├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── configs
│   ├── cifar100
│   │   ├── erm.yaml
│   │   ├── ft.yaml
│   │   ├── ft_plus.yaml
│   │   ├── plus.yaml
│   │   ├── reinforce.yaml
│   │   └── teacher.yaml
│   ├── flowers102
│   │   ├── erm.yaml
│   │   ├── ft.yaml
│   │   ├── ft_plus.yaml
│   │   ├── plus.yaml
│   │   ├── reinforce.yaml
│   │   └── teacher.yaml
│   ├── food101
│   │   ├── erm.yaml
│   │   ├── ft.yaml
│   │   ├── ft_plus.yaml
│   │   ├── plus.yaml
│   │   ├── reinforce.yaml
│   │   └── teacher.yaml
│   └── imagenet
│       ├── erm.yaml
│       ├── kd.yaml
│       ├── plus.yaml
│       └── reinforce
│           ├── cropflip.yaml
│           ├── mixing.yaml
│           ├── randaug.yaml
│           └── randaug_mixing.yaml
├── data.py
├── dr
│   ├── __init__.py
│   ├── data.py
│   ├── transforms.py
│   └── utils.py
├── figures
│   ├── DR_illustration_wide.pdf
│   ├── DR_illustration_wide.png
│   ├── imagenet_RRC+RARE_accuracy_annotated.pdf
│   └── imagenet_RRC+RARE_accuracy_annotated.png
├── models.py
├── reinforce.py
├── requirements.txt
├── results
│   ├── table_E1000.md
│   ├── table_E150.md
│   └── table_E300.md
├── tests
│   └── test_reinforce.py
├── train.py
├── trainers.py
├── transforms.py
└── utils.py
/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 
50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4, 71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository. 4 | 5 | While we welcome new pull requests and issues, please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged. 6 | 7 | ## Before you get started 8 | 9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE). 10 | 11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2022 Apple Inc. All Rights Reserved. 2 | 3 | IMPORTANT: This Apple software is supplied to you by Apple 4 | Inc. ("Apple") in consideration of your agreement to the following 5 | terms, and your use, installation, modification or redistribution of 6 | this Apple software constitutes acceptance of these terms. If you do 7 | not agree with these terms, please do not use, install, modify or 8 | redistribute this Apple software. 
9 | 10 | In consideration of your agreement to abide by the following terms, and 11 | subject to these terms, Apple grants you a personal, non-exclusive 12 | license, under Apple's copyrights in this original Apple software (the 13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple 14 | Software, with or without modifications, in source and/or binary forms; 15 | provided that if you redistribute the Apple Software in its entirety and 16 | without modifications, you must retain this notice and the following 17 | text and disclaimers in all such redistributions of the Apple Software. 18 | Neither the name, trademarks, service marks or logos of Apple Inc. may 19 | be used to endorse or promote products derived from the Apple Software 20 | without specific prior written permission from Apple. Except as 21 | expressly stated in this notice, no other rights or licenses, express or 22 | implied, are granted by Apple herein, including but not limited to any 23 | patent rights that may be infringed by your derivative works or by other 24 | works in which the Apple Software may be incorporated. 25 | 26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE 27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION 28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS 29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND 30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. 31 | 32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL 33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, 36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED 37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), 38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE 39 | POSSIBILITY OF SUCH DAMAGE. 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dataset Reinforcement 2 | A lightweight implementation of Dataset Reinforcement, pretrained checkpoints, and reinforced datasets. 3 | 4 | **[Reinforce Data, Multiply Impact: Improved Model Accuracy and Robustness with Dataset Reinforcement.](https://arxiv.org/abs/2303.08983)** 5 | *Faghri, F., Pouransari, H., Mehta, S., Farajtabar, M., Farhadi, A., Rastegari, M., & Tuzel, O., Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV), 2023.* 6 | 7 | 8 | **Update 2023/09/22**: Table 7-Average column corrected in arXiv v3. Correct 9 | numbers: 30.4, 37.1, 37.9, 43.7, 39.6, 51.1. 10 | 11 | **Reinforced ImageNet, ImageNet+, improves accuracy at similar iterations/wall-clock** 12 | 13 |
<img src="figures/imagenet_RRC+RARE_accuracy_annotated.png" alt="ImageNet validation accuracy of ResNet-50 as a function of training duration for ImageNet, KD, and ImageNet+.">
14 | Reinforced ImageNet, ImageNet+, improves accuracy at similar iterations/wall-clock. 15 |

16 | 17 | ImageNet validation accuracy of ResNet-50 is shown as 18 | a function of training duration with (1) the ImageNet dataset, (2) knowledge 19 | distillation (KD), and (3) the ImageNet+ dataset (ours). Each point is a full 20 | training with epochs varying from 50 to 1000. An epoch has the same number of 21 | iterations for ImageNet/ImageNet+. 22 | 23 | **Illustration of Dataset Reinforcement** 24 | 25 |
<img src="figures/DR_illustration_wide.png" alt="Illustration of Dataset Reinforcement.">
26 | Illustration of Dataset Reinforcement. 27 |

28 | 29 | Data augmentation and 30 | knowledge distillation are common approaches to 31 | improving accuracy. Dataset reinforcement combines the benefits of both by 32 | bringing the advantages of large models trained on large datasets to other 33 | datasets and models. Training of new models with a reinforced dataset is as 34 | fast as training on the original dataset for the same total iterations. 35 | Creating a reinforced dataset is a one-time process (e.g., ImageNet to 36 | ImageNet+), the cost of which is amortized over repeated uses. 37 | 38 | 39 | ## Requirements 40 | Install the requirements using: 41 | ```shell 42 | pip install -r requirements.txt 43 | ``` 44 | 45 | We support loading models from the Timm and CVNets libraries. 46 | 47 | To install the CVNets library, follow their [installation 48 | instructions](https://github.com/apple/ml-cvnets#installation). 49 | 50 | ## Reinforced Data 51 | 52 | The following is a list of reinforcements for ImageNet/CIFAR-100/Food-101/Flowers-102. We recommend ImageNet+-RA/RE based on the analysis in the paper. 53 | 54 | | Reinforce Data | Download | Size (GBs) | Comments | 55 | |------------------------|--------------------------------------------------------------------------------------------------------------------|------------|---------------| 56 | | ImageNet+-RRC | [rdata](https://docs-assets.developer.apple.com/ml-research/datasets/dr/imagenet_plus_rrc.tar.gz) | 33.4 | [NS=400] | 57 | | ImageNet+-+M* | [rdata](https://docs-assets.developer.apple.com/ml-research/datasets/dr/imagenet_plus_mixing.tar.gz) | 46.3 | [NS=400] | 58 | | **ImageNet+-+RA/RE** | [rdata](https://docs-assets.developer.apple.com/ml-research/datasets/dr/imagenet_plus_randaug.tar.gz) | 37.5 | [NS=400] | 59 | | ImageNet+-+M*+R* | [rdata](https://docs-assets.developer.apple.com/ml-research/datasets/dr/imagenet_plus_mixing_randaug.tar.gz) | 53.3 | [NS=400] | 60 | | ImageNet+-RRC-Small | [rdata](https://docs-assets.developer.apple.com/ml-research/datasets/dr/imagenet_plus_rrc_small.tar.gz) | 4.7 | [NS=100, K=5] | 61 | | ImageNet+-+M*-Small | [rdata](https://docs-assets.developer.apple.com/ml-research/datasets/dr/imagenet_plus_mixing_small.tar.gz) | 7.8 | [NS=100, K=5] | 62 | | ImageNet+-+RA/RE-Small | [rdata](https://docs-assets.developer.apple.com/ml-research/datasets/dr/imagenet_plus_randaug_small.tar.gz) | 5.6 | [NS=100, K=5] | 63 | | ImageNet+-+M*+R*-Small | [rdata](https://docs-assets.developer.apple.com/ml-research/datasets/dr/imagenet_plus_mixing_randaug_small.tar.gz) | 9.4 | [NS=100, K=5] | 64 | | ImageNet+-RRC-Mini | [rdata](https://docs-assets.developer.apple.com/ml-research/datasets/dr/imagenet_plus_rrc_mini.tar.gz) | 4.4 | [NS=50] | 65 | | ImageNet+-+M*-Mini | [rdata](https://docs-assets.developer.apple.com/ml-research/datasets/dr/imagenet_plus_mixing_mini.tar.gz) | 6.1 | [NS=50] | 66 | | ImageNet+-+RA/RE-Mini | [rdata](https://docs-assets.developer.apple.com/ml-research/datasets/dr/imagenet_plus_randaug_mini.tar.gz) | 4.9 | [NS=50] | 67 | | ImageNet+-+M*+R*-Mini | [rdata](https://docs-assets.developer.apple.com/ml-research/datasets/dr/imagenet_plus_mixing_randaug.tar.gz) | 7.0 | [NS=50] | 68 | | CIFAR-100 | [rdata](https://docs-assets.developer.apple.com/ml-research/datasets/dr/cifar100.tar.gz) | 2.5 | [NS=800] | 69 | | Food-101 | [rdata](https://docs-assets.developer.apple.com/ml-research/datasets/dr/food_101.tar.gz) | 4.2 | [NS=800] | 70 | | Flowers-102 | [rdata](https://docs-assets.developer.apple.com/ml-research/datasets/dr/flowers_102.tar.gz) | 0.5 | [NS=8000] 
| 71 | 72 | ## Pretrained Checkpoints 73 | 74 | ### CVNets Checkpoints 75 | 76 | We provide pretrained checkpoints for various models in CVNets. 77 | The accuracies can be verified using the CVNets library. 78 | 79 | - [150 Epochs Checkpoints](./results/table_E150.md) 80 | - [300 Epochs Checkpoints](./results/table_E300.md) 81 | - [1000 Epochs Checkpoints](./results/table_E1000.md) 82 | - [imagenet-cvnets.tar](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets.tar): All CVNets checkpoints trained on ImageNet (14.3GBs). 83 | - [imagenet-plus-cvnets.tar](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets.tar): All CVNets checkpoints trained on ImageNet+ (14.3GBs). 84 | 85 | Selected results trained for 1000 epochs: 86 | 87 | | Name | Mode | Params | ImageNet | ImageNet+ | ImageNet (EMA) | ImageNet+ (EMA) | Links | 88 | | :------------- | :--------- | :------------- | :----- | :----------- | :----- | :------ | ----- | 89 | | MobileNetV3 | large | 5.5M | 74.8 | 77.9 (**+3.1**) | 75.8 | 77.9 (**+2.1**) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv3_large_E1000/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv3_large_E1000/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv3_large_E1000/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv3_large_E1000/metrics.jb) | 90 | | ResNet | 50 | 25.6M | 80.0 | 82.0 (**+2.0**) | 80.1 | 82.0 (**+1.9**) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_50_E1000/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_50_E1000/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_50_E1000/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_50_E1000/metrics.jb) | 91 | | ViT | base | 86.7M | 76.8 | 85.1 (**+8.3**) | 80.8 | 85.1 (**+4.3**) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit_base_E1000/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit_base_E1000/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit_base_E1000/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit_base_E1000/metrics.jb) | 92 | | ViT-384 | base | 86.7M | 79.4 | 85.4 (**+6.0**) | 83.1 | 85.5 (**+2.4**) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit-384_base_E1000/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit-384_base_E1000/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit-384_base_E1000/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit-384_base_E1000/metrics.jb) | 93 | | Swin | tiny | 28.3M | 81.3 | 84.0 (**+2.7**) | 80.5 | 83.5 (**+3.0**) | 
[[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_tiny_E1000/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_tiny_E1000/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_tiny_E1000/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_tiny_E1000/metrics.jb) | 94 | | Swin | small | 49.7M | 81.3 | 85.0 (**+3.7**) | 81.9 | 84.5 (**+2.6**) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_small_E1000/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_small_E1000/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_small_E1000/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_small_E1000/metrics.jb) | 95 | | Swin | base | 87.8M | 81.5 | 85.4 (**+3.9**) | 81.8 | 85.2 (**+3.4**) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_base_E1000/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_base_E1000/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_base_E1000/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_base_E1000/metrics.jb) | 96 | | Swin-384 | base | 87.8M | 83.6 | 85.8 (**+2.2**) | 83.8 | 85.5 (**+1.7**) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin-384_base_E1000/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin-384_base_E1000/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin-384_base_E1000/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin-384_base_E1000/metrics.jb) | 97 | 98 | ### Timm Checkpoints 99 | 100 | We provide pretrained checkpoints for ResNet50d from Timm library trained for 101 | 150 epochs using various reinforced datasets: 102 | 103 | - [imagenet-timm.tar](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm.tar): All Timm checkpoints trained on ImageNet and ImageNet+ (2.3GBs). 
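
The checkpoint bundles can be fetched and unpacked with standard tools. A minimal sketch for the Timm bundle follows; the extracted directory layout is an assumption inferred from the per-checkpoint link paths in the table below (e.g. `imagenet-timm/imagenet_plus_randaug/best.pt`):

```shell
# Download the Timm checkpoint bundle and unpack it locally.
curl -L -O https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm.tar
tar -xf imagenet-timm.tar
# Assumption: each run directory contains best.pt, config.yaml, and metrics.jb,
# matching the links in the table below.
```
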
104 | 105 | | Model | Reinforce Data | Accuracy | Links | 106 | | ------------------------ | :------------------------ | :-------------: | :-----------------------------------------------------------------------------: | 107 | | ResNet50d [ERM] | N/A | 78.9 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet/best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet/metrics.jb) | 108 | | ResNet50d | ImageNet+-RRC | 80.0 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_rrc/best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_rrc/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_rrc/metrics.jb) | 109 | | ResNet50d | ImageNet+-+M* | 80.5 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing/best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing/metrics.jb) | 110 | | ResNet50d | ImageNet+-+RA/RE | 80.4 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_randaug/best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_randaug/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_randaug/metrics.jb) | 111 | | ResNet50d | ImageNet+-+M*+R* | 80.2 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing_randaug/best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing_randaug/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing_randaug/metrics.jb) | 112 | | ResNet50d | ImageNet+-RRC-Small | 80.0 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_rrc_small/best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_rrc_small/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_rrc_small/metrics.jb) | 113 | | ResNet50d | ImageNet+-+M*-Small | 80.6 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing_small/best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing_small/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing_small/metrics.jb) | 114 | | ResNet50d | ImageNet+-+RA/RE-Small | 80.2 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_randaug_small/best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_randaug_small/config.yaml) 
[[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_randaug_small/metrics.jb) | 115 | | ResNet50d | ImageNet+-+M*+R*-Small | 80.1 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing_randaug_small/best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing_randaug_small/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing_randaug_small/metrics.jb) | 116 | | ResNet50d | ImageNet+-RRC-Mini | 80.1 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_rrc_mini/best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_rrc_mini/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_rrc_mini/metrics.jb) | 117 | | ResNet50d | ImageNet+-+M*-Mini | 80.5 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing_mini/best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing_mini/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing_mini/metrics.jb) | 118 | | ResNet50d | ImageNet+-+RA/RE-Mini | 80.4 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_randaug_mini/best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_randaug_mini/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_randaug_mini/metrics.jb) | 119 | | ResNet50d | ImageNet+-+M*+R*-Mini | 80.2 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing_randaug/best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing_randaug/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing_randaug/metrics.jb) | 120 | 121 | 122 | ## Training 123 | 124 | We provide YAML configurations for training ResNet-50 in 125 | `CFG_FILE=configs/${DATASET}/${TRAINER}.yaml`, with the following options: 126 | - `DATASET`: `imagenet`, `cifar100`, `flowers102`, and `food101`. 127 | - `TRAINER`: standard training (`erm`), knowledge distillation (`kd`), and with reinforced data (`plus`). 128 | 129 | Follow the steps: 130 | - Choose the dataset and trainer from the choices above. 131 | - Download ImageNet data and set `data_path` in `$CFG_FILE`. 132 | - Download reinforcement metadata and set `reinforce.data_path` in `$CFG_FILE`. 133 | 134 | ```shell 135 | python train.py --config configs/imagenet/erm.yaml # ImageNet training without Reinforcements (ERM) 136 | python train.py --config configs/imagenet/kd.yaml # Knowledge Distillation 137 | python train.py --config configs/imagenet/plus.yaml # ImageNet+ training with reinforcements 138 | ``` 139 | 140 | Hyperparameters such as batch size for ImageNet training are optimized for 141 | running on a single node with 8xA100 40GB GPUs. For 142 | CIFAR-100/Flowers-102/Food-101, the configurations are optimized for training 143 | on a single GPU. 
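
As a concrete example, the fields to edit before launching the ImageNet+ run above are the two paths in the config. A minimal sketch of the relevant excerpt of `configs/imagenet/plus.yaml`, where the path values are placeholders for your local copies of the data:

```yaml
# Excerpt of configs/imagenet/plus.yaml; the two paths below are placeholders.
parameters:
  data_path: '/datasets/imagenet'                 # original ImageNet images
  reinforce:
    enable: true
    p: 0.99                                       # defaults from the released config
    num_samples: NULL
    densify: smooth
    data_path: '/datasets/imagenet_plus_randaug'  # extracted reinforcement metadata (rdata)
```

With these paths set, `python train.py --config configs/imagenet/plus.yaml` picks up the reinforcements; the same `reinforce.data_path` key appears in the `plus.yaml`/`ft_plus.yaml` configs for CIFAR-100, Flowers-102, and Food-101.
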
144 | 145 | ## Reinforce ImageNet 146 | 147 | Follow the steps: 148 | - Download ImageNet data and set `data_path` in `$CFG_FILE`. 149 | - If needed, change the teacher in `$CFG_FILE` to a smaller architecture. 150 | 151 | ```shell 152 | python reinforce.py --config configs/imagenet/reinforce/randaug.yaml 153 | ``` 154 | 155 | ## Reference 156 | 157 | If you found this code useful, please cite the following paper: 158 | 159 | @InProceedings{faghri2023reinforce, 160 | author = {Faghri, Fartash and Pouransari, Hadi and Mehta, Sachin and Farajtabar, Mehrdad and Farhadi, Ali and Rastegari, Mohammad and Tuzel, Oncel}, 161 | title = {Reinforce Data, Multiply Impact: Improved Model Accuracy and Robustness with Dataset Reinforcement}, 162 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 163 | month = {October}, 164 | year = {2023}, 165 | } 166 | 167 | ## License 168 | This sample code is released under the [LICENSE](LICENSE) terms. 169 | -------------------------------------------------------------------------------- /configs/cifar100/erm.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | parameters: 5 | dataset: cifar100 6 | trainer: ERM 7 | arch: timm 8 | model: 9 | model_name: resnet50d 10 | pretrained: false 11 | num_classes: 100 12 | loss: 13 | label_smoothing: 0.1 14 | optim: 15 | name: sgd 16 | lr: 0.2 17 | momentum: 0.9 18 | weight_decay: 5.e-4 19 | warmup_length: 0 20 | no_decay_bn_filter_bias: false 21 | nesterov: false 22 | epochs: 1000 23 | save_freq: 100 24 | start_epoch: 0 25 | batch_size: 256 26 | print_freq: 100 27 | resume: '' 28 | evaluate: false 29 | pretrained: false 30 | dist_url: 'tcp://127.0.0.1:23333' 31 | dist_backend: 'nccl' 32 | # Single GPU training 33 | multiprocessing_distributed: false 34 | world_size: 1 35 | rank: 0 36 | workers: 8 37 | pin_memory: true 38 | persistent_workers: true 39 | seed: NULL 40 | gpu: 0 41 | download_data: false # Set to True to download 42 | data_path: '' 43 | artifact_path: '' 44 | image_augmentation: 45 | train: 46 | resize: 47 | size: 224 48 | random_crop: 49 | size: 224 50 | padding: 16 51 | rand_augment: # No horizontal flip when rand-augment is enabled 52 | enable: true 53 | to_tensor: 54 | enable: true 55 | normalize: 56 | mean: [0.507075159237, 0.4865488733149, 0.440917843367] 57 | std: [0.267334285879, 0.2564384629170, 0.276150471325] 58 | mixup: 59 | enable: true 60 | alpha: 0.2 61 | p: 1.0 62 | cutmix: 63 | enable: true 64 | alpha: 1.0 65 | p: 1.0 66 | val: 67 | resize: 68 | size: 224 69 | to_tensor: 70 | enable: true 71 | normalize: 72 | mean: [0.507075159237, 0.4865488733149, 0.440917843367] 73 | std: [0.267334285879, 0.2564384629170, 0.276150471325] 74 | -------------------------------------------------------------------------------- /configs/cifar100/ft.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 
3 | # 4 | parameters: 5 | dataset: cifar100 6 | trainer: ERM 7 | arch: timm 8 | model: 9 | model_name: resnet50d 10 | pretrained: false 11 | num_classes: 100 12 | # TODO: load timm checkpoint 13 | loss: 14 | label_smoothing: 0.1 15 | optim: 16 | name: sgd 17 | lr: 0.002 18 | momentum: 0.9 19 | weight_decay: 5.e-4 20 | warmup_length: 0 21 | no_decay_bn_filter_bias: false 22 | nesterov: false 23 | epochs: 1000 24 | save_freq: 100 25 | start_epoch: 0 26 | batch_size: 256 27 | print_freq: 100 28 | resume: '' 29 | evaluate: false 30 | pretrained: false 31 | dist_url: 'tcp://127.0.0.1:23333' 32 | dist_backend: 'nccl' 33 | # Single GPU training 34 | multiprocessing_distributed: false 35 | world_size: 1 36 | rank: 0 37 | workers: 8 38 | pin_memory: true 39 | persistent_workers: true 40 | seed: NULL 41 | gpu: 0 42 | download_data: false # Set to True to download 43 | data_path: '' 44 | artifact_path: '' 45 | image_augmentation: 46 | train: 47 | resize: 48 | size: 224 49 | random_crop: 50 | size: 224 51 | padding: 16 52 | rand_augment: # No horizontal flip when rand-augment is enabled 53 | enable: true 54 | to_tensor: 55 | enable: true 56 | normalize: 57 | mean: [0.507075159237, 0.4865488733149, 0.440917843367] 58 | std: [0.267334285879, 0.2564384629170, 0.276150471325] 59 | mixup: 60 | enable: true 61 | alpha: 0.2 62 | p: 1.0 63 | cutmix: 64 | enable: true 65 | alpha: 1.0 66 | p: 1.0 67 | val: 68 | resize: 69 | size: 224 70 | to_tensor: 71 | enable: true 72 | normalize: 73 | mean: [0.507075159237, 0.4865488733149, 0.440917843367] 74 | std: [0.267334285879, 0.2564384629170, 0.276150471325] 75 | -------------------------------------------------------------------------------- /configs/cifar100/ft_plus.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 
3 | # 4 | parameters: 5 | dataset: cifar100 6 | trainer: DR 7 | arch: timm 8 | model: 9 | model_name: resnet50d 10 | pretrained: false 11 | num_classes: 100 12 | # TODO: load timm checkpoint 13 | loss: 14 | label_smoothing: 0.1 15 | optim: 16 | name: sgd 17 | lr: 0.002 18 | momentum: 0.9 19 | weight_decay: 5.e-4 20 | warmup_length: 0 21 | no_decay_bn_filter_bias: false 22 | nesterov: false 23 | epochs: 1000 24 | save_freq: 100 25 | start_epoch: 0 26 | batch_size: 256 27 | print_freq: 100 28 | resume: '' 29 | evaluate: false 30 | pretrained: false 31 | dist_url: 'tcp://127.0.0.1:23333' 32 | dist_backend: 'nccl' 33 | # Single GPU training 34 | multiprocessing_distributed: false 35 | world_size: 1 36 | rank: 0 37 | workers: 8 38 | pin_memory: true 39 | persistent_workers: true 40 | seed: NULL 41 | gpu: 0 42 | download_data: false # Set to True to download 43 | data_path: '' 44 | artifact_path: '' 45 | image_augmentation: 46 | # Training transformations for non-reinforced samples 47 | train: 48 | resize: 49 | size: 224 50 | random_crop: 51 | size: 224 52 | padding: 16 53 | rand_augment: # No horizontal flip when rand-augment is enabled 54 | enable: true 55 | to_tensor: 56 | enable: true 57 | normalize: 58 | mean: [0.507075159237, 0.4865488733149, 0.440917843367] 59 | std: [0.267334285879, 0.2564384629170, 0.276150471325] 60 | val: 61 | resize: 62 | size: 224 63 | to_tensor: 64 | enable: true 65 | normalize: 66 | mean: [0.507075159237, 0.4865488733149, 0.440917843367] 67 | std: [0.267334285879, 0.2564384629170, 0.276150471325] 68 | reinforce: 69 | enable: true 70 | p: 0.99 71 | num_samples: NULL 72 | densify: smooth 73 | data_path: NULL 74 | -------------------------------------------------------------------------------- /configs/cifar100/plus.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 
3 | # 4 | parameters: 5 | dataset: cifar100 6 | trainer: DR 7 | arch: timm 8 | model: 9 | model_name: resnet50d 10 | pretrained: false 11 | num_classes: 100 12 | loss: 13 | label_smoothing: 0.1 14 | optim: 15 | name: sgd 16 | lr: 0.2 17 | momentum: 0.9 18 | weight_decay: 5.e-4 19 | warmup_length: 0 20 | no_decay_bn_filter_bias: false 21 | nesterov: false 22 | epochs: 1000 23 | save_freq: 100 24 | start_epoch: 0 25 | batch_size: 256 26 | print_freq: 100 27 | resume: '' 28 | evaluate: false 29 | pretrained: false 30 | dist_url: 'tcp://127.0.0.1:23333' 31 | dist_backend: 'nccl' 32 | # Single GPU training 33 | multiprocessing_distributed: false 34 | world_size: 1 35 | rank: 0 36 | workers: 8 37 | pin_memory: true 38 | persistent_workers: true 39 | seed: NULL 40 | gpu: 0 41 | download_data: false # Set to True to download 42 | data_path: '' 43 | artifact_path: '' 44 | image_augmentation: 45 | # Training transformations for non-reinforced samples 46 | train: 47 | resize: 48 | size: 224 49 | random_crop: 50 | size: 224 51 | padding: 16 52 | rand_augment: # No horizontal flip when rand-augment is enabled 53 | enable: true 54 | to_tensor: 55 | enable: true 56 | normalize: 57 | mean: [0.507075159237, 0.4865488733149, 0.440917843367] 58 | std: [0.267334285879, 0.2564384629170, 0.276150471325] 59 | val: 60 | resize: 61 | size: 224 62 | to_tensor: 63 | enable: true 64 | normalize: 65 | mean: [0.507075159237, 0.4865488733149, 0.440917843367] 66 | std: [0.267334285879, 0.2564384629170, 0.276150471325] 67 | reinforce: 68 | enable: true 69 | p: 0.99 70 | num_samples: NULL 71 | densify: smooth 72 | data_path: NULL 73 | -------------------------------------------------------------------------------- /configs/cifar100/reinforce.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | parameters: 5 | dataset: cifar100 6 | teacher: 7 | ensemble: true 8 | # TODO: Load pretrained 9 | batch_size: 32 10 | print_freq: 10 11 | dist_url: 'tcp://127.0.0.1:23333' 12 | dist_backend: 'nccl' 13 | multiprocessing_distributed: true 14 | world_size: 1 15 | rank: 0 16 | workers: 88 17 | pin_memory: true 18 | persistent_workers: true 19 | seed: NULL 20 | gpu: NULL 21 | download_data: false # Set to True to download dataset 22 | data_path: '' 23 | artifact_path: '' 24 | reinforce: 25 | num_samples: 100 26 | num_candidates: NULL 27 | topk: 10 28 | compress: true 29 | joblib: false 30 | gzip: true 31 | image_augmentation: 32 | uint8: 33 | enable: true 34 | resize: 35 | size: 224 36 | random_crop: 37 | size: 224 38 | padding: 16 39 | random_horizontal_flip: 40 | enable: true 41 | p: 0.5 42 | rand_augment: 43 | enable: true 44 | p: 1.0 45 | to_tensor: 46 | enable: true 47 | normalize: 48 | mean: [0.507075159237, 0.4865488733149, 0.440917843367] 49 | std: [0.267334285879, 0.2564384629170, 0.276150471325] 50 | random_erase: 51 | enable: true 52 | p: 0.25 53 | -------------------------------------------------------------------------------- /configs/cifar100/teacher.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 
3 | # 4 | parameters: 5 | dataset: cifar100 6 | trainer: ERM 7 | arch: timm 8 | model: 9 | model_name: resnet50d 10 | pretrained: false 11 | num_classes: 100 12 | # TODO: load timm pretrained checkpoint 13 | loss: 14 | label_smoothing: 0.1 15 | optim: 16 | name: sgd 17 | lr: 0.002 18 | momentum: 0.9 19 | weight_decay: 5.e-4 20 | warmup_length: 0 21 | no_decay_bn_filter_bias: false 22 | nesterov: false 23 | epochs: 1000 24 | save_freq: 100 25 | start_epoch: 0 26 | batch_size: 256 27 | print_freq: 100 28 | resume: '' 29 | evaluate: false 30 | pretrained: false 31 | dist_url: 'tcp://127.0.0.1:23333' 32 | dist_backend: 'nccl' 33 | # Single GPU training 34 | multiprocessing_distributed: false 35 | world_size: 1 36 | rank: 0 37 | workers: 8 38 | pin_memory: true 39 | persistent_workers: true 40 | seed: NULL 41 | gpu: 0 42 | download_data: false # Set to True to download 43 | data_path: '' 44 | artifact_path: '' 45 | image_augmentation: 46 | train: 47 | resize: 48 | size: 224 49 | random_crop: 50 | size: 224 51 | padding: 16 52 | rand_augment: # No horizontal flip when rand-augment is enabled 53 | enable: true 54 | to_tensor: 55 | enable: true 56 | normalize: 57 | mean: [0.507075159237, 0.4865488733149, 0.440917843367] 58 | std: [0.267334285879, 0.2564384629170, 0.276150471325] 59 | mixup: 60 | enable: true 61 | alpha: 1.0 # Strong mixup 62 | p: 1.0 63 | cutmix: 64 | enable: true 65 | alpha: 1.0 66 | p: 1.0 67 | val: 68 | resize: 69 | size: 224 70 | to_tensor: 71 | enable: true 72 | normalize: 73 | mean: [0.507075159237, 0.4865488733149, 0.440917843367] 74 | std: [0.267334285879, 0.2564384629170, 0.276150471325] 75 | -------------------------------------------------------------------------------- /configs/flowers102/erm.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 
3 | # 4 | parameters: 5 | dataset: flowers102 6 | trainer: ERM 7 | arch: timm 8 | model: 9 | model_name: resnet50d 10 | pretrained: false 11 | num_classes: 102 12 | loss: 13 | label_smoothing: 0.1 14 | optim: 15 | name: sgd 16 | lr: 0.2 17 | momentum: 0.9 18 | weight_decay: 5.e-4 19 | warmup_length: 0 20 | no_decay_bn_filter_bias: false 21 | nesterov: false 22 | epochs: 10000 23 | save_freq: 100 24 | start_epoch: 0 25 | batch_size: 256 26 | print_freq: 100 27 | resume: '' 28 | evaluate: false 29 | pretrained: false 30 | dist_url: 'tcp://127.0.0.1:23333' 31 | dist_backend: 'nccl' 32 | # Single GPU training 33 | multiprocessing_distributed: false 34 | world_size: 1 35 | rank: 0 36 | workers: 8 37 | pin_memory: true 38 | persistent_workers: true 39 | seed: NULL 40 | gpu: 0 41 | download_data: false # Set to True to download 42 | data_path: '' 43 | artifact_path: '' 44 | image_augmentation: 45 | train: 46 | random_resized_crop: 47 | size: 224 48 | rand_augment: # No horizontal flip when rand-augment is enabled 49 | enable: true 50 | to_tensor: 51 | enable: true 52 | normalize: 53 | mean: [0.485, 0.456, 0.406] 54 | std: [0.229, 0.224, 0.225] 55 | mixup: 56 | enable: true 57 | alpha: 0.2 58 | p: 1.0 59 | cutmix: 60 | enable: true 61 | alpha: 1.0 62 | p: 1.0 63 | val: 64 | resize: 65 | size: 256 66 | center_crop: 67 | size: 224 68 | to_tensor: 69 | enable: true 70 | normalize: 71 | mean: [0.485, 0.456, 0.406] 72 | std: [0.229, 0.224, 0.225] 73 | -------------------------------------------------------------------------------- /configs/flowers102/ft.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | parameters: 5 | dataset: flowers102 6 | trainer: ERM 7 | arch: timm 8 | model: 9 | model_name: resnet50d 10 | pretrained: false 11 | num_classes: 102 12 | # TODO: load pretrained timm checkpoint 13 | loss: 14 | label_smoothing: 0.1 15 | optim: 16 | name: sgd 17 | lr: 0.002 18 | momentum: 0.9 19 | weight_decay: 5.e-4 20 | warmup_length: 0 21 | no_decay_bn_filter_bias: false 22 | nesterov: false 23 | epochs: 10000 24 | save_freq: 100 25 | start_epoch: 0 26 | batch_size: 256 27 | print_freq: 100 28 | resume: '' 29 | evaluate: false 30 | pretrained: false 31 | dist_url: 'tcp://127.0.0.1:23333' 32 | dist_backend: 'nccl' 33 | # Single GPU training 34 | multiprocessing_distributed: false 35 | world_size: 1 36 | rank: 0 37 | workers: 8 38 | pin_memory: true 39 | persistent_workers: true 40 | seed: NULL 41 | gpu: 0 42 | download_data: false # Set to True to download 43 | data_path: '' 44 | artifact_path: '' 45 | image_augmentation: 46 | train: 47 | random_resized_crop: 48 | size: 224 49 | rand_augment: # No horizontal flip when rand-augment is enabled 50 | enable: true 51 | to_tensor: 52 | enable: true 53 | normalize: 54 | mean: [0.485, 0.456, 0.406] 55 | std: [0.229, 0.224, 0.225] 56 | mixup: 57 | enable: true 58 | alpha: 0.2 59 | p: 1.0 60 | cutmix: 61 | enable: true 62 | alpha: 1.0 63 | p: 1.0 64 | val: 65 | resize: 66 | size: 256 67 | center_crop: 68 | size: 224 69 | to_tensor: 70 | enable: true 71 | normalize: 72 | mean: [0.485, 0.456, 0.406] 73 | std: [0.229, 0.224, 0.225] 74 | -------------------------------------------------------------------------------- /configs/flowers102/ft_plus.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 
3 | # 4 | parameters: 5 | dataset: flowers102 6 | trainer: DR 7 | arch: timm 8 | model: 9 | model_name: resnet50d 10 | pretrained: false 11 | num_classes: 102 12 | # TODO: load pretrained timm checkpoint 13 | loss: 14 | label_smoothing: 0.1 15 | optim: 16 | name: sgd 17 | lr: 0.002 18 | momentum: 0.9 19 | weight_decay: 5.e-4 20 | warmup_length: 0 21 | no_decay_bn_filter_bias: false 22 | nesterov: false 23 | epochs: 10000 24 | save_freq: 100 25 | start_epoch: 0 26 | batch_size: 256 27 | print_freq: 100 28 | resume: '' 29 | evaluate: false 30 | pretrained: false 31 | dist_url: 'tcp://127.0.0.1:23333' 32 | dist_backend: 'nccl' 33 | # Single GPU training 34 | multiprocessing_distributed: false 35 | world_size: 1 36 | rank: 0 37 | workers: 8 38 | pin_memory: true 39 | persistent_workers: true 40 | seed: NULL 41 | gpu: 0 42 | download_data: false # Set to True to download 43 | data_path: '' 44 | artifact_path: '' 45 | image_augmentation: 46 | train: 47 | random_resized_crop: 48 | size: 224 49 | rand_augment: # No horizontal flip when rand-augment is enabled 50 | enable: true 51 | to_tensor: 52 | enable: true 53 | normalize: 54 | mean: [0.485, 0.456, 0.406] 55 | std: [0.229, 0.224, 0.225] 56 | val: 57 | resize: 58 | size: 256 59 | center_crop: 60 | size: 224 61 | to_tensor: 62 | enable: true 63 | normalize: 64 | mean: [0.485, 0.456, 0.406] 65 | std: [0.229, 0.224, 0.225] 66 | reinforce: 67 | enable: true 68 | p: 0.99 69 | num_samples: NULL 70 | densify: smooth 71 | data_path: NULL 72 | -------------------------------------------------------------------------------- /configs/flowers102/plus.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | parameters: 5 | dataset: flowers102 6 | trainer: DR 7 | arch: timm 8 | model: 9 | model_name: resnet50d 10 | pretrained: false 11 | num_classes: 102 12 | loss: 13 | label_smoothing: 0.1 14 | optim: 15 | name: sgd 16 | lr: 0.2 17 | momentum: 0.9 18 | weight_decay: 5.e-4 19 | warmup_length: 0 20 | no_decay_bn_filter_bias: false 21 | nesterov: false 22 | epochs: 10000 23 | save_freq: 100 24 | start_epoch: 0 25 | batch_size: 256 26 | print_freq: 100 27 | resume: '' 28 | evaluate: false 29 | pretrained: false 30 | dist_url: 'tcp://127.0.0.1:23333' 31 | dist_backend: 'nccl' 32 | # Single GPU training 33 | multiprocessing_distributed: false 34 | world_size: 1 35 | rank: 0 36 | workers: 8 37 | pin_memory: true 38 | persistent_workers: true 39 | seed: NULL 40 | gpu: 0 41 | download_data: false # Set to True to download 42 | data_path: '' 43 | artifact_path: '' 44 | image_augmentation: 45 | train: 46 | random_resized_crop: 47 | size: 224 48 | rand_augment: # No horizontal flip when rand-augment is enabled 49 | enable: true 50 | to_tensor: 51 | enable: true 52 | normalize: 53 | mean: [0.485, 0.456, 0.406] 54 | std: [0.229, 0.224, 0.225] 55 | val: 56 | resize: 57 | size: 256 58 | center_crop: 59 | size: 224 60 | to_tensor: 61 | enable: true 62 | normalize: 63 | mean: [0.485, 0.456, 0.406] 64 | std: [0.229, 0.224, 0.225] 65 | reinforce: 66 | enable: true 67 | p: 0.99 68 | num_samples: NULL 69 | densify: smooth 70 | data_path: NULL 71 | -------------------------------------------------------------------------------- /configs/flowers102/reinforce.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 
3 | # 4 | parameters: 5 | dataset: flowers102 6 | teacher: 7 | ensemble: true 8 | # TODO: Load pretrained 9 | batch_size: 16 10 | print_freq: 10 11 | dist_url: 'tcp://127.0.0.1:23333' 12 | dist_backend: 'nccl' 13 | multiprocessing_distributed: true 14 | world_size: 1 15 | rank: 0 16 | workers: 88 17 | pin_memory: true 18 | persistent_workers: true 19 | seed: NULL 20 | gpu: NULL 21 | download_data: false # Set to True to download dataset 22 | data_path: '' 23 | artifact_path: '' 24 | reinforce: 25 | num_samples: 1000 26 | num_candidates: NULL 27 | topk: 10 28 | compress: true 29 | joblib: false 30 | gzip: true 31 | image_augmentation: 32 | uint8: 33 | enable: true 34 | random_resize_crop: 35 | enable: true 36 | size: 224 37 | random_horizontal_flip: 38 | enable: true 39 | p: 0.5 40 | rand_augment: 41 | enable: true 42 | p: 1.0 43 | to_tensor: 44 | enable: true 45 | normalize: 46 | mean: [0.485, 0.456, 0.406] 47 | std: [0.229, 0.224, 0.225] 48 | random_erase: 49 | enable: true 50 | p: 0.25 51 | -------------------------------------------------------------------------------- /configs/flowers102/teacher.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | parameters: 5 | dataset: flowers102 6 | trainer: ERM 7 | arch: timm 8 | model: 9 | model_name: resnet50d 10 | pretrained: false 11 | num_classes: 102 12 | # TODO: load timm pretrained checkpoint 13 | loss: 14 | label_smoothing: 0.1 15 | optim: 16 | name: sgd 17 | lr: 0.002 18 | momentum: 0.9 19 | weight_decay: 5.e-4 20 | warmup_length: 0 21 | no_decay_bn_filter_bias: false 22 | nesterov: false 23 | epochs: 10000 24 | save_freq: 100 25 | start_epoch: 0 26 | batch_size: 256 27 | print_freq: 100 28 | resume: '' 29 | evaluate: false 30 | pretrained: false 31 | dist_url: 'tcp://127.0.0.1:23333' 32 | dist_backend: 'nccl' 33 | # Single GPU training 34 | multiprocessing_distributed: false 35 | world_size: 1 36 | rank: 0 37 | workers: 8 38 | pin_memory: true 39 | persistent_workers: true 40 | seed: NULL 41 | gpu: 0 42 | download_data: false # Set to True to download 43 | data_path: '' 44 | artifact_path: '' 45 | image_augmentation: 46 | train: 47 | random_resized_crop: 48 | size: 224 49 | rand_augment: # No horizontal flip when rand-augment is enabled 50 | enable: true 51 | to_tensor: 52 | enable: true 53 | normalize: 54 | mean: [0.485, 0.456, 0.406] 55 | std: [0.229, 0.224, 0.225] 56 | mixup: 57 | enable: true 58 | alpha: 1.0 # Strong mixup 59 | p: 1.0 60 | cutmix: 61 | enable: true 62 | alpha: 1.0 63 | p: 1.0 64 | val: 65 | resize: 66 | size: 256 67 | center_crop: 68 | size: 224 69 | to_tensor: 70 | enable: true 71 | normalize: 72 | mean: [0.485, 0.456, 0.406] 73 | std: [0.229, 0.224, 0.225] 74 | -------------------------------------------------------------------------------- /configs/food101/erm.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 
3 | # 4 | parameters: 5 | dataset: food101 6 | trainer: ERM 7 | arch: timm 8 | model: 9 | model_name: resnet50d 10 | pretrained: false 11 | num_classes: 101 12 | loss: 13 | label_smoothing: 0.1 14 | optim: 15 | name: sgd 16 | lr: 0.2 17 | momentum: 0.9 18 | weight_decay: 5.e-4 19 | warmup_length: 0 20 | no_decay_bn_filter_bias: false 21 | nesterov: false 22 | epochs: 1000 23 | save_freq: 100 24 | start_epoch: 0 25 | batch_size: 256 26 | print_freq: 100 27 | resume: '' 28 | evaluate: false 29 | pretrained: false 30 | dist_url: 'tcp://127.0.0.1:23333' 31 | dist_backend: 'nccl' 32 | # Single GPU training 33 | multiprocessing_distributed: false 34 | world_size: 1 35 | rank: 0 36 | workers: 8 37 | pin_memory: true 38 | persistent_workers: true 39 | seed: NULL 40 | gpu: 0 41 | download_data: false # Set to True to download 42 | data_path: '' 43 | artifact_path: '' 44 | image_augmentation: 45 | train: 46 | random_resized_crop: 47 | size: 224 48 | rand_augment: # No horizontal flip when rand-augment is enabled 49 | enable: true 50 | to_tensor: 51 | enable: true 52 | normalize: 53 | mean: [0.485, 0.456, 0.406] 54 | std: [0.229, 0.224, 0.225] 55 | mixup: 56 | enable: true 57 | alpha: 0.2 58 | p: 1.0 59 | cutmix: 60 | enable: true 61 | alpha: 1.0 62 | p: 1.0 63 | val: 64 | resize: 65 | size: 256 66 | center_crop: 67 | size: 224 68 | to_tensor: 69 | enable: true 70 | normalize: 71 | mean: [0.485, 0.456, 0.406] 72 | std: [0.229, 0.224, 0.225] 73 | -------------------------------------------------------------------------------- /configs/food101/ft.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | parameters: 5 | dataset: food101 6 | trainer: ERM 7 | arch: timm 8 | model: 9 | model_name: resnet50d 10 | pretrained: false 11 | num_classes: 101 12 | # TODO: load pretrained timm checkpoint 13 | loss: 14 | label_smoothing: 0.1 15 | optim: 16 | name: sgd 17 | lr: 0.002 18 | momentum: 0.9 19 | weight_decay: 5.e-4 20 | warmup_length: 0 21 | no_decay_bn_filter_bias: false 22 | nesterov: false 23 | epochs: 1000 24 | save_freq: 100 25 | start_epoch: 0 26 | batch_size: 256 27 | print_freq: 100 28 | resume: '' 29 | evaluate: false 30 | pretrained: false 31 | dist_url: 'tcp://127.0.0.1:23333' 32 | dist_backend: 'nccl' 33 | # Single GPU training 34 | multiprocessing_distributed: false 35 | world_size: 1 36 | rank: 0 37 | workers: 8 38 | pin_memory: true 39 | persistent_workers: true 40 | seed: NULL 41 | gpu: 0 42 | download_data: false # Set to True to download 43 | data_path: '' 44 | artifact_path: '' 45 | image_augmentation: 46 | train: 47 | random_resized_crop: 48 | size: 224 49 | rand_augment: # No horizontal flip when rand-augment is enabled 50 | enable: true 51 | to_tensor: 52 | enable: true 53 | normalize: 54 | mean: [0.485, 0.456, 0.406] 55 | std: [0.229, 0.224, 0.225] 56 | mixup: 57 | enable: true 58 | alpha: 0.2 59 | p: 1.0 60 | cutmix: 61 | enable: true 62 | alpha: 1.0 63 | p: 1.0 64 | val: 65 | resize: 66 | size: 256 67 | center_crop: 68 | size: 224 69 | to_tensor: 70 | enable: true 71 | normalize: 72 | mean: [0.485, 0.456, 0.406] 73 | std: [0.229, 0.224, 0.225] 74 | -------------------------------------------------------------------------------- /configs/food101/ft_plus.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 
3 | # 4 | parameters: 5 | dataset: food101 6 | trainer: DR 7 | arch: timm 8 | model: 9 | model_name: resnet50d 10 | pretrained: false 11 | num_classes: 101 12 | # TODO: load pretrained timm checkpoint 13 | loss: 14 | label_smoothing: 0.1 15 | optim: 16 | name: sgd 17 | lr: 0.002 18 | momentum: 0.9 19 | weight_decay: 5.e-4 20 | warmup_length: 0 21 | no_decay_bn_filter_bias: false 22 | nesterov: false 23 | epochs: 1000 24 | save_freq: 100 25 | start_epoch: 0 26 | batch_size: 256 27 | print_freq: 100 28 | resume: '' 29 | evaluate: false 30 | pretrained: false 31 | dist_url: 'tcp://127.0.0.1:23333' 32 | dist_backend: 'nccl' 33 | # Single GPU training 34 | multiprocessing_distributed: false 35 | world_size: 1 36 | rank: 0 37 | workers: 8 38 | pin_memory: true 39 | persistent_workers: true 40 | seed: NULL 41 | gpu: 0 42 | download_data: false # Set to True to download 43 | data_path: '' 44 | artifact_path: '' 45 | image_augmentation: 46 | train: 47 | random_resized_crop: 48 | size: 224 49 | rand_augment: # No horizontal flip when rand-augment is enabled 50 | enable: true 51 | to_tensor: 52 | enable: true 53 | normalize: 54 | mean: [0.485, 0.456, 0.406] 55 | std: [0.229, 0.224, 0.225] 56 | val: 57 | resize: 58 | size: 256 59 | center_crop: 60 | size: 224 61 | to_tensor: 62 | enable: true 63 | normalize: 64 | mean: [0.485, 0.456, 0.406] 65 | std: [0.229, 0.224, 0.225] 66 | reinforce: 67 | enable: true 68 | p: 0.99 69 | num_samples: NULL 70 | densify: smooth 71 | data_path: NULL 72 | -------------------------------------------------------------------------------- /configs/food101/plus.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | parameters: 5 | dataset: food101 6 | trainer: DR 7 | arch: timm 8 | model: 9 | model_name: resnet50d 10 | pretrained: false 11 | num_classes: 101 12 | loss: 13 | label_smoothing: 0.1 14 | optim: 15 | name: sgd 16 | lr: 0.2 17 | momentum: 0.9 18 | weight_decay: 5.e-4 19 | warmup_length: 0 20 | no_decay_bn_filter_bias: false 21 | nesterov: false 22 | epochs: 1000 23 | save_freq: 100 24 | start_epoch: 0 25 | batch_size: 256 26 | print_freq: 100 27 | resume: '' 28 | evaluate: false 29 | pretrained: false 30 | dist_url: 'tcp://127.0.0.1:23333' 31 | dist_backend: 'nccl' 32 | # Single GPU training 33 | multiprocessing_distributed: false 34 | world_size: 1 35 | rank: 0 36 | workers: 8 37 | pin_memory: true 38 | persistent_workers: true 39 | seed: NULL 40 | gpu: 0 41 | download_data: false # Set to True to download 42 | data_path: '' 43 | artifact_path: '' 44 | image_augmentation: 45 | train: 46 | random_resized_crop: 47 | size: 224 48 | rand_augment: # No horizontal flip when rand-augment is enabled 49 | enable: true 50 | to_tensor: 51 | enable: true 52 | normalize: 53 | mean: [0.485, 0.456, 0.406] 54 | std: [0.229, 0.224, 0.225] 55 | val: 56 | resize: 57 | size: 256 58 | center_crop: 59 | size: 224 60 | to_tensor: 61 | enable: true 62 | normalize: 63 | mean: [0.485, 0.456, 0.406] 64 | std: [0.229, 0.224, 0.225] 65 | reinforce: 66 | enable: true 67 | p: 0.99 68 | num_samples: NULL 69 | densify: smooth 70 | data_path: NULL 71 | -------------------------------------------------------------------------------- /configs/food101/reinforce.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 
3 | # 4 | parameters: 5 | dataset: food101 6 | teacher: 7 | ensemble: true 8 | # TODO: Load pretrained 9 | batch_size: 32 10 | print_freq: 10 11 | dist_url: 'tcp://127.0.0.1:23333' 12 | dist_backend: 'nccl' 13 | multiprocessing_distributed: true 14 | world_size: 1 15 | rank: 0 16 | workers: 88 17 | pin_memory: true 18 | persistent_workers: true 19 | seed: NULL 20 | gpu: NULL 21 | download_data: false # Set to True to download dataset 22 | data_path: '' 23 | artifact_path: '' 24 | reinforce: 25 | num_samples: 100 26 | num_candidates: NULL 27 | topk: 10 28 | compress: true 29 | joblib: false 30 | gzip: true 31 | image_augmentation: 32 | uint8: 33 | enable: true 34 | random_resize_crop: 35 | enable: true 36 | size: 224 37 | random_horizontal_flip: 38 | enable: true 39 | p: 0.5 40 | rand_augment: 41 | enable: true 42 | p: 1.0 43 | to_tensor: 44 | enable: true 45 | normalize: 46 | mean: [0.485, 0.456, 0.406] 47 | std: [0.229, 0.224, 0.225] 48 | random_erase: 49 | enable: true 50 | p: 0.25 51 | -------------------------------------------------------------------------------- /configs/food101/teacher.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | parameters: 5 | dataset: food101 6 | trainer: ERM 7 | arch: timm 8 | model: 9 | model_name: resnet50d 10 | pretrained: false 11 | num_classes: 101 12 | # TODO: load timm pretrained checkpoint 13 | loss: 14 | label_smoothing: 0.1 15 | optim: 16 | name: sgd 17 | lr: 0.002 18 | momentum: 0.9 19 | weight_decay: 5.e-4 20 | warmup_length: 0 21 | no_decay_bn_filter_bias: false 22 | nesterov: false 23 | epochs: 1000 24 | save_freq: 100 25 | start_epoch: 0 26 | batch_size: 256 27 | print_freq: 100 28 | resume: '' 29 | evaluate: false 30 | pretrained: false 31 | dist_url: 'tcp://127.0.0.1:23333' 32 | dist_backend: 'nccl' 33 | # Single GPU training 34 | multiprocessing_distributed: false 35 | world_size: 1 36 | rank: 0 37 | workers: 8 38 | pin_memory: true 39 | persistent_workers: true 40 | seed: NULL 41 | gpu: 0 42 | download_data: false # Set to True to download 43 | data_path: '' 44 | artifact_path: '' 45 | image_augmentation: 46 | train: 47 | random_resized_crop: 48 | size: 224 49 | rand_augment: # No horizontal flip when rand-augment is enabled 50 | enable: true 51 | to_tensor: 52 | enable: true 53 | normalize: 54 | mean: [0.485, 0.456, 0.406] 55 | std: [0.229, 0.224, 0.225] 56 | mixup: 57 | enable: true 58 | alpha: 1.0 # Strong mixup 59 | p: 1.0 60 | cutmix: 61 | enable: true 62 | alpha: 1.0 63 | p: 1.0 64 | val: 65 | resize: 66 | size: 256 67 | center_crop: 68 | size: 224 69 | to_tensor: 70 | enable: true 71 | normalize: 72 | mean: [0.485, 0.456, 0.406] 73 | std: [0.229, 0.224, 0.225] 74 | -------------------------------------------------------------------------------- /configs/imagenet/erm.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 
3 | # 4 | parameters: 5 | dataset: imagenet 6 | trainer: ERM 7 | arch: timm 8 | model: 9 | model_name: resnet50d 10 | pretrained: false 11 | loss: 12 | label_smoothing: 0.1 13 | optim: 14 | name: sgd 15 | lr: 0.4 16 | momentum: 0.9 17 | weight_decay: 1.e-4 18 | warmup_length: 5 19 | epochs: 150 20 | save_freq: 50 21 | start_epoch: 0 22 | batch_size: 1024 23 | print_freq: 100 24 | resume: '' 25 | evaluate: false 26 | pretrained: false 27 | dist_url: 'tcp://127.0.0.1:23333' 28 | dist_backend: 'nccl' 29 | # Multi-GPU training 30 | multiprocessing_distributed: true 31 | world_size: 1 32 | rank: 0 33 | workers: 88 34 | pin_memory: true 35 | persistent_workers: true 36 | seed: NULL 37 | gpu: NULL 38 | download_data: false # Set to True to download 39 | data_path: '' 40 | artifact_path: '' 41 | image_augmentation: 42 | train: 43 | random_resized_crop: 44 | size: 224 45 | rand_augment: # No horizontal flip when rand-augment is enabled 46 | enable: true 47 | to_tensor: 48 | enable: true 49 | normalize: 50 | mean: [0.485, 0.456, 0.406] 51 | std: [0.229, 0.224, 0.225] 52 | random_erase: 53 | enable: true 54 | p: 0.25 55 | mixup: 56 | enable: true 57 | alpha: 0.2 58 | p: 1.0 59 | cutmix: 60 | enable: true 61 | alpha: 1.0 62 | p: 1.0 63 | val: 64 | resize: 65 | size: 256 66 | center_crop: 67 | size: 224 68 | to_tensor: 69 | enable: true 70 | normalize: 71 | mean: [0.485, 0.456, 0.406] 72 | std: [0.229, 0.224, 0.225] 73 | -------------------------------------------------------------------------------- /configs/imagenet/kd.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | parameters: 5 | dataset: imagenet 6 | trainer: KD 7 | teacher: 8 | timm_ensemble: true 9 | name: ig_resnext101_32x48d,ig_resnext101_32x32d,ig_resnext101_32x16d,ig_resnext101_32x8d 10 | validate: false 11 | student: 12 | arch: timm 13 | model: 14 | model_name: resnet50d 15 | pretrained: false 16 | loss: 17 | loss_type: kl 18 | lambda_cls: 0.0 19 | lambda_kd: 1.0 20 | temperature: 1.0 21 | optim: 22 | name: sgd 23 | lr: 0.4 24 | momentum: 0.9 25 | weight_decay: 1.e-5 # Smaller weight decay for KD 26 | warmup_length: 5 27 | epochs: 150 28 | save_freq: 50 29 | start_epoch: 0 30 | batch_size: 1024 31 | print_freq: 100 32 | resume: '' 33 | evaluate: false 34 | pretrained: false 35 | dist_url: 'tcp://127.0.0.1:23333' 36 | dist_backend: 'nccl' 37 | # Multi-GPU training 38 | multiprocessing_distributed: true 39 | world_size: 1 40 | rank: 0 41 | workers: 88 42 | pin_memory: true 43 | persistent_workers: true 44 | seed: NULL 45 | gpu: NULL 46 | download_data: false # Set to True to download 47 | data_path: '' 48 | artifact_path: '' 49 | image_augmentation: 50 | train: 51 | random_resized_crop: 52 | size: 224 53 | timm_resize_crop_norm: 54 | enable: true 55 | name: ig_resnext101_32x48d 56 | rand_augment: 57 | enable: true 58 | to_tensor: 59 | enable: true 60 | normalize: 61 | mean: [0.485, 0.456, 0.406] 62 | std: [0.229, 0.224, 0.225] 63 | random_erase: 64 | enable: true 65 | p: 0.25 66 | mixup: 67 | enable: true 68 | alpha: 1.0 # Stronger mixup for KD 69 | p: 1.0 70 | cutmix: 71 | enable: true 72 | alpha: 1.0 73 | p: 1.0 74 | val: 75 | resize: 76 | size: 256 77 | center_crop: 78 | size: 224 79 | to_tensor: 80 | enable: true 81 | normalize: 82 | mean: [0.485, 0.456, 0.406] 83 | std: [0.229, 0.224, 0.225] 84 | -------------------------------------------------------------------------------- /configs/imagenet/plus.yaml: 
-------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | parameters: 5 | dataset: imagenet 6 | trainer: DR 7 | arch: timm 8 | model: 9 | model_name: resnet50d 10 | pretrained: false 11 | loss: 12 | label_smoothing: 0.1 13 | optim: 14 | name: sgd 15 | lr: 0.4 16 | momentum: 0.9 17 | weight_decay: 1.e-4 18 | warmup_length: 5 19 | epochs: 150 20 | save_freq: 50 21 | start_epoch: 0 22 | batch_size: 1024 23 | print_freq: 100 24 | resume: '' 25 | evaluate: false 26 | pretrained: false 27 | dist_url: 'tcp://127.0.0.1:23333' 28 | dist_backend: 'nccl' 29 | # Multi-GPU training 30 | multiprocessing_distributed: true 31 | world_size: 1 32 | rank: 0 33 | workers: 88 34 | pin_memory: true 35 | persistent_workers: true 36 | seed: NULL 37 | gpu: NULL 38 | download_data: false # Set to True to download 39 | data_path: '' 40 | artifact_path: '' 41 | image_augmentation: 42 | # Training transformations for non-reinforced samples 43 | train: 44 | random_resized_crop: 45 | size: 224 46 | rand_augment: 47 | enable: true 48 | to_tensor: 49 | enable: true 50 | normalize: 51 | mean: [0.485, 0.456, 0.406] 52 | std: [0.229, 0.224, 0.225] 53 | random_erase: 54 | enable: true 55 | p: 0.25 56 | val: 57 | resize: 58 | size: 256 59 | center_crop: 60 | size: 224 61 | to_tensor: 62 | enable: true 63 | normalize: 64 | mean: [0.485, 0.456, 0.406] 65 | std: [0.229, 0.224, 0.225] 66 | # Reinforcement configs 67 | reinforce: 68 | enable: true 69 | p: 0.99 70 | num_samples: NULL 71 | densify: smooth 72 | data_path: NULL 73 | -------------------------------------------------------------------------------- /configs/imagenet/reinforce/cropflip.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | parameters: 5 | dataset: imagenet 6 | teacher: 7 | timm_ensemble: true 8 | name: 'ig_resnext101_32x8d,ig_resnext101_32x16d,ig_resnext101_32x32d,ig_resnext101_32x48d' 9 | batch_size: 32 10 | print_freq: 10 11 | dist_url: 'tcp://127.0.0.1:23333' 12 | dist_backend: 'nccl' 13 | multiprocessing_distributed: true 14 | world_size: 1 15 | rank: 0 16 | workers: 88 17 | pin_memory: true 18 | seed: NULL 19 | gpu: NULL 20 | download_data: false # Set to True to download dataset 21 | data_path: '' 22 | artifact_path: '' 23 | reinforce: 24 | num_samples: 50 25 | num_candidates: NULL 26 | topk: 10 27 | compress: true 28 | joblib: false 29 | gzip: true 30 | image_augmentation: 31 | uint8: 32 | enable: true 33 | random_resized_crop: 34 | enable: true 35 | size: 224 36 | random_horizontal_flip: 37 | enable: true 38 | p: 0.5 39 | to_tensor: 40 | enable: true 41 | normalize: 42 | mean: [0.485, 0.456, 0.406] 43 | std: [0.229, 0.224, 0.225] 44 | -------------------------------------------------------------------------------- /configs/imagenet/reinforce/mixing.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 
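# Reinforcement recipe: random resized crop and horizontal flip plus
# mixup/cutmix, without RandAugment; compare with cropflip.yaml and
# randaug.yaml in this directory.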
3 | # 4 | parameters: 5 | dataset: imagenet 6 | teacher: 7 | timm_ensemble: true 8 | name: 'ig_resnext101_32x8d,ig_resnext101_32x16d,ig_resnext101_32x32d,ig_resnext101_32x48d' 9 | batch_size: 32 10 | print_freq: 10 11 | dist_url: 'tcp://127.0.0.1:23333' 12 | dist_backend: 'nccl' 13 | multiprocessing_distributed: true 14 | world_size: 1 15 | rank: 0 16 | workers: 88 17 | pin_memory: true 18 | seed: NULL 19 | gpu: NULL 20 | download_data: false # Set to True to download dataset 21 | data_path: '' 22 | artifact_path: '' 23 | reinforce: 24 | num_samples: 50 25 | num_candidates: NULL 26 | topk: 10 27 | compress: true 28 | joblib: false 29 | gzip: true 30 | image_augmentation: 31 | uint8: 32 | enable: true 33 | random_resized_crop: 34 | enable: true 35 | size: 224 36 | random_horizontal_flip: 37 | enable: true 38 | p: 0.5 39 | to_tensor: 40 | enable: true 41 | normalize: 42 | mean: [0.485, 0.456, 0.406] 43 | std: [0.229, 0.224, 0.225] 44 | mixup: 45 | enable: true 46 | alpha: 1.0 47 | div_by: 2.0 48 | p: 0.5 49 | cutmix: 50 | enable: true 51 | alpha: 1.0 52 | p: 0.5 53 | -------------------------------------------------------------------------------- /configs/imagenet/reinforce/randaug.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | parameters: 5 | dataset: imagenet 6 | teacher: 7 | timm_ensemble: true 8 | name: 'ig_resnext101_32x8d,ig_resnext101_32x16d,ig_resnext101_32x32d,ig_resnext101_32x48d' 9 | batch_size: 32 10 | print_freq: 10 11 | dist_url: 'tcp://127.0.0.1:23333' 12 | dist_backend: 'nccl' 13 | multiprocessing_distributed: true 14 | world_size: 1 15 | rank: 0 16 | workers: 88 17 | pin_memory: true 18 | seed: NULL 19 | gpu: NULL 20 | download_data: false # Set to True to download dataset 21 | data_path: '' 22 | artifact_path: '' 23 | reinforce: 24 | num_samples: 50 25 | num_candidates: NULL 26 | topk: 10 27 | compress: true 28 | joblib: false 29 | gzip: true 30 | image_augmentation: 31 | uint8: 32 | enable: true 33 | random_resized_crop: 34 | enable: true 35 | size: 224 36 | random_horizontal_flip: 37 | enable: true 38 | p: 0.5 39 | rand_augment: 40 | enable: true 41 | p: 1.0 42 | to_tensor: 43 | enable: true 44 | normalize: 45 | mean: [0.485, 0.456, 0.406] 46 | std: [0.229, 0.224, 0.225] 47 | random_erase: 48 | enable: true 49 | p: 0.25 50 | -------------------------------------------------------------------------------- /configs/imagenet/reinforce/randaug_mixing.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 
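# Reinforcement recipe combining RandAugment, random erasing, and mixup/cutmix
# on top of crop/flip; the most heavily augmented of the four ImageNet
# reinforcement configs.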
3 | # 4 | parameters: 5 | dataset: imagenet 6 | teacher: 7 | timm_ensemble: true 8 | name: 'ig_resnext101_32x8d,ig_resnext101_32x16d,ig_resnext101_32x32d,ig_resnext101_32x48d' 9 | batch_size: 32 10 | print_freq: 10 11 | dist_url: 'tcp://127.0.0.1:23333' 12 | dist_backend: 'nccl' 13 | multiprocessing_distributed: true 14 | world_size: 1 15 | rank: 0 16 | workers: 88 17 | pin_memory: true 18 | seed: NULL 19 | gpu: NULL 20 | download_data: false # Set to True to download dataset 21 | data_path: '' 22 | artifact_path: '' 23 | reinforce: 24 | num_samples: 50 25 | num_candidates: NULL 26 | topk: 10 27 | compress: true 28 | joblib: false 29 | gzip: true 30 | image_augmentation: 31 | uint8: 32 | enable: true 33 | random_resized_crop: 34 | enable: true 35 | size: 224 36 | random_horizontal_flip: 37 | enable: true 38 | p: 0.5 39 | rand_augment: 40 | enable: true 41 | p: 1.0 42 | to_tensor: 43 | enable: true 44 | normalize: 45 | mean: [0.485, 0.456, 0.406] 46 | std: [0.229, 0.224, 0.225] 47 | random_erase: 48 | enable: true 49 | p: 0.25 50 | mixup: 51 | enable: true 52 | alpha: 1.0 53 | div_by: 2.0 54 | p: 0.5 55 | cutmix: 56 | enable: true 57 | alpha: 1.0 58 | p: 0.5 59 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | """Methods to initialize and create dataset loaders.""" 5 | import os 6 | import logging 7 | from typing import Any, Dict, Tuple 8 | 9 | from torch.utils.data import Dataset 10 | import torchvision.datasets as datasets 11 | import transforms 12 | 13 | 14 | def download_dataset(data_path: str, dataset: str) -> None: 15 | """Download dataset prior to spawning workers. 16 | 17 | Args: 18 | data_path: Path to the root of the dataset. 19 | dataset: The name of the dataset. 
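
    ImageNet must be arranged manually under <data_path>/training and
    <data_path>/validation; CIFAR-100, Flowers-102, and Food-101 are
    fetched through torchvision.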
20 | """ 21 | if dataset == "imagenet": 22 | # ImageNet data requires manual download 23 | traindir = os.path.join(data_path, "training") 24 | valdir = os.path.join(data_path, "validation") 25 | assert os.path.isdir( 26 | traindir 27 | ), "Please download ImageNet training set to {}.".format(traindir) 28 | assert os.path.isdir( 29 | valdir 30 | ), "Please download ImageNet validation set to {}.".format(valdir) 31 | elif dataset == "cifar100": 32 | datasets.CIFAR100(root=data_path, train=True, download=True, transform=None) 33 | elif dataset == "flowers102": 34 | datasets.Flowers102( 35 | root=data_path, split="train", download=True, transform=None 36 | ) 37 | elif dataset == "food101": 38 | datasets.Food101(root=data_path, split="train", download=True, transform=None) 39 | 40 | 41 | def get_dataset_size(dataset: str) -> int: 42 | """Return dataset size to compute the number of iterations per epoch.""" 43 | if dataset == "imagenet": 44 | return 1281167 45 | elif dataset == "cifar100": 46 | return 50000 47 | elif dataset == "flowers102": 48 | return 1020 49 | elif dataset == "food101": 50 | return 75750 51 | 52 | 53 | def get_dataset_num_classes(dataset: str) -> int: 54 | """Return number of classes in a dataset.""" 55 | if dataset == "imagenet": 56 | return 1000 57 | elif dataset == "cifar100": 58 | return 100 59 | elif dataset == "flowers102": 60 | return 102 61 | elif dataset == "food101": 62 | return 101 63 | 64 | 65 | def get_dataset(config: Dict[str, Any]) -> Tuple[Dataset, Dataset]: 66 | """Return data loaders for training and validation sets.""" 67 | logging.info("Instantiating {} dataset.".format(config["dataset"])) 68 | 69 | dataset_name = config.get("dataset", None) 70 | if dataset_name is None: 71 | logging.error("Dataset name can't be None") 72 | dataset_name = dataset_name.lower() 73 | 74 | if dataset_name == "imagenet": 75 | return get_imagenet_dataset(config) 76 | elif dataset_name == "cifar100": 77 | return get_cifar100_dataset(config) 78 | elif dataset_name == "flowers102": 79 | return get_flowers102_dataset(config) 80 | elif dataset_name == "food101": 81 | return get_food101_dataset(config) 82 | else: 83 | raise NotImplementedError 84 | 85 | 86 | def get_cifar100_dataset(config) -> Tuple[Dataset, Dataset]: 87 | """Return training/test datasets for CIFAR-100 dataset. 
88 | 89 | @TECHREPORT{Krizhevsky09learningmultiple, 90 | author = {Alex Krizhevsky}, 91 | title = {Learning multiple layers of features from tiny images}, 92 | institution = {}, 93 | year = {2009} 94 | } 95 | """ 96 | relative_path = config["data_path"] 97 | download_dataset = config.get("download_dataset", False) 98 | 99 | train_dataset = datasets.CIFAR100( 100 | root=relative_path, 101 | train=True, 102 | download=download_dataset, 103 | transform=transforms.compose_from_config(config["image_augmentation"]["train"]), 104 | ) 105 | 106 | val_dataset = datasets.CIFAR100( 107 | root=relative_path, 108 | train=False, 109 | download=False, 110 | transform=transforms.compose_from_config(config["image_augmentation"]["val"]), 111 | ) 112 | return train_dataset, val_dataset 113 | 114 | 115 | def get_imagenet_dataset(config) -> Tuple[Dataset, Dataset]: 116 | """Return training/validation datasets for ImageNet dataset.""" 117 | traindir = os.path.join(config["data_path"], "training") 118 | valdir = os.path.join(config["data_path"], "validation") 119 | 120 | train_dataset = datasets.ImageFolder( 121 | traindir, 122 | transforms.compose_from_config(config["image_augmentation"]["train"]), 123 | ) 124 | 125 | val_dataset = datasets.ImageFolder( 126 | valdir, 127 | transforms.compose_from_config(config["image_augmentation"]["val"]), 128 | ) 129 | return train_dataset, val_dataset 130 | 131 | 132 | def get_flowers102_dataset(config) -> Tuple[Dataset, Dataset]: 133 | """Return training/test datasets for Flowers-102 dataset. 134 | 135 | @InProceedings{Nilsback08, 136 | author = "Nilsback, M-E. and Zisserman, A.", 137 | title = "Automated Flower Classification over a Large Number of Classes", 138 | booktitle = "Proceedings of the Indian Conference on Computer Vision, Graphics 139 | and Image Processing", 140 | year = "2008", 141 | month = "Dec" 142 | } 143 | """ 144 | relative_path = config["data_path"] 145 | download_dataset = config.get("download_dataset", False) 146 | 147 | train_dataset = datasets.Flowers102( 148 | root=relative_path, 149 | split="train", 150 | download=download_dataset, 151 | transform=transforms.compose_from_config(config["image_augmentation"]["train"]), 152 | ) 153 | val_dataset = datasets.Flowers102( 154 | root=relative_path, 155 | split="test", 156 | download=False, 157 | transform=transforms.compose_from_config(config["image_augmentation"]["val"]), 158 | ) 159 | return train_dataset, val_dataset 160 | 161 | 162 | def get_food101_dataset(config) -> Tuple[Dataset, Dataset]: 163 | """Return training/test datasets for Food-101 dataset. 
164 | 165 | @inproceedings{bossard14, 166 | title = {Food-101 -- Mining Discriminative Components with Random Forests}, 167 | author = {Bossard, Lukas and Guillaumin, Matthieu and Van Gool, Luc}, 168 | booktitle = {European Conference on Computer Vision}, 169 | year = {2014} 170 | } 171 | """ 172 | relative_path = config["data_path"] 173 | download_dataset = config.get("download_dataset", False) 174 | 175 | train_dataset = datasets.Food101( 176 | root=relative_path, 177 | split="train", 178 | download=download_dataset, 179 | transform=transforms.compose_from_config(config["image_augmentation"]["train"]), 180 | ) 181 | val_dataset = datasets.Food101( 182 | root=relative_path, 183 | split="test", 184 | download=False, 185 | transform=transforms.compose_from_config(config["image_augmentation"]["val"]), 186 | ) 187 | return train_dataset, val_dataset 188 | -------------------------------------------------------------------------------- /dr/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | 5 | -------------------------------------------------------------------------------- /dr/data.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | """Dataset wrappers and loaders for reinforced data.""" 5 | import os 6 | import random 7 | from typing import Union, Tuple, Any, List, Dict 8 | 9 | import torch 10 | from torch import Tensor 11 | import torch.nn.parallel 12 | import torch.optim 13 | import torch.utils.data 14 | import torchvision 15 | from torch.utils.data import Dataset 16 | import torch.utils.data.distributed 17 | import dr.transforms as T_dr 18 | from dr.utils import densify 19 | from data import get_dataset_num_classes 20 | 21 | 22 | class ReinforceMetadata: 23 | """A class to load and return only the metadata of the reinforced dataset.""" 24 | 25 | def __init__(self, rdata_path: Union[str, List[str]]) -> None: 26 | """Initialize the metadata files and configuration.""" 27 | self.rdata_path = rdata_path if isinstance(rdata_path, list) else [rdata_path] 28 | assert any( 29 | [os.path.exists(rp) for rp in self.rdata_path] 30 | ), f"Please download reinforce metadata to {self.rdata_path}."
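        # Layout note: each reinforcement directory is expected to hold one
        # {index}.pth.tar file per training sample plus a config.pth.tar with
        # the generation settings (written by reinforce.py).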
31 | self.rconfig = self.get_rconfig() 32 | 33 | def __getitem__(self, index: int) -> Tuple[int, List[List[float]]]: 34 | """Return reinforced metadata for a single data point at given index.""" 35 | p = random.randint(0, len(self.rdata_path) - 1) 36 | rdata = torch.load(os.path.join(self.rdata_path[p], "{}.pth.tar".format(index))) 37 | return rdata 38 | 39 | def get_rconfig(self) -> Dict[str, Any]: 40 | """Read configuration file for reinforcements.""" 41 | num_samples = 0 42 | for rdata in self.rdata_path: 43 | rconfig = torch.load(os.path.join(rdata, "config.pth.tar")) 44 | num_samples += rconfig["reinforce"]["num_samples"] 45 | rconfig["reinforce"]["num_samples"] = num_samples 46 | return rconfig 47 | 48 | 49 | class ReinforcedDataset(Dataset): 50 | """A class to reinforce a given dataset with metadata in rdata_path.""" 51 | 52 | def __init__( 53 | self, 54 | dataset: Dataset, 55 | rdata_path: Union[str, List[str]], 56 | config: Dict[str, Any], 57 | num_classes: int, 58 | ) -> None: 59 | """Initialize the metadata configuration and parameters.""" 60 | self.ds = dataset 61 | self.num_classes = num_classes 62 | self.densify_method = config.get("densify", "zeros") 63 | self.p = config.get("p", 1.0) or 1.0 64 | self.config = config 65 | 66 | # Use specified data augmentations only if sampling the original data 67 | self.transform_orig = dataset.transform 68 | dataset.transform = None 69 | 70 | self.r_metadata = ReinforceMetadata(rdata_path) 71 | rconfig = self.r_metadata.get_rconfig() 72 | 73 | # Initialize transformations from config 74 | self.transforms = T_dr.compose_from_config( 75 | rconfig["reinforce"]["image_augmentation"] 76 | ) 77 | 78 | def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: 79 | """Return the sample from the dataset with given index.""" 80 | # With probability self.p sample reinforced data, otherwise sample original 81 | p = random.random() 82 | 83 | if p < self.p: 84 | # Load reinforcement meta data for the image 85 | rdata = self.r_metadata[index] 86 | 87 | # Choose a random reinforcement 88 | assert rdata[0] == index, "Index does not match the metadata index." 89 | rdata = rdata[1] # tuple (id, rdata) 90 | i = random.randint(0, len(rdata) - 1) 91 | rdata_sample = rdata[i] 92 | params, target = rdata_sample["params"], rdata_sample["prob"] 93 | if isinstance(params, list): 94 | params = self.transforms.decompress(params) 95 | 96 | # Load image 97 | img, _ = self.ds[index] 98 | 99 | # Load the image pair if reinforcement has mixup/cutmix 100 | img2 = None 101 | if "cutmix" in params or "mixup" in params: 102 | img2, _ = self.ds[params["id2"]] 103 | 104 | # Reapply augmentation 105 | img, reparams = self.transforms.reapply(img, params, img2) 106 | for k, v in reparams.items(): 107 | assert v == params[k], "params changed." 
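            # Expand the stored sparse top-k teacher probabilities into a
            # dense distribution over all classes (see dr/utils.densify).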
108 | target = densify(target, self.num_classes, self.densify_method) 109 | else: 110 | # With probability 1-self.p sample the original data 111 | img, target = self.ds[index] 112 | if self.transform_orig: 113 | img = self.transform_orig(img) 114 | target = torch.nn.functional.one_hot( 115 | torch.tensor(target), num_classes=self.num_classes 116 | ).float() 117 | return img.detach(), target.detach() 118 | 119 | def __len__(self) -> int: 120 | """Return the number of samples in the dataset.""" 121 | return len(self.ds) 122 | 123 | 124 | class DatasetWithParameters: 125 | """A wrapper for PyTorch datasets that also returns transformation parameters.""" 126 | 127 | def __init__( 128 | self, 129 | dataset: torchvision.datasets.VisionDataset, 130 | transform: T_dr.Compose, 131 | num_samples: int, 132 | ) -> None: 133 | """Initialize and set the number of random crops per sample.""" 134 | self.ds = dataset 135 | self.num_samples = num_samples 136 | self.transform = transform 137 | self.ds.transform = None 138 | 139 | def __getitem__(self, index: int) -> Tuple[Tensor, int, Tensor]: 140 | """Return multiple random transformations of a sample at given index. 141 | 142 | Args: 143 | index: An integer that is the unique ID of a sample in the dataset. 144 | 145 | Returns: 146 | A tuple of (sample_all, target, params_all) with shapes: 147 | sample_all: [num_samples,]+sample.shape 148 | target: int 149 | params_all: A dictionary of parameters, each value is a Tensor of shape 150 | [num_samples, ...] 151 | """ 152 | sample, target = self.ds[index] 153 | sample_all, params_all = T_dr.before_collate_apply( 154 | sample, self.transform, self.num_samples 155 | ) 156 | return sample_all, target, params_all 157 | 158 | def __len__(self) -> int: 159 | """Return the number of samples in the dataset.""" 160 | return len(self.ds) 161 | 162 | 163 | class IndexedDataset(torch.utils.data.Dataset): 164 | """A wrapper to PyTorch datasets that returns an index.""" 165 | 166 | def __init__(self, dataset: torch.utils.data.Dataset) -> None: 167 | """Set the dataset.""" 168 | self.ds = dataset 169 | 170 | def __getitem__(self, index: int) -> Tuple[int, ...]: 171 | """Return a sample and the given index.""" 172 | data = self.ds[index] 173 | return (index, *data) 174 | 175 | def __len__(self) -> int: 176 | """Return the number of samples in the dataset.""" 177 | return len(self.ds) 178 | -------------------------------------------------------------------------------- /dr/transforms.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
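# Each transform below returns (img, params) so that an augmentation sampled
# at generation time can be reapplied exactly later; Compose chains them and
# can compress/decompress the parameters for compact on-disk storage.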
3 | # 4 | 5 | """Extending transformations from torchvision to be reproducible.""" 6 | 7 | from collections import defaultdict 8 | from typing import List, OrderedDict, Union, Tuple, Optional, Any, Dict 9 | 10 | import torch 11 | from torch import Tensor 12 | import torch.nn.parallel 13 | import torch.optim 14 | import torch.utils.data 15 | import torch.utils.data.distributed 16 | import torchvision.transforms as T 17 | from torchvision.transforms import functional as F 18 | from torchvision.transforms.autoaugment import _apply_op 19 | 20 | import transforms 21 | from transforms import clean_config 22 | 23 | 24 | NO_PARAM = () 25 | NO_PARAM_TYPE = Tuple 26 | 27 | 28 | class Compressible: 29 | """Base class for reproducible transformations with compressible parameters.""" 30 | 31 | @staticmethod 32 | def compress_params(params: Any) -> Any: 33 | """Return compressed parameters.""" 34 | return params 35 | 36 | @staticmethod 37 | def decompress_params(params: Any) -> Any: 38 | """Return decompressed parameters.""" 39 | return params 40 | 41 | 42 | class Resize(T.Resize, Compressible): 43 | """Extending PyTorch's Resize to reapply a given transformation.""" 44 | 45 | def forward( 46 | self, img: Tensor, params: Optional[torch.Size] = None 47 | ) -> Tuple[Tensor, torch.Size]: 48 | """Transform an image randomly or reapply based on given parameters.""" 49 | img = super().forward(img) 50 | return img, self.size 51 | 52 | 53 | class CenterCrop(T.CenterCrop, Compressible): 54 | """Extending PyTorch's CenterCrop to reapply a given transformation.""" 55 | 56 | def forward( 57 | self, img: Tensor, params: Optional[NO_PARAM_TYPE] = None 58 | ) -> Tuple[Tensor, NO_PARAM_TYPE]: 59 | """Transform an image randomly or reapply based on given parameters.""" 60 | img = super().forward(img) 61 | # TODO: can we remove contiguous? 
62 | img = img.contiguous() 63 | return img, NO_PARAM 64 | 65 | 66 | class RandomCrop(T.RandomCrop, Compressible): 67 | """Extending PyTorch's RandomCrop to reapply a given transformation.""" 68 | 69 | def __init__(self, *args, **kwargs) -> None: 70 | """Initialize super and set last parameters to None.""" 71 | super().__init__(*args, **kwargs) 72 | self.params = None 73 | 74 | def get_params(self, *args, **kwargs) -> Tuple[int, int, int, int]: 75 | """Return self.params or new transformation params if self.params not set.""" 76 | if self.params is None: 77 | self.params = super().get_params(*args, **kwargs) 78 | return self.params 79 | 80 | def forward( 81 | self, img: Tensor, params: Optional[Tuple[int, int]] = None 82 | ) -> Tuple[Tensor, Tuple[int, int]]: 83 | """Transform an image randomly or reapply based on given parameters.""" 84 | self.params = None 85 | if params is not None: 86 | # Add the constant value of size 87 | self.params = (params[0], params[1], self.size[0], self.size[1]) 88 | img = super().forward(img) 89 | params = self.params 90 | return img, params[:2] # Return only [top, left], ie random parameters 91 | 92 | 93 | class RandomResizedCrop(T.RandomResizedCrop, Compressible): 94 | """Extending PyTorch's RandomResizedCrop to reapply a given transformation.""" 95 | 96 | def __init__(self, *args, **kwargs) -> None: 97 | """Initialize super and set last parameters to None.""" 98 | super().__init__(*args, **kwargs) 99 | self.params = None 100 | 101 | def get_params(self, *args, **kwargs) -> Tuple[int, int, int, int]: 102 | """Return self.params or new transformation params if self.params not set.""" 103 | if self.params is None: 104 | self.params = super().get_params(*args, **kwargs) 105 | return self.params 106 | 107 | def forward( 108 | self, img: Tensor, params: Optional[Tuple[int, int, int, int]] = None 109 | ) -> Tuple[Tensor, Tuple[int, int, int, int]]: 110 | """Transform an image randomly or reapply based on given parameters.""" 111 | self.params = None 112 | if params is not None: 113 | self.params = params 114 | img = super().forward(img) 115 | params = self.params 116 | return img, params 117 | 118 | 119 | class RandomHorizontalFlip(T.RandomHorizontalFlip, Compressible): 120 | """Extending PyTorch's RandomHorizontalFlip to reapply a given transformation.""" 121 | 122 | def forward( 123 | self, img: Tensor, params: Optional[bool] = None 124 | ) -> Tuple[Tensor, bool]: 125 | """Transform an image randomly or reapply based on given parameters. 126 | 127 | Args: 128 | img (PIL Image or Tensor): Image to be flipped. 129 | 130 | Returns: 131 | PIL Image or Tensor: Randomly flipped image. 132 | """ 133 | if params is None: 134 | # Randomly skip only if params=None 135 | params = torch.rand(1).item() < self.p 136 | if params: 137 | img = F.hflip(img) 138 | return img, params 139 | 140 | 141 | class RandAugment(T.RandAugment, Compressible): 142 | """Extending PyTorch's RandAugment to reapply a given transformation.""" 143 | 144 | op_names = [ 145 | "Identity", 146 | "ShearX", 147 | "ShearY", 148 | "TranslateX", 149 | "TranslateY", 150 | "Rotate", 151 | "Brightness", 152 | "Color", 153 | "Contrast", 154 | "Sharpness", 155 | "Posterize", 156 | "Solarize", 157 | "AutoContrast", 158 | "Equalize", 159 | ] 160 | 161 | def __init__(self, p: float = 1.0, *args, **kwargs) -> None: 162 | """Initialize RandAugment with probability p of augmentation. 163 | 164 | Args: 165 | p: The probability of applying transformation. A float in [0, 1.0]. 
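            Remaining positional and keyword arguments are forwarded to
            torchvision's RandAugment (e.g. num_ops, magnitude).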
166 | """ 167 | super().__init__(*args, **kwargs) 168 | self.p = p 169 | 170 | def forward( 171 | self, img: Tensor, params: Optional[List[Tuple[str, float]]] = None, **kwargs 172 | ) -> Tuple[Tensor, List[Tuple[str, float]]]: 173 | """Transform an image randomly or reapply based on given parameters. 174 | 175 | Args: 176 | img (PIL Image or Tensor): Image to be transformed. 177 | 178 | Returns: 179 | PIL Image or Tensor: Transformed image. 180 | """ 181 | fill = self.fill 182 | channels, height, width = F.get_dimensions(img) 183 | if isinstance(img, torch.Tensor): 184 | if isinstance(fill, (int, float)): 185 | fill = [float(fill)] * channels 186 | elif fill is not None: 187 | fill = [float(f) for f in fill] 188 | 189 | op_meta = self._augmentation_space(self.num_magnitude_bins, (height, width)) 190 | if params is None: 191 | # Randomly skip only if params=None 192 | if torch.rand(1) > self.p: 193 | return img, None 194 | 195 | params = [] 196 | for _ in range(self.num_ops): 197 | op_index = int(torch.randint(len(op_meta), (1,)).item()) 198 | op_name = list(op_meta.keys())[op_index] 199 | magnitudes, signed = op_meta[op_name] 200 | magnitude = ( 201 | float(magnitudes[self.magnitude].item()) 202 | if magnitudes.ndim > 0 203 | else 0.0 204 | ) 205 | if signed and torch.randint(2, (1,)): 206 | magnitude *= -1.0 207 | params += [(op_name, magnitude)] 208 | 209 | for i in range(self.num_ops): 210 | op_name, magnitude = params[i] 211 | img = _apply_op( 212 | img, op_name, magnitude, interpolation=self.interpolation, fill=fill 213 | ) 214 | 215 | return img, params 216 | 217 | @staticmethod 218 | def compress_params(params: List[Tuple[str, float]]) -> List[Tuple[int, float]]: 219 | """Return compressed parameters.""" 220 | if params is None: 221 | return None 222 | pc = [] 223 | for p in params: 224 | pc += [(RandAugment.op_names.index(p[0]), p[1])] 225 | return pc 226 | 227 | @staticmethod 228 | def decompress_params(params: List[Tuple[int, float]]) -> List[Tuple[str, float]]: 229 | """Return decompressed parameters.""" 230 | if params is None: 231 | return None 232 | pc = [] 233 | for p in params: 234 | pc += [(RandAugment.op_names[p[0]], p[1])] 235 | return pc 236 | 237 | 238 | class RandomErasing(T.RandomErasing, Compressible): 239 | """Extending PyTorch's RandomErasing to reapply a given transformation.""" 240 | 241 | def forward( 242 | self, img: Tensor, params: Optional[Tuple] = None, **kwargs 243 | ) -> Tuple[Tensor, Tuple]: 244 | """Transform an image randomly or reapply based on given parameters. 245 | 246 | Args: 247 | img (Tensor): Tensor image to be erased. 248 | 249 | Returns: 250 | img (Tensor): Erased Tensor image. 
251 | """ 252 | if params is None: 253 | # Randomly skip only if params=None 254 | if torch.rand(1) > self.p: 255 | return img, None 256 | x, y, h, w, _ = self.get_params(img, scale=self.scale, ratio=self.ratio) 257 | else: 258 | x, y, h, w = params 259 | # In early experiments F.erase used in pytorch's RE was very slow 260 | # TODO: verify that F.erase is still slower than assigning zeros 261 | if x != -1: 262 | img[:, x : x + h, y : y + w] = 0 263 | return img, (x, y, h, w) 264 | 265 | 266 | class Normalize(T.Normalize, Compressible): 267 | """PyTorch's Normalize with an extra dummy transformation parameter.""" 268 | 269 | def forward( 270 | self, *args, params: Optional[NO_PARAM_TYPE] = None, **kwargs 271 | ) -> Tuple[Tensor, Tuple]: 272 | """Return normalized input and NO_PARAM as parameters.""" 273 | x = super().forward(*args, **kwargs) 274 | return x, NO_PARAM 275 | 276 | 277 | class MixUp(transforms.MixUp, Compressible): 278 | """Extending MixUp to reapply a given transformation.""" 279 | 280 | def __init__(self, *args, **kwargs) -> None: 281 | """Initialize super and set last parameters to None.""" 282 | super().__init__(*args, **kwargs) 283 | self.params = None 284 | 285 | def get_params(self, *args, **kwargs) -> float: 286 | """Return self.params or new transformation params if self.params not set.""" 287 | if self.params is None: 288 | self.params = super().get_params(*args, **kwargs) 289 | return self.params 290 | 291 | def forward( 292 | self, 293 | x: Tensor, 294 | x2: Tensor, 295 | y: Optional[Tensor] = None, 296 | y2: Optional[Tensor] = None, 297 | params: Dict[str, float] = None, 298 | ) -> Tuple[Tuple[Tensor, Tensor], Dict[str, float]]: 299 | """Transform an image randomly or reapply based on given parameters.""" 300 | self.params = None 301 | if params is not None: 302 | self.params = params 303 | x, y = super().forward(x, x2, y, y2) 304 | params = self.params 305 | return (x, y), params 306 | 307 | 308 | class CutMix(transforms.CutMix, Compressible): 309 | """Extending CutMix to reapply a given transformation.""" 310 | 311 | def __init__(self, *args, **kwargs) -> None: 312 | """Initialize super and set last parameters to None.""" 313 | super().__init__(*args, **kwargs) 314 | self.params = None 315 | 316 | def get_params(self, *args, **kwargs) -> Tuple[float, Tuple[int, int, int, int]]: 317 | """Return self.params or new transformation params if self.params not set.""" 318 | if self.params is None: 319 | self.params = super().get_params(*args, **kwargs) 320 | return self.params 321 | 322 | def forward( 323 | self, 324 | x: Tensor, 325 | x2: Tensor, 326 | y: Optional[Tensor] = None, 327 | y2: Optional[Tensor] = None, 328 | params: Dict[str, float] = None, 329 | ) -> Tuple[ 330 | Tuple[Tensor, Tensor], Dict[str, Union[float, Tuple[int, int, int, int]]] 331 | ]: 332 | """Transform an image randomly or reapply based on given parameters.""" 333 | self.params = None 334 | if params is not None: 335 | self.params = params 336 | x, y = super().forward(x, x2, y, y2) 337 | params = self.params 338 | return (x, y), params 339 | 340 | @staticmethod 341 | def compress_params(params: Any) -> Any: 342 | """Return compressed parameters.""" 343 | if params is None: 344 | return None 345 | return [params[0]]+list(params[1]) 346 | 347 | @staticmethod 348 | def decompress_params(params: Any) -> Any: 349 | """Return decompressed parameters.""" 350 | if params is None: 351 | return None 352 | return params[0], tuple(params[1:]) 353 | 354 | 355 | class ToUint8(torch.nn.Module, Compressible): 
356 | """Convert float32 Tensor in range [0, 1] to uint8 [0, 255].""" 357 | 358 | def forward(self, img: Tensor, **kwargs) -> Tuple[Tensor, NO_PARAM_TYPE]: 359 | """Return uint8(img) and NO_PARAM as parameters.""" 360 | if not isinstance(img, torch.Tensor): 361 | return img, NO_PARAM 362 | return (img * 255).to(torch.uint8), NO_PARAM 363 | 364 | 365 | class ToTensor(torch.nn.Module, Compressible): 366 | """Convert PIL to torch.Tensor or if Tensor uint8 [0, 255] to float32 [0, 1].""" 367 | 368 | def forward(self, img: Tensor, **kwargs) -> Tuple[Tensor, NO_PARAM_TYPE]: 369 | """Return tensor(img) and NO_PARAM as parameters.""" 370 | if isinstance(img, torch.Tensor): 371 | """Return float32(img) and NO_PARAM as parameters.""" 372 | return (img / 255.0).to(torch.float32), NO_PARAM 373 | return F.to_tensor(img), NO_PARAM 374 | 375 | 376 | # Transformations are composed according to the order below, not the order in config 377 | TRANSFORMATION_TO_NAME = OrderedDict( 378 | [ 379 | ("uint8", ToUint8), 380 | ("resize", Resize), 381 | ("center_crop", CenterCrop), 382 | ("random_crop", RandomCrop), 383 | ("random_resized_crop", RandomResizedCrop), 384 | ("random_horizontal_flip", RandomHorizontalFlip), 385 | ("rand_augment", RandAugment), 386 | ("to_tensor", ToTensor), 387 | ("random_erase", RandomErasing), # TODO: fix the order of RE with transforms 388 | ("normalize", Normalize), 389 | ("mixup", MixUp), 390 | ("cutmix", CutMix), 391 | ] 392 | ) 393 | # Only in datagen 394 | BEFORE_COLLATE_TRANSFORMS = [ 395 | "uint8", 396 | "resize", 397 | "center_crop", 398 | "random_crop", 399 | "random_resized_crop", 400 | "to_tensor", 401 | ] 402 | NO_PARAM_TRANSFORMS = [ 403 | "uint8", 404 | "center_crop", 405 | "to_tensor", 406 | "normalize", 407 | ] 408 | 409 | 410 | class Compose: 411 | """Compose a list of reproducible data transformations.""" 412 | 413 | def __init__(self, transforms: List[Tuple[str, Compressible]]) -> None: 414 | """Initialize transformations.""" 415 | self.transforms = transforms 416 | 417 | def has_random_resized_crop(self) -> bool: 418 | """Return True if RandomResizedCrop is one of the transformations.""" 419 | return any([t.__class__ == RandomResizedCrop for _, t in self.transforms]) 420 | 421 | def __call__( 422 | self, 423 | img: Tensor, 424 | img2: Tensor = None, 425 | after_collate: Optional[bool] = False, 426 | ) -> Tuple[Tensor, Dict[str, Any]]: 427 | """Apply a transformation to two images and return augmentation parameters. 428 | 429 | Args: 430 | img: A tensor to be transformed. 431 | params: Transformation parameters to be reapplied. 432 | img2: Second tensor to be used for mixing transformations. 433 | 434 | The value of `params` can be None or empty in 3 cases: 435 | 1) `params=None` in apply(): The value should be generated randomly, 436 | 2) `params=None` in reapply(): Transformation was randomly skipped during 437 | generation time, 438 | 3) `params=()`: Trasformation has no random parameters. 439 | 440 | Returns: 441 | A Tuple of a transformed image and a dictionary with transformation 442 | parameters. 
443 | """ 444 | params = dict() 445 | for t_name, t in self.transforms: 446 | if after_collate and ( 447 | t_name in BEFORE_COLLATE_TRANSFORMS 448 | and t_name != "uint8" 449 | and t_name != "to_tensor" 450 | ): 451 | # Skip transformations applied in data loader 452 | pass 453 | elif t_name == "cutmix" or t_name == "mixup": 454 | # Mix images 455 | if img2 is not None: 456 | (img, _), p = t(img, img2) 457 | params[t_name] = p 458 | else: 459 | # Apply an augmentation to both images, skip img2 if no mixing 460 | img, p = t(img) 461 | if img2 is not None: 462 | img2, p2 = t(img2) 463 | p = (p, p2) 464 | params[t_name] = p 465 | return img, params 466 | 467 | def reapply( 468 | self, img: Tensor, params: Dict[str, Any], img2: Tensor = None 469 | ) -> Tuple[Tensor, Dict[str, Any]]: 470 | """Reapply transformations to an image given augmentation parameters. 471 | 472 | Args: 473 | img: A tensor to be transformed. 474 | params: Transformation parameters to be reapplied. 475 | img2: Second tensor to be used for mixing transformations. 476 | 477 | The value of `params` can be None or empty in 3 cases: 478 | 1) `params=None` in apply(): The value should be generated randomly, 479 | 2) `params=None` in reapply(): Transformation was randomly skipped during 480 | generation time, 481 | 3) `params=()`: Trasformation has no random parameters. 482 | 483 | Returns: 484 | A Tuple of a transformed image and a dictionary with transformation 485 | parameters. 486 | """ 487 | for t_name, t in self.transforms: 488 | if t_name in params: 489 | if t_name == "cutmix" or t_name == "mixup": 490 | # Remix images 491 | if params[t_name] is not None: 492 | (img, _), _ = t(img, img2, params=params[t_name]) 493 | else: 494 | # Reapply an augmentation to both images, skip img2 if no 495 | # mixing 496 | if params[t_name][0] is not None: 497 | img, _ = t(img, params=params[t_name][0]) 498 | if img2 is not None and params[t_name][1] is not None: 499 | img2, _ = t(img2, params=params[t_name][1]) 500 | return img, params 501 | 502 | def compress(self, params: Dict[str, Any]) -> List[Any]: 503 | """Compress augmentation parameters.""" 504 | params_compressed = [] 505 | 506 | no_pair = True 507 | # Save second pair id if mixup or cutmix enabled 508 | t_names = [t[0] for t in self.transforms] 509 | if "mixup" in t_names or "cutmix" in t_names: 510 | if ( 511 | params.get("mixup", None) is not None 512 | or params.get("cutmix", None) is not None 513 | ): 514 | params_compressed += [params["id2"]] 515 | no_pair = False 516 | else: 517 | params_compressed += [None] 518 | 519 | # Save transformation parameters 520 | for t_name, t in self.transforms: 521 | p = params[t_name] 522 | if t_name in NO_PARAM_TRANSFORMS: 523 | pass 524 | elif t_name == "mixup" or t_name == "cutmix": 525 | params_compressed += [t.compress_params(p)] 526 | else: 527 | if no_pair: 528 | params_compressed += [t.compress_params(p[0])] 529 | else: 530 | params_compressed += [ 531 | [t.compress_params(p[0]), t.compress_params(p[1])] 532 | ] 533 | return params_compressed 534 | 535 | def decompress(self, params_compressed: List[Any]) -> Dict[str, Any]: 536 | """Decompress augmentation parameters.""" 537 | params = {} 538 | 539 | # Read second pair id if mixup or cutmix enabled 540 | t_names = [t[0] for t in self.transforms] 541 | no_pair = None 542 | if "mixup" in t_names or "cutmix" in t_names: 543 | no_pair = params_compressed[0] 544 | if no_pair is not None: 545 | params["id2"] = no_pair 546 | params_compressed = params_compressed[1:] 547 | 548 | # Read 
parameters for transformations with random parameters 549 | with_param_transforms = [(t_name, t) 550 | for t_name, t in self.transforms 551 | if t_name not in NO_PARAM_TRANSFORMS] 552 | for p, (t_name, t) in zip(params_compressed, with_param_transforms): 553 | if p is None: 554 | pass 555 | elif t_name == "mixup" or t_name == "cutmix": 556 | params[t_name] = t.decompress_params(p) 557 | else: 558 | if no_pair is not None and len(p) > 1: 559 | params[t_name] = ( 560 | t.decompress_params(p[0]), 561 | t.decompress_params(p[1]), 562 | ) 563 | else: 564 | params[t_name] = (t.decompress_params(p),) 565 | 566 | # Fill non-random transformations 567 | for t_name, t in self.transforms: 568 | if t_name in NO_PARAM_TRANSFORMS: 569 | params[t_name] = (NO_PARAM, NO_PARAM) 570 | return params 571 | 572 | 573 | def compose_from_config(config: Dict[str, Any]) -> Compose: 574 | """Initialize transformations given the dataset name and configurations. 575 | 576 | Args: 577 | config: A dictionary of augmentation parameters. 578 | """ 579 | config = clean_config(config) 580 | transforms = [] 581 | for t_name, t_class in TRANSFORMATION_TO_NAME.items(): 582 | if t_name in config: 583 | # TODO: warn for every key in config that was not used 584 | transforms += [(t_name, t_class(**config[t_name]))] 585 | return Compose(transforms) 586 | 587 | 588 | def before_collate_config( 589 | config: Dict[str, Dict[str, Any]] 590 | ) -> Dict[str, Dict[str, Any]]: 591 | """Return configs with resize/crop transformations to pass to data loader. 592 | 593 | Only transformations that cannot be applied after data collate are 594 | composed. For example, RandomResizedCrop has to be applied before collate 595 | to create tensors of matching shapes. 596 | 597 | Args: 598 | config: A dictionary of augmentation parameters. 599 | """ 600 | return {k: v for k, v in config.items() if k in BEFORE_COLLATE_TRANSFORMS} 601 | 602 | 603 | def after_collate_config( 604 | config: Dict[str, Dict[str, Any]] 605 | ) -> Dict[str, Dict[str, Any]]: 606 | """Return configs after excluding transformations from `before_collate_config`.""" 607 | return {k: v for k, v in config.items() if k not in BEFORE_COLLATE_TRANSFORMS} 608 | 609 | 610 | def before_collate_apply( 611 | sample: Tensor, transform: Compose, num_samples: int 612 | ) -> Tuple[Tensor, Dict[str, Tensor]]: 613 | """Return multiple randomly transformed versions of a sample. 614 | 615 | Args: 616 | sample: A single sample to be randomly transformed. 617 | transform: A list of transformations to be applied. 618 | num_samples: The number of random transformations to be generated. 619 | 620 | Returns: 621 | Random transformations of the input. Shape: [num_samples,]+sample.shape 622 | """ 623 | sample_all = [] 624 | params_all = defaultdict(list) 625 | for _ in range(num_samples): 626 | # [height, width, channels] 627 | # -> ([height_new, width_new, channels], Dict(str, Tuple)) 628 | sample_new, params = transform(sample) 629 | sample_all.append(sample_new) 630 | for k, v in params.items(): 631 | params_all[k].append(v) 632 | 633 | sample_all = torch.stack(sample_all, dim=0) 634 | for k in params_all.keys(): 635 | params_all[k] = torch.tensor(params_all[k]) 636 | return sample_all, params_all 637 | -------------------------------------------------------------------------------- /dr/utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
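# sparsify() keeps a teacher's top-k probabilities (values plus class indices)
# at generation time; densify() reconstructs a full distribution from them,
# spreading the leftover probability mass uniformly when method="smooth".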
3 | # 4 | """Utilities for training.""" 5 | from typing import List, Tuple, Optional 6 | 7 | import torch 8 | import numpy as np 9 | 10 | 11 | def sparsify(p: np.ndarray, k: int) -> Tuple[List[float], List[int]]: 12 | """Sparsify a probability vector to its top-k values.""" 13 | ids = np.argpartition(p, -k)[-k:] 14 | return p[ids].tolist(), ids.tolist() 15 | 16 | 17 | def densify( 18 | sp: Tuple[List[float], List[int]], 19 | num_classes: int, 20 | method: Optional[str] = "smooth", 21 | ) -> torch.Tensor: 22 | """Densify a sparse probability vector.""" 23 | if not isinstance(sp, list) and not isinstance(sp, tuple): 24 | return torch.tensor(sp) 25 | sp, ids = torch.tensor(sp[0]), sp[1] 26 | r = 1.0 - sp.sum() # r can be slightly negative when sp.sum() is close to 1.0 due to rounding; handled below 27 | # assert r >= 0.0, f"Sum of sparse probabilities ({r}) should be less than 1.0." 28 | if method == "zeros" or r < 0.0: 29 | p = torch.zeros(num_classes) 30 | p[ids] = torch.nn.functional.softmax(sp, dim=0) 31 | elif method == "smooth": 32 | p = torch.ones(num_classes) * r / (num_classes - len(sp)) 33 | p[ids] = sp 34 | return p 35 | -------------------------------------------------------------------------------- /figures/DR_illustration_wide.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-dr/a21113c779714b68bcc8b6518f2a4ad149b29f00/figures/DR_illustration_wide.pdf -------------------------------------------------------------------------------- /figures/DR_illustration_wide.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-dr/a21113c779714b68bcc8b6518f2a4ad149b29f00/figures/DR_illustration_wide.png -------------------------------------------------------------------------------- /figures/imagenet_RRC+RARE_accuracy_annotated.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-dr/a21113c779714b68bcc8b6518f2a4ad149b29f00/figures/imagenet_RRC+RARE_accuracy_annotated.pdf -------------------------------------------------------------------------------- /figures/imagenet_RRC+RARE_accuracy_annotated.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-dr/a21113c779714b68bcc8b6518f2a4ad149b29f00/figures/imagenet_RRC+RARE_accuracy_annotated.png -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
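# Note: ClassificationEnsembleNet averages the members' softmax probabilities
# and returns log-probabilities by default (probabilities if return_prob=True).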
3 | # 4 | 5 | """Methods to create, load, and ensemble models.""" 6 | import os 7 | from typing import Any, Dict, Generator, List, Optional, Tuple, Union 8 | import yaml 9 | import logging 10 | import torch 11 | import torch.nn as nn 12 | from torch.nn import functional as F 13 | import torchvision.models as models 14 | 15 | 16 | 17 | def move_to_device(model: torch.nn.Module, config: Dict[str, Any]) -> torch.nn.Module: 18 | """Wrap model with DDP/DP if distributed, convert to CUDA if GPU set, else CPU.""" 19 | if not torch.cuda.is_available(): 20 | logging.info("using CPU, this will be slow") 21 | elif config["distributed"]: 22 | ngpus_per_node = torch.cuda.device_count() 23 | # For multiprocessing distributed, DistributedDataParallel constructor 24 | # should always set the single device scope, otherwise, 25 | # DistributedDataParallel will use all available devices. 26 | if config["gpu"] is not None: 27 | torch.cuda.set_device(config["gpu"]) 28 | model.cuda(config["gpu"]) 29 | # When using a single GPU per process and per 30 | # DistributedDataParallel, we need to divide the batch size 31 | # ourselves based on the total number of GPUs we have 32 | config["batch_size"] = int(config["batch_size"] / ngpus_per_node) 33 | config["workers"] = int( 34 | (config["workers"] + ngpus_per_node - 1) / ngpus_per_node 35 | ) 36 | model = torch.nn.parallel.DistributedDataParallel( 37 | model, device_ids=[config["gpu"]] 38 | ) 39 | else: 40 | model.cuda() 41 | # DistributedDataParallel will divide and allocate batch_size to 42 | # all available GPUs if device_ids are not set 43 | model = torch.nn.parallel.DistributedDataParallel(model) 44 | elif config["gpu"] is not None: 45 | torch.cuda.set_device(config["gpu"]) 46 | model = model.cuda(config["gpu"]) 47 | else: 48 | # DataParallel will divide and allocate batch_size to all available 49 | # GPUs 50 | model = torch.nn.DataParallel(model).cuda() 51 | return model 52 | 53 | 54 | def load_model(gpu: torch.device, config: Dict[str, Any]) -> torch.nn.Module: 55 | """Load a pretrained model or an ensemble of pretrained models.""" 56 | if config.get("ensemble", False): 57 | # Load an ensemble from a checkpoint path 58 | checkpoints_path = config["checkpoint"]  # assumed: directory of member checkpoints (see load_ensemble) 59 | device = None 60 | if gpu is not None: 61 | device = "cuda:{}".format(gpu) 62 | # Load models 63 | members = torch.nn.ModuleList(load_ensemble(checkpoints_path, device)) 64 | model = ClassificationEnsembleNet(members) 65 | elif config.get("timm_ensemble", False): 66 | import timm 67 | # Load an ensemble of Timm models 68 | model_names = config.get("name", None) 69 | if gpu is not None: 70 | torch.cuda.set_device(gpu) 71 | # Load pretrained models 72 | members = torch.nn.ModuleList() 73 | if not isinstance(model_names, list): 74 | model_names = model_names.split(",") 75 | for m in model_names: 76 | members += [timm.create_model(m, pretrained=True)] 77 | # Create Ensemble 78 | model = ClassificationEnsembleNet(members) 79 | elif config.get("checkpoint", None) is not None: 80 | # Load a single pretrained model 81 | checkpoint_path = config["checkpoint"] 82 | arch = config["arch"] 83 | model = load_from_local(checkpoint_path, arch, gpu) 84 | else: 85 | # Use default pretrained model from pytorch.
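        # (torchvision's legacy `pretrained=True` flag; newer torchvision
        # versions express the same thing via the `weights` argument.)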
86 | model = models.__dict__[config["arch"]](pretrained=True) 87 | return model 88 | 89 | 90 | def load_from_local(checkpoint_path: str, arch: str, gpu: int) -> torch.nn.Module: 91 | """Load model from local path and move to GPU if gpu set.""" 92 | teacher_model = models.__dict__[arch]() 93 | 94 | # Load from checkpoint 95 | if gpu is None: 96 | checkpoint = torch.load(checkpoint_path) 97 | else: 98 | # Map model to be loaded to specified single gpu. 99 | loc = "cuda:{}".format(gpu) 100 | checkpoint = torch.load(checkpoint_path, map_location=loc) 101 | 102 | # Strip module. from checkpoint 103 | ckpt = {k.replace("module.", ""): v for k, v in checkpoint["state_dict"].items()} 104 | teacher_model.load_state_dict(ckpt) 105 | logging.info("Loaded checkpoint {} for teacher".format(checkpoint_path)) 106 | return teacher_model 107 | 108 | 109 | def create_model(arch: str, config: dict) -> torch.nn.Module: 110 | """Create models from CVNets/Timm/Torch.""" 111 | if arch == "cvnets": 112 | import cvnets 113 | from cvnets import modeling_arguments 114 | import argparse 115 | # TODO: cvnets does not yet support easy model creation outside the library 116 | parser = argparse.ArgumentParser(description="") 117 | parser = modeling_arguments(parser) 118 | opts = parser.parse_args() 119 | config_dot = dict(convert_dict_to_dotted(config)) 120 | for k, v in config_dot.items(): 121 | if hasattr(opts, k): 122 | setattr(opts, k, v) 123 | setattr(opts, 'dataset.category', 'classification') 124 | model = cvnets.get_model(opts)  # pass the populated opts namespace, not the raw dict 125 | elif arch == "timm": 126 | import timm 127 | model = timm.create_model(**config["model"]) 128 | else: 129 | model = models.__dict__[arch]() 130 | logging.info(model) 131 | return model 132 | 133 | 134 | class ClassificationEnsembleNet(nn.Module): 135 | """Ensemble model for classification based on averaging.""" 136 | 137 | def __init__(self, members: torch.nn.ModuleList) -> None: 138 | """Init ensemble.""" 139 | super().__init__() 140 | self.members = members 141 | 142 | def forward( 143 | self, x: torch.Tensor, return_prob: bool = False, temperature: float = 1.0 144 | ) -> torch.Tensor: 145 | """Reduce function for classification using averaging.""" 146 | output = 0 147 | for a_network in self.members: 148 | logits = a_network(x) 149 | prob = F.softmax(logits / temperature, dim=1) 150 | output = output + prob 151 | 152 | if return_prob: 153 | return output / float(len(self.members)) 154 | return (output / float(len(self.members))).log() 155 | 156 | 157 | def init_model_from_ckpt( 158 | model: torch.nn.Module, 159 | ckpt_path: str, 160 | device: torch.device, 161 | strict_keys: Optional[bool] = False, 162 | ) -> torch.nn.Module: 163 | """Init a model from an already trained model. 164 | 165 | Args: 166 | model: the pytorch model object to be loaded. 167 | ckpt_path: path to a model checkpoint. 168 | device: the device to load parameters to. Note that 169 | the model could have been saved from a different device; 170 | parameters are mapped to the given device, so a model 171 | trained and saved on GPU can be loaded on CPU, 172 | for example. 173 | strict_keys: If True, keys in the state_dict of both models should be 174 | identical.
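
    With strict_keys=False, missing or unexpected keys are ignored, which is
    convenient when loading a backbone checkpoint without its classifier head.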
175 | """ 176 | ckpt = torch.load(ckpt_path, map_location=device) 177 | pretrained_dict = ckpt["state_dict"] 178 | # For incorrectly saved DataParallel models 179 | pretrained_dict = {k.replace("module.", ""): v for k, v in pretrained_dict.items()} 180 | model.load_state_dict(pretrained_dict, strict=strict_keys) 181 | 182 | return model 183 | 184 | 185 | def load_ensemble(checkpoints_path: str, device: torch.device) -> List[torch.nn.Module]: 186 | """Traverse all subdirs and load checkpoints.""" 187 | models = list() 188 | for root, dirs, files in os.walk(checkpoints_path): 189 | dirs.sort() 190 | # a directory is a legitimate checkpoint directory if the root has a config.yaml 191 | if "config.yaml" in files: 192 | with open(os.path.join(root, "config.yaml")) as f: 193 | model_config = yaml.safe_load(f).get("parameters") 194 | arch = model_config["arch"] 195 | if ( 196 | model_config.get("model", {}) 197 | .get("classification", {}) 198 | .get("pretrained", None) 199 | is not None 200 | ): 201 | model_config["model"]["classification"]["pretrained"] = None 202 | model = create_model(arch, model_config) 203 | ckpt_path = get_path_to_checkpoint(root) 204 | model = init_model_from_ckpt(model, ckpt_path, device) 205 | models.append(model) 206 | return models 207 | 208 | 209 | def get_path_to_checkpoint(artifact_dir: str, epoch=None) -> str: 210 | """Find checkpoint file path in an artifact directory. 211 | 212 | Args: 213 | artifact_dir: path to an experiment artifact directory, 214 | to laod checkpoints from there. 215 | epoch: If given tries to load that checkpoint, otherwise 216 | loads the latest. This function assumes checkpoints are saved 217 | as `checkpoint_epoch.tar' 218 | """ 219 | ckpts_path = os.path.join(artifact_dir, "checkpoints") 220 | ckpts_list = os.listdir(ckpts_path) 221 | ckpts_dict = { 222 | int(ckpt.split("_")[1].split(".")[0]): os.path.join(ckpts_path, ckpt) 223 | for ckpt in ckpts_list 224 | } 225 | if len(ckpts_list) == 0: 226 | msg = "No checkpoint exists!" 227 | raise ValueError(msg) 228 | if epoch is not None: 229 | if epoch not in ckpts_dict.keys(): 230 | msg = "Could not find checkpoint for epoch {} !" 231 | raise ValueError(msg.format(epoch)) 232 | else: 233 | epoch = max(ckpts_dict.keys()) 234 | return ckpts_dict[epoch] 235 | 236 | 237 | def convert_dict_to_dotted( 238 | c: Dict[str, Any], 239 | prefix: str = "", 240 | ) -> Generator[Union[Any, Tuple[str, Any]], None, None]: 241 | """Convert a nested dictionary of configs to flat dotted notation.""" 242 | if isinstance(c, dict): 243 | prefix += "." if prefix != "" else "" 244 | for k, v in c.items(): 245 | for x in convert_dict_to_dotted(v, prefix + k): 246 | yield x 247 | else: 248 | yield (prefix, c) 249 | -------------------------------------------------------------------------------- /reinforce.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 
3 | #
4 | """Reinforce a dataset given a pretrained teacher and save the metadata."""
5 | 
6 | import os
7 | import time
8 | import argparse
9 | from typing import Any, Dict, List, Tuple
10 | import yaml
11 | import random
12 | import warnings
13 | import copy
14 | import joblib
15 | 
16 | import torch
17 | import torch.nn.parallel
18 | import torch.distributed as dist
19 | import torch.backends.cudnn as cudnn
20 | import torch.multiprocessing as mp
21 | import torch.optim
22 | import torch.utils.data
23 | import torch.utils.data.distributed
24 | import dr.transforms as T_dr
25 | from dr.data import IndexedDataset, DatasetWithParameters
26 | from dr.utils import sparsify
27 | 
28 | 
29 | 
30 | from models import load_model
31 | 
32 | from utils import ProgressMeter, AverageMeter
33 | from data import (
34 |     download_dataset,
35 |     get_dataset,
36 |     get_dataset_num_classes,
37 | )
38 | 
39 | import logging
40 | 
41 | logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO)
42 | 
43 | 
44 | def parse_args() -> argparse.Namespace:
45 |     """Parse command-line arguments."""
46 |     parser = argparse.ArgumentParser(description="PyTorch Dataset Reinforcement")
47 |     parser.add_argument(
48 |         "--config", type=str, required=False, help="Path to a yaml config file."
49 |     )
50 |     args = parser.parse_args()
51 |     return args
52 | 
53 | 
54 | def main(args) -> None:
55 |     """Reinforce a dataset with the configurations specified in the given arguments."""
56 |     if args.config is not None:
57 |         # Read parameters from yaml config for local run
58 |         yaml_path = args.config
59 |         with open(yaml_path, "r") as file:
60 |             config = yaml.safe_load(file).get("parameters")
61 | 
62 |     dataset = config.get("dataset", "imagenet")
63 |     config["num_classes"] = get_dataset_num_classes(dataset)
64 | 
65 |     if config["seed"] is not None:
66 |         random.seed(config["seed"])
67 |         torch.manual_seed(config["seed"])
68 |         cudnn.deterministic = True
69 |         cudnn.benchmark = False
70 |         warnings.warn(
71 |             "You have chosen to seed training. "
72 |             "This will turn on the CUDNN deterministic setting, "
73 |             "which can slow down your training considerably! "
74 |             "You may see unexpected behavior when restarting "
75 |             "from checkpoints."
76 |         )
77 | 
78 |     ngpus_per_node = torch.cuda.device_count()
79 |     if config["gpu"] is not None:
80 |         warnings.warn(
81 |             "You have chosen a specific GPU. This will completely "
82 |             "disable data parallelism."
83 |         )
84 |         ngpus_per_node = 1
85 | 
86 |     if config["dist_url"] == "env://" and config["world_size"] == -1:
87 |         config["world_size"] = int(os.environ["WORLD_SIZE"])
88 | 
89 |     config["distributed"] = (
90 |         config["world_size"] > 1 or config["multiprocessing_distributed"]
91 |     )
92 | 
93 |     if config["download_data"]:
94 |         config["data_path"] = download_dataset(config["data_path"], config["dataset"])
95 |     if config["multiprocessing_distributed"]:
96 |         # Since we have ngpus_per_node processes per node, the total world_size
97 |         # needs to be adjusted accordingly
98 |         config["world_size"] = ngpus_per_node * config["world_size"]
99 |         # Use torch.multiprocessing.spawn to launch distributed processes: the
100 |         # main_worker process function
101 |         mp.spawn(
102 |             main_worker,
103 |             nprocs=ngpus_per_node,
104 |             args=(ngpus_per_node, copy.deepcopy(config)),
105 |         )
106 | 
107 |     else:
108 |         # Simply call the main_worker function
109 |         main_worker(config["gpu"], ngpus_per_node, copy.deepcopy(config))
110 | 
111 |     rdata_ext_path = "/mnt/reinforced_data"
112 |     fnames = next(os.walk(rdata_ext_path))[2]
113 |     max_id = max(
114 |         [
115 |             int(f.replace(".pth.tar", "").replace(".jb", ""))
116 |             for f in fnames
117 |             if (".pth.tar" in f or ".jb" in f) and f != "config.pth.tar"
118 |         ]
119 |     )
120 |     assert max_id + 1 == len(fnames), "Saved {} files but max id is {}.".format(
121 |         len(fnames), max_id
122 |     )
123 |     logging.info("Saving reinforced data.")
124 |     torch.save(config, os.path.join(rdata_ext_path, "config.pth.tar"))
125 |     artifact_path = config["artifact_path"]
126 |     gzip = config["reinforce"]["gzip"]
127 |     cmd = "tar c{}f {} -C {} reinforced_data".format(
128 |         "z" if gzip else "",
129 |         os.path.join(
130 |             artifact_path, "reinforced_data.tar{}".format(".gz" if gzip else "")
131 |         ),
132 |         "/mnt/",
133 |     )
134 |     os.system(cmd)
135 | 
136 | 
137 | def main_worker(gpu: int, ngpus_per_node: int, config: Dict[str, Any]) -> None:
138 |     """Reinforce data with a single process.
In distributed training, run on one GPU.""" 139 | config["gpu"] = gpu 140 | 141 | if config["gpu"] is not None: 142 | logging.info("Use GPU: {} for datagen".format(config["gpu"])) 143 | 144 | if config["distributed"]: 145 | if config["dist_url"] == "env://" and config["rank"] == -1: 146 | config["rank"] = int(os.environ["RANK"]) 147 | if config["multiprocessing_distributed"]: 148 | # For multiprocessing distributed training, rank needs to be the 149 | # global rank among all the processes 150 | config["rank"] = config["rank"] * ngpus_per_node + gpu 151 | dist.init_process_group( 152 | backend=config["dist_backend"], 153 | init_method=config["dist_url"], 154 | world_size=config["world_size"], 155 | rank=config["rank"], 156 | ) 157 | 158 | # Disable logging on all workers except rank=0 159 | if config["multiprocessing_distributed"] and config["rank"] % ngpus_per_node != 0: 160 | logging.basicConfig( 161 | format="%(asctime)s %(message)s", level=logging.WARNING, force=True 162 | ) 163 | 164 | model = load_model(gpu=config["gpu"], config=config["teacher"]) 165 | if config["gpu"] is not None: 166 | torch.cuda.set_device(config["gpu"]) 167 | model = model.cuda(config["gpu"]) 168 | else: 169 | model.cuda() 170 | 171 | # Set model to eval 172 | model.eval() 173 | 174 | # Print number of model parameters 175 | logging.info( 176 | "Number of model parameters: {}".format( 177 | sum(p.numel() for p in model.parameters() if p.requires_grad) 178 | ) 179 | ) 180 | 181 | # Data loading code 182 | logging.info("Instantiating dataset.") 183 | config["image_augmentation"] = {"train": {}, "val": {}} 184 | train_dataset, _ = get_dataset(config) 185 | num_samples = config["reinforce"]["num_samples"] 186 | # Set transforms to None and wrap the dataset 187 | train_dataset.transform = None 188 | tr = T_dr.compose_from_config( 189 | T_dr.before_collate_config(config["reinforce"]["image_augmentation"]) 190 | ) 191 | train_dataset = DatasetWithParameters( 192 | train_dataset, transform=tr, num_samples=num_samples 193 | ) 194 | 195 | train_dataset = IndexedDataset(train_dataset) 196 | if config["distributed"]: 197 | # When using a single GPU per process and per 198 | # DistributedDataParallel, we need to divide the batch size 199 | # ourselves based on the total number of GPUs we have 200 | config["batch_size"] = int(config["batch_size"] / ngpus_per_node) 201 | config["workers"] = int( 202 | (config["workers"] + ngpus_per_node - 1) / ngpus_per_node 203 | ) 204 | # Shuffle=True is better for mixing reinforcements 205 | train_sampler = torch.utils.data.distributed.DistributedSampler( 206 | train_dataset, shuffle=True, drop_last=False 207 | ) 208 | train_sampler.set_epoch(0) 209 | else: 210 | train_sampler = None 211 | 212 | train_loader = torch.utils.data.DataLoader( 213 | train_dataset, 214 | batch_size=config["batch_size"], 215 | shuffle=(train_sampler is None), 216 | drop_last=False, 217 | num_workers=config["workers"], 218 | pin_memory=config["pin_memory"], 219 | sampler=train_sampler, 220 | ) 221 | 222 | reinforce(train_loader, model, config) 223 | 224 | 225 | def reinforce( 226 | train_loader: torch.utils.data.DataLoader, 227 | model: torch.nn.Module, 228 | config: Dict[str, Any], 229 | ) -> None: 230 | """Generate reinforcements and save in individual files per training sample.""" 231 | batch_time = AverageMeter("Time", ":6.3f") 232 | progress = ProgressMeter(len(train_loader), [batch_time]) 233 | 234 | num_samples = config["reinforce"]["num_samples"] 235 | transforms = 
        T_dr.compose_from_config(config["reinforce"]["image_augmentation"])
236 | 
237 |     rdata_ext_path = "/mnt/reinforced_data"
238 |     os.makedirs(rdata_ext_path, exist_ok=True)
239 | 
240 |     with torch.no_grad():
241 |         end = time.time()
242 |         for batch_i, data in enumerate(train_loader):
243 |             ids, images, target, coords = data
244 |             if config["gpu"] is not None:
245 |                 images = images.cuda(config["gpu"], non_blocking=True)
246 |                 target = target.cuda(config["gpu"], non_blocking=True)
247 | 
248 |             # Apply transformations
249 |             images_aug, params_aug = transform_batch(
250 |                 ids, images, target, coords, transforms, config
251 |             )
252 |             images_aug = images_aug.reshape((-1,) + images.shape[2:])
253 | 
254 |             # Compute output
255 |             prob = model(images_aug, return_prob=True)
256 |             prob = prob.reshape((images.shape[0], num_samples, -1))
257 |             prob = prob.cpu().numpy()
258 |             for j in range(images.shape[0]):
259 |                 new_samples = [
260 |                     {
261 |                         "params": params_aug[j][k],
262 |                         "prob": sparsify(prob[j][k], config["reinforce"]["topk"]),
263 |                     }
264 |                     for k in range(num_samples)
265 |                 ]
266 |                 fname = os.path.join(rdata_ext_path, "{}.pth.tar".format(int(ids[j])))
267 |                 if not os.path.exists(fname):
268 |                     if not config["reinforce"]["joblib"]:
269 |                         # protocol=4 gives ~1.3x compression vs. protocol=1
270 |                         torch.save((int(ids[j]), new_samples), fname, pickle_protocol=4)
271 |                     else:
272 |                         # joblib gives ~2.2x compression vs. torch.save(protocol=4),
273 |                         # but is 1.6x slower to save and 7x slower to load
274 |                         joblib.dump((int(ids[j]), new_samples), fname, compress=True)
275 | 
276 |             # Measure elapsed time
277 |             batch_time.update(time.time() - end)
278 |             end = time.time()
279 | 
280 |             if batch_i % config["print_freq"] == 0:
281 |                 progress.display(batch_i)
282 | 
283 |     time.sleep(0.01)  # Prevent a broken-pipe error at shutdown
284 | 
285 | 
286 | def transform_batch(
287 |     ids: torch.Tensor,
288 |     images: torch.Tensor,
289 |     target: torch.Tensor,
290 |     coords: Dict[str, torch.Tensor],
291 |     transforms: T_dr.Compose,
292 |     config: Dict[str, Any],
293 | ) -> Tuple[torch.Tensor, List[List[Dict[str, Any]]]]:
294 |     """Apply image transformations to a batch and return the parameters.
295 | 
296 |     Args:
297 |         ids: Image indices. Shape: [batch_size].
298 |         images: Image crops. If random-resized-crop is enabled, `images` has
299 |             multiple random crops per image.
300 |             Shape without RRC: [batch_size, n_channels, crop_height, crop_width],
301 |             Shape with RRC: [batch_size, n_samples, n_channels, crop_height, crop_width]
302 |         target: Ground-truth labels. Shape: [batch_size].
303 |         coords: A dict of random-resized-crop coordinates; each value has shape [batch_size, n_samples, 4].
304 |         transforms: A composition of transformations to be applied randomly and stored.
305 |         config: A dictionary of configurations with a `reinforce` key.
306 | 
307 |     Returns:
308 |         A tuple of transformed images and transformation parameters.
309 |         Shape [0]: [batch_size, n_samples, n_channels, crop_height, crop_width].
310 |         Shape [1]: a nested list of size [batch_size][n_samples] of parameter dicts.
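
        Example of one `params_aug[i][k]` entry (illustrative only; the exact
        keys depend on which augmentations are enabled in the config):

            {"random_resized_crop": ((8, 12, 160, 160), (0, 3, 200, 200)),
             "rand_augment": (...), "mixup": 0.37, "id2": 1234}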
311 | """ 312 | num_samples = config["reinforce"]["num_samples"] 313 | images_aug = [] 314 | params_aug = [] 315 | for i in range(images.shape[0]): 316 | images_aug += [[]] 317 | params_aug += [[]] 318 | for j in range(num_samples): 319 | # Choose a sample with pre-collate transformation parameters 320 | img = images[i, j] 321 | 322 | # Sample an image pair 323 | index = random.randint(0, images.shape[0] - 1) 324 | j2 = random.randint(0, images.shape[1] - 1) 325 | img2 = images[index, j2] 326 | id2 = int(ids[index]) 327 | 328 | # Apply remaining transformations after collate 329 | img, params = transforms(img, img2, after_collate=True) 330 | 331 | # Update parameters with before collate transformations 332 | params0 = { 333 | k: (tuple(v[i, j].tolist()), tuple(v[index, j2].tolist())) 334 | for k, v in coords.items() 335 | } 336 | params.update(params0) 337 | 338 | if "mixup" in params or "cutmix" in params: 339 | params["id2"] = id2 340 | if config["reinforce"]["compress"]: 341 | params = transforms.compress(params) 342 | images_aug[-1] += [img] 343 | params_aug[-1] += [params] 344 | images_aug[-1] = torch.stack(images_aug[-1], axis=0) 345 | images_aug = torch.stack(images_aug, axis=0) 346 | return images_aug, params_aug 347 | 348 | 349 | if __name__ == "__main__": 350 | args = parse_args() 351 | main(args) 352 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scikit-image 3 | torch 4 | torchvision 5 | timm 6 | -------------------------------------------------------------------------------- /results/table_E150.md: -------------------------------------------------------------------------------- 1 | | Name | Mode | Parameters | ImageNet | ImageNet+ | ImageNet | ImageNet+ | ImageNet Links | ImageNet+ Links | 2 | |:-------------|:---------|:-------------|:------------------------------------------------|:-----------------------------------------------------|:------------------------------------------------|:-----------------------------------------------------|:---|:---| 3 | | MobileNetV1 | 0.25 | 0.5M | 55.2 | 55.4 | 55.4 | 55.4 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv1_0.25_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv1_0.25_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv1_0.25_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv1_0.25_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv1_0.25_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv1_0.25_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv1_0.25_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv1_0.25_E150/metrics.jb) | 4 | | MobileNetV1 | 0.5 | 1.3M | 66.2 | 67.1 | 66.4 | 67.1 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv1_0.5_E150/best.pt) 
[[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv1_0.5_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv1_0.5_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv1_0.5_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv1_0.5_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv1_0.5_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv1_0.5_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv1_0.5_E150/metrics.jb) | 5 | | MobileNetV1 | 1.0 | 4.3M | 73.5 | 75.1 | 73.6 | 75.1 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv1_1.0_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv1_1.0_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv1_1.0_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv1_1.0_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv1_1.0_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv1_1.0_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv1_1.0_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv1_1.0_E150/metrics.jb) | 6 | | MobileNetV2 | 0.25 | 1.5M | 54.7 | 54.2 | 54.7 | 54.2 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv2_0.25_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv2_0.25_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv2_0.25_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv2_0.25_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv2_0.25_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv2_0.25_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv2_0.25_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv2_0.25_E150/metrics.jb) | 7 | | MobileNetV2 | 0.5 | 2.0M | 65.7 | 65.7 | 65.7 | 65.7 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv2_0.5_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv2_0.5_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv2_0.5_E150/config.yaml) 
[[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv2_0.5_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv2_0.5_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv2_0.5_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv2_0.5_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv2_0.5_E150/metrics.jb) | 8 | | MobileNetV2 | 1.0 | 3.5M | 72.8 | 73.8 | 72.8 | 73.9 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv2_1.0_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv2_1.0_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv2_1.0_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv2_1.0_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv2_1.0_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv2_1.0_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv2_1.0_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv2_1.0_E150/metrics.jb) | 9 | | MobileNetV3 | small | 2.5M | 66.6 | 67.7 | 66.7 | 67.7 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv3_small_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv3_small_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv3_small_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv3_small_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv3_small_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv3_small_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv3_small_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv3_small_E150/metrics.jb) | 10 | | MobileNetV3 | large | 5.5M | 74.7 | 76.5 | 74.8 | 76.5 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv3_large_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv3_large_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv3_large_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv3_large_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv3_large_E150/best.pt) 
[[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv3_large_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv3_large_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv3_large_E150/metrics.jb) | 11 | | MobileViT | xx_small | 1.3M | 66.0 | 67.4 | 66.7 | 67.9 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilevit_xx_small_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilevit_xx_small_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilevit_xx_small_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilevit_xx_small_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilevit_xx_small_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilevit_xx_small_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilevit_xx_small_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilevit_xx_small_E150/metrics.jb) | 12 | | MobileViT | x_small | 2.3M | 72.6 | 74.0 | 73.3 | 74.7 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilevit_x_small_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilevit_x_small_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilevit_x_small_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilevit_x_small_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilevit_x_small_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilevit_x_small_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilevit_x_small_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilevit_x_small_E150/metrics.jb) | 13 | | MobileViT | small | 5.6M | 76.3 | 78.3 | 76.7 | 78.6 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilevit_small_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilevit_small_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilevit_small_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilevit_small_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilevit_small_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilevit_small_E150/ema_best.pt) 
[[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilevit_small_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilevit_small_E150/metrics.jb) | 14 | | ResNet | 18 | 11.7M | 69.9 | 73.2 | 69.8 | 73.2 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_18_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_18_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_18_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_18_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_18_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_18_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_18_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_18_E150/metrics.jb) | 15 | | ResNet | 34 | 21.8M | 74.6 | 76.9 | 74.7 | 76.9 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_34_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_34_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_34_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_34_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_34_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_34_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_34_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_34_E150/metrics.jb) | 16 | | ResNet | 50 | 25.6M | 79.0 | 80.3 | 79.1 | 80.3 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_50_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_50_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_50_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_50_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_50_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_50_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_50_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_50_E150/metrics.jb) | 17 | | ResNet | 101 | 44.6M | 80.5 | 81.8 | 80.5 | 81.9 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_101_E150/best.pt) 
[[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_101_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_101_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_101_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_101_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_101_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_101_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_101_E150/metrics.jb) | 18 | | ResNet | 152 | 60.3M | 81.3 | 82.2 | 81.3 | 82.3 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_152_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_152_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_152_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_152_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_152_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_152_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_152_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_152_E150/metrics.jb) | 19 | | EfficientNet | b2 | 9.3M | 79.5 | 81.5 | 79.5 | 81.6 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/efficientnet_b2_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/efficientnet_b2_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/efficientnet_b2_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/efficientnet_b2_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/efficientnet_b2_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/efficientnet_b2_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/efficientnet_b2_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/efficientnet_b2_E150/metrics.jb) | 20 | | EfficientNet | b3 | 12.4M | 80.9 | 82.4 | 80.8 | 82.5 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/efficientnet_b3_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/efficientnet_b3_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/efficientnet_b3_E150/config.yaml) 
[[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/efficientnet_b3_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/efficientnet_b3_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/efficientnet_b3_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/efficientnet_b3_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/efficientnet_b3_E150/metrics.jb) | 21 | | EfficientNet | b4 | 19.7M | 82.7 | 83.6 | 82.7 | 83.7 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/efficientnet_b4_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/efficientnet_b4_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/efficientnet_b4_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/efficientnet_b4_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/efficientnet_b4_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/efficientnet_b4_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/efficientnet_b4_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/efficientnet_b4_E150/metrics.jb) | 22 | | ViT | tiny | 5.6M | 72.1 | 74.3 | 72.1 | 74.4 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/vit_tiny_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/vit_tiny_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/vit_tiny_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/vit_tiny_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit_tiny_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit_tiny_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit_tiny_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit_tiny_E150/metrics.jb) | 23 | | ViT | small | 21.9M | 78.4 | 79.8 | 78.7 | 79.9 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/vit_small_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/vit_small_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/vit_small_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/vit_small_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit_small_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit_small_E150/ema_best.pt) 
[[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit_small_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit_small_E150/metrics.jb) | 24 | | ViT | base | 86.7M | 79.5 | 81.7 | 80.6 | 81.7 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/vit_base_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/vit_base_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/vit_base_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/vit_base_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit_base_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit_base_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit_base_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit_base_E150/metrics.jb) | 25 | | ViT-384 | base | 86.7M | 80.5 | 83.0 | 81.9 | 83.1 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/vit-384_base_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/vit-384_base_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/vit-384_base_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/vit-384_base_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit-384_base_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit-384_base_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit-384_base_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit-384_base_E150/metrics.jb) | 26 | | Swin | tiny | 28.3M | 80.5 | 82.1 | 80.3 | 81.9 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/swin_tiny_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/swin_tiny_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/swin_tiny_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/swin_tiny_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_tiny_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_tiny_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_tiny_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_tiny_E150/metrics.jb) | 27 | | Swin | small | 49.7M | 82.2 | 83.6 | 81.9 | 83.3 | 
[[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/swin_small_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/swin_small_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/swin_small_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/swin_small_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_small_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_small_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_small_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_small_E150/metrics.jb) | 28 | | Swin | base | 87.8M | 82.7 | 83.9 | 82.2 | 83.7 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/swin_base_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/swin_base_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/swin_base_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/swin_base_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_base_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_base_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_base_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_base_E150/metrics.jb) | 29 | | Swin-384 | base | 87.8M | 82.6 | 83.2 | 82.4 | 83.0 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/swin-384_base_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/swin-384_base_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/swin-384_base_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/swin-384_base_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin-384_base_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin-384_base_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin-384_base_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin-384_base_E150/metrics.jb) | 30 | -------------------------------------------------------------------------------- /tests/test_reinforce.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 
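#
# These tests check the apply/reapply contract of dr.transforms: reapplying the
# stored parameters must reproduce each augmented sample exactly.
# Run them from the repo root with, e.g.: python -m pytest tests/test_reinforce.py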
3 | #
4 | 
5 | """Tests for dataset generation."""
6 | 
7 | from typing import Any, Dict
8 | import pytest
9 | import torch
10 | 
11 | import dr.transforms as T_dr
12 | from reinforce import transform_batch
13 | 
14 | 
15 | 
16 | 
17 | def single_transform_config_test(
18 |     config: Dict[str, Any],
19 |     num_images: int,
20 |     num_samples: int,
21 |     height: int,
22 |     width: int,
23 |     crop_height: int,
24 |     crop_width: int,
25 |     compress: bool,
26 |     single_image_test: bool,
27 | ) -> None:
28 |     """Test applying and reapplying transformations to a batch of images."""
29 |     ids = torch.arange(num_images)
30 |     images_orig = torch.rand(size=(num_images, 3, height, width))
31 | 
32 |     target = torch.randint(low=0, high=10, size=(num_images,))
33 |     # Simulate the data loader
34 |     transforms = T_dr.compose_from_config(
35 |         T_dr.before_collate_config(config["reinforce"]["image_augmentation"])
36 |     )
37 | 
38 |     images = []
39 |     coords = []
40 |     for i in range(num_images):
41 |         sample_all, params_all = T_dr.before_collate_apply(
42 |             images_orig[i], transforms, num_samples
43 |         )
44 |         images += [sample_all]
45 |         coords += [params_all]
46 |     images = torch.utils.data.default_collate(images)
47 |     coords = torch.utils.data.default_collate(coords)
48 | 
49 |     # Apply after-collate transformations
50 |     transforms = T_dr.compose_from_config(config["reinforce"]["image_augmentation"])
51 |     images_aug, params_aug = transform_batch(
52 |         ids, images, target, coords, transforms, config
53 |     )
54 |     assert images_aug.shape == (
55 |         num_images,
56 |         num_samples,
57 |         3,
58 |         crop_height,
59 |         crop_width,
60 |     ), "Incorrect shape of transformed image."
61 |     transforms = T_dr.compose_from_config(config["reinforce"]["image_augmentation"])
62 | 
63 |     # Single-image apply/reapply test
64 |     # TODO: support transforms(img) and reapply(img, params) without img2
65 |     if single_image_test:
66 |         img, param = transforms(images_orig[0], images_orig[0])
67 |         img2, _ = transforms.reapply(images_orig[0], param)
68 |         torch.testing.assert_close(
69 |             actual=img,
70 |             expected=img2,
71 |         )
72 | 
73 |     if num_images > 1 and num_samples > 1:
74 |         assert (
75 |             len(set([str(q) for p in params_aug for q in p])) > 1
76 |         ), "Parameters are not random."
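
    # The loop below is the core determinism check of dataset reinforcement:
    # for every stored parameter dict, `reapply` must rebuild exactly the sample
    # generated above (decompressing parameters first when testing the
    # compressed format).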
77 | 78 | # Test reapply 79 | for i in range(num_images): 80 | for j in range(num_samples): 81 | if compress: 82 | params_aug[i][j] = transforms.decompress(params_aug[i][j]) 83 | img2 = None 84 | if "mixup" in params_aug[i][j] or "cutmix" in params_aug[i][j]: 85 | img2 = images_orig[params_aug[i][j]["id2"]] 86 | out, _ = transforms.reapply(images_orig[i], params_aug[i][j], img2) 87 | torch.testing.assert_close(actual=out, expected=images_aug[i][j]) 88 | 89 | 90 | @pytest.mark.parametrize("compress", [False, True]) 91 | def test_random_resized_crop(compress: bool) -> None: 92 | """Test RRC with other transformations.""" 93 | num_images = 2 94 | num_samples = 4 95 | height = 10 96 | width = 10 97 | crop_height = 5 98 | crop_width = 5 99 | single_image_test = True 100 | config = { 101 | "reinforce": { 102 | "num_samples": num_samples, 103 | "compress": compress, 104 | "image_augmentation": { 105 | "uint8": {"enable": True}, 106 | "random_resized_crop": { 107 | "enable": True, 108 | "size": [crop_height, crop_width], 109 | }, 110 | "random_horizontal_flip": {"enable": True, "p": 0.5}, 111 | "rand_augment": {"enable": True, "p": 0.5}, 112 | "to_tensor": {"enable": True}, 113 | "normalize": { 114 | "enable": True, 115 | "mean": [0.485, 0.456, 0.406], 116 | "std": [0.229, 0.224, 0.225], 117 | }, 118 | "random_erase": {"enable": True, "p": 0.25}, 119 | }, 120 | } 121 | } 122 | single_transform_config_test( 123 | config, 124 | num_images, 125 | num_samples, 126 | height, 127 | width, 128 | crop_height, 129 | crop_width, 130 | compress, 131 | single_image_test, 132 | ) 133 | 134 | 135 | @pytest.mark.parametrize("compress", [False, True]) 136 | def test_center_crop(compress: bool) -> None: 137 | """Test center-crop with other transformations.""" 138 | num_images = 2 139 | num_samples = 4 140 | height = 10 141 | width = 10 142 | crop_height = 5 143 | crop_width = 5 144 | single_image_test = True 145 | config = { 146 | "reinforce": { 147 | "num_samples": num_samples, 148 | "compress": compress, 149 | "image_augmentation": { 150 | "uint8": {"enable": True}, 151 | "center_crop": {"enable": True, "size": [crop_height, crop_width]}, 152 | "random_horizontal_flip": {"enable": True, "p": 0.5}, 153 | "rand_augment": {"enable": True, "p": 0.5}, 154 | "to_tensor": {"enable": True}, 155 | "normalize": { 156 | "enable": True, 157 | "mean": [0.485, 0.456, 0.406], 158 | "std": [0.229, 0.224, 0.225], 159 | }, 160 | "random_erase": {"enable": True, "p": 0.25}, 161 | }, 162 | } 163 | } 164 | single_transform_config_test( 165 | config, 166 | num_images, 167 | num_samples, 168 | height, 169 | width, 170 | crop_height, 171 | crop_width, 172 | compress, 173 | single_image_test, 174 | ) 175 | 176 | 177 | @pytest.mark.parametrize("compress", [False, True]) 178 | def test_mixing(compress: bool) -> None: 179 | """Test MixUp/CutMix with other transformations.""" 180 | num_images = 10 181 | num_samples = 20 182 | height = 256 183 | width = 256 184 | crop_height = 224 185 | crop_width = 224 186 | single_image_test = False 187 | config = { 188 | "reinforce": { 189 | "num_samples": num_samples, 190 | "compress": compress, 191 | "image_augmentation": { 192 | "uint8": {"enable": True}, 193 | "random_resized_crop": { 194 | "enable": True, 195 | "size": [crop_height, crop_width], 196 | }, 197 | "random_horizontal_flip": {"enable": True, "p": 0.5}, 198 | "rand_augment": {"enable": True, "p": 0.5}, 199 | "to_tensor": {"enable": True}, 200 | "normalize": { 201 | "enable": True, 202 | "mean": [0.485, 0.456, 0.406], 203 | "std": [0.229, 
                    0.224, 0.225],
204 |                 },
205 |                 "random_erase": {"enable": True, "p": 0.25},
206 |                 "mixup": {"enable": True, "alpha": 1.0, "p": 0.5},
207 |                 "cutmix": {"enable": True, "alpha": 1.0, "p": 0.5},
208 |             },
209 |         }
210 |     }
211 |     single_transform_config_test(
212 |         config,
213 |         num_images,
214 |         num_samples,
215 |         height,
216 |         width,
217 |         crop_height,
218 |         crop_width,
219 |         compress,
220 |         single_image_test,
221 |     )
222 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | """Modification of the PyTorch ImageNet training code to handle additional datasets."""
5 | import argparse
6 | import os
7 | import random
8 | import shutil
9 | import warnings
10 | import yaml
11 | import copy
12 | from typing import List, Dict, Any
13 | 
14 | import torch
15 | import torch.nn.parallel
16 | import torch.backends.cudnn as cudnn
17 | import torch.distributed as dist
18 | import torch.optim
19 | import torch.multiprocessing as mp
20 | import torch.utils.data
21 | import torch.utils.data.distributed
22 | 
23 | from trainers import get_trainer
24 | from utils import CosineLR
25 | from data import (
26 |     download_dataset,
27 |     get_dataset,
28 |     get_dataset_num_classes,
29 |     get_dataset_size,
30 | )
31 | 
32 | from dr.data import ReinforcedDataset
33 | 
34 | import logging
35 | 
36 | logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO)
37 | 
38 | best_acc1 = 0
39 | 
40 | 
41 | def parse_args() -> argparse.Namespace:
42 |     """Parse command-line arguments."""
43 |     parser = argparse.ArgumentParser(description="PyTorch Training")
44 |     parser.add_argument(
45 |         "--config", type=str, required=False, help="Path to a yaml config file."
46 |     )
47 |     args = parser.parse_args()
48 |     return args
49 | 
50 | 
51 | def get_trainable_parameters(
52 |     model: torch.nn.Module,
53 |     weight_decay: float,
54 |     no_decay_bn_filter_bias: bool,
55 |     *args,
56 |     **kwargs
57 | ) -> List[Dict[str, List[torch.nn.Parameter]]]:
58 |     """Group trainable parameters, optionally excluding biases and norm layers from weight decay.
59 | 
60 |     Args:
61 |         model: The Torch model to be trained.
62 |         weight_decay: The weight decay coefficient.
63 |         no_decay_bn_filter_bias: If True, exclude biases and normalization-layer parameters from weight decay.
64 | 
65 |     Returns:
66 |         A list of one or two parameter groups: one with weight decay and, if any parameters are excluded, one without.
67 |     """
68 |     with_decay = []
69 |     without_decay = []
70 |     for _, param in model.named_parameters():
71 |         if param.requires_grad and len(param.shape) == 1 and no_decay_bn_filter_bias:
72 |             # biases and normalization parameters have 1-D shapes
73 |             without_decay.append(param)
74 |         elif param.requires_grad:
75 |             with_decay.append(param)
76 |     param_list = [{"params": with_decay, "weight_decay": weight_decay}]
77 |     if len(without_decay) > 0:
78 |         param_list.append({"params": without_decay, "weight_decay": 0.0})
79 |     return param_list
80 | 
81 | 
82 | def get_optimizer(
83 |     model: torch.nn.Module, config: Dict[str, Any]
84 | ) -> torch.optim.Optimizer:
85 |     """Initialize an optimizer with the parameters of a model.
86 | 
87 |     Args:
88 |         model: The model to be trained.
89 |         config: A dictionary of optimizer hyperparameters and configurations. The
90 |             configuration should at least have a `name`.
91 | 
92 |     Returns:
93 |         A Torch optimizer.
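
        Example `config` (a minimal sketch of the expected keys; values are
        hypothetical):

            {"name": "sgd", "lr": 0.4, "momentum": 0.9,
             "weight_decay": 4.0e-5, "no_decay_bn_filter_bias": True}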
94 | """ 95 | optim_name = config["name"] 96 | 97 | params = get_trainable_parameters( 98 | model, 99 | weight_decay=config.get("weight_decay", 0.0), 100 | no_decay_bn_filter_bias=config.get("no_decay_bn_filter_bias", False), 101 | ) 102 | 103 | if optim_name == "sgd": 104 | optimizer = torch.optim.SGD( 105 | params=params, 106 | lr=config["lr"], 107 | momentum=config["momentum"], 108 | nesterov=config.get("nesterov", True), 109 | ) 110 | elif optim_name == "adamw": 111 | optimizer = torch.optim.AdamW( 112 | params=params, 113 | lr=config["lr"], 114 | betas=(config["beta1"], config["beta2"]), 115 | ) 116 | else: 117 | raise NotImplementedError 118 | return optimizer 119 | 120 | 121 | def main(args) -> None: 122 | """Train a model with the configurations specified in given arguments.""" 123 | if args.config is not None: 124 | # Read parameters from yaml config for local run 125 | yaml_path = args.config 126 | with open(yaml_path, "r") as file: 127 | config = yaml.safe_load(file).get("parameters") 128 | 129 | dataset = config.get("dataset", "imagenet") 130 | config["num_classes"] = get_dataset_num_classes(dataset) 131 | 132 | # Print args before training 133 | logging.info("Training args: {}".format(config)) 134 | 135 | if config["seed"] is not None: 136 | random.seed(config["seed"]) 137 | torch.manual_seed(config["seed"]) 138 | cudnn.deterministic = True 139 | cudnn.benchmark = False 140 | warnings.warn( 141 | "You have chosen to seed training. " 142 | "This will turn on the CUDNN deterministic setting, " 143 | "which can slow down your training considerably! " 144 | "You may see unexpected behavior when restarting " 145 | "from checkpoints." 146 | ) 147 | 148 | ngpus_per_node = torch.cuda.device_count() 149 | if config["gpu"] is not None: 150 | warnings.warn( 151 | "You have chosen a specific GPU. This will completely " 152 | "disable data parallelism." 153 | ) 154 | ngpus_per_node = 1 155 | 156 | if config["dist_url"] == "env://" and config["world_size"] == -1: 157 | config["world_size"] = int(os.environ["WORLD_SIZE"]) 158 | 159 | config["distributed"] = ( 160 | config["world_size"] > 1 or config["multiprocessing_distributed"] 161 | ) 162 | 163 | if config["download_data"]: 164 | config["data_path"] = download_dataset(config["data_path"], config["dataset"]) 165 | 166 | if config["multiprocessing_distributed"]: 167 | # Since we have ngpus_per_node processes per node, the total world_size 168 | # needs to be adjusted accordingly 169 | config["world_size"] = ngpus_per_node * config["world_size"] 170 | # Use torch.multiprocessing.spawn to launch distributed processes: the 171 | # main_worker process function 172 | mp.spawn( 173 | main_worker, 174 | nprocs=ngpus_per_node, 175 | args=(ngpus_per_node, copy.deepcopy(config)), 176 | ) 177 | else: 178 | # Simply call main_worker function 179 | main_worker(config["gpu"], ngpus_per_node, copy.deepcopy(config)) 180 | 181 | 182 | def main_worker(gpu: int, ngpus_per_node: int, config: Dict[str, Any]) -> None: 183 | """Train a model with a single process. 
    In distributed training, run on one GPU."""
184 |     global best_acc1
185 |     config["gpu"] = gpu
186 | 
187 |     if config["gpu"] is not None:
188 |         logging.info("Use GPU: {} for training".format(config["gpu"]))
189 | 
190 |     if config["distributed"]:
191 |         if config["dist_url"] == "env://" and config["rank"] == -1:
192 |             config["rank"] = int(os.environ["RANK"])
193 |         if config["multiprocessing_distributed"]:
194 |             # For multiprocessing distributed training, rank needs to be the
195 |             # global rank among all the processes
196 |             config["rank"] = config["rank"] * ngpus_per_node + gpu
197 |         dist.init_process_group(
198 |             backend=config["dist_backend"],
199 |             init_method=config["dist_url"],
200 |             world_size=config["world_size"],
201 |             rank=config["rank"],
202 |         )
203 |     # Disable logging on all workers except rank=0
204 |     if config["multiprocessing_distributed"] and config["rank"] % ngpus_per_node != 0:
205 |         logging.basicConfig(
206 |             format="%(asctime)s %(message)s", level=logging.WARNING, force=True
207 |         )
208 | 
209 |     # Initialize a trainer class that handles different modes of training, such as
210 |     # standard training (ERM), knowledge distillation (KD), and Dataset Reinforcement (DR)
211 |     trainer = get_trainer(config)
212 | 
213 |     # Create the model to train. Also create and load the teacher model for KD
214 |     model = trainer.get_model()
215 | 
216 |     # Print number of model parameters
217 |     logging.info(
218 |         "Number of model parameters: {}".format(
219 |             sum(p.numel() for p in model.parameters() if p.requires_grad)
220 |         )
221 |     )
222 | 
223 |     # Define the loss function (criterion)
224 |     criterion = trainer.get_criterion()
225 |     logging.info("Criterion: {}".format(criterion))
226 | 
227 |     # Define the optimizer
228 |     optimizer = get_optimizer(model, config["optim"])
229 |     logging.info("Optimizer: {}".format(optimizer))
230 | 
231 |     dataset_size = get_dataset_size(config["dataset"])
232 |     # Compute warmup and total iterations on each gpu
233 |     warmup_length = (
234 |         config["optim"]["warmup_length"]
235 |         * dataset_size
236 |         // config["batch_size"]
237 |         // ngpus_per_node
238 |     )
239 |     total_steps = (
240 |         config["epochs"] * dataset_size // config["batch_size"] // ngpus_per_node
241 |     )
242 |     lr_scheduler = CosineLR(
243 |         optimizer,
244 |         warmup_length=warmup_length,
245 |         total_steps=total_steps,
246 |         lr=config["optim"]["lr"],
247 |         end_lr=config["optim"].get("end_lr", 0.0),
248 |     )
249 | 
250 |     # Resume from a checkpoint
251 |     if config["resume"]:
252 |         if os.path.isfile(config["resume"]):
253 |             logging.info("=> loading checkpoint '{}'".format(config["resume"]))
254 |             if config["gpu"] is None:
255 |                 checkpoint = torch.load(config["resume"])
256 |             else:
257 |                 # Map model to be loaded to specified single gpu.
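                # (map_location remaps storages saved on, e.g., "cuda:0" onto
                # this worker's device, so a job can resume on a different GPU)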
258 | loc = "cuda:{}".format(config["gpu"]) 259 | checkpoint = torch.load(config["resume"], map_location=loc) 260 | config["start_epoch"] = checkpoint["epoch"] 261 | best_acc1 = checkpoint["best_acc1"] 262 | if hasattr(model, "module"): 263 | model.module.load_state_dict(checkpoint["state_dict"]) 264 | else: 265 | model.load_state_dict(checkpoint["state_dict"]) 266 | optimizer.load_state_dict(checkpoint["optimizer"]) 267 | lr_scheduler.load_state_dict(checkpoint["scheduler"]) 268 | logging.info( 269 | "=> loaded checkpoint '{}' (epoch {})".format( 270 | config["resume"], checkpoint["epoch"] 271 | ) 272 | ) 273 | else: 274 | logging.info("=> no checkpoint found at '{}'".format(config["resume"])) 275 | 276 | # Data loading code 277 | logging.info("Instantiating dataset.") 278 | train_dataset, val_dataset = get_dataset(config) 279 | logging.info("Training Dataset: {}".format(train_dataset)) 280 | logging.info("Validation Dataset: {}".format(val_dataset)) 281 | 282 | if config["trainer"] == "DR": 283 | # Reinforce dataset 284 | train_dataset = ReinforcedDataset( 285 | train_dataset, 286 | rdata_path=config["reinforce"]["data_path"], 287 | config=config["reinforce"], 288 | num_classes=config["num_classes"], 289 | ) 290 | 291 | if config["distributed"]: 292 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 293 | val_sampler = torch.utils.data.distributed.DistributedSampler( 294 | val_dataset, shuffle=False, drop_last=True 295 | ) 296 | else: 297 | train_sampler = None 298 | val_sampler = None 299 | 300 | train_loader = torch.utils.data.DataLoader( 301 | train_dataset, 302 | batch_size=config["batch_size"], 303 | shuffle=(train_sampler is None), 304 | num_workers=config["workers"], 305 | pin_memory=config["pin_memory"], 306 | persistent_workers=config['persistent_workers'], 307 | sampler=train_sampler, 308 | ) 309 | 310 | val_loader = torch.utils.data.DataLoader( 311 | val_dataset, 312 | batch_size=config["batch_size"], 313 | shuffle=False, 314 | num_workers=config["workers"], 315 | pin_memory=config["pin_memory"], 316 | persistent_workers=config['persistent_workers'], 317 | sampler=val_sampler, 318 | ) 319 | 320 | if config["evaluate"]: 321 | # Evaluate a pretrained model without training 322 | trainer.validate(val_loader, model, criterion, config) 323 | return 324 | 325 | # Evaluate a pretrained teacher before training 326 | trainer.validate_pretrained(val_loader, model, criterion, config) 327 | 328 | # Start training 329 | for epoch in range(config["start_epoch"], config["epochs"]): 330 | if config["distributed"]: 331 | train_sampler.set_epoch(epoch) 332 | 333 | # Train for one epoch 334 | train_metrics = trainer.train( 335 | train_loader, 336 | model, 337 | criterion, 338 | optimizer, 339 | epoch, 340 | config, 341 | lr_scheduler, 342 | ) 343 | 344 | # Evaluate on validation set 345 | val_metrics = trainer.validate(val_loader, model, criterion, config) 346 | val_acc1 = val_metrics["val_accuracy@top1"] 347 | 348 | # remember best acc@1 and save checkpoint 349 | is_best = val_acc1 > best_acc1 350 | best_acc1 = max(val_acc1, best_acc1) 351 | 352 | if not config["multiprocessing_distributed"] or ( 353 | config["multiprocessing_distributed"] 354 | and config["rank"] % ngpus_per_node == 0 355 | ): 356 | metrics = dict(train_metrics) 357 | metrics.update(val_metrics) 358 | if isinstance(model, torch.nn.DataParallel) or isinstance( 359 | model, torch.nn.parallel.DistributedDataParallel 360 | ): 361 | model_state_dict = model.module.state_dict() 362 | else: 363 | 
model_state_dict = model.state_dict() 364 | is_save_epoch = (epoch + 1) % config.get("save_freq", 1) == 0 365 | is_save_epoch = is_save_epoch or ((epoch + 1) == config["epochs"]) 366 | save_checkpoint( 367 | { 368 | "epoch": epoch + 1, 369 | "config": config, 370 | "state_dict": model_state_dict, 371 | "best_acc1": best_acc1, 372 | "optimizer": optimizer.state_dict(), 373 | "scheduler": lr_scheduler.state_dict(), 374 | }, 375 | is_best, 376 | is_save_epoch, 377 | config["artifact_path"], 378 | ) 379 | 380 | 381 | def save_checkpoint(state, is_best, is_save_epoch, artifact_path) -> None: 382 | """Save checkpoint and update the best model.""" 383 | fname = os.path.join(artifact_path, "checkpoint.pth.tar") 384 | torch.save(state, fname) 385 | 386 | if is_save_epoch: 387 | checkpoint_fname = os.path.join( 388 | artifact_path, "checkpoint_{}.pth.tar".format(state["epoch"]) 389 | ) 390 | shutil.copyfile(fname, checkpoint_fname) 391 | if is_best: 392 | best_model_fname = os.path.join(artifact_path, "model_best.pth.tar") 393 | shutil.copyfile(fname, best_model_fname) 394 | 395 | 396 | if __name__ == "__main__": 397 | args = parse_args() 398 | main(args) 399 | -------------------------------------------------------------------------------- /trainers.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | """Training methods for ERM, Knowledge Distillation, and Dataset Reinforcement.""" 5 | from abc import ABC 6 | import time 7 | import logging 8 | from typing import Callable, Dict, Any, Type 9 | 10 | import torch 11 | from torch import Tensor 12 | import torch.nn as nn 13 | from torch.utils.data import DataLoader, Subset 14 | from torch.nn import functional as F 15 | 16 | from utils import AverageMeter, CosineLR, ProgressMeter, Summary, accuracy 17 | from models import move_to_device, load_model, create_model 18 | from transforms import MixingTransforms 19 | 20 | 21 | class Trainer(ABC): 22 | """Abstract class for various training methodologies.""" 23 | 24 | def get_model(self) -> nn.Module: 25 | """Create and initialize the model to train using self.config.""" 26 | raise NotImplementedError("Implement `get_model` to initialize a model.") 27 | 28 | def get_criterion(self) -> nn.Module: 29 | """Return the training criterion.""" 30 | raise NotImplementedError("Implement `get_criterion`.") 31 | 32 | def train( 33 | self, 34 | train_loader: DataLoader, 35 | model: torch.nn.Module, 36 | criterion: torch.nn.Module, 37 | optimizer: torch.optim.Optimizer, 38 | epoch: int, 39 | config: Dict[str, Any], 40 | lr_scheduler: Type[CosineLR], 41 | ) -> Dict[str, Any]: 42 | """Train a model for a single epoch and return training metrics dictionary.""" 43 | raise NotImplementedError("Implement `train` method.") 44 | 45 | def validate_pretrained(self, *args, **kwargs) -> None: 46 | """Validate pretrained teacher model.""" 47 | pass 48 | 49 | def validate(self, *args, **kwargs) -> Dict[str, Any]: 50 | """Validate the model that is being trained and return a metrics dictionary.""" 51 | return validate(*args, **kwargs) 52 | 53 | 54 | def get_trainer(config: Dict[str, Any]) -> Trainer: 55 | """Initialize a trainer given a configuration dictionary.""" 56 | trainer_type = config["trainer"] 57 | if trainer_type == "ERM": 58 | return ERMTrainer(config) 59 | elif trainer_type == "KD": 60 | return KDTrainer(config) 61 | elif trainer_type == "DR": 62 | return ReinforcedTrainer(config) 63 | raise NotImplementedError("Trainer not 
implemented.") 64 | 65 | 66 | class ERMTrainer(Trainer): 67 | """Trainer class for Empirical Risk Minimization (ERM) with cross-entropy.""" 68 | 69 | def __init__(self, config: Dict[str, Any]) -> None: 70 | """Initialize ERMTrainer.""" 71 | self.config = config 72 | self.label_smoothing = config["loss"].get("label_smoothing", 0.0) 73 | 74 | def get_model(self) -> torch.nn.Module: 75 | """Create and initialize the model to train using self.config.""" 76 | arch = self.config["arch"] 77 | model = create_model(arch, self.config) 78 | model = move_to_device(model, self.config) 79 | return model 80 | 81 | def get_criterion(self) -> torch.nn.Module: 82 | """Return the training criterion.""" 83 | criterion = nn.CrossEntropyLoss(label_smoothing=self.label_smoothing).cuda( 84 | self.config["gpu"] 85 | ) 86 | return criterion 87 | 88 | def train( 89 | self, 90 | train_loader: DataLoader, 91 | model: torch.nn.Module, 92 | criterion: torch.nn.Module, 93 | optimizer: torch.optim.Optimizer, 94 | epoch: int, 95 | config: Dict[str, Any], 96 | lr_scheduler: Type[CosineLR], 97 | ) -> Dict[str, Any]: 98 | """Train a model for a single epoch and return training metrics dictionary.""" 99 | batch_time = AverageMeter("Time", ":6.3f") 100 | data_time = AverageMeter("Data", ":6.3f") 101 | losses = AverageMeter("Loss", ":.6f") 102 | top1 = AverageMeter("Acc@1", ":6.2f") 103 | top5 = AverageMeter("Acc@5", ":6.2f") 104 | lrs = AverageMeter("Lr", ":.4f") 105 | conf = AverageMeter("Confidence", ":.5f") 106 | progress = ProgressMeter( 107 | len(train_loader), 108 | [batch_time, data_time, losses, top1, top5, lrs, conf], 109 | prefix="Epoch: [{}]".format(epoch), 110 | ) 111 | 112 | # switch to train mode 113 | model.train() 114 | 115 | mixing_transforms = MixingTransforms( 116 | config["image_augmentation"], config["num_classes"] 117 | ) 118 | 119 | end = time.time() 120 | for i, (images, target) in enumerate(train_loader): 121 | # measure data loading time 122 | data_time.update(time.time() - end) 123 | 124 | if config["gpu"] is not None: 125 | images = images.cuda(config["gpu"], non_blocking=True) 126 | target = target.cuda(config["gpu"], non_blocking=True) 127 | 128 | # apply mixup / cutmix 129 | mix_images, mix_target = mixing_transforms(images, target) 130 | 131 | # compute output 132 | output = model(mix_images) 133 | 134 | # classification loss 135 | loss = criterion(output, mix_target) 136 | 137 | # measure accuracy and record loss 138 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 139 | losses.update(loss.item(), images.size(0)) 140 | top1.update(acc1, images.size(0)) 141 | top5.update(acc5, images.size(0)) 142 | lrs.update(lr_scheduler.get_last_lr()[0]) 143 | 144 | # measure confidence 145 | prob = torch.nn.functional.softmax(output, dim=1) 146 | conf.update(prob.max(1).values.mean().item(), images.size(0)) 147 | 148 | # compute gradient and do SGD step 149 | optimizer.zero_grad() 150 | loss.backward() 151 | optimizer.step() 152 | 153 | # measure elapsed time 154 | batch_time.update(time.time() - end) 155 | end = time.time() 156 | 157 | lr_scheduler.step() 158 | 159 | if i % config["print_freq"] == 0: 160 | progress.display(i) 161 | 162 | if config["distributed"]: 163 | top1.all_reduce() 164 | top5.all_reduce() 165 | 166 | metrics = { 167 | "train_accuracy@top1": top1.avg, 168 | "train_accuracy@top5": top5.avg, 169 | "train_loss": losses.avg, 170 | "lr": lrs.avg, 171 | "train_confidence": conf.avg, 172 | } 173 | return metrics 174 | 175 | 176 | class ReinforcedTrainer(ERMTrainer): 177 | """Trainer with a 
reinforced dataset. Same as ERMTrainer but with a KL-divergence loss."""
178 | 
179 |     def get_criterion(self) -> Callable[[Tensor, Tensor], Tensor]:
180 |         """Return KL loss instead of cross-entropy."""
181 |         return lambda output, target: F.kl_div(
182 |             F.log_softmax(output, dim=1), target, reduction="batchmean"
183 |         )
184 | 
185 | 
186 | class KDTrainer(ERMTrainer):
187 |     """Trainer for Knowledge Distillation."""
188 | 
189 |     def __init__(self, config: Dict[str, Any]) -> None:
190 |         """Initialize trainer and set hyperparameters of KD."""
191 |         # Loss config
192 |         self.lambda_kd = config["loss"].get("lambda_kd", 1.0)
193 |         self.lambda_cls = config["loss"].get("lambda_cls", 0.0)
194 |         self.temperature = config["loss"].get("temperature", 1.0)
195 |         assert self.temperature > 0, "Softmax with temperature=0 is undefined."
196 |         self.label_smoothing = config["loss"].get("label_smoothing", 0.0)
197 | 
198 |         self.config = config
199 |         self.teacher_model = None
200 | 
201 |     def get_model(self) -> torch.nn.Module:
202 |         """Create and initialize student and teacher models."""
203 |         config = self.config
204 | 
205 |         # Instantiate student model for training.
206 |         student_arch = config["student"]["arch"]
207 |         model = create_model(student_arch, config["student"])
208 |         model = move_to_device(model, self.config)
209 | 
210 |         # Instantiate teacher model
211 |         teacher_model = load_model(config["gpu"], config["teacher"])
212 | 
213 |         if config["gpu"] is not None:
214 |             torch.cuda.set_device(config["gpu"])
215 |             teacher_model = teacher_model.cuda(config["gpu"])
216 |         else:
217 |             teacher_model.cuda()
218 |         # Set teacher to eval mode
219 |         teacher_model.eval()
220 |         self.teacher_model = teacher_model
221 | 
222 |         return model
223 | 
224 |     def validate_pretrained(
225 |         self,
226 |         val_loader: DataLoader,
227 |         model: torch.nn.Module,
228 |         criterion: torch.nn.Module,
229 |         config: Dict[str, Any],
230 |     ) -> None:
231 |         """Validate teacher accuracy before training."""
232 |         teacher_model = self.teacher_model
233 |         do_validate = config.get("teacher", {}).get("validate", True)
234 |         if teacher_model is not None and do_validate:
235 |             logging.info(
236 |                 "Validation loader resizes to standard 256x256 resolution,"
237 |                 " which is not necessarily the optimal resolution for the teacher."
238 | ) 239 | val_metrics = validate(val_loader, teacher_model, criterion, config) 240 | logging.info( 241 | "Teacher accuracy@top1: {}, @top5: {}".format( 242 | val_metrics["val_accuracy@top1"], val_metrics["val_accuracy@top5"] 243 | ) 244 | ) 245 | 246 | def train( 247 | self, 248 | train_loader: DataLoader, 249 | model: torch.nn.Module, 250 | criterion: torch.nn.Module, 251 | optimizer: torch.optim.Optimizer, 252 | epoch: int, 253 | config: Dict[str, Any], 254 | lr_scheduler: Type[CosineLR], 255 | ) -> Dict[str, Any]: 256 | """Train a model for a single epoch and return training metrics dictionary.""" 257 | batch_time = AverageMeter("Time", ":6.3f") 258 | data_time = AverageMeter("Data", ":6.3f") 259 | losses = AverageMeter("Loss", ":.6f") 260 | kd_losses = AverageMeter("KD Loss", ":.6f") 261 | overall_losses = AverageMeter("Loss", ":.6f") 262 | top1 = AverageMeter("Acc@1", ":6.2f") 263 | top5 = AverageMeter("Acc@5", ":6.2f") 264 | lrs = AverageMeter("Lr", ":.4f") 265 | progress = ProgressMeter( 266 | len(train_loader), 267 | [batch_time, data_time, losses, kd_losses, overall_losses, top1, top5, lrs], 268 | prefix="Epoch: [{}]".format(epoch), 269 | ) 270 | 271 | # Switch to train mode 272 | model.train() 273 | 274 | mixing_transforms = MixingTransforms( 275 | config["image_augmentation"], config["num_classes"] 276 | ) 277 | 278 | end = time.time() 279 | for i, (images, target) in enumerate(train_loader): 280 | # Measure data loading time 281 | data_time.update(time.time() - end) 282 | 283 | if config["gpu"] is not None: 284 | images = images.cuda(config["gpu"], non_blocking=True) 285 | target = target.cuda(config["gpu"], non_blocking=True) 286 | 287 | # Apply mixup / cutmix 288 | mix_images, mix_target = mixing_transforms(images, target) 289 | 290 | # Compute output for differing resolution. 
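# (Editorial aside, not part of the original source: in the block below the
# student consumes a bilinear 224x224 downsample while the teacher, further
# down, consumes the full-resolution mix_images, so both models are queried
# on the same augmented content. F.interpolate(x, size=(224, 224),
# mode="bilinear") resizes the trailing H and W dimensions of a
# [B, C, H, W] batch.)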
Only 224x224 student inputs are supported.
291 |             mix_images_small = mix_images
292 |             if mix_images.shape[-1] != 224:
293 |                 mix_images_small = F.interpolate(
294 |                     mix_images, size=(224, 224), mode="bilinear"
295 |                 )
296 |             output = model(mix_images_small)
297 | 
298 |             # Classification loss
299 |             loss = criterion(output, mix_target)
300 |             losses.update(loss.item(), images.size(0))
301 | 
302 |             # Distillation loss
303 |             # Get teacher's output for this input
304 |             with torch.no_grad():
305 |                 teacher_probs = self.teacher_model(
306 |                     mix_images, return_prob=True, temperature=self.temperature
307 |                 ).detach()
308 |             kd_loss = F.kl_div(
309 |                 F.log_softmax(output / self.temperature, dim=1),
310 |                 teacher_probs,
311 |                 reduction="batchmean",
312 |             ) * (self.temperature**2)
313 |             kd_losses.update(kd_loss.item(), images.size(0))
314 | 
315 |             # Overall loss is a combination of kd loss and classification loss
316 |             loss = self.lambda_cls * loss + self.lambda_kd * kd_loss
317 | 
318 |             # measure accuracy and record loss
319 |             acc1, acc5 = accuracy(output, target, topk=(1, 5))
320 |             overall_losses.update(loss.item(), images.size(0))
321 |             top1.update(acc1, images.size(0))
322 |             top5.update(acc5, images.size(0))
323 |             lrs.update(lr_scheduler.get_last_lr()[0])
324 | 
325 |             # compute gradient and do SGD step
326 |             optimizer.zero_grad()
327 |             loss.backward()
328 |             optimizer.step()
329 | 
330 |             lr_scheduler.step()
331 | 
332 |             # measure elapsed time
333 |             batch_time.update(time.time() - end)
334 |             end = time.time()
335 | 
336 |             if i % config["print_freq"] == 0:
337 |                 progress.display(i)
338 | 
339 |         if config["distributed"]:
340 |             top1.all_reduce()
341 |             top5.all_reduce()
342 | 
343 |         metrics = {
344 |             "train_accuracy@top1": top1.avg,
345 |             "train_accuracy@top5": top5.avg,
346 |             "train_loss_ce": losses.avg,
347 |             "train_loss_kd": kd_losses.avg,
348 |             "train_loss_total": overall_losses.avg,
349 |             "lr": lrs.avg,
350 |         }
351 |         return metrics
352 | 
353 | 
354 | def validate(
355 |     val_loader: DataLoader,
356 |     model: torch.nn.Module,
357 |     criterion: torch.nn.Module,
358 |     config: Dict[str, Any],
359 | ) -> Dict[str, Any]:
360 |     """Validate the model that is being trained and return a metrics dictionary."""
361 | 
362 |     def run_validate(loader: DataLoader, base_progress: int = 0) -> None:
363 |         with torch.no_grad():
364 |             end = time.time()
365 |             for i, (images, target) in enumerate(loader):  # use the argument, not the closed-over val_loader
366 |                 i = base_progress + i
367 |                 if config["gpu"] is not None:
368 |                     images = images.cuda(config["gpu"], non_blocking=True)
369 |                     target = target.cuda(config["gpu"], non_blocking=True)
370 | 
371 |                 # compute output
372 |                 output = model(images)
373 |                 # for validation, compute standard CE loss without label smoothing
374 |                 loss = F.cross_entropy(output, target)
375 | 
376 |                 # measure accuracy and record loss
377 |                 acc1, acc5 = accuracy(output, target, topk=(1, 5))
378 |                 losses.update(loss.item(), images.size(0))
379 |                 top1.update(acc1, images.size(0))
380 |                 top5.update(acc5, images.size(0))
381 | 
382 |                 # measure confidence
383 |                 prob = torch.nn.functional.softmax(output, dim=1)
384 |                 conf.update(prob.max(1).values.mean().item(), images.size(0))
385 | 
386 |                 # measure elapsed time
387 |                 batch_time.update(time.time() - end)
388 |                 end = time.time()
389 | 
390 |                 if i % config["print_freq"] == 0:
391 |                     progress.display(i)
392 | 
393 |     batch_time = AverageMeter("Time", ":6.3f", Summary.NONE)
394 |     losses = AverageMeter("Loss", ":.6f", Summary.NONE)
395 |     top1 = AverageMeter("Acc@1", ":6.2f", Summary.AVERAGE)
396 |     top5 = AverageMeter("Acc@5", ":6.2f", Summary.AVERAGE)
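# (Editorial aside, not part of the original source: the Summary enum chosen
# here controls what ProgressMeter.display_summary() prints after validation;
# meters constructed with Summary.NONE, such as batch_time above, are still
# updated every step but contribute nothing to the final summary line.)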
397 |     conf = AverageMeter("Confidence", ":.5f", Summary.AVERAGE)
398 |     progress = ProgressMeter(
399 |         len(val_loader)
400 |         + (
401 |             config["distributed"]
402 |             and (
403 |                 len(val_loader.sampler) * config["world_size"] < len(val_loader.dataset)
404 |             )
405 |         ),
406 |         [batch_time, losses, top1, top5],
407 |         prefix="Test: ",
408 |     )
409 | 
410 |     # switch to evaluate mode
411 |     model.eval()
412 | 
413 |     # run validation using all nodes in a distributed env and aggregate results
414 |     run_validate(val_loader)
415 |     if config["distributed"]:
416 |         top1.all_reduce()
417 |         top5.all_reduce()
418 | 
419 |     if config["distributed"] and (
420 |         len(val_loader.sampler) * config["world_size"] < len(val_loader.dataset)
421 |     ):
422 |         aux_val_dataset = Subset(
423 |             val_loader.dataset,
424 |             range(
425 |                 len(val_loader.sampler) * config["world_size"], len(val_loader.dataset)
426 |             ),
427 |         )
428 |         aux_val_loader = torch.utils.data.DataLoader(
429 |             aux_val_dataset,
430 |             batch_size=config["batch_size"],
431 |             shuffle=False,
432 |             num_workers=config["workers"],
433 |             pin_memory=True,
434 |         )
435 |         run_validate(aux_val_loader, len(val_loader))
436 | 
437 |     progress.display_summary()
438 | 
439 |     metrics = {
440 |         "val_loss": losses.avg,
441 |         "val_accuracy@top1": top1.avg,
442 |         "val_accuracy@top5": top5.avg,
443 |         "val_confidence": conf.avg,
444 |     }
445 |     return metrics
446 | 
-------------------------------------------------------------------------------- /transforms.py: --------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | 
5 | """Simplified composition of PyTorch transformations from a configuration dictionary."""
6 | 
7 | import math
8 | import random
9 | from typing import Any, Dict, Optional, OrderedDict, Tuple
10 | import numpy as np
11 | 
12 | import timm
13 | from timm.data.transforms import str_to_interp_mode
14 | import torch
15 | from torch import Tensor
16 | import torchvision.transforms as T
17 | from torch.nn import functional as F
18 | 
19 | 
20 | INTERPOLATION_MODE_MAP = {
21 |     "nearest": T.InterpolationMode.NEAREST,
22 |     "bilinear": T.InterpolationMode.BILINEAR,
23 |     "bicubic": T.InterpolationMode.BICUBIC,
24 |     "cubic": T.InterpolationMode.BICUBIC,
25 |     "box": T.InterpolationMode.BOX,
26 |     "hamming": T.InterpolationMode.HAMMING,
27 |     "lanczos": T.InterpolationMode.LANCZOS,
28 | }
29 | 
30 | 
31 | class AutoAugment(T.AutoAugment):
32 |     """Extend PyTorch's AutoAugment to init from a policy and an interpolation name."""
33 | 
34 |     def __init__(
35 |         self, policy: str = "imagenet", interpolation: str = "bilinear", *args, **kwargs
36 |     ) -> None:
37 |         """Init from a policy and an interpolation name."""
38 |         if "cifar" in policy.lower():
39 |             policy = T.AutoAugmentPolicy.CIFAR10
40 |         elif "svhn" in policy.lower():
41 |             policy = T.AutoAugmentPolicy.SVHN
42 |         else:
43 |             policy = T.AutoAugmentPolicy.IMAGENET
44 |         interpolation = INTERPOLATION_MODE_MAP[interpolation]
45 |         super().__init__(*args, policy=policy, interpolation=interpolation, **kwargs)
46 | 
47 | 
48 | class RandAugment(T.RandAugment):
49 |     """Extend PyTorch's RandAugment to init from an interpolation name."""
50 | 
51 |     def __init__(self, interpolation: str = "bilinear", *args, **kwargs) -> None:
52 |         """Init from an interpolation name."""
53 |         interpolation = INTERPOLATION_MODE_MAP[interpolation]
54 |         super().__init__(*args, interpolation=interpolation, **kwargs)
55 | 
56 | 
57 | class TrivialAugmentWide(T.TrivialAugmentWide):
58 |     """Extend PyTorch's TrivialAugmentWide to init from an interpolation name."""
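# (Editorial aside, not part of the original source: these thin subclasses
# exist so YAML configs can name augmentation options with plain strings,
# e.g. interpolation: "bilinear"; INTERPOLATION_MODE_MAP converts the string
# to the torchvision InterpolationMode enum the parent classes expect.)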
"""Extend PyTorch's TrivialAugmentWide to init from an interpolation name.""" 59 | 60 | def __init__(self, interpolation: str = "bilinear", *args, **kwargs) -> None: 61 | """Init from an interpolation name.""" 62 | interpolation = INTERPOLATION_MODE_MAP[interpolation] 63 | super().__init__(*args, interpolation=interpolation, **kwargs) 64 | 65 | 66 | # Transformations are composed according to the order in this dict, not the order in 67 | # yaml config 68 | TRANSFORMATION_TO_NAME = OrderedDict( 69 | [ 70 | ("resize", T.Resize), 71 | ("center_crop", T.CenterCrop), 72 | ("random_crop", T.RandomCrop), 73 | ("random_resized_crop", T.RandomResizedCrop), 74 | ("random_horizontal_flip", T.RandomHorizontalFlip), 75 | ("rand_augment", RandAugment), 76 | ("auto_augment", AutoAugment), 77 | ("trivial_augment_wide", TrivialAugmentWide), 78 | ("to_tensor", T.ToTensor), 79 | ("random_erase", T.RandomErasing), 80 | ("normalize", T.Normalize), 81 | ] 82 | ) 83 | 84 | 85 | def timm_resize_crop_norm(config: Dict[str, Any]) -> torch.nn.Module: 86 | """Set Resize/RandomCrop/Normalization parameters from configs of a Timm teacher.""" 87 | teacher_name = config["timm_resize_crop_norm"]["name"] 88 | cfg = timm.models.get_pretrained_cfg(teacher_name).to_dict() 89 | if "test_input_size" in cfg: 90 | img_size = list(cfg["test_input_size"])[-1] 91 | else: 92 | img_size = list(cfg["input_size"])[-1] 93 | # Crop ratio and image size for optimal performance of a Timm model 94 | crop_pct = cfg["crop_pct"] 95 | scale_size = int(math.floor(img_size / crop_pct)) 96 | interpolation = cfg["interpolation"] 97 | config["resize"] = { 98 | "size": scale_size, 99 | "interpolation": str_to_interp_mode(interpolation), 100 | } 101 | config["random_crop"] = { 102 | "size": img_size, 103 | "pad_if_needed": True, 104 | } 105 | config["normalize"] = {"mean": cfg["mean"], "std": cfg["std"]} 106 | return config 107 | 108 | 109 | def clean_config(config: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]: 110 | """Return a clone of configs and remove unnecessary keys from configurations.""" 111 | new_config = {} 112 | for k, v in config.items(): 113 | vv = dict(v) 114 | if vv.pop("enable", True): 115 | new_config[k] = vv 116 | return new_config 117 | 118 | 119 | def compose_from_config(config_tr: Dict[str, Any]) -> torch.nn.Module: 120 | """Initialize transformations given the dataset name and configurations. 121 | 122 | Args: 123 | config_tr: A dictionary of transformation parameters. 124 | 125 | Returns a composition of transformations. 126 | """ 127 | config_tr = clean_config(config_tr) 128 | if "timm_resize_crop_norm" in config_tr: 129 | config_tr = timm_resize_crop_norm(config_tr) 130 | transforms = [] 131 | for t_name, t_class in TRANSFORMATION_TO_NAME.items(): 132 | if t_name in config_tr: 133 | # TODO: warn for every key in config_tr that was not used 134 | transforms += [t_class(**config_tr[t_name])] 135 | return T.Compose(transforms) 136 | 137 | 138 | class MixUp(torch.nn.Module): 139 | r"""MixUp image transformation. 140 | 141 | For an input x the 142 | output is :math:`\lambda x + (1-\lambda) x_p` , where :math:`x_p` is a 143 | random permutation of `x` along the batch dimension, and lam is a random 144 | number between 0 and 1. 145 | See https://arxiv.org/abs/1710.09412 for more details. 146 | """ 147 | 148 | def __init__( 149 | self, alpha: float = 1.0, p: float = 1.0, div_by: float = 1.0, *args, **kwargs 150 | ) -> None: 151 | """Initialize MixUp transformation. 
152 | 
153 |         Args:
154 |             alpha: A positive real number that determines the sampling
155 |                 distribution. Each mixed sample is a convex combination of two
156 |                 examples from the batch with mixing coefficient lambda.
157 |                 lambda is sampled from a symmetric Beta distribution with
158 |                 parameter alpha. When alpha=0 no mixing happens. Defaults to 1.0.
159 |             p: Mixing is applied with probability `p`. Defaults to 1.0.
160 |             div_by: Divide the lambda by a constant. Set to 2.0 to make sure mixing is
161 |                 biased towards the first input. Defaults to 1.0.
162 |         """
163 |         super().__init__(*args, **kwargs)
164 |         assert alpha >= 0
165 |         assert p >= 0 and p <= 1.0
166 |         assert div_by >= 1.0
167 |         self.alpha = alpha
168 |         self.p = p
169 |         self.div_by = div_by
170 | 
171 |     def get_params(self, alpha: float, div_by: float) -> Optional[float]:
172 |         """Return the mixing coefficient, or None when mixing is skipped."""
173 |         # Skip mixing by probability 1-self.p
174 |         if alpha == 0 or torch.rand(1) > self.p:
175 |             return None
176 | 
177 |         lam = np.random.beta(alpha, alpha) / div_by
178 |         return lam
179 | 
180 |     def forward(
181 |         self,
182 |         x: Tensor,
183 |         x2: Optional[Tensor] = None,
184 |         y: Optional[Tensor] = None,
185 |         y2: Optional[Tensor] = None,
186 |     ) -> Tuple[Tensor, Tensor]:
187 |         r"""Apply pixel-space mixing to a batch of examples.
188 | 
189 |         Args:
190 |             x: A tensor with a batch of samples. Shape: [batch_size, ...].
191 |             x2: A tensor with exactly one matching sample for any input in `x`. Shape:
192 |                 [batch_size, ...].
193 |             y: A tensor of target labels. Shape: [batch_size, ...].
194 |             y2: A tensor of target labels for paired samples. Shape: [batch_size, ...].
195 | 
196 |         Returns:
197 |             The mixed x tensor and mixed labels y (returned unchanged if mixing is skipped).
198 |         """
199 |         alpha = self.alpha
200 |         # Randomly sample lambda if not provided
201 |         params = self.get_params(alpha, self.div_by)
202 |         if params is None:
203 |             return x, y
204 |         lam = params
205 | 
206 |         # Randomly sample second input from the same mini-batch if not provided
207 |         if x2 is None:
208 |             batch_size = int(x.size()[0])
209 |             index = torch.randperm(batch_size, device=x.device)
210 |             x2 = x[index, :]
211 |             y2 = y[index, :] if y is not None else None
212 | 
213 |         # Mix inputs and labels
214 |         mixed_x = lam * x + (1 - lam) * x2
215 |         mixed_y = y
216 |         if y is not None:
217 |             mixed_y = lam * y + (1 - lam) * y2
218 | 
219 |         return mixed_x, mixed_y
220 | 
221 | 
222 | class CutMix(torch.nn.Module):
223 |     r"""CutMix image transformation.
224 | 
225 |     Please see the full paper for more details:
226 |     https://arxiv.org/pdf/1905.04899.pdf
227 |     """
228 | 
229 |     def __init__(self, alpha: float = 1.0, p: float = 1.0, *args, **kwargs) -> None:
230 |         """Initialize CutMix transformation.
231 | 
232 |         Args:
233 |             alpha: The alpha parameter to the Beta for producing a mixing lambda.
234 |         """
235 |         super().__init__(*args, **kwargs)
236 |         assert alpha >= 0
237 |         assert p >= 0 and p <= 1.0
238 |         self.alpha = alpha
239 |         self.p = p
240 | 
241 |     @staticmethod
242 |     def rand_bbox(size: torch.Size, lam: float) -> Tuple[int, int, int, int]:
243 |         """Return random bbox coordinates.
244 | 
245 |         Args:
246 |             size: model input tensor shape in this format: (...,H,W)
247 |             lam: lambda sampling parameter in CutMix method. See equation 1
248 |                 in the original paper: https://arxiv.org/pdf/1905.04899.pdf
249 | 
250 |         Returns:
251 |             The output bbox format is a tuple: (x1, y1, x2, y2), where (x1,
252 |             y1) and (x2, y2) are the coordinates of the top-left and bottom-right
253 |             corners of the bbox in the pixel-space.
254 |         """
255 |         assert lam >= 0 and lam <= 1.0
256 |         h = size[-2]
257 |         w = size[-1]
258 |         cut_rat = np.sqrt(1.0 - lam)
259 |         cut_h = int(h * cut_rat)
260 |         cut_w = int(w * cut_rat)
261 | 
262 |         # uniform
263 |         cx = np.random.randint(h)
264 |         cy = np.random.randint(w)
265 | 
266 |         bbx1 = np.clip(cx - cut_h // 2, 0, h)
267 |         bby1 = np.clip(cy - cut_w // 2, 0, w)
268 |         bbx2 = np.clip(cx + cut_h // 2, 0, h)
269 |         bby2 = np.clip(cy + cut_w // 2, 0, w)
270 | 
271 |         return (bbx1, bby1, bbx2, bby2)
272 | 
273 |     def get_params(
274 |         self, size: torch.Size, alpha: float
275 |     ) -> Optional[Tuple[float, Tuple[int, int, int, int]]]:
276 |         """Return the mixing coefficient and bbox, or None when mixing is skipped."""
277 |         # Skip mixing by probability 1-self.p
278 |         if alpha == 0 or torch.rand(1) > self.p:
279 |             return None
280 | 
281 |         lam = np.random.beta(alpha, alpha)
282 |         # Compute mask
283 |         bbx1, bby1, bbx2, bby2 = self.rand_bbox(size, lam)
284 |         return lam, (bbx1, bby1, bbx2, bby2)
285 | 
286 |     def forward(
287 |         self,
288 |         x: Tensor,
289 |         x2: Optional[Tensor] = None,
290 |         y: Optional[Tensor] = None,
291 |         y2: Optional[Tensor] = None,
292 |     ) -> Tuple[Tensor, Tensor]:
293 |         """Mix images by replacing random patches from one to the other.
294 | 
295 |         Args:
296 |             x: A tensor with a batch of samples. Shape: [batch_size, ...].
297 |             x2: A tensor with exactly one matching sample for any input in `x`. Shape:
298 |                 [batch_size, ...].
299 |             y: A tensor of target labels. Shape: [batch_size, ...].
300 |             y2: A tensor of target labels for paired samples. Shape: [batch_size, ...].
301 | 
302 |         Returns: The mixed x tensor and mixed labels y (returned unchanged if mixing is skipped).
303 |         """
304 |         alpha = self.alpha
305 | 
306 |         # Randomly sample lambda and bbox coordinates if not provided
307 |         params = self.get_params(x.shape, alpha)
308 |         if params is None:
309 |             return x, y
310 |         lam, (bbx1, bby1, bbx2, bby2) = params
311 | 
312 |         # Randomly sample second input from the same mini-batch if not provided
313 |         if x2 is None:
314 |             batch_size = int(x.size()[0])
315 |             index = torch.randperm(batch_size, device=x.device)
316 |             x2 = x[index, :]
317 |             y2 = y[index, :] if y is not None else None
318 | 
319 |         # Mix inputs and labels; patch the trailing spatial dims (H, W) of the batch
320 |         mixed_x = x.detach().clone()
321 |         mixed_x[..., bbx1:bbx2, bby1:bby2] = x2[..., bbx1:bbx2, bby1:bby2]
322 |         mixed_y = y
323 |         if y is not None:
324 |             # Adjust lambda
325 |             lam = 1.0 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2]))
326 |             mixed_y = lam * y + (1 - lam) * y2
327 | 
328 |         return mixed_x, mixed_y
329 | 
330 | 
331 | class MixingTransforms:
332 |     """Randomly apply only one of MixUp or CutMix.
Used for standard training.""" 333 | 334 | def __init__(self, config_tr: Dict[str, Any], num_classes: int) -> None: 335 | """Initialize mixup and/or cutmix.""" 336 | config_tr = clean_config(config_tr) 337 | self.mixing_transforms = [] 338 | if "mixup" in config_tr: 339 | self.mixing_transforms += [MixUp(**config_tr["mixup"])] 340 | if "cutmix" in config_tr: 341 | self.mixing_transforms += [CutMix(**config_tr["cutmix"])] 342 | self.num_classes = num_classes 343 | 344 | def __call__(self, images: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: 345 | """Apply only one of MixUp or CutMix.""" 346 | if len(self.mixing_transforms) > 0: 347 | one_hot_label = F.one_hot(target, num_classes=self.num_classes) 348 | mix_f = random.choice(self.mixing_transforms) 349 | images, target = mix_f(x=images, y=one_hot_label) 350 | return images, target 351 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | """Utilities for training.""" 5 | from enum import Enum 6 | from typing import Dict, Any, Iterable, List, Optional 7 | 8 | import torch 9 | from torch import Tensor 10 | import numpy as np 11 | import logging 12 | 13 | import torch.distributed as dist 14 | 15 | 16 | class Summary(Enum): 17 | """Meter value types.""" 18 | 19 | NONE = 0 20 | AVERAGE = 1 21 | SUM = 2 22 | COUNT = 3 23 | 24 | 25 | class AverageMeter(object): 26 | """Computes and stores the average and current value.""" 27 | 28 | def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE): 29 | self.name = name 30 | self.fmt = fmt 31 | self.summary_type = summary_type 32 | self.reset() 33 | 34 | def reset(self): 35 | self.val = 0 36 | self.avg = 0 37 | self.sum = 0 38 | self.count = 0 39 | 40 | def update(self, val, n=1): 41 | self.val = val 42 | self.sum += val * n 43 | self.count += n 44 | self.avg = self.sum / self.count 45 | 46 | def all_reduce(self): 47 | if torch.cuda.is_available(): 48 | device = torch.device("cuda") 49 | elif torch.backends.mps.is_available(): 50 | device = torch.device("mps") 51 | else: 52 | device = torch.device("cpu") 53 | total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device) 54 | dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False) 55 | self.sum, self.count = total.tolist() 56 | self.avg = self.sum / self.count 57 | 58 | def __str__(self): 59 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" 60 | return fmtstr.format(**self.__dict__) 61 | 62 | def summary(self): 63 | fmtstr = "" 64 | if self.summary_type is Summary.NONE: 65 | fmtstr = "" 66 | elif self.summary_type is Summary.AVERAGE: 67 | fmtstr = "{name} {avg:.3f}" 68 | elif self.summary_type is Summary.SUM: 69 | fmtstr = "{name} {sum:.3f}" 70 | elif self.summary_type is Summary.COUNT: 71 | fmtstr = "{name} {count:.3f}" 72 | else: 73 | raise ValueError("invalid summary type %r" % self.summary_type) 74 | 75 | return fmtstr.format(**self.__dict__) 76 | 77 | 78 | class ProgressMeter(object): 79 | def __init__(self, num_batches, meters, prefix=""): 80 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 81 | self.meters = meters 82 | self.prefix = prefix 83 | 84 | def display(self, batch): 85 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 86 | entries += [str(meter) for meter in self.meters] 87 | logging.info("\t".join(entries)) 88 | 89 | def display_summary(self): 90 | entries = [" *"] 91 | entries += 
[meter.summary() for meter in self.meters] 92 | logging.info(" ".join(entries)) 93 | 94 | def _get_batch_fmtstr(self, num_batches): 95 | num_digits = len(str(num_batches // 1)) 96 | fmt = "{:" + str(num_digits) + "d}" 97 | return "[" + fmt + "/" + fmt.format(num_batches) + "]" 98 | 99 | 100 | def accuracy( 101 | output: Tensor, target: Tensor, topk: Optional[Iterable[int]] = (1,) 102 | ) -> List[float]: 103 | """Compute the accuracy over the k top predictions for the specified values of k.""" 104 | with torch.no_grad(): 105 | maxk = max(topk) 106 | batch_size = target.size(0) 107 | 108 | if len(target.shape) > 1 and target.shape[1] > 1: 109 | # soft labels 110 | _, target = target.max(dim=1) 111 | 112 | _, pred = output.topk(maxk, 1, True, True) 113 | pred = pred.t() 114 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 115 | 116 | res = [] 117 | for k in topk: 118 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 119 | res.append(correct_k.mul_(100.0 / batch_size).item()) 120 | return res 121 | 122 | 123 | def assign_learning_rate(optimizer: torch.optim.Optimizer, new_lr: float) -> None: 124 | """Update lr parameter of an optimizer. 125 | 126 | Args: 127 | optimizer: A torch optimizer. 128 | new_lr: updated value of learning rate. 129 | """ 130 | for param_group in optimizer.param_groups: 131 | param_group["lr"] = new_lr 132 | 133 | 134 | def _warmup_lr(base_lr: float, warmup_length: int, n_iter: int) -> float: 135 | """Get updated lr after applying initial warmup. 136 | 137 | Args: 138 | base_lr: Nominal learning rate. 139 | warmup_length: Number of total iterations for initial warmup. 140 | n_iter: Current iteration number. 141 | 142 | Returns: 143 | Warmup-updated learning rate. 144 | """ 145 | return base_lr * (n_iter + 1) / warmup_length 146 | 147 | 148 | class CosineLR: 149 | """LR adjustment callable with cosine schedule. 150 | 151 | Args: 152 | optimizer: A torch optimizer. 153 | warmup_length: Number of iterations for initial warmup. 154 | total_steps: Total number of iterations. 155 | lr: Nominal learning rate value. 156 | 157 | Returns: 158 | A callable to adjust learning rate per iteration. 159 | """ 160 | 161 | def __init__( 162 | self, 163 | optimizer: torch.optim.Optimizer, 164 | warmup_length: int, 165 | total_steps: int, 166 | lr: float, 167 | end_lr: float = 0.0, 168 | **kwargs 169 | ) -> None: 170 | """Set parameters of cosine learning rate with warmup.""" 171 | assert lr > end_lr, ( 172 | "End LR should be less than the LR. 
Got:" " lr={} and last_lr={}" 173 | ).format(lr, end_lr) 174 | self.optimizer = optimizer 175 | self.warmup_length = warmup_length 176 | self.total_steps = total_steps 177 | self.lr = lr 178 | self.last_lr = 0 179 | self.end_lr = end_lr 180 | self.last_n_iter = 0 181 | 182 | def step(self) -> float: 183 | """Return updated learning rate for the next iteration.""" 184 | self.last_n_iter += 1 185 | n_iter = self.last_n_iter 186 | 187 | if n_iter < self.warmup_length: 188 | new_lr = _warmup_lr(self.lr, self.warmup_length, n_iter) 189 | else: 190 | e = n_iter - self.warmup_length + 1 191 | es = self.total_steps - self.warmup_length 192 | 193 | new_lr = self.end_lr + 0.5 * (self.lr - self.end_lr) * ( 194 | 1 + np.cos(np.pi * e / es) 195 | ) 196 | 197 | assign_learning_rate(self.optimizer, new_lr) 198 | 199 | self.last_lr = new_lr 200 | 201 | def get_last_lr(self) -> List[float]: 202 | """Return the value of the last learning rate.""" 203 | return [self.last_lr] 204 | 205 | def state_dict(self) -> Dict[str, Any]: 206 | """Return the state dictionary to recover optimization in training restart.""" 207 | return { 208 | "warmup_length": self.warmup_length, 209 | "total_steps": self.total_steps, 210 | "lr": self.lr, 211 | "last_lr": self.last_lr, 212 | "last_n_iter": self.last_n_iter, 213 | } 214 | 215 | def load_state_dict(self, state: Dict[str, Any]) -> None: 216 | """Restore scheduler state.""" 217 | self.warmup_length = state["warmup_length"] 218 | self.total_steps = state["total_steps"] 219 | self.lr = state["lr"] 220 | self.last_lr = state["last_lr"] 221 | self.last_n_iter = state["last_n_iter"] 222 | --------------------------------------------------------------------------------