├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── configs
│   ├── cifar100
│   │   ├── erm.yaml
│   │   ├── ft.yaml
│   │   ├── ft_plus.yaml
│   │   ├── plus.yaml
│   │   ├── reinforce.yaml
│   │   └── teacher.yaml
│   ├── flowers102
│   │   ├── erm.yaml
│   │   ├── ft.yaml
│   │   ├── ft_plus.yaml
│   │   ├── plus.yaml
│   │   ├── reinforce.yaml
│   │   └── teacher.yaml
│   ├── food101
│   │   ├── erm.yaml
│   │   ├── ft.yaml
│   │   ├── ft_plus.yaml
│   │   ├── plus.yaml
│   │   ├── reinforce.yaml
│   │   └── teacher.yaml
│   └── imagenet
│       ├── erm.yaml
│       ├── kd.yaml
│       ├── plus.yaml
│       └── reinforce
│           ├── cropflip.yaml
│           ├── mixing.yaml
│           ├── randaug.yaml
│           └── randaug_mixing.yaml
├── data.py
├── dr
│   ├── __init__.py
│   ├── data.py
│   ├── transforms.py
│   └── utils.py
├── figures
│   ├── DR_illustration_wide.pdf
│   ├── DR_illustration_wide.png
│   ├── imagenet_RRC+RARE_accuracy_annotated.pdf
│   └── imagenet_RRC+RARE_accuracy_annotated.png
├── models.py
├── reinforce.py
├── requirements.txt
├── results
│   ├── table_E1000.md
│   ├── table_E150.md
│   └── table_E300.md
├── tests
│   └── test_reinforce.py
├── train.py
├── trainers.py
├── transforms.py
└── utils.py
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to making participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies within all project spaces, and it also applies when
49 | an individual is representing the project or its community in public spaces.
50 | Examples of representing a project or community include using an official
51 | project e-mail address, posting via an official social media account, or acting
52 | as an appointed representative at an online or offline event. Representation of
53 | a project may be further defined and clarified by project maintainers.
54 |
55 | ## Enforcement
56 |
57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All
59 | complaints will be reviewed and investigated and will result in a response that
60 | is deemed necessary and appropriate to the circumstances. The project team is
61 | obligated to maintain confidentiality with regard to the reporter of an incident.
62 | Further details of specific enforcement policies may be posted separately.
63 |
64 | Project maintainers who do not follow or enforce the Code of Conduct in good
65 | faith may face temporary or permanent repercussions as determined by other
66 | members of the project's leadership.
67 |
68 | ## Attribution
69 |
70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4,
71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html)
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contribution Guide
2 |
3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository.
4 |
5 | While we welcome new pull requests and issues, please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged.
6 |
7 | ## Before you get started
8 |
9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE).
10 |
11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md).
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (C) 2022 Apple Inc. All Rights Reserved.
2 |
3 | IMPORTANT: This Apple software is supplied to you by Apple
4 | Inc. ("Apple") in consideration of your agreement to the following
5 | terms, and your use, installation, modification or redistribution of
6 | this Apple software constitutes acceptance of these terms. If you do
7 | not agree with these terms, please do not use, install, modify or
8 | redistribute this Apple software.
9 |
10 | In consideration of your agreement to abide by the following terms, and
11 | subject to these terms, Apple grants you a personal, non-exclusive
12 | license, under Apple's copyrights in this original Apple software (the
13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple
14 | Software, with or without modifications, in source and/or binary forms;
15 | provided that if you redistribute the Apple Software in its entirety and
16 | without modifications, you must retain this notice and the following
17 | text and disclaimers in all such redistributions of the Apple Software.
18 | Neither the name, trademarks, service marks or logos of Apple Inc. may
19 | be used to endorse or promote products derived from the Apple Software
20 | without specific prior written permission from Apple. Except as
21 | expressly stated in this notice, no other rights or licenses, express or
22 | implied, are granted by Apple herein, including but not limited to any
23 | patent rights that may be infringed by your derivative works or by other
24 | works in which the Apple Software may be incorporated.
25 |
26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE
27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
31 |
32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
39 | POSSIBILITY OF SUCH DAMAGE.
40 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dataset Reinforcement
2 | A lightweight implementation of Dataset Reinforcement, pretrained checkpoints, and reinforced datasets.
3 |
4 | **[Reinforce Data, Multiply Impact: Improved Model Accuracy and Robustness with Dataset Reinforcement.](https://arxiv.org/abs/2303.08983)**
5 | *Faghri, F., Pouransari, H., Mehta, S., Farajtabar, M., Farhadi, A., Rastegari, M., & Tuzel, O., Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV), 2023.*
6 |
7 |
8 | **Update 2023/09/22**: The "Average" column of Table 7 was corrected in arXiv v3. Correct
9 | numbers: 30.4, 37.1, 37.9, 43.7, 39.6, 51.1.
10 |
11 | **Reinforced ImageNet, ImageNet+, improves accuracy at similar iterations/wall-clock**
12 |
13 | ![ImageNet validation accuracy vs. training duration](figures/imagenet_RRC+RARE_accuracy_annotated.png)
14 |
15 |
16 |
17 | ImageNet validation accuracy of ResNet-50 is shown as
18 | a function of training duration with (1) the ImageNet dataset, (2) knowledge
19 | distillation (KD), and (3) the ImageNet+ dataset (ours). Each point is a full
20 | training run with epochs varying from 50 to 1000. An epoch has the same number of
21 | iterations for ImageNet and ImageNet+.
22 |
23 | **Illustration of Dataset Reinforcement**
24 |
25 | ![Illustration of Dataset Reinforcement](figures/DR_illustration_wide.png)
26 |
27 |
28 |
29 | Data augmentation and
30 | knowledge distillation are common approaches to
31 | improving accuracy. Dataset reinforcement combines the benefits of both by
32 | bringing the advantages of large models trained on large datasets to other
33 | datasets and models. Training new models with a reinforced dataset is as
34 | fast as training on the original dataset for the same total iterations.
35 | Creating a reinforced dataset is a one-time process (e.g., ImageNet to
36 | ImageNet+), the cost of which is amortized over repeated uses.
37 |
38 |
39 | ## Requirements
40 | Install the requirements using:
41 | ```shell
42 | pip install -r requirements.txt
43 | ```
44 |
45 | We support loading models from the Timm and CVNets libraries.
46 |
47 | To install the CVNets library, follow their [installation
48 | instructions](https://github.com/apple/ml-cvnets#installation).
49 |
50 | ## Reinforced Data
51 |
52 | The following is a list of reinforcements for ImageNet/CIFAR-100/Food-101/Flowers-102. We recommend ImageNet+-RA/RE based on the analysis in the paper.
53 |
54 | | Reinforce Data | Download | Size (GB) | Comments |
55 | |------------------------|--------------------------------------------------------------------------------------------------------------------|------------|---------------|
56 | | ImageNet+-RRC | [rdata](https://docs-assets.developer.apple.com/ml-research/datasets/dr/imagenet_plus_rrc.tar.gz) | 33.4 | [NS=400] |
57 | | ImageNet+-+M* | [rdata](https://docs-assets.developer.apple.com/ml-research/datasets/dr/imagenet_plus_mixing.tar.gz) | 46.3 | [NS=400] |
58 | | **ImageNet+-+RA/RE** | [rdata](https://docs-assets.developer.apple.com/ml-research/datasets/dr/imagenet_plus_randaug.tar.gz) | 37.5 | [NS=400] |
59 | | ImageNet+-+M*+R* | [rdata](https://docs-assets.developer.apple.com/ml-research/datasets/dr/imagenet_plus_mixing_randaug.tar.gz) | 53.3 | [NS=400] |
60 | | ImageNet+-RRC-Small | [rdata](https://docs-assets.developer.apple.com/ml-research/datasets/dr/imagenet_plus_rrc_small.tar.gz) | 4.7 | [NS=100, K=5] |
61 | | ImageNet+-+M*-Small | [rdata](https://docs-assets.developer.apple.com/ml-research/datasets/dr/imagenet_plus_mixing_small.tar.gz) | 7.8 | [NS=100, K=5] |
62 | | ImageNet+-+RA/RE-Small | [rdata](https://docs-assets.developer.apple.com/ml-research/datasets/dr/imagenet_plus_randaug_small.tar.gz) | 5.6 | [NS=100, K=5] |
63 | | ImageNet+-+M*+R*-Small | [rdata](https://docs-assets.developer.apple.com/ml-research/datasets/dr/imagenet_plus_mixing_randaug_small.tar.gz) | 9.4 | [NS=100, K=5] |
64 | | ImageNet+-RRC-Mini | [rdata](https://docs-assets.developer.apple.com/ml-research/datasets/dr/imagenet_plus_rrc_mini.tar.gz) | 4.4 | [NS=50] |
65 | | ImageNet+-+M*-Mini | [rdata](https://docs-assets.developer.apple.com/ml-research/datasets/dr/imagenet_plus_mixing_mini.tar.gz) | 6.1 | [NS=50] |
66 | | ImageNet+-+RA/RE-Mini | [rdata](https://docs-assets.developer.apple.com/ml-research/datasets/dr/imagenet_plus_randaug_mini.tar.gz) | 4.9 | [NS=50] |
67 | | ImageNet+-+M*+R*-Mini | [rdata](https://docs-assets.developer.apple.com/ml-research/datasets/dr/imagenet_plus_mixing_randaug.tar.gz) | 7.0 | [NS=50] |
68 | | CIFAR-100 | [rdata](https://docs-assets.developer.apple.com/ml-research/datasets/dr/cifar100.tar.gz) | 2.5 | [NS=800] |
69 | | Food-101 | [rdata](https://docs-assets.developer.apple.com/ml-research/datasets/dr/food_101.tar.gz) | 4.2 | [NS=800] |
70 | | Flowers-102 | [rdata](https://docs-assets.developer.apple.com/ml-research/datasets/dr/flowers_102.tar.gz) | 0.5 | [NS=8000] |
71 |
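Each reinforcement is distributed as a tarball of metadata. As a sketch, downloading and extracting the recommended ImageNet+-RA/RE data looks like this (the destination directory is a placeholder; point `reinforce.data_path` at it during training):

```shell
# Fetch the recommended ImageNet+-RA/RE reinforcement metadata (~37.5 GB).
curl -LO https://docs-assets.developer.apple.com/ml-research/datasets/dr/imagenet_plus_randaug.tar.gz
# Extract to a storage location of your choice (placeholder path).
mkdir -p /path/to/reinforcements
tar -xzf imagenet_plus_randaug.tar.gz -C /path/to/reinforcements
```
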
72 | ## Pretrained Checkpoints
73 |
74 | ### CVNets Checkpoints
75 |
76 | We provide pretrained checkpoints for various models in CVNets.
77 | The accuracies can be verified using the CVNets library.
78 |
79 | - [150 Epochs Checkpoints](./results/table_E150.md)
80 | - [300 Epochs Checkpoints](./results/table_E300.md)
81 | - [1000 Epochs Checkpoints](./results/table_E1000.md)
82 | - [imagenet-cvnets.tar](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets.tar): All CVNets checkpoints trained on ImageNet (14.3 GB).
83 | - [imagenet-plus-cvnets.tar](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets.tar): All CVNets checkpoints trained on ImageNet+ (14.3 GB).
84 |
85 | Selected results for models trained for 1000 epochs:
86 |
87 | | Name | Mode | Params | ImageNet | ImageNet+ | ImageNet (EMA) | ImageNet+ (EMA) | Links |
88 | | :------------- | :--------- | :------------- | :----- | :----------- | :----- | :------ | ----- |
89 | | MobileNetV3 | large | 5.5M | 74.8 | 77.9 (**+3.1**) | 75.8 | 77.9 (**+2.1**) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv3_large_E1000/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv3_large_E1000/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv3_large_E1000/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv3_large_E1000/metrics.jb) |
90 | | ResNet | 50 | 25.6M | 80.0 | 82.0 (**+2.0**) | 80.1 | 82.0 (**+1.9**) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_50_E1000/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_50_E1000/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_50_E1000/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_50_E1000/metrics.jb) |
91 | | ViT | base | 86.7M | 76.8 | 85.1 (**+8.3**) | 80.8 | 85.1 (**+4.3**) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit_base_E1000/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit_base_E1000/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit_base_E1000/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit_base_E1000/metrics.jb) |
92 | | ViT-384 | base | 86.7M | 79.4 | 85.4 (**+6.0**) | 83.1 | 85.5 (**+2.4**) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit-384_base_E1000/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit-384_base_E1000/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit-384_base_E1000/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit-384_base_E1000/metrics.jb) |
93 | | Swin | tiny | 28.3M | 81.3 | 84.0 (**+2.7**) | 80.5 | 83.5 (**+3.0**) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_tiny_E1000/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_tiny_E1000/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_tiny_E1000/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_tiny_E1000/metrics.jb) |
94 | | Swin | small | 49.7M | 81.3 | 85.0 (**+3.7**) | 81.9 | 84.5 (**+2.6**) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_small_E1000/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_small_E1000/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_small_E1000/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_small_E1000/metrics.jb) |
95 | | Swin | base | 87.8M | 81.5 | 85.4 (**+3.9**) | 81.8 | 85.2 (**+3.4**) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_base_E1000/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_base_E1000/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_base_E1000/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_base_E1000/metrics.jb) |
96 | | Swin-384 | base | 87.8M | 83.6 | 85.8 (**+2.2**) | 83.8 | 85.5 (**+1.7**) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin-384_base_E1000/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin-384_base_E1000/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin-384_base_E1000/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin-384_base_E1000/metrics.jb) |
97 |
98 | ### Timm Checkpoints
99 |
100 | We provide pretrained checkpoints for ResNet50d from the Timm library, trained for
101 | 150 epochs using various reinforced datasets:
102 |
103 | - [imagenet-timm.tar](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm.tar): All Timm checkpoints trained on ImageNet and ImageNet+ (2.3 GB).
104 |
105 | | Model | Reinforce Data | Accuracy | Links |
106 | | ------------------------ | :------------------------ | :-------------: | :-----------------------------------------------------------------------------: |
107 | | ResNet50d [ERM] | N/A | 78.9 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet/best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet/metrics.jb) |
108 | | ResNet50d | ImageNet+-RRC | 80.0 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_rrc/best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_rrc/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_rrc/metrics.jb) |
109 | | ResNet50d | ImageNet+-+M* | 80.5 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing/best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing/metrics.jb) |
110 | | ResNet50d | ImageNet+-+RA/RE | 80.4 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_randaug/best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_randaug/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_randaug/metrics.jb) |
111 | | ResNet50d | ImageNet+-+M*+R* | 80.2 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing_randaug/best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing_randaug/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing_randaug/metrics.jb) |
112 | | ResNet50d | ImageNet+-RRC-Small | 80.0 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_rrc_small/best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_rrc_small/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_rrc_small/metrics.jb) |
113 | | ResNet50d | ImageNet+-+M*-Small | 80.6 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing_small/best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing_small/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing_small/metrics.jb) |
114 | | ResNet50d | ImageNet+-+RA/RE-Small | 80.2 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_randaug_small/best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_randaug_small/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_randaug_small/metrics.jb) |
115 | | ResNet50d | ImageNet+-+M*+R*-Small | 80.1 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing_randaug_small/best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing_randaug_small/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing_randaug_small/metrics.jb) |
116 | | ResNet50d | ImageNet+-RRC-Mini | 80.1 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_rrc_mini/best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_rrc_mini/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_rrc_mini/metrics.jb) |
117 | | ResNet50d | ImageNet+-+M*-Mini | 80.5 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing_mini/best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing_mini/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing_mini/metrics.jb) |
118 | | ResNet50d | ImageNet+-+RA/RE-Mini | 80.4 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_randaug_mini/best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_randaug_mini/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_randaug_mini/metrics.jb) |
119 | | ResNet50d | ImageNet+-+M*+R*-Mini | 80.2 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing_randaug/best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing_randaug/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-timm/imagenet_plus_mixing_randaug/metrics.jb) |
120 |
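To sanity-check a downloaded Timm checkpoint outside of this repository's training code, a minimal sketch such as the following should work; the exact layout of `best.pt` and the format of `metrics.jb` are assumptions, and the state dict may need unwrapping:

```python
# Sketch: load a downloaded ResNet50d checkpoint and inspect its metrics.
# Assumes best.pt holds a Timm-compatible state_dict and metrics.jb is a
# joblib archive; adjust the unwrapping below if the checkpoint differs.
import joblib
import timm
import torch

state = torch.load("best.pt", map_location="cpu")
if isinstance(state, dict) and "state_dict" in state:
    state = state["state_dict"]  # unwrap trainer-style checkpoints

model = timm.create_model("resnet50d", pretrained=False)
model.load_state_dict(state)
model.eval()

print(joblib.load("metrics.jb"))  # recorded training/validation metrics
```
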
121 |
122 | ## Training
123 |
124 | We provide YAML configurations for training ResNet-50 in
125 | `CFG_FILE=configs/${DATASET}/${TRAINER}.yaml`, with the following options:
126 | - `DATASET`: `imagenet`, `cifar100`, `flowers102`, and `food101`.
127 | - `TRAINER`: standard training (`erm`), knowledge distillation (`kd`, ImageNet only), and training with reinforced data (`plus`). The smaller datasets additionally provide fine-tuning (`ft`, `ft_plus`) and teacher (`teacher`) configurations.
128 |
129 | Follow these steps:
130 | - Choose the dataset and trainer from the options above.
131 | - Download the dataset and set `data_path` in `$CFG_FILE`.
132 | - For `plus` configurations, download the reinforcement metadata and set `reinforce.data_path` in `$CFG_FILE` (see the sketch below).
133 |
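As a sketch, the relevant fields of a `plus` config after editing look like this; the paths are placeholders, and the `reinforce` block mirrors the `plus.yaml` configs included below:

```yaml
parameters:
  # ... remaining fields unchanged ...
  data_path: /path/to/imagenet                 # ImageNet root directory
  reinforce:
    enable: true
    data_path: /path/to/imagenet_plus_randaug  # extracted reinforcement metadata
```
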
134 | ```shell
135 | python train.py --config configs/imagenet/erm.yaml # ImageNet training without Reinforcements (ERM)
136 | python train.py --config configs/imagenet/kd.yaml # Knowledge Distillation
137 | python train.py --config configs/imagenet/plus.yaml # ImageNet+ training with reinforcements
138 | ```
139 |
140 | Hyperparameters such as batch size for ImageNet training are optimized for
141 | running on a single node with 8xA100 40GB GPUs. For
142 | CIFAR-100/Flowers-102/Food-101, the configurations are optimized for training
143 | on a single GPU.
144 |
145 | ## Reinforce ImageNet
146 |
147 | Follow these steps:
148 | - Download ImageNet data and set `data_path` in `$CFG_FILE`.
149 | - If needed, change the teacher in `$CFG_FILE` to a smaller architecture (a sketch follows the command below).
150 |
151 | ```shell
152 | python reinforce.py --config configs/imagenet/reinforce/randaug.yaml
153 | ```
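
A hedged sketch of such a teacher edit, assuming the reinforce configs share the `teacher` block of `configs/imagenet/kd.yaml` (the single-model choice is illustrative):

```yaml
teacher:
  timm_ensemble: true
  # Keep only the smallest member of the default four-model ensemble.
  name: ig_resnext101_32x8d
```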
154 |
155 | ## Reference
156 |
157 | If you found this code useful, please cite the following paper:
158 |
159 | @InProceedings{faghri2023reinforce,
160 | author = {Faghri, Fartash and Pouransari, Hadi and Mehta, Sachin and Farajtabar, Mehrdad and Farhadi, Ali and Rastegari, Mohammad and Tuzel, Oncel},
161 | title = {Reinforce Data, Multiply Impact: Improved Model Accuracy and Robustness with Dataset Reinforcement},
162 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
163 | month = {October},
164 | year = {2023},
165 | }
166 |
167 | ## License
168 | This sample code is released under the [LICENSE](LICENSE) terms.
169 |
--------------------------------------------------------------------------------
/configs/cifar100/erm.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | parameters:
5 | dataset: cifar100
6 | trainer: ERM
7 | arch: timm
8 | model:
9 | model_name: resnet50d
10 | pretrained: false
11 | num_classes: 100
12 | loss:
13 | label_smoothing: 0.1
14 | optim:
15 | name: sgd
16 | lr: 0.2
17 | momentum: 0.9
18 | weight_decay: 5.e-4
19 | warmup_length: 0
20 | no_decay_bn_filter_bias: false
21 | nesterov: false
22 | epochs: 1000
23 | save_freq: 100
24 | start_epoch: 0
25 | batch_size: 256
26 | print_freq: 100
27 | resume: ''
28 | evaluate: false
29 | pretrained: false
30 | dist_url: 'tcp://127.0.0.1:23333'
31 | dist_backend: 'nccl'
32 | # Single GPU training
33 | multiprocessing_distributed: false
34 | world_size: 1
35 | rank: 0
36 | workers: 8
37 | pin_memory: true
38 | persistent_workers: true
39 | seed: NULL
40 | gpu: 0
41 | download_data: false # Set to True to download
42 | data_path: ''
43 | artifact_path: ''
44 | image_augmentation:
45 | train:
46 | resize:
47 | size: 224
48 | random_crop:
49 | size: 224
50 | padding: 16
51 | rand_augment: # No horizontal flip when rand-augment is enabled
52 | enable: true
53 | to_tensor:
54 | enable: true
55 | normalize:
56 | mean: [0.507075159237, 0.4865488733149, 0.440917843367]
57 | std: [0.267334285879, 0.2564384629170, 0.276150471325]
58 | mixup:
59 | enable: true
60 | alpha: 0.2
61 | p: 1.0
62 | cutmix:
63 | enable: true
64 | alpha: 1.0
65 | p: 1.0
66 | val:
67 | resize:
68 | size: 224
69 | to_tensor:
70 | enable: true
71 | normalize:
72 | mean: [0.507075159237, 0.4865488733149, 0.440917843367]
73 | std: [0.267334285879, 0.2564384629170, 0.276150471325]
74 |
--------------------------------------------------------------------------------
/configs/cifar100/ft.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | parameters:
5 | dataset: cifar100
6 | trainer: ERM
7 | arch: timm
8 | model:
9 | model_name: resnet50d
10 | pretrained: false
11 | num_classes: 100
12 | # TODO: load timm checkpoint
13 | loss:
14 | label_smoothing: 0.1
15 | optim:
16 | name: sgd
17 | lr: 0.002
18 | momentum: 0.9
19 | weight_decay: 5.e-4
20 | warmup_length: 0
21 | no_decay_bn_filter_bias: false
22 | nesterov: false
23 | epochs: 1000
24 | save_freq: 100
25 | start_epoch: 0
26 | batch_size: 256
27 | print_freq: 100
28 | resume: ''
29 | evaluate: false
30 | pretrained: false
31 | dist_url: 'tcp://127.0.0.1:23333'
32 | dist_backend: 'nccl'
33 | # Single GPU training
34 | multiprocessing_distributed: false
35 | world_size: 1
36 | rank: 0
37 | workers: 8
38 | pin_memory: true
39 | persistent_workers: true
40 | seed: NULL
41 | gpu: 0
42 | download_data: false # Set to True to download
43 | data_path: ''
44 | artifact_path: ''
45 | image_augmentation:
46 | train:
47 | resize:
48 | size: 224
49 | random_crop:
50 | size: 224
51 | padding: 16
52 | rand_augment: # No horizontal flip when rand-augment is enabled
53 | enable: true
54 | to_tensor:
55 | enable: true
56 | normalize:
57 | mean: [0.507075159237, 0.4865488733149, 0.440917843367]
58 | std: [0.267334285879, 0.2564384629170, 0.276150471325]
59 | mixup:
60 | enable: true
61 | alpha: 0.2
62 | p: 1.0
63 | cutmix:
64 | enable: true
65 | alpha: 1.0
66 | p: 1.0
67 | val:
68 | resize:
69 | size: 224
70 | to_tensor:
71 | enable: true
72 | normalize:
73 | mean: [0.507075159237, 0.4865488733149, 0.440917843367]
74 | std: [0.267334285879, 0.2564384629170, 0.276150471325]
75 |
--------------------------------------------------------------------------------
/configs/cifar100/ft_plus.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | parameters:
5 | dataset: cifar100
6 | trainer: DR
7 | arch: timm
8 | model:
9 | model_name: resnet50d
10 | pretrained: false
11 | num_classes: 100
12 | # TODO: load timm checkpoint
13 | loss:
14 | label_smoothing: 0.1
15 | optim:
16 | name: sgd
17 | lr: 0.002
18 | momentum: 0.9
19 | weight_decay: 5.e-4
20 | warmup_length: 0
21 | no_decay_bn_filter_bias: false
22 | nesterov: false
23 | epochs: 1000
24 | save_freq: 100
25 | start_epoch: 0
26 | batch_size: 256
27 | print_freq: 100
28 | resume: ''
29 | evaluate: false
30 | pretrained: false
31 | dist_url: 'tcp://127.0.0.1:23333'
32 | dist_backend: 'nccl'
33 | # Single GPU training
34 | multiprocessing_distributed: false
35 | world_size: 1
36 | rank: 0
37 | workers: 8
38 | pin_memory: true
39 | persistent_workers: true
40 | seed: NULL
41 | gpu: 0
42 | download_data: false # Set to True to download
43 | data_path: ''
44 | artifact_path: ''
45 | image_augmentation:
46 | # Training transformations for non-reinforced samples
47 | train:
48 | resize:
49 | size: 224
50 | random_crop:
51 | size: 224
52 | padding: 16
53 | rand_augment: # No horizontal flip when rand-augment is enabled
54 | enable: true
55 | to_tensor:
56 | enable: true
57 | normalize:
58 | mean: [0.507075159237, 0.4865488733149, 0.440917843367]
59 | std: [0.267334285879, 0.2564384629170, 0.276150471325]
60 | val:
61 | resize:
62 | size: 224
63 | to_tensor:
64 | enable: true
65 | normalize:
66 | mean: [0.507075159237, 0.4865488733149, 0.440917843367]
67 | std: [0.267334285879, 0.2564384629170, 0.276150471325]
68 | reinforce:
69 | enable: true
70 | p: 0.99
71 | num_samples: NULL
72 | densify: smooth
73 | data_path: NULL
74 |
--------------------------------------------------------------------------------
/configs/cifar100/plus.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | parameters:
5 | dataset: cifar100
6 | trainer: DR
7 | arch: timm
8 | model:
9 | model_name: resnet50d
10 | pretrained: false
11 | num_classes: 100
12 | loss:
13 | label_smoothing: 0.1
14 | optim:
15 | name: sgd
16 | lr: 0.2
17 | momentum: 0.9
18 | weight_decay: 5.e-4
19 | warmup_length: 0
20 | no_decay_bn_filter_bias: false
21 | nesterov: false
22 | epochs: 1000
23 | save_freq: 100
24 | start_epoch: 0
25 | batch_size: 256
26 | print_freq: 100
27 | resume: ''
28 | evaluate: false
29 | pretrained: false
30 | dist_url: 'tcp://127.0.0.1:23333'
31 | dist_backend: 'nccl'
32 | # Single GPU training
33 | multiprocessing_distributed: false
34 | world_size: 1
35 | rank: 0
36 | workers: 8
37 | pin_memory: true
38 | persistent_workers: true
39 | seed: NULL
40 | gpu: 0
41 | download_data: false # Set to True to download
42 | data_path: ''
43 | artifact_path: ''
44 | image_augmentation:
45 | # Training transformations for non-reinforced samples
46 | train:
47 | resize:
48 | size: 224
49 | random_crop:
50 | size: 224
51 | padding: 16
52 | rand_augment: # No horizontal flip when rand-augment is enabled
53 | enable: true
54 | to_tensor:
55 | enable: true
56 | normalize:
57 | mean: [0.507075159237, 0.4865488733149, 0.440917843367]
58 | std: [0.267334285879, 0.2564384629170, 0.276150471325]
59 | val:
60 | resize:
61 | size: 224
62 | to_tensor:
63 | enable: true
64 | normalize:
65 | mean: [0.507075159237, 0.4865488733149, 0.440917843367]
66 | std: [0.267334285879, 0.2564384629170, 0.276150471325]
67 | reinforce:
68 | enable: true
69 | p: 0.99
70 | num_samples: NULL
71 | densify: smooth
72 | data_path: NULL
73 |
--------------------------------------------------------------------------------
/configs/cifar100/reinforce.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | parameters:
5 | dataset: cifar100
6 | teacher:
7 | ensemble: true
8 | # TODO: Load pretrained
9 | batch_size: 32
10 | print_freq: 10
11 | dist_url: 'tcp://127.0.0.1:23333'
12 | dist_backend: 'nccl'
13 | multiprocessing_distributed: true
14 | world_size: 1
15 | rank: 0
16 | workers: 88
17 | pin_memory: true
18 | persistent_workers: true
19 | seed: NULL
20 | gpu: NULL
21 | download_data: false # Set to True to download dataset
22 | data_path: ''
23 | artifact_path: ''
24 | reinforce:
25 | num_samples: 100
26 | num_candidates: NULL
27 | topk: 10
28 | compress: true
29 | joblib: false
30 | gzip: true
31 | image_augmentation:
32 | uint8:
33 | enable: true
34 | resize:
35 | size: 224
36 | random_crop:
37 | size: 224
38 | padding: 16
39 | random_horizontal_flip:
40 | enable: true
41 | p: 0.5
42 | rand_augment:
43 | enable: true
44 | p: 1.0
45 | to_tensor:
46 | enable: true
47 | normalize:
48 | mean: [0.507075159237, 0.4865488733149, 0.440917843367]
49 | std: [0.267334285879, 0.2564384629170, 0.276150471325]
50 | random_erase:
51 | enable: true
52 | p: 0.25
53 |
--------------------------------------------------------------------------------
/configs/cifar100/teacher.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | parameters:
5 | dataset: cifar100
6 | trainer: ERM
7 | arch: timm
8 | model:
9 | model_name: resnet50d
10 | pretrained: false
11 | num_classes: 100
12 | # TODO: load timm pretrained checkpoint
13 | loss:
14 | label_smoothing: 0.1
15 | optim:
16 | name: sgd
17 | lr: 0.002
18 | momentum: 0.9
19 | weight_decay: 5.e-4
20 | warmup_length: 0
21 | no_decay_bn_filter_bias: false
22 | nesterov: false
23 | epochs: 1000
24 | save_freq: 100
25 | start_epoch: 0
26 | batch_size: 256
27 | print_freq: 100
28 | resume: ''
29 | evaluate: false
30 | pretrained: false
31 | dist_url: 'tcp://127.0.0.1:23333'
32 | dist_backend: 'nccl'
33 | # Single GPU training
34 | multiprocessing_distributed: false
35 | world_size: 1
36 | rank: 0
37 | workers: 8
38 | pin_memory: true
39 | persistent_workers: true
40 | seed: NULL
41 | gpu: 0
42 | download_data: false # Set to True to download
43 | data_path: ''
44 | artifact_path: ''
45 | image_augmentation:
46 | train:
47 | resize:
48 | size: 224
49 | random_crop:
50 | size: 224
51 | padding: 16
52 | rand_augment: # No horizontal flip when rand-augment is enabled
53 | enable: true
54 | to_tensor:
55 | enable: true
56 | normalize:
57 | mean: [0.507075159237, 0.4865488733149, 0.440917843367]
58 | std: [0.267334285879, 0.2564384629170, 0.276150471325]
59 | mixup:
60 | enable: true
61 | alpha: 1.0 # Strong mixup
62 | p: 1.0
63 | cutmix:
64 | enable: true
65 | alpha: 1.0
66 | p: 1.0
67 | val:
68 | resize:
69 | size: 224
70 | to_tensor:
71 | enable: true
72 | normalize:
73 | mean: [0.507075159237, 0.4865488733149, 0.440917843367]
74 | std: [0.267334285879, 0.2564384629170, 0.276150471325]
75 |
--------------------------------------------------------------------------------
/configs/flowers102/erm.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | parameters:
5 | dataset: flowers102
6 | trainer: ERM
7 | arch: timm
8 | model:
9 | model_name: resnet50d
10 | pretrained: false
11 | num_classes: 102
12 | loss:
13 | label_smoothing: 0.1
14 | optim:
15 | name: sgd
16 | lr: 0.2
17 | momentum: 0.9
18 | weight_decay: 5.e-4
19 | warmup_length: 0
20 | no_decay_bn_filter_bias: false
21 | nesterov: false
22 | epochs: 10000
23 | save_freq: 100
24 | start_epoch: 0
25 | batch_size: 256
26 | print_freq: 100
27 | resume: ''
28 | evaluate: false
29 | pretrained: false
30 | dist_url: 'tcp://127.0.0.1:23333'
31 | dist_backend: 'nccl'
32 | # Single GPU training
33 | multiprocessing_distributed: false
34 | world_size: 1
35 | rank: 0
36 | workers: 8
37 | pin_memory: true
38 | persistent_workers: true
39 | seed: NULL
40 | gpu: 0
41 | download_data: false # Set to True to download
42 | data_path: ''
43 | artifact_path: ''
44 | image_augmentation:
45 | train:
46 | random_resized_crop:
47 | size: 224
48 | rand_augment: # No horizontal flip when rand-augment is enabled
49 | enable: true
50 | to_tensor:
51 | enable: true
52 | normalize:
53 | mean: [0.485, 0.456, 0.406]
54 | std: [0.229, 0.224, 0.225]
55 | mixup:
56 | enable: true
57 | alpha: 0.2
58 | p: 1.0
59 | cutmix:
60 | enable: true
61 | alpha: 1.0
62 | p: 1.0
63 | val:
64 | resize:
65 | size: 256
66 | center_crop:
67 | size: 224
68 | to_tensor:
69 | enable: true
70 | normalize:
71 | mean: [0.485, 0.456, 0.406]
72 | std: [0.229, 0.224, 0.225]
73 |
--------------------------------------------------------------------------------
/configs/flowers102/ft.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | parameters:
5 | dataset: flowers102
6 | trainer: ERM
7 | arch: timm
8 | model:
9 | model_name: resnet50d
10 | pretrained: false
11 | num_classes: 102
12 | # TODO: load pretrained timm checkpoint
13 | loss:
14 | label_smoothing: 0.1
15 | optim:
16 | name: sgd
17 | lr: 0.002
18 | momentum: 0.9
19 | weight_decay: 5.e-4
20 | warmup_length: 0
21 | no_decay_bn_filter_bias: false
22 | nesterov: false
23 | epochs: 10000
24 | save_freq: 100
25 | start_epoch: 0
26 | batch_size: 256
27 | print_freq: 100
28 | resume: ''
29 | evaluate: false
30 | pretrained: false
31 | dist_url: 'tcp://127.0.0.1:23333'
32 | dist_backend: 'nccl'
33 | # Single GPU training
34 | multiprocessing_distributed: false
35 | world_size: 1
36 | rank: 0
37 | workers: 8
38 | pin_memory: true
39 | persistent_workers: true
40 | seed: NULL
41 | gpu: 0
42 | download_data: false # Set to True to download
43 | data_path: ''
44 | artifact_path: ''
45 | image_augmentation:
46 | train:
47 | random_resized_crop:
48 | size: 224
49 | rand_augment: # No horizontal flip when rand-augment is enabled
50 | enable: true
51 | to_tensor:
52 | enable: true
53 | normalize:
54 | mean: [0.485, 0.456, 0.406]
55 | std: [0.229, 0.224, 0.225]
56 | mixup:
57 | enable: true
58 | alpha: 0.2
59 | p: 1.0
60 | cutmix:
61 | enable: true
62 | alpha: 1.0
63 | p: 1.0
64 | val:
65 | resize:
66 | size: 256
67 | center_crop:
68 | size: 224
69 | to_tensor:
70 | enable: true
71 | normalize:
72 | mean: [0.485, 0.456, 0.406]
73 | std: [0.229, 0.224, 0.225]
74 |
--------------------------------------------------------------------------------
/configs/flowers102/ft_plus.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | parameters:
5 | dataset: flowers102
6 | trainer: DR
7 | arch: timm
8 | model:
9 | model_name: resnet50d
10 | pretrained: false
11 | num_classes: 102
12 | # TODO: load pretrained timm checkpoint
13 | loss:
14 | label_smoothing: 0.1
15 | optim:
16 | name: sgd
17 | lr: 0.002
18 | momentum: 0.9
19 | weight_decay: 5.e-4
20 | warmup_length: 0
21 | no_decay_bn_filter_bias: false
22 | nesterov: false
23 | epochs: 10000
24 | save_freq: 100
25 | start_epoch: 0
26 | batch_size: 256
27 | print_freq: 100
28 | resume: ''
29 | evaluate: false
30 | pretrained: false
31 | dist_url: 'tcp://127.0.0.1:23333'
32 | dist_backend: 'nccl'
33 | # Single GPU training
34 | multiprocessing_distributed: false
35 | world_size: 1
36 | rank: 0
37 | workers: 8
38 | pin_memory: true
39 | persistent_workers: true
40 | seed: NULL
41 | gpu: 0
42 | download_data: false # Set to True to download
43 | data_path: ''
44 | artifact_path: ''
45 | image_augmentation:
46 | train:
47 | random_resized_crop:
48 | size: 224
49 | rand_augment: # No horizontal flip when rand-augment is enabled
50 | enable: true
51 | to_tensor:
52 | enable: true
53 | normalize:
54 | mean: [0.485, 0.456, 0.406]
55 | std: [0.229, 0.224, 0.225]
56 | val:
57 | resize:
58 | size: 256
59 | center_crop:
60 | size: 224
61 | to_tensor:
62 | enable: true
63 | normalize:
64 | mean: [0.485, 0.456, 0.406]
65 | std: [0.229, 0.224, 0.225]
66 | reinforce:
67 | enable: true
68 | p: 0.99
69 | num_samples: NULL
70 | densify: smooth
71 | data_path: NULL
72 |
--------------------------------------------------------------------------------
/configs/flowers102/plus.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | parameters:
5 | dataset: flowers102
6 | trainer: DR
7 | arch: timm
8 | model:
9 | model_name: resnet50d
10 | pretrained: false
11 | num_classes: 102
12 | loss:
13 | label_smoothing: 0.1
14 | optim:
15 | name: sgd
16 | lr: 0.2
17 | momentum: 0.9
18 | weight_decay: 5.e-4
19 | warmup_length: 0
20 | no_decay_bn_filter_bias: false
21 | nesterov: false
22 | epochs: 10000
23 | save_freq: 100
24 | start_epoch: 0
25 | batch_size: 256
26 | print_freq: 100
27 | resume: ''
28 | evaluate: false
29 | pretrained: false
30 | dist_url: 'tcp://127.0.0.1:23333'
31 | dist_backend: 'nccl'
32 | # Single GPU training
33 | multiprocessing_distributed: false
34 | world_size: 1
35 | rank: 0
36 | workers: 8
37 | pin_memory: true
38 | persistent_workers: true
39 | seed: NULL
40 | gpu: 0
41 | download_data: false # Set to True to download
42 | data_path: ''
43 | artifact_path: ''
44 | image_augmentation:
45 | train:
46 | random_resized_crop:
47 | size: 224
48 | rand_augment: # No horizontal flip when rand-augment is enabled
49 | enable: true
50 | to_tensor:
51 | enable: true
52 | normalize:
53 | mean: [0.485, 0.456, 0.406]
54 | std: [0.229, 0.224, 0.225]
55 | val:
56 | resize:
57 | size: 256
58 | center_crop:
59 | size: 224
60 | to_tensor:
61 | enable: true
62 | normalize:
63 | mean: [0.485, 0.456, 0.406]
64 | std: [0.229, 0.224, 0.225]
65 | reinforce:
66 | enable: true
67 | p: 0.99
68 | num_samples: NULL
69 | densify: smooth
70 | data_path: NULL
71 |
--------------------------------------------------------------------------------
/configs/flowers102/reinforce.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | parameters:
5 | dataset: flowers102
6 | teacher:
7 | ensemble: true
8 | # TODO: Load pretrained
9 | batch_size: 16
10 | print_freq: 10
11 | dist_url: 'tcp://127.0.0.1:23333'
12 | dist_backend: 'nccl'
13 | multiprocessing_distributed: true
14 | world_size: 1
15 | rank: 0
16 | workers: 88
17 | pin_memory: true
18 | persistent_workers: true
19 | seed: NULL
20 | gpu: NULL
21 | download_data: false # Set to True to download dataset
22 | data_path: ''
23 | artifact_path: ''
24 | reinforce:
25 | num_samples: 1000
26 | num_candidates: NULL
27 | topk: 10
28 | compress: true
29 | joblib: false
30 | gzip: true
31 | image_augmentation:
32 | uint8:
33 | enable: true
34 | random_resize_crop:
35 | enable: true
36 | size: 224
37 | random_horizontal_flip:
38 | enable: true
39 | p: 0.5
40 | rand_augment:
41 | enable: true
42 | p: 1.0
43 | to_tensor:
44 | enable: true
45 | normalize:
46 | mean: [0.485, 0.456, 0.406]
47 | std: [0.229, 0.224, 0.225]
48 | random_erase:
49 | enable: true
50 | p: 0.25
51 |
--------------------------------------------------------------------------------
/configs/flowers102/teacher.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | parameters:
5 | dataset: flowers102
6 | trainer: ERM
7 | arch: timm
8 | model:
9 | model_name: resnet50d
10 | pretrained: false
11 | num_classes: 102
12 | # TODO: load timm pretrained checkpoint
13 | loss:
14 | label_smoothing: 0.1
15 | optim:
16 | name: sgd
17 | lr: 0.002
18 | momentum: 0.9
19 | weight_decay: 5.e-4
20 | warmup_length: 0
21 | no_decay_bn_filter_bias: false
22 | nesterov: false
23 | epochs: 10000
24 | save_freq: 100
25 | start_epoch: 0
26 | batch_size: 256
27 | print_freq: 100
28 | resume: ''
29 | evaluate: false
30 | pretrained: false
31 | dist_url: 'tcp://127.0.0.1:23333'
32 | dist_backend: 'nccl'
33 | # Single GPU training
34 | multiprocessing_distributed: false
35 | world_size: 1
36 | rank: 0
37 | workers: 8
38 | pin_memory: true
39 | persistent_workers: true
40 | seed: NULL
41 | gpu: 0
42 | download_data: false # Set to True to download
43 | data_path: ''
44 | artifact_path: ''
45 | image_augmentation:
46 | train:
47 | random_resized_crop:
48 | size: 224
49 | rand_augment: # No horizontal flip when rand-augment is enabled
50 | enable: true
51 | to_tensor:
52 | enable: true
53 | normalize:
54 | mean: [0.485, 0.456, 0.406]
55 | std: [0.229, 0.224, 0.225]
56 | mixup:
57 | enable: true
58 | alpha: 1.0 # Strong mixup
59 | p: 1.0
60 | cutmix:
61 | enable: true
62 | alpha: 1.0
63 | p: 1.0
64 | val:
65 | resize:
66 | size: 256
67 | center_crop:
68 | size: 224
69 | to_tensor:
70 | enable: true
71 | normalize:
72 | mean: [0.485, 0.456, 0.406]
73 | std: [0.229, 0.224, 0.225]
74 |
--------------------------------------------------------------------------------
/configs/food101/erm.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | parameters:
5 | dataset: food101
6 | trainer: ERM
7 | arch: timm
8 | model:
9 | model_name: resnet50d
10 | pretrained: false
11 | num_classes: 101
12 | loss:
13 | label_smoothing: 0.1
14 | optim:
15 | name: sgd
16 | lr: 0.2
17 | momentum: 0.9
18 | weight_decay: 5.e-4
19 | warmup_length: 0
20 | no_decay_bn_filter_bias: false
21 | nesterov: false
22 | epochs: 1000
23 | save_freq: 100
24 | start_epoch: 0
25 | batch_size: 256
26 | print_freq: 100
27 | resume: ''
28 | evaluate: false
29 | pretrained: false
30 | dist_url: 'tcp://127.0.0.1:23333'
31 | dist_backend: 'nccl'
32 | # Single GPU training
33 | multiprocessing_distributed: false
34 | world_size: 1
35 | rank: 0
36 | workers: 8
37 | pin_memory: true
38 | persistent_workers: true
39 | seed: NULL
40 | gpu: 0
41 | download_data: false # Set to True to download
42 | data_path: ''
43 | artifact_path: ''
44 | image_augmentation:
45 | train:
46 | random_resized_crop:
47 | size: 224
48 | rand_augment: # No horizontal flip when rand-augment is enabled
49 | enable: true
50 | to_tensor:
51 | enable: true
52 | normalize:
53 | mean: [0.485, 0.456, 0.406]
54 | std: [0.229, 0.224, 0.225]
55 | mixup:
56 | enable: true
57 | alpha: 0.2
58 | p: 1.0
59 | cutmix:
60 | enable: true
61 | alpha: 1.0
62 | p: 1.0
63 | val:
64 | resize:
65 | size: 256
66 | center_crop:
67 | size: 224
68 | to_tensor:
69 | enable: true
70 | normalize:
71 | mean: [0.485, 0.456, 0.406]
72 | std: [0.229, 0.224, 0.225]
73 |
--------------------------------------------------------------------------------
/configs/food101/ft.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | parameters:
5 | dataset: food101
6 | trainer: ERM
7 | arch: timm
8 | model:
9 | model_name: resnet50d
10 | pretrained: false
11 | num_classes: 101
12 | # TODO: load pretrained timm checkpoint
13 | loss:
14 | label_smoothing: 0.1
15 | optim:
16 | name: sgd
17 | lr: 0.002
18 | momentum: 0.9
19 | weight_decay: 5.e-4
20 | warmup_length: 0
21 | no_decay_bn_filter_bias: false
22 | nesterov: false
23 | epochs: 1000
24 | save_freq: 100
25 | start_epoch: 0
26 | batch_size: 256
27 | print_freq: 100
28 | resume: ''
29 | evaluate: false
30 | pretrained: false
31 | dist_url: 'tcp://127.0.0.1:23333'
32 | dist_backend: 'nccl'
33 | # Single GPU training
34 | multiprocessing_distributed: false
35 | world_size: 1
36 | rank: 0
37 | workers: 8
38 | pin_memory: true
39 | persistent_workers: true
40 | seed: NULL
41 | gpu: 0
42 | download_data: false # Set to True to download
43 | data_path: ''
44 | artifact_path: ''
45 | image_augmentation:
46 | train:
47 | random_resized_crop:
48 | size: 224
49 | rand_augment: # No horizontal flip when rand-augment is enabled
50 | enable: true
51 | to_tensor:
52 | enable: true
53 | normalize:
54 | mean: [0.485, 0.456, 0.406]
55 | std: [0.229, 0.224, 0.225]
56 | mixup:
57 | enable: true
58 | alpha: 0.2
59 | p: 1.0
60 | cutmix:
61 | enable: true
62 | alpha: 1.0
63 | p: 1.0
64 | val:
65 | resize:
66 | size: 256
67 | center_crop:
68 | size: 224
69 | to_tensor:
70 | enable: true
71 | normalize:
72 | mean: [0.485, 0.456, 0.406]
73 | std: [0.229, 0.224, 0.225]
74 |
--------------------------------------------------------------------------------
/configs/food101/ft_plus.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | parameters:
5 | dataset: food101
6 | trainer: DR
7 | arch: timm
8 | model:
9 | model_name: resnet50d
10 | pretrained: false
11 | num_classes: 101
12 | # TODO: load pretrained timm checkpoint
13 | loss:
14 | label_smoothing: 0.1
15 | optim:
16 | name: sgd
17 | lr: 0.002
18 | momentum: 0.9
19 | weight_decay: 5.e-4
20 | warmup_length: 0
21 | no_decay_bn_filter_bias: false
22 | nesterov: false
23 | epochs: 1000
24 | save_freq: 100
25 | start_epoch: 0
26 | batch_size: 256
27 | print_freq: 100
28 | resume: ''
29 | evaluate: false
30 | pretrained: false
31 | dist_url: 'tcp://127.0.0.1:23333'
32 | dist_backend: 'nccl'
33 | # Single GPU training
34 | multiprocessing_distributed: false
35 | world_size: 1
36 | rank: 0
37 | workers: 8
38 | pin_memory: true
39 | persistent_workers: true
40 | seed: NULL
41 | gpu: 0
42 | download_data: false # Set to True to download
43 | data_path: ''
44 | artifact_path: ''
45 | image_augmentation:
46 | train:
47 | random_resized_crop:
48 | size: 224
49 | rand_augment: # No horizontal flip when rand-augment is enabled
50 | enable: true
51 | to_tensor:
52 | enable: true
53 | normalize:
54 | mean: [0.485, 0.456, 0.406]
55 | std: [0.229, 0.224, 0.225]
56 | val:
57 | resize:
58 | size: 256
59 | center_crop:
60 | size: 224
61 | to_tensor:
62 | enable: true
63 | normalize:
64 | mean: [0.485, 0.456, 0.406]
65 | std: [0.229, 0.224, 0.225]
66 | reinforce:
67 | enable: true
68 | p: 0.99
69 | num_samples: NULL
70 | densify: smooth
71 | data_path: NULL
72 |
--------------------------------------------------------------------------------
/configs/food101/plus.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | parameters:
5 | dataset: food101
6 | trainer: DR
7 | arch: timm
8 | model:
9 | model_name: resnet50d
10 | pretrained: false
11 | num_classes: 101
12 | loss:
13 | label_smoothing: 0.1
14 | optim:
15 | name: sgd
16 | lr: 0.2
17 | momentum: 0.9
18 | weight_decay: 5.e-4
19 | warmup_length: 0
20 | no_decay_bn_filter_bias: false
21 | nesterov: false
22 | epochs: 1000
23 | save_freq: 100
24 | start_epoch: 0
25 | batch_size: 256
26 | print_freq: 100
27 | resume: ''
28 | evaluate: false
29 | pretrained: false
30 | dist_url: 'tcp://127.0.0.1:23333'
31 | dist_backend: 'nccl'
32 | # Single GPU training
33 | multiprocessing_distributed: false
34 | world_size: 1
35 | rank: 0
36 | workers: 8
37 | pin_memory: true
38 | persistent_workers: true
39 | seed: NULL
40 | gpu: 0
41 | download_data: false # Set to True to download
42 | data_path: ''
43 | artifact_path: ''
44 | image_augmentation:
45 | train:
46 | random_resized_crop:
47 | size: 224
48 | rand_augment: # No horizontal flip when rand-augment is enabled
49 | enable: true
50 | to_tensor:
51 | enable: true
52 | normalize:
53 | mean: [0.485, 0.456, 0.406]
54 | std: [0.229, 0.224, 0.225]
55 | val:
56 | resize:
57 | size: 256
58 | center_crop:
59 | size: 224
60 | to_tensor:
61 | enable: true
62 | normalize:
63 | mean: [0.485, 0.456, 0.406]
64 | std: [0.229, 0.224, 0.225]
65 | reinforce:
66 | enable: true
67 | p: 0.99
68 | num_samples: NULL
69 | densify: smooth
70 | data_path: NULL
71 |
--------------------------------------------------------------------------------
/configs/food101/reinforce.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | parameters:
5 | dataset: food101
6 | teacher:
7 | ensemble: true
8 | # TODO: Load pretrained
9 | batch_size: 32
10 | print_freq: 10
11 | dist_url: 'tcp://127.0.0.1:23333'
12 | dist_backend: 'nccl'
13 | multiprocessing_distributed: true
14 | world_size: 1
15 | rank: 0
16 | workers: 88
17 | pin_memory: true
18 | persistent_workers: true
19 | seed: NULL
20 | gpu: NULL
21 | download_data: false # Set to True to download dataset
22 | data_path: ''
23 | artifact_path: ''
24 | reinforce:
25 | num_samples: 100
26 | num_candidates: NULL
27 | topk: 10
28 | compress: true
29 | joblib: false
30 | gzip: true
31 | image_augmentation:
32 | uint8:
33 | enable: true
34 | random_resize_crop:
35 | enable: true
36 | size: 224
37 | random_horizontal_flip:
38 | enable: true
39 | p: 0.5
40 | rand_augment:
41 | enable: true
42 | p: 1.0
43 | to_tensor:
44 | enable: true
45 | normalize:
46 | mean: [0.485, 0.456, 0.406]
47 | std: [0.229, 0.224, 0.225]
48 | random_erase:
49 | enable: true
50 | p: 0.25
51 |
--------------------------------------------------------------------------------
/configs/food101/teacher.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | parameters:
5 | dataset: food101
6 | trainer: ERM
7 | arch: timm
8 | model:
9 | model_name: resnet50d
10 | pretrained: false
11 | num_classes: 101
12 | # TODO: load timm pretrained checkpoint
13 | loss:
14 | label_smoothing: 0.1
15 | optim:
16 | name: sgd
17 | lr: 0.002
18 | momentum: 0.9
19 | weight_decay: 5.e-4
20 | warmup_length: 0
21 | no_decay_bn_filter_bias: false
22 | nesterov: false
23 | epochs: 1000
24 | save_freq: 100
25 | start_epoch: 0
26 | batch_size: 256
27 | print_freq: 100
28 | resume: ''
29 | evaluate: false
30 | pretrained: false
31 | dist_url: 'tcp://127.0.0.1:23333'
32 | dist_backend: 'nccl'
33 | # Single GPU training
34 | multiprocessing_distributed: false
35 | world_size: 1
36 | rank: 0
37 | workers: 8
38 | pin_memory: true
39 | persistent_workers: true
40 | seed: NULL
41 | gpu: 0
42 | download_data: false # Set to True to download
43 | data_path: ''
44 | artifact_path: ''
45 | image_augmentation:
46 | train:
47 | random_resized_crop:
48 | size: 224
49 | rand_augment: # No horizontal flip when rand-augment is enabled
50 | enable: true
51 | to_tensor:
52 | enable: true
53 | normalize:
54 | mean: [0.485, 0.456, 0.406]
55 | std: [0.229, 0.224, 0.225]
56 | mixup:
57 | enable: true
58 | alpha: 1.0 # Strong mixup
59 | p: 1.0
60 | cutmix:
61 | enable: true
62 | alpha: 1.0
63 | p: 1.0
64 | val:
65 | resize:
66 | size: 256
67 | center_crop:
68 | size: 224
69 | to_tensor:
70 | enable: true
71 | normalize:
72 | mean: [0.485, 0.456, 0.406]
73 | std: [0.229, 0.224, 0.225]
74 |
--------------------------------------------------------------------------------
/configs/imagenet/erm.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | parameters:
5 | dataset: imagenet
6 | trainer: ERM
7 | arch: timm
8 | model:
9 | model_name: resnet50d
10 | pretrained: false
11 | loss:
12 | label_smoothing: 0.1
13 | optim:
14 | name: sgd
15 | lr: 0.4
16 | momentum: 0.9
17 | weight_decay: 1.e-4
18 | warmup_length: 5
19 | epochs: 150
20 | save_freq: 50
21 | start_epoch: 0
22 | batch_size: 1024
23 | print_freq: 100
24 | resume: ''
25 | evaluate: false
26 | pretrained: false
27 | dist_url: 'tcp://127.0.0.1:23333'
28 | dist_backend: 'nccl'
29 | # Multi-GPU training
30 | multiprocessing_distributed: true
31 | world_size: 1
32 | rank: 0
33 | workers: 88
34 | pin_memory: true
35 | persistent_workers: true
36 | seed: NULL
37 | gpu: NULL
38 | download_data: false # Set to True to download
39 | data_path: ''
40 | artifact_path: ''
41 | image_augmentation:
42 | train:
43 | random_resized_crop:
44 | size: 224
45 | rand_augment: # No horizontal flip when rand-augment is enabled
46 | enable: true
47 | to_tensor:
48 | enable: true
49 | normalize:
50 | mean: [0.485, 0.456, 0.406]
51 | std: [0.229, 0.224, 0.225]
52 | random_erase:
53 | enable: true
54 | p: 0.25
55 | mixup:
56 | enable: true
57 | alpha: 0.2
58 | p: 1.0
59 | cutmix:
60 | enable: true
61 | alpha: 1.0
62 | p: 1.0
63 | val:
64 | resize:
65 | size: 256
66 | center_crop:
67 | size: 224
68 | to_tensor:
69 | enable: true
70 | normalize:
71 | mean: [0.485, 0.456, 0.406]
72 | std: [0.229, 0.224, 0.225]
73 |
--------------------------------------------------------------------------------
/configs/imagenet/kd.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | parameters:
5 | dataset: imagenet
6 | trainer: KD
7 | teacher:
8 | timm_ensemble: true
9 | name: ig_resnext101_32x48d,ig_resnext101_32x32d,ig_resnext101_32x16d,ig_resnext101_32x8d
10 | validate: false
11 | student:
12 | arch: timm
13 | model:
14 | model_name: resnet50d
15 | pretrained: false
16 | loss:
17 | loss_type: kl
18 | lambda_cls: 0.0
19 | lambda_kd: 1.0
20 | temperature: 1.0
21 | optim:
22 | name: sgd
23 | lr: 0.4
24 | momentum: 0.9
25 | weight_decay: 1.e-5 # Smaller weight decay for KD
26 | warmup_length: 5
27 | epochs: 150
28 | save_freq: 50
29 | start_epoch: 0
30 | batch_size: 1024
31 | print_freq: 100
32 | resume: ''
33 | evaluate: false
34 | pretrained: false
35 | dist_url: 'tcp://127.0.0.1:23333'
36 | dist_backend: 'nccl'
37 | # Multi-GPU training
38 | multiprocessing_distributed: true
39 | world_size: 1
40 | rank: 0
41 | workers: 88
42 | pin_memory: true
43 | persistent_workers: true
44 | seed: NULL
45 | gpu: NULL
46 | download_data: false # Set to True to download
47 | data_path: ''
48 | artifact_path: ''
49 | image_augmentation:
50 | train:
51 | random_resized_crop:
52 | size: 224
53 | timm_resize_crop_norm:
54 | enable: true
55 | name: ig_resnext101_32x48d
56 | rand_augment:
57 | enable: true
58 | to_tensor:
59 | enable: true
60 | normalize:
61 | mean: [0.485, 0.456, 0.406]
62 | std: [0.229, 0.224, 0.225]
63 | random_erase:
64 | enable: true
65 | p: 0.25
66 | mixup:
67 | enable: true
68 | alpha: 1.0 # Stronger mixup for KD
69 | p: 1.0
70 | cutmix:
71 | enable: true
72 | alpha: 1.0
73 | p: 1.0
74 | val:
75 | resize:
76 | size: 256
77 | center_crop:
78 | size: 224
79 | to_tensor:
80 | enable: true
81 | normalize:
82 | mean: [0.485, 0.456, 0.406]
83 | std: [0.229, 0.224, 0.225]
84 |
--------------------------------------------------------------------------------
/configs/imagenet/plus.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | parameters:
5 | dataset: imagenet
6 | trainer: DR
7 | arch: timm
8 | model:
9 | model_name: resnet50d
10 | pretrained: false
11 | loss:
12 | label_smoothing: 0.1
13 | optim:
14 | name: sgd
15 | lr: 0.4
16 | momentum: 0.9
17 | weight_decay: 1.e-4
18 | warmup_length: 5
19 | epochs: 150
20 | save_freq: 50
21 | start_epoch: 0
22 | batch_size: 1024
23 | print_freq: 100
24 | resume: ''
25 | evaluate: false
26 | pretrained: false
27 | dist_url: 'tcp://127.0.0.1:23333'
28 | dist_backend: 'nccl'
29 | # Multi-GPU training
30 | multiprocessing_distributed: true
31 | world_size: 1
32 | rank: 0
33 | workers: 88
34 | pin_memory: true
35 | persistent_workers: true
36 | seed: NULL
37 | gpu: NULL
38 | download_data: false # Set to True to download
39 | data_path: ''
40 | artifact_path: ''
41 | image_augmentation:
42 | # Training transformations for non-reinforced samples
43 | train:
44 | random_resized_crop:
45 | size: 224
46 | rand_augment:
47 | enable: true
48 | to_tensor:
49 | enable: true
50 | normalize:
51 | mean: [0.485, 0.456, 0.406]
52 | std: [0.229, 0.224, 0.225]
53 | random_erase:
54 | enable: true
55 | p: 0.25
56 | val:
57 | resize:
58 | size: 256
59 | center_crop:
60 | size: 224
61 | to_tensor:
62 | enable: true
63 | normalize:
64 | mean: [0.485, 0.456, 0.406]
65 | std: [0.229, 0.224, 0.225]
66 | # Reinforcement configs
67 | reinforce:
68 | enable: true
69 | p: 0.99
70 | num_samples: NULL
71 | densify: smooth
72 | data_path: NULL
73 |
--------------------------------------------------------------------------------
/configs/imagenet/reinforce/cropflip.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | parameters:
5 | dataset: imagenet
6 | teacher:
7 | timm_ensemble: true
8 | name: 'ig_resnext101_32x8d,ig_resnext101_32x16d,ig_resnext101_32x32d,ig_resnext101_32x48d'
9 | batch_size: 32
10 | print_freq: 10
11 | dist_url: 'tcp://127.0.0.1:23333'
12 | dist_backend: 'nccl'
13 | multiprocessing_distributed: true
14 | world_size: 1
15 | rank: 0
16 | workers: 88
17 | pin_memory: true
18 | seed: NULL
19 | gpu: NULL
20 | download_data: false # Set to True to download dataset
21 | data_path: ''
22 | artifact_path: ''
23 | reinforce:
24 | num_samples: 50
25 | num_candidates: NULL
26 | topk: 10
27 | compress: true
28 | joblib: false
29 | gzip: true
30 | image_augmentation:
31 | uint8:
32 | enable: true
33 | random_resized_crop:
34 | enable: true
35 | size: 224
36 | random_horizontal_flip:
37 | enable: true
38 | p: 0.5
39 | to_tensor:
40 | enable: true
41 | normalize:
42 | mean: [0.485, 0.456, 0.406]
43 | std: [0.229, 0.224, 0.225]
44 |
--------------------------------------------------------------------------------
/configs/imagenet/reinforce/mixing.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | parameters:
5 | dataset: imagenet
6 | teacher:
7 | timm_ensemble: true
8 | name: 'ig_resnext101_32x8d,ig_resnext101_32x16d,ig_resnext101_32x32d,ig_resnext101_32x48d'
9 | batch_size: 32
10 | print_freq: 10
11 | dist_url: 'tcp://127.0.0.1:23333'
12 | dist_backend: 'nccl'
13 | multiprocessing_distributed: true
14 | world_size: 1
15 | rank: 0
16 | workers: 88
17 | pin_memory: true
18 | seed: NULL
19 | gpu: NULL
20 | download_data: false # Set to True to download dataset
21 | data_path: ''
22 | artifact_path: ''
23 | reinforce:
24 | num_samples: 50
25 | num_candidates: NULL
26 | topk: 10
27 | compress: true
28 | joblib: false
29 | gzip: true
30 | image_augmentation:
31 | uint8:
32 | enable: true
33 | random_resized_crop:
34 | enable: true
35 | size: 224
36 | random_horizontal_flip:
37 | enable: true
38 | p: 0.5
39 | to_tensor:
40 | enable: true
41 | normalize:
42 | mean: [0.485, 0.456, 0.406]
43 | std: [0.229, 0.224, 0.225]
44 | mixup:
45 | enable: true
46 | alpha: 1.0
47 | div_by: 2.0
48 | p: 0.5
49 | cutmix:
50 | enable: true
51 | alpha: 1.0
52 | p: 0.5
53 |
--------------------------------------------------------------------------------
/configs/imagenet/reinforce/randaug.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | parameters:
5 | dataset: imagenet
6 | teacher:
7 | timm_ensemble: true
8 | name: 'ig_resnext101_32x8d,ig_resnext101_32x16d,ig_resnext101_32x32d,ig_resnext101_32x48d'
9 | batch_size: 32
10 | print_freq: 10
11 | dist_url: 'tcp://127.0.0.1:23333'
12 | dist_backend: 'nccl'
13 | multiprocessing_distributed: true
14 | world_size: 1
15 | rank: 0
16 | workers: 88
17 | pin_memory: true
18 | seed: NULL
19 | gpu: NULL
20 | download_data: false # Set to True to download dataset
21 | data_path: ''
22 | artifact_path: ''
23 | reinforce:
24 | num_samples: 50
25 | num_candidates: NULL
26 | topk: 10
27 | compress: true
28 | joblib: false
29 | gzip: true
30 | image_augmentation:
31 | uint8:
32 | enable: true
33 | random_resized_crop:
34 | enable: true
35 | size: 224
36 | random_horizontal_flip:
37 | enable: true
38 | p: 0.5
39 | rand_augment:
40 | enable: true
41 | p: 1.0
42 | to_tensor:
43 | enable: true
44 | normalize:
45 | mean: [0.485, 0.456, 0.406]
46 | std: [0.229, 0.224, 0.225]
47 | random_erase:
48 | enable: true
49 | p: 0.25
50 |
--------------------------------------------------------------------------------
/configs/imagenet/reinforce/randaug_mixing.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | parameters:
5 | dataset: imagenet
6 | teacher:
7 | timm_ensemble: true
8 | name: 'ig_resnext101_32x8d,ig_resnext101_32x16d,ig_resnext101_32x32d,ig_resnext101_32x48d'
9 | batch_size: 32
10 | print_freq: 10
11 | dist_url: 'tcp://127.0.0.1:23333'
12 | dist_backend: 'nccl'
13 | multiprocessing_distributed: true
14 | world_size: 1
15 | rank: 0
16 | workers: 88
17 | pin_memory: true
18 | seed: NULL
19 | gpu: NULL
20 | download_data: false # Set to True to download dataset
21 | data_path: ''
22 | artifact_path: ''
23 | reinforce:
24 | num_samples: 50
25 | num_candidates: NULL
26 | topk: 10
27 | compress: true
28 | joblib: false
29 | gzip: true
30 | image_augmentation:
31 | uint8:
32 | enable: true
33 | random_resized_crop:
34 | enable: true
35 | size: 224
36 | random_horizontal_flip:
37 | enable: true
38 | p: 0.5
39 | rand_augment:
40 | enable: true
41 | p: 1.0
42 | to_tensor:
43 | enable: true
44 | normalize:
45 | mean: [0.485, 0.456, 0.406]
46 | std: [0.229, 0.224, 0.225]
47 | random_erase:
48 | enable: true
49 | p: 0.25
50 | mixup:
51 | enable: true
52 | alpha: 1.0
53 | div_by: 2.0
54 | p: 0.5
55 | cutmix:
56 | enable: true
57 | alpha: 1.0
58 | p: 0.5
59 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | """Methods to initialize and create dataset loaders."""
5 | import os
6 | import logging
7 | from typing import Any, Dict, Tuple
8 |
9 | from torch.utils.data import Dataset
10 | import torchvision.datasets as datasets
11 | import transforms
12 |
13 |
14 | def download_dataset(data_path: str, dataset: str) -> None:
15 | """Download dataset prior to spawning workers.
16 |
17 | Args:
18 | data_path: Path to the root of the dataset.
19 | dataset: The name of the dataset.
20 | """
21 | if dataset == "imagenet":
22 | # ImageNet data requires manual download
23 | traindir = os.path.join(data_path, "training")
24 | valdir = os.path.join(data_path, "validation")
25 | assert os.path.isdir(
26 | traindir
27 | ), "Please download ImageNet training set to {}.".format(traindir)
28 | assert os.path.isdir(
29 | valdir
30 | ), "Please download ImageNet validation set to {}.".format(valdir)
31 | elif dataset == "cifar100":
32 | datasets.CIFAR100(root=data_path, train=True, download=True, transform=None)
33 | elif dataset == "flowers102":
34 | datasets.Flowers102(
35 | root=data_path, split="train", download=True, transform=None
36 | )
37 | elif dataset == "food101":
38 | datasets.Food101(root=data_path, split="train", download=True, transform=None)
39 |
40 |
41 | def get_dataset_size(dataset: str) -> int:
42 | """Return dataset size to compute the number of iterations per epoch."""
43 | if dataset == "imagenet":
44 | return 1281167
45 | elif dataset == "cifar100":
46 | return 50000
47 | elif dataset == "flowers102":
48 | return 1020
49 | elif dataset == "food101":
50 | return 75750
51 |
52 |
53 | def get_dataset_num_classes(dataset: str) -> int:
54 | """Return number of classes in a dataset."""
55 | if dataset == "imagenet":
56 | return 1000
57 | elif dataset == "cifar100":
58 | return 100
59 | elif dataset == "flowers102":
60 | return 102
61 | elif dataset == "food101":
62 | return 101
63 |
64 |
65 | def get_dataset(config: Dict[str, Any]) -> Tuple[Dataset, Dataset]:
66 | """Return data loaders for training and validation sets."""
67 | logging.info("Instantiating {} dataset.".format(config["dataset"]))
68 |
69 | dataset_name = config.get("dataset", None)
70 | if dataset_name is None:
71 |         raise ValueError("Dataset name can't be None.")
72 |     dataset_name = dataset_name.lower()
73 |
74 | if dataset_name == "imagenet":
75 | return get_imagenet_dataset(config)
76 | elif dataset_name == "cifar100":
77 | return get_cifar100_dataset(config)
78 | elif dataset_name == "flowers102":
79 | return get_flowers102_dataset(config)
80 | elif dataset_name == "food101":
81 | return get_food101_dataset(config)
82 | else:
83 | raise NotImplementedError
84 |
85 |
86 | def get_cifar100_dataset(config) -> Tuple[Dataset, Dataset]:
87 | """Return training/test datasets for CIFAR-100 dataset.
88 |
89 | @TECHREPORT{Krizhevsky09learningmultiple,
90 | author = {Alex Krizhevsky},
91 | title = {Learning multiple layers of features from tiny images},
92 | institution = {},
93 | year = {2009}
94 | }
95 | """
96 | relative_path = config["data_path"]
97 |     download = config.get("download_data", False)
98 |
99 | train_dataset = datasets.CIFAR100(
100 | root=relative_path,
101 | train=True,
102 |         download=download,
103 | transform=transforms.compose_from_config(config["image_augmentation"]["train"]),
104 | )
105 |
106 | val_dataset = datasets.CIFAR100(
107 | root=relative_path,
108 | train=False,
109 | download=False,
110 | transform=transforms.compose_from_config(config["image_augmentation"]["val"]),
111 | )
112 | return train_dataset, val_dataset
113 |
114 |
115 | def get_imagenet_dataset(config) -> Tuple[Dataset, Dataset]:
116 | """Return training/validation datasets for ImageNet dataset."""
117 | traindir = os.path.join(config["data_path"], "training")
118 | valdir = os.path.join(config["data_path"], "validation")
119 |
120 | train_dataset = datasets.ImageFolder(
121 | traindir,
122 | transforms.compose_from_config(config["image_augmentation"]["train"]),
123 | )
124 |
125 | val_dataset = datasets.ImageFolder(
126 | valdir,
127 | transforms.compose_from_config(config["image_augmentation"]["val"]),
128 | )
129 | return train_dataset, val_dataset
130 |
131 |
132 | def get_flowers102_dataset(config) -> Tuple[Dataset, Dataset]:
133 | """Return training/test datasets for Flowers-102 dataset.
134 |
135 | @InProceedings{Nilsback08,
136 | author = "Nilsback, M-E. and Zisserman, A.",
137 | title = "Automated Flower Classification over a Large Number of Classes",
138 | booktitle = "Proceedings of the Indian Conference on Computer Vision, Graphics
139 | and Image Processing",
140 | year = "2008",
141 | month = "Dec"
142 | }
143 | """
144 | relative_path = config["data_path"]
145 |     download = config.get("download_data", False)
146 |
147 | train_dataset = datasets.Flowers102(
148 | root=relative_path,
149 | split="train",
150 |         download=download,
151 | transform=transforms.compose_from_config(config["image_augmentation"]["train"]),
152 | )
153 | val_dataset = datasets.Flowers102(
154 | root=relative_path,
155 | split="test",
156 | download=False,
157 | transform=transforms.compose_from_config(config["image_augmentation"]["val"]),
158 | )
159 | return train_dataset, val_dataset
160 |
161 |
162 | def get_food101_dataset(config) -> Tuple[Dataset, Dataset]:
163 | """Return training/test datasets for Food-101 dataset.
164 |
165 | @inproceedings{bossard14,
166 | title = {Food-101 -- Mining Discriminative Components with Random Forests},
167 | author = {Bossard, Lukas and Guillaumin, Matthieu and Van Gool, Luc},
168 | booktitle = {European Conference on Computer Vision},
169 | year = {2014}
170 | }
171 | """
172 | relative_path = config["data_path"]
173 |     download = config.get("download_data", False)
174 |
175 | train_dataset = datasets.Food101(
176 | root=relative_path,
177 | split="train",
178 |         download=download,
179 | transform=transforms.compose_from_config(config["image_augmentation"]["train"]),
180 | )
181 | val_dataset = datasets.Food101(
182 | root=relative_path,
183 | split="test",
184 | download=False,
185 | transform=transforms.compose_from_config(config["image_augmentation"]["val"]),
186 | )
187 | return train_dataset, val_dataset
188 |
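The loaders above are driven entirely by the config dictionary parsed from the YAML files. A minimal sketch of calling `get_dataset` directly, assuming the top-level `transforms.compose_from_config` consumes the same keys as the YAML configs above; the `data_path` below is a hypothetical placeholder:

import data

config = {
    "dataset": "cifar100",
    "data_path": "/tmp/cifar100",  # hypothetical local path
    "download_data": True,
    "image_augmentation": {
        "train": {
            "random_resized_crop": {"size": 224},
            "to_tensor": {"enable": True},
        },
        "val": {
            "resize": {"size": 256},
            "center_crop": {"size": 224},
            "to_tensor": {"enable": True},
        },
    },
}
train_ds, val_ds = data.get_dataset(config)
print(len(train_ds), len(val_ds))  # 50000 10000 for CIFAR-100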
--------------------------------------------------------------------------------
/dr/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 |
5 |
--------------------------------------------------------------------------------
/dr/data.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | """Methods to initialize and create dataset loaders."""
5 | import os
6 | import random
7 | from typing import Union, Tuple, Any, List, Dict
8 |
9 | import torch
10 | from torch import Tensor
11 | import torch.nn.parallel
12 | import torch.optim
13 | import torch.utils.data
14 | import torchvision
15 | from torch.utils.data import Dataset
16 | import torch.utils.data.distributed
17 | import dr.transforms as T_dr
18 | from dr.utils import densify
19 | from data import get_dataset_num_classes
20 |
21 |
22 | class ReinforceMetadata:
23 | """A class to load and return only the metadata of the reinforced dataset."""
24 |
25 | def __init__(self, rdata_path: Union[str, List[str]]) -> None:
26 |         """Initialize the metadata files and configuration."""
27 | self.rdata_path = rdata_path if isinstance(rdata_path, list) else [rdata_path]
28 | assert any(
29 |             [os.path.exists(rp) for rp in self.rdata_path]
30 | ), f"Please download reinforce metadata to {rdata_path}."
31 | self.rconfig = self.get_rconfig()
32 |
33 | def __getitem__(self, index: int) -> Tuple[int, List[List[float]]]:
34 | """Return reinforced metadata for a single data point at given index."""
35 | p = random.randint(0, len(self.rdata_path) - 1)
36 | rdata = torch.load(os.path.join(self.rdata_path[p], "{}.pth.tar".format(index)))
37 | return rdata
38 |
39 | def get_rconfig(self) -> Dict[str, Any]:
40 | """Read configuration file for reinforcements."""
41 | num_samples = 0
42 | for rdata in self.rdata_path:
43 | rconfig = torch.load(os.path.join(rdata, "config.pth.tar"))
44 | num_samples += rconfig["reinforce"]["num_samples"]
45 | rconfig["reinforce"]["num_samples"] = num_samples
46 | return rconfig
47 |
48 |
49 | class ReinforcedDataset(Dataset):
50 | """A class to reinforce a given dataset with metadata in rdata_path."""
51 |
52 | def __init__(
53 | self,
54 | dataset: Dataset,
55 | rdata_path: Union[str, List[str]],
56 | config: Dict[str, Any],
57 | num_classes: int,
58 | ) -> None:
59 | """Initialize the metadata configuration and parameters."""
60 | self.ds = dataset
61 | self.num_classes = num_classes
62 | self.densify_method = config.get("densify", "zeros")
63 | self.p = config.get("p", 1.0) or 1.0
64 | self.config = config
65 |
66 | # Use specified data augmentations only if sampling the original data
67 | self.transform_orig = dataset.transform
68 | dataset.transform = None
69 |
70 | self.r_metadata = ReinforceMetadata(rdata_path)
71 | rconfig = self.r_metadata.get_rconfig()
72 |
73 | # Initialize transformations from config
74 | self.transforms = T_dr.compose_from_config(
75 | rconfig["reinforce"]["image_augmentation"]
76 | )
77 |
78 | def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
79 | """Return the sample from the dataset with given index."""
80 | # With probability self.p sample reinforced data, otherwise sample original
81 | p = random.random()
82 |
83 | if p < self.p:
84 | # Load reinforcement meta data for the image
85 | rdata = self.r_metadata[index]
86 |
87 | # Choose a random reinforcement
88 | assert rdata[0] == index, "Index does not match the metadata index."
89 | rdata = rdata[1] # tuple (id, rdata)
90 | i = random.randint(0, len(rdata) - 1)
91 | rdata_sample = rdata[i]
92 | params, target = rdata_sample["params"], rdata_sample["prob"]
93 | if isinstance(params, list):
94 | params = self.transforms.decompress(params)
95 |
96 | # Load image
97 | img, _ = self.ds[index]
98 |
99 | # Load the image pair if reinforcement has mixup/cutmix
100 | img2 = None
101 | if "cutmix" in params or "mixup" in params:
102 | img2, _ = self.ds[params["id2"]]
103 |
104 | # Reapply augmentation
105 | img, reparams = self.transforms.reapply(img, params, img2)
106 | for k, v in reparams.items():
107 | assert v == params[k], "params changed."
108 | target = densify(target, self.num_classes, self.densify_method)
109 | else:
110 | # With probability 1-self.p sample the original data
111 | img, target = self.ds[index]
112 | if self.transform_orig:
113 | img = self.transform_orig(img)
114 | target = torch.nn.functional.one_hot(
115 | torch.tensor(target), num_classes=self.num_classes
116 | ).float()
117 | return img.detach(), target.detach()
118 |
119 | def __len__(self) -> int:
120 | """Return the number of samples in the dataset."""
121 | return len(self.ds)
122 |
123 |
124 | class DatasetWithParameters:
125 |     """A wrapper around PyTorch datasets that also returns transformation parameters."""
126 |
127 | def __init__(
128 | self,
129 | dataset: torchvision.datasets.VisionDataset,
130 | transform: T_dr.Compose,
131 | num_samples: int,
132 | ) -> None:
133 | """Initialize and set the number of random crops per sample."""
134 | self.ds = dataset
135 | self.num_samples = num_samples
136 | self.transform = transform
137 | self.ds.transform = None
138 |
139 | def __getitem__(self, index: int) -> Tuple[Tensor, int, Tensor]:
140 | """Return multiple random transformations of a sample at given index.
141 |
142 | Args:
143 | index: An integer that is the unique ID of a sample in the dataset.
144 |
145 | Returns:
146 | A Tuple of (inputs, target, params) of shape:
147 | sample_all: [num_samples,]+sample.shape
148 | target: int
149 | params: A dictionary of parameters, each value is a Tensor of shape
150 |                 [num_samples, ...]
151 | """
152 | sample, target = self.ds[index]
153 | sample_all, params_all = T_dr.before_collate_apply(
154 | sample, self.transform, self.num_samples
155 | )
156 | return sample_all, target, params_all
157 |
158 | def __len__(self) -> int:
159 | """Return the number of samples in the dataset."""
160 | return len(self.ds)
161 |
162 |
163 | class IndexedDataset(torch.utils.data.Dataset):
164 | """A wrapper to PyTorch datasets that returns an index."""
165 |
166 | def __init__(self, dataset: torch.utils.data.Dataset) -> None:
167 | """Set the dataset."""
168 | self.ds = dataset
169 |
170 | def __getitem__(self, index: int) -> Tuple[int, ...]:
171 | """Return a sample and the given index."""
172 | data = self.ds[index]
173 | return (index, *data)
174 |
175 | def __len__(self) -> int:
176 | """Return the number of samples in the dataset."""
177 | return len(self.ds)
178 |
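Taken together, `ReinforcedDataset` swaps a dataset's own augmentations for replayed reinforcements with probability `p`, returning the teacher's densified probability vector as the target. A minimal sketch, assuming reinforcement metadata was already written by reinforce.py; the paths below are hypothetical placeholders:

from torchvision import datasets
from dr.data import ReinforcedDataset

base = datasets.CIFAR100(root="/tmp/cifar100", train=True, download=True)
rds = ReinforcedDataset(
    base,
    rdata_path="/tmp/cifar100_reinforced",  # hypothetical reinforce.py output
    config={"p": 0.99, "densify": "smooth"},
    num_classes=100,
)
# With probability 0.99 this replays a stored augmentation and returns the
# teacher's (densified) probabilities; otherwise it falls back to the
# original sample with a one-hot target.
img, target = rds[0]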
--------------------------------------------------------------------------------
/dr/transforms.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 |
5 | """Extending transformations from torchvision to be reproducible."""
6 |
7 | from collections import defaultdict
8 | from typing import List, OrderedDict, Union, Tuple, Optional, Any, Dict
9 |
10 | import torch
11 | from torch import Tensor
12 | import torch.nn.parallel
13 | import torch.optim
14 | import torch.utils.data
15 | import torch.utils.data.distributed
16 | import torchvision.transforms as T
17 | from torchvision.transforms import functional as F
18 | from torchvision.transforms.autoaugment import _apply_op
19 |
20 | import transforms
21 | from transforms import clean_config
22 |
23 |
24 | NO_PARAM = ()
25 | NO_PARAM_TYPE = Tuple
26 |
27 |
28 | class Compressible:
29 | """Base class for reproducible transformations with compressible parameters."""
30 |
31 | @staticmethod
32 | def compress_params(params: Any) -> Any:
33 | """Return compressed parameters."""
34 | return params
35 |
36 | @staticmethod
37 | def decompress_params(params: Any) -> Any:
38 | """Return decompressed parameters."""
39 | return params
40 |
41 |
42 | class Resize(T.Resize, Compressible):
43 | """Extending PyTorch's Resize to reapply a given transformation."""
44 |
45 | def forward(
46 | self, img: Tensor, params: Optional[torch.Size] = None
47 | ) -> Tuple[Tensor, torch.Size]:
48 | """Transform an image randomly or reapply based on given parameters."""
49 | img = super().forward(img)
50 | return img, self.size
51 |
52 |
53 | class CenterCrop(T.CenterCrop, Compressible):
54 | """Extending PyTorch's CenterCrop to reapply a given transformation."""
55 |
56 | def forward(
57 | self, img: Tensor, params: Optional[NO_PARAM_TYPE] = None
58 | ) -> Tuple[Tensor, NO_PARAM_TYPE]:
59 | """Transform an image randomly or reapply based on given parameters."""
60 | img = super().forward(img)
61 | # TODO: can we remove contiguous?
62 | img = img.contiguous()
63 | return img, NO_PARAM
64 |
65 |
66 | class RandomCrop(T.RandomCrop, Compressible):
67 | """Extending PyTorch's RandomCrop to reapply a given transformation."""
68 |
69 | def __init__(self, *args, **kwargs) -> None:
70 | """Initialize super and set last parameters to None."""
71 | super().__init__(*args, **kwargs)
72 | self.params = None
73 |
74 | def get_params(self, *args, **kwargs) -> Tuple[int, int, int, int]:
75 | """Return self.params or new transformation params if self.params not set."""
76 | if self.params is None:
77 | self.params = super().get_params(*args, **kwargs)
78 | return self.params
79 |
80 | def forward(
81 | self, img: Tensor, params: Optional[Tuple[int, int]] = None
82 | ) -> Tuple[Tensor, Tuple[int, int]]:
83 | """Transform an image randomly or reapply based on given parameters."""
84 | self.params = None
85 | if params is not None:
86 | # Add the constant value of size
87 | self.params = (params[0], params[1], self.size[0], self.size[1])
88 | img = super().forward(img)
89 | params = self.params
90 |         return img, params[:2]  # Return only [top, left], i.e., the random parameters
91 |
92 |
93 | class RandomResizedCrop(T.RandomResizedCrop, Compressible):
94 | """Extending PyTorch's RandomResizedCrop to reapply a given transformation."""
95 |
96 | def __init__(self, *args, **kwargs) -> None:
97 | """Initialize super and set last parameters to None."""
98 | super().__init__(*args, **kwargs)
99 | self.params = None
100 |
101 | def get_params(self, *args, **kwargs) -> Tuple[int, int, int, int]:
102 | """Return self.params or new transformation params if self.params not set."""
103 | if self.params is None:
104 | self.params = super().get_params(*args, **kwargs)
105 | return self.params
106 |
107 | def forward(
108 | self, img: Tensor, params: Optional[Tuple[int, int, int, int]] = None
109 | ) -> Tuple[Tensor, Tuple[int, int, int, int]]:
110 | """Transform an image randomly or reapply based on given parameters."""
111 | self.params = None
112 | if params is not None:
113 | self.params = params
114 | img = super().forward(img)
115 | params = self.params
116 | return img, params
117 |
118 |
119 | class RandomHorizontalFlip(T.RandomHorizontalFlip, Compressible):
120 | """Extending PyTorch's RandomHorizontalFlip to reapply a given transformation."""
121 |
122 | def forward(
123 | self, img: Tensor, params: Optional[bool] = None
124 | ) -> Tuple[Tensor, bool]:
125 | """Transform an image randomly or reapply based on given parameters.
126 |
127 | Args:
128 | img (PIL Image or Tensor): Image to be flipped.
129 |
130 | Returns:
131 | PIL Image or Tensor: Randomly flipped image.
132 | """
133 | if params is None:
134 | # Randomly skip only if params=None
135 | params = torch.rand(1).item() < self.p
136 | if params:
137 | img = F.hflip(img)
138 | return img, params
139 |
140 |
141 | class RandAugment(T.RandAugment, Compressible):
142 | """Extending PyTorch's RandAugment to reapply a given transformation."""
143 |
144 | op_names = [
145 | "Identity",
146 | "ShearX",
147 | "ShearY",
148 | "TranslateX",
149 | "TranslateY",
150 | "Rotate",
151 | "Brightness",
152 | "Color",
153 | "Contrast",
154 | "Sharpness",
155 | "Posterize",
156 | "Solarize",
157 | "AutoContrast",
158 | "Equalize",
159 | ]
160 |
161 | def __init__(self, p: float = 1.0, *args, **kwargs) -> None:
162 | """Initialize RandAugment with probability p of augmentation.
163 |
164 | Args:
165 | p: The probability of applying transformation. A float in [0, 1.0].
166 | """
167 | super().__init__(*args, **kwargs)
168 | self.p = p
169 |
170 | def forward(
171 | self, img: Tensor, params: Optional[List[Tuple[str, float]]] = None, **kwargs
172 | ) -> Tuple[Tensor, List[Tuple[str, float]]]:
173 | """Transform an image randomly or reapply based on given parameters.
174 |
175 | Args:
176 | img (PIL Image or Tensor): Image to be transformed.
177 |
178 | Returns:
179 | PIL Image or Tensor: Transformed image.
180 | """
181 | fill = self.fill
182 | channels, height, width = F.get_dimensions(img)
183 | if isinstance(img, torch.Tensor):
184 | if isinstance(fill, (int, float)):
185 | fill = [float(fill)] * channels
186 | elif fill is not None:
187 | fill = [float(f) for f in fill]
188 |
189 | op_meta = self._augmentation_space(self.num_magnitude_bins, (height, width))
190 | if params is None:
191 | # Randomly skip only if params=None
192 | if torch.rand(1) > self.p:
193 | return img, None
194 |
195 | params = []
196 | for _ in range(self.num_ops):
197 | op_index = int(torch.randint(len(op_meta), (1,)).item())
198 | op_name = list(op_meta.keys())[op_index]
199 | magnitudes, signed = op_meta[op_name]
200 | magnitude = (
201 | float(magnitudes[self.magnitude].item())
202 | if magnitudes.ndim > 0
203 | else 0.0
204 | )
205 | if signed and torch.randint(2, (1,)):
206 | magnitude *= -1.0
207 | params += [(op_name, magnitude)]
208 |
209 | for i in range(self.num_ops):
210 | op_name, magnitude = params[i]
211 | img = _apply_op(
212 | img, op_name, magnitude, interpolation=self.interpolation, fill=fill
213 | )
214 |
215 | return img, params
216 |
217 | @staticmethod
218 | def compress_params(params: List[Tuple[str, float]]) -> List[Tuple[int, float]]:
219 | """Return compressed parameters."""
220 | if params is None:
221 | return None
222 | pc = []
223 | for p in params:
224 | pc += [(RandAugment.op_names.index(p[0]), p[1])]
225 | return pc
226 |
227 | @staticmethod
228 | def decompress_params(params: List[Tuple[int, float]]) -> List[Tuple[str, float]]:
229 | """Return decompressed parameters."""
230 | if params is None:
231 | return None
232 | pc = []
233 | for p in params:
234 | pc += [(RandAugment.op_names[p[0]], p[1])]
235 | return pc
236 |
237 |
238 | class RandomErasing(T.RandomErasing, Compressible):
239 | """Extending PyTorch's RandomErasing to reapply a given transformation."""
240 |
241 | def forward(
242 | self, img: Tensor, params: Optional[Tuple] = None, **kwargs
243 | ) -> Tuple[Tensor, Tuple]:
244 | """Transform an image randomly or reapply based on given parameters.
245 |
246 | Args:
247 | img (Tensor): Tensor image to be erased.
248 |
249 | Returns:
250 | img (Tensor): Erased Tensor image.
251 | """
252 | if params is None:
253 | # Randomly skip only if params=None
254 | if torch.rand(1) > self.p:
255 | return img, None
256 | x, y, h, w, _ = self.get_params(img, scale=self.scale, ratio=self.ratio)
257 | else:
258 | x, y, h, w = params
259 | # In early experiments F.erase used in pytorch's RE was very slow
260 | # TODO: verify that F.erase is still slower than assigning zeros
261 | if x != -1:
262 | img[:, x : x + h, y : y + w] = 0
263 | return img, (x, y, h, w)
264 |
265 |
266 | class Normalize(T.Normalize, Compressible):
267 | """PyTorch's Normalize with an extra dummy transformation parameter."""
268 |
269 | def forward(
270 | self, *args, params: Optional[NO_PARAM_TYPE] = None, **kwargs
271 | ) -> Tuple[Tensor, Tuple]:
272 | """Return normalized input and NO_PARAM as parameters."""
273 | x = super().forward(*args, **kwargs)
274 | return x, NO_PARAM
275 |
276 |
277 | class MixUp(transforms.MixUp, Compressible):
278 | """Extending MixUp to reapply a given transformation."""
279 |
280 | def __init__(self, *args, **kwargs) -> None:
281 | """Initialize super and set last parameters to None."""
282 | super().__init__(*args, **kwargs)
283 | self.params = None
284 |
285 | def get_params(self, *args, **kwargs) -> float:
286 | """Return self.params or new transformation params if self.params not set."""
287 | if self.params is None:
288 | self.params = super().get_params(*args, **kwargs)
289 | return self.params
290 |
291 | def forward(
292 | self,
293 | x: Tensor,
294 | x2: Tensor,
295 | y: Optional[Tensor] = None,
296 | y2: Optional[Tensor] = None,
297 |         params: Optional[Dict[str, float]] = None,
298 | ) -> Tuple[Tuple[Tensor, Tensor], Dict[str, float]]:
299 | """Transform an image randomly or reapply based on given parameters."""
300 | self.params = None
301 | if params is not None:
302 | self.params = params
303 | x, y = super().forward(x, x2, y, y2)
304 | params = self.params
305 | return (x, y), params
306 |
307 |
308 | class CutMix(transforms.CutMix, Compressible):
309 | """Extending CutMix to reapply a given transformation."""
310 |
311 | def __init__(self, *args, **kwargs) -> None:
312 | """Initialize super and set last parameters to None."""
313 | super().__init__(*args, **kwargs)
314 | self.params = None
315 |
316 | def get_params(self, *args, **kwargs) -> Tuple[float, Tuple[int, int, int, int]]:
317 | """Return self.params or new transformation params if self.params not set."""
318 | if self.params is None:
319 | self.params = super().get_params(*args, **kwargs)
320 | return self.params
321 |
322 | def forward(
323 | self,
324 | x: Tensor,
325 | x2: Tensor,
326 | y: Optional[Tensor] = None,
327 | y2: Optional[Tensor] = None,
328 |         params: Optional[Dict[str, float]] = None,
329 | ) -> Tuple[
330 | Tuple[Tensor, Tensor], Dict[str, Union[float, Tuple[int, int, int, int]]]
331 | ]:
332 | """Transform an image randomly or reapply based on given parameters."""
333 | self.params = None
334 | if params is not None:
335 | self.params = params
336 | x, y = super().forward(x, x2, y, y2)
337 | params = self.params
338 | return (x, y), params
339 |
340 | @staticmethod
341 | def compress_params(params: Any) -> Any:
342 | """Return compressed parameters."""
343 | if params is None:
344 | return None
345 | return [params[0]]+list(params[1])
346 |
347 | @staticmethod
348 | def decompress_params(params: Any) -> Any:
349 | """Return decompressed parameters."""
350 | if params is None:
351 | return None
352 | return params[0], tuple(params[1:])
353 |
354 |
355 | class ToUint8(torch.nn.Module, Compressible):
356 | """Convert float32 Tensor in range [0, 1] to uint8 [0, 255]."""
357 |
358 | def forward(self, img: Tensor, **kwargs) -> Tuple[Tensor, NO_PARAM_TYPE]:
359 | """Return uint8(img) and NO_PARAM as parameters."""
360 | if not isinstance(img, torch.Tensor):
361 | return img, NO_PARAM
362 | return (img * 255).to(torch.uint8), NO_PARAM
363 |
364 |
365 | class ToTensor(torch.nn.Module, Compressible):
366 | """Convert PIL to torch.Tensor or if Tensor uint8 [0, 255] to float32 [0, 1]."""
367 |
368 | def forward(self, img: Tensor, **kwargs) -> Tuple[Tensor, NO_PARAM_TYPE]:
369 | """Return tensor(img) and NO_PARAM as parameters."""
370 | if isinstance(img, torch.Tensor):
371 |             # Convert uint8 [0, 255] back to float32 [0, 1].
372 | return (img / 255.0).to(torch.float32), NO_PARAM
373 | return F.to_tensor(img), NO_PARAM
374 |
375 |
376 | # Transformations are composed according to the order below, not the order in config
377 | TRANSFORMATION_TO_NAME = OrderedDict(
378 | [
379 | ("uint8", ToUint8),
380 | ("resize", Resize),
381 | ("center_crop", CenterCrop),
382 | ("random_crop", RandomCrop),
383 | ("random_resized_crop", RandomResizedCrop),
384 | ("random_horizontal_flip", RandomHorizontalFlip),
385 | ("rand_augment", RandAugment),
386 | ("to_tensor", ToTensor),
387 | ("random_erase", RandomErasing), # TODO: fix the order of RE with transforms
388 | ("normalize", Normalize),
389 | ("mixup", MixUp),
390 | ("cutmix", CutMix),
391 | ]
392 | )
393 | # Only in datagen
394 | BEFORE_COLLATE_TRANSFORMS = [
395 | "uint8",
396 | "resize",
397 | "center_crop",
398 | "random_crop",
399 | "random_resized_crop",
400 | "to_tensor",
401 | ]
402 | NO_PARAM_TRANSFORMS = [
403 | "uint8",
404 | "center_crop",
405 | "to_tensor",
406 | "normalize",
407 | ]
408 |
409 |
410 | class Compose:
411 | """Compose a list of reproducible data transformations."""
412 |
413 | def __init__(self, transforms: List[Tuple[str, Compressible]]) -> None:
414 | """Initialize transformations."""
415 | self.transforms = transforms
416 |
417 | def has_random_resized_crop(self) -> bool:
418 | """Return True if RandomResizedCrop is one of the transformations."""
419 | return any([t.__class__ == RandomResizedCrop for _, t in self.transforms])
420 |
421 | def __call__(
422 | self,
423 | img: Tensor,
424 | img2: Tensor = None,
425 |         after_collate: bool = False,
426 | ) -> Tuple[Tensor, Dict[str, Any]]:
427 | """Apply a transformation to two images and return augmentation parameters.
428 |
429 |         Args:
430 |             img: A tensor to be transformed.
431 |             img2: Second tensor to be used for mixing transformations.
432 |             after_collate: If True, skip transformations that were already
433 |                 applied in the data loader before collate.
434 | 
435 |         A value in the returned `params` dictionary can be None or empty in
436 |         2 cases:
437 |             1) `params=None`: Transformation was randomly skipped,
438 |             2) `params=()`: Transformation has no random parameters.
439 |
440 | Returns:
441 | A Tuple of a transformed image and a dictionary with transformation
442 | parameters.
443 | """
444 | params = dict()
445 | for t_name, t in self.transforms:
446 | if after_collate and (
447 | t_name in BEFORE_COLLATE_TRANSFORMS
448 | and t_name != "uint8"
449 | and t_name != "to_tensor"
450 | ):
451 | # Skip transformations applied in data loader
452 | pass
453 | elif t_name == "cutmix" or t_name == "mixup":
454 | # Mix images
455 | if img2 is not None:
456 | (img, _), p = t(img, img2)
457 | params[t_name] = p
458 | else:
459 | # Apply an augmentation to both images, skip img2 if no mixing
460 | img, p = t(img)
461 | if img2 is not None:
462 | img2, p2 = t(img2)
463 | p = (p, p2)
464 | params[t_name] = p
465 | return img, params
466 |
467 | def reapply(
468 | self, img: Tensor, params: Dict[str, Any], img2: Tensor = None
469 | ) -> Tuple[Tensor, Dict[str, Any]]:
470 | """Reapply transformations to an image given augmentation parameters.
471 |
472 | Args:
473 | img: A tensor to be transformed.
474 | params: Transformation parameters to be reapplied.
475 | img2: Second tensor to be used for mixing transformations.
476 |
477 | The value of `params` can be None or empty in 3 cases:
478 | 1) `params=None` in apply(): The value should be generated randomly,
479 | 2) `params=None` in reapply(): Transformation was randomly skipped during
480 | generation time,
481 |             3) `params=()`: Transformation has no random parameters.
482 |
483 | Returns:
484 | A Tuple of a transformed image and a dictionary with transformation
485 | parameters.
486 | """
487 | for t_name, t in self.transforms:
488 | if t_name in params:
489 | if t_name == "cutmix" or t_name == "mixup":
490 | # Remix images
491 | if params[t_name] is not None:
492 | (img, _), _ = t(img, img2, params=params[t_name])
493 | else:
494 | # Reapply an augmentation to both images, skip img2 if no
495 | # mixing
496 | if params[t_name][0] is not None:
497 | img, _ = t(img, params=params[t_name][0])
498 | if img2 is not None and params[t_name][1] is not None:
499 | img2, _ = t(img2, params=params[t_name][1])
500 | return img, params
501 |
502 | def compress(self, params: Dict[str, Any]) -> List[Any]:
503 | """Compress augmentation parameters."""
504 | params_compressed = []
505 |
506 | no_pair = True
507 | # Save second pair id if mixup or cutmix enabled
508 | t_names = [t[0] for t in self.transforms]
509 | if "mixup" in t_names or "cutmix" in t_names:
510 | if (
511 | params.get("mixup", None) is not None
512 | or params.get("cutmix", None) is not None
513 | ):
514 | params_compressed += [params["id2"]]
515 | no_pair = False
516 | else:
517 | params_compressed += [None]
518 |
519 | # Save transformation parameters
520 | for t_name, t in self.transforms:
521 | p = params[t_name]
522 | if t_name in NO_PARAM_TRANSFORMS:
523 | pass
524 | elif t_name == "mixup" or t_name == "cutmix":
525 | params_compressed += [t.compress_params(p)]
526 | else:
527 | if no_pair:
528 | params_compressed += [t.compress_params(p[0])]
529 | else:
530 | params_compressed += [
531 | [t.compress_params(p[0]), t.compress_params(p[1])]
532 | ]
533 | return params_compressed
534 |
535 | def decompress(self, params_compressed: List[Any]) -> Dict[str, Any]:
536 | """Decompress augmentation parameters."""
537 | params = {}
538 |
539 | # Read second pair id if mixup or cutmix enabled
540 | t_names = [t[0] for t in self.transforms]
541 | no_pair = None
542 | if "mixup" in t_names or "cutmix" in t_names:
543 | no_pair = params_compressed[0]
544 | if no_pair is not None:
545 | params["id2"] = no_pair
546 | params_compressed = params_compressed[1:]
547 |
548 | # Read parameters for transformations with random parameters
549 | with_param_transforms = [(t_name, t)
550 | for t_name, t in self.transforms
551 | if t_name not in NO_PARAM_TRANSFORMS]
552 | for p, (t_name, t) in zip(params_compressed, with_param_transforms):
553 | if p is None:
554 | pass
555 | elif t_name == "mixup" or t_name == "cutmix":
556 | params[t_name] = t.decompress_params(p)
557 | else:
558 | if no_pair is not None and len(p) > 1:
559 | params[t_name] = (
560 | t.decompress_params(p[0]),
561 | t.decompress_params(p[1]),
562 | )
563 | else:
564 | params[t_name] = (t.decompress_params(p),)
565 |
566 | # Fill non-random transformations
567 | for t_name, t in self.transforms:
568 | if t_name in NO_PARAM_TRANSFORMS:
569 | params[t_name] = (NO_PARAM, NO_PARAM)
570 | return params
571 |
572 |
573 | def compose_from_config(config: Dict[str, Any]) -> Compose:
574 | """Initialize transformations given the dataset name and configurations.
575 |
576 | Args:
577 | config: A dictionary of augmentation parameters.
578 | """
579 | config = clean_config(config)
580 | transforms = []
581 | for t_name, t_class in TRANSFORMATION_TO_NAME.items():
582 | if t_name in config:
583 | # TODO: warn for every key in config_tr that was not used
584 | transforms += [(t_name, t_class(**config[t_name]))]
585 | return Compose(transforms)
586 |
587 |
588 | def before_collate_config(
589 | config: Dict[str, Dict[str, Any]]
590 | ) -> Dict[str, Dict[str, Any]]:
591 | """Return configs with resize/crop transformations to pass to data loader.
592 |
593 | Only transformations that cannot be applied after data collate are
594 | composed. For example, RandomResizedCrop has to be applied before collate
595 |     to create tensors of the same shape.
596 |
597 | Args:
598 | config: A dictionary of augmentation parameters.
599 | """
600 | return {k: v for k, v in config.items() if k in BEFORE_COLLATE_TRANSFORMS}
601 |
602 |
603 | def after_collate_config(
604 | config: Dict[str, Dict[str, Any]]
605 | ) -> Dict[str, Dict[str, Any]]:
606 |     """Return configs after excluding transformations from `before_collate_config`."""
607 | return {k: v for k, v in config.items() if k not in BEFORE_COLLATE_TRANSFORMS}
608 |
609 |
610 | def before_collate_apply(
611 | sample: Tensor, transform: Compose, num_samples: int
612 | ) -> Tuple[Tensor, Dict[str, Tensor]]:
613 | """Return multiple samples applying the transformations.
614 |
615 | Args:
616 | sample: A single sample to be randomly transformed.
617 | transform: A list of transformations to be applied.
618 | num_samples: The number of random transformations to be generated.
619 |
620 | Returns:
621 | Random transformations of the input. Shape: [num_samples,]+sample.shape
622 | """
623 | sample_all = []
624 | params_all = defaultdict(list)
625 | for _ in range(num_samples):
626 | # [height, width, channels]
627 | # -> ([height_new, width_new, channels], Dict(str, Tuple))
628 | sample_new, params = transform(sample)
629 | sample_all.append(sample_new)
630 | for k, v in params.items():
631 | params_all[k].append(v)
632 |
633 |     sample_all = torch.stack(sample_all, dim=0)
634 | for k in params_all.keys():
635 | params_all[k] = torch.tensor(params_all[k])
636 | return sample_all, params_all
637 |
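The classes above share one record/replay contract: `__call__` draws random parameters and returns them, `compress`/`decompress` pack them into the compact lists stored as metadata, and `reapply` reproduces the exact augmentation. A minimal round-trip sketch in pair mode, the mode used at generation time, built directly from the classes in this file:

import torch
import dr.transforms as T_dr

t = T_dr.Compose([
    ("random_resized_crop", T_dr.RandomResizedCrop(size=32)),
    ("random_horizontal_flip", T_dr.RandomHorizontalFlip(p=0.5)),
])
img, img2 = torch.rand(3, 64, 64), torch.rand(3, 64, 64)

out, params = t(img, img2)   # random draw; parameters are recorded
packed = t.compress(params)  # compact list, ready for on-disk metadata
replayed, _ = t.reapply(img, t.decompress(packed))
assert torch.allclose(out, replayed)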
--------------------------------------------------------------------------------
/dr/utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | """Utilities for training."""
5 | from typing import List, Tuple, Optional
6 |
7 | import torch
8 | import numpy as np
9 |
10 |
11 | def sparsify(p: np.ndarray, k: int) -> Tuple[List[float], List[int]]:
12 | """Sparsify a probabilities vector to k top values."""
13 | ids = np.argpartition(p, -k)[-k:]
14 | return p[ids].tolist(), ids.tolist()
15 |
16 |
17 | def densify(
18 |     sp: Tuple[List[float], List[int]],
19 | num_classes: int,
20 | method: Optional[str] = "smooth",
21 | ) -> torch.Tensor:
22 | """Densify a sparse probability vector."""
23 | if not isinstance(sp, list) and not isinstance(sp, tuple):
24 | return torch.tensor(sp)
25 | sp, ids = torch.tensor(sp[0]), sp[1]
26 |     r = 1.0 - sp.sum()  # Residual mass; may be negative if sp.sum() is close to 1.0
27 | # assert r >= 0.0, f"Sum of sparse probabilities ({r}) should be less than 1.0."
28 | if method == "zeros" or r < 0.0:
29 | p = torch.zeros(num_classes)
30 | p[ids] = torch.nn.functional.softmax(sp, dim=0)
31 | elif method == "smooth":
32 | p = torch.ones(num_classes) * r / (num_classes - len(sp))
33 | p[ids] = sp
34 | return p
35 |
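A small worked example of the round trip on a 5-class toy distribution: `sparsify` keeps the top-k probabilities with their class indices, and `densify` with method "smooth" spreads the residual mass uniformly over the unkept classes:

import numpy as np
from dr.utils import sparsify, densify

p = np.array([0.70, 0.20, 0.05, 0.03, 0.02])
vals, ids = sparsify(p, k=2)  # e.g. ([0.2, 0.7], [1, 0]); order may vary
dense = densify((vals, ids), num_classes=5, method="smooth")
# The residual 0.10 is spread over the 3 unkept classes (~0.0333 each)
assert abs(dense.sum().item() - 1.0) < 1e-5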
--------------------------------------------------------------------------------
/figures/DR_illustration_wide.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-dr/a21113c779714b68bcc8b6518f2a4ad149b29f00/figures/DR_illustration_wide.pdf
--------------------------------------------------------------------------------
/figures/DR_illustration_wide.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-dr/a21113c779714b68bcc8b6518f2a4ad149b29f00/figures/DR_illustration_wide.png
--------------------------------------------------------------------------------
/figures/imagenet_RRC+RARE_accuracy_annotated.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-dr/a21113c779714b68bcc8b6518f2a4ad149b29f00/figures/imagenet_RRC+RARE_accuracy_annotated.pdf
--------------------------------------------------------------------------------
/figures/imagenet_RRC+RARE_accuracy_annotated.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-dr/a21113c779714b68bcc8b6518f2a4ad149b29f00/figures/imagenet_RRC+RARE_accuracy_annotated.png
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 |
5 | """Methods to create, load, and ensemble models."""
6 | import os
7 | from typing import Any, Dict, Generator, List, Optional, Tuple, Union
8 | import yaml
9 | import logging
10 | import torch
11 | import torch.nn as nn
12 | from torch.nn import functional as F
13 | import torchvision.models as models
14 |
15 |
16 |
17 | def move_to_device(model: torch.nn.Module, config: Dict[str, Any]) -> torch.nn.Module:
18 | """Wrap model with DDP/DP if distributed, convert to CUDA if GPU set, else CPU."""
19 | if not torch.cuda.is_available():
20 | logging.info("using CPU, this will be slow")
21 | elif config["distributed"]:
22 | ngpus_per_node = torch.cuda.device_count()
23 | # For multiprocessing distributed, DistributedDataParallel constructor
24 | # should always set the single device scope, otherwise,
25 | # DistributedDataParallel will use all available devices.
26 | if config["gpu"] is not None:
27 | torch.cuda.set_device(config["gpu"])
28 | model.cuda(config["gpu"])
29 | # When using a single GPU per process and per
30 | # DistributedDataParallel, we need to divide the batch size
31 | # ourselves based on the total number of GPUs we have
32 | config["batch_size"] = int(config["batch_size"] / ngpus_per_node)
33 | config["workers"] = int(
34 | (config["workers"] + ngpus_per_node - 1) / ngpus_per_node
35 | )
36 | model = torch.nn.parallel.DistributedDataParallel(
37 | model, device_ids=[config["gpu"]]
38 | )
39 | else:
40 | model.cuda()
41 | # DistributedDataParallel will divide and allocate batch_size to
42 | # all available GPUs if device_ids are not set
43 | model = torch.nn.parallel.DistributedDataParallel(model)
44 | elif config["gpu"] is not None:
45 | torch.cuda.set_device(config["gpu"])
46 | model = model.cuda(config["gpu"])
47 | else:
48 | # DataParallel will divide and allocate batch_size to all available
49 | # GPUs
50 | model = torch.nn.DataParallel(model).cuda()
51 | return model
52 |
53 |
54 | def load_model(gpu: torch.device, config: Dict[str, Any]) -> torch.nn.Module:
55 | """Load a pretrained model or an ensemble of pretrained models."""
56 | if config.get("ensemble", False):
57 | # Load an ensemble from a checkpoint path
58 |         checkpoints_path = config["checkpoint"]
59 | device = None
60 | if gpu is not None:
61 | device = "cuda:{}".format(gpu)
62 | # Load models
63 | members = torch.nn.ModuleList(load_ensemble(checkpoints_path, device))
64 | model = ClassificationEnsembleNet(members)
65 | elif config.get("timm_ensemble", False):
66 | import timm
67 | # Load an ensemble of Timm models
68 | model_names = config.get("name", None)
69 | if gpu is not None:
70 | torch.cuda.set_device(gpu)
71 | # Load pretrained models
72 | members = torch.nn.ModuleList()
73 | if not isinstance(model_names, list):
74 | model_names = model_names.split(",")
75 | for m in model_names:
76 | members += [timm.create_model(m, pretrained=True)]
77 | # Create Ensemble
78 | model = ClassificationEnsembleNet(members)
79 | elif config.get("checkpoint", None) is not None:
80 | # Load a single pretrained model
81 | checkpoint_path = config["checkpoint"]
82 | arch = config["arch"]
83 | model = load_from_local(checkpoint_path, arch, gpu)
84 | else:
85 | # Use default pretrained model from pytorch.
86 | model = models.__dict__[config["arch"]](pretrained=True)
87 | return model
88 |
89 |
90 | def load_from_local(checkpoint_path: str, arch: str, gpu: int) -> torch.nn.Module:
91 | """Load model from local path and move to GPU if gpu set."""
92 | teacher_model = models.__dict__[arch]()
93 |
94 | # Load from checkpoint
95 | if gpu is None:
96 | checkpoint = torch.load(checkpoint_path)
97 | else:
98 | # Map model to be loaded to specified single gpu.
99 | loc = "cuda:{}".format(gpu)
100 | checkpoint = torch.load(checkpoint_path, map_location=loc)
101 |
102 | # Strip module. from checkpoint
103 | ckpt = {k.replace("module.", ""): v for k, v in checkpoint["state_dict"].items()}
104 | teacher_model.load_state_dict(ckpt)
105 | logging.info("Loaded checkpoint {} for teacher".format(checkpoint_path))
106 | return teacher_model
107 |
108 |
109 | def create_model(arch: str, config: dict) -> torch.nn.Module:
110 | """Create models from CVNets/Timm/Torch."""
111 | if arch == "cvnets":
112 | import cvnets
113 | from cvnets import modeling_arguments
114 | import argparse
115 | # TODO: cvnets does not yet support easy model creation outside the library
116 | parser = argparse.ArgumentParser(description="")
117 | parser = modeling_arguments(parser)
118 | opts = parser.parse_args()
119 | config_dot = dict(convert_dict_to_dotted(config))
120 | for k, v in config_dot.items():
121 | if hasattr(opts, k):
122 | setattr(opts, k, v)
123 | setattr(opts, 'dataset.category', 'classification')
124 | model = cvnets.get_model(config_dot)
125 | elif arch == "timm":
126 | import timm
127 | model = timm.create_model(**config["model"])
128 | else:
129 | model = models.__dict__[arch]()
130 | logging.info(model)
131 | return model
132 |
133 |
134 | class ClassificationEnsembleNet(nn.Module):
135 | """Ensemble model for classification based on averaging."""
136 |
137 | def __init__(self, members: torch.nn.ModuleList) -> None:
138 | """Init ensemble."""
139 | super().__init__()
140 | self.members = members
141 |
142 | def forward(
143 | self, x: torch.Tensor, return_prob: bool = False, temperature: float = 1.0
144 | ) -> torch.Tensor:
145 | """Reduce function for classification using averaging."""
146 | output = 0
147 | for a_network in self.members:
148 | logits = a_network(x)
149 | prob = F.softmax(logits / temperature, dim=1)
150 | output = output + prob
151 |
152 | if return_prob:
153 | return output / float(len(self.members))
154 | return (output / float(len(self.members))).log()
155 |
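A tiny sketch of the averaging reduction above, using two toy linear members (illustrative only): member softmaxes are averaged, so the result is a valid probability distribution:

import torch

members = torch.nn.ModuleList([torch.nn.Linear(8, 3), torch.nn.Linear(8, 3)])
ensemble = ClassificationEnsembleNet(members)
x = torch.randn(4, 8)
prob = ensemble(x, return_prob=True)  # mean of member softmaxes
assert torch.allclose(prob.sum(dim=1), torch.ones(4), atol=1e-5)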
156 |
157 | def init_model_from_ckpt(
158 | model: torch.nn.Module,
159 | ckpt_path: str,
160 | device: torch.device,
161 | strict_keys: Optional[bool] = False,
162 | ) -> torch.nn.Module:
163 | """Init a model from an already trained model.
164 |
165 | Args:
166 | model: the pytorch model object to be loaded.
167 | ckpt_path: path to a model checkpoint.
168 | device: the device to load data to. Note that
169 | the model could be saved from a different device.
170 | Here we transfer the paramters to the current given device. So,
171 | a model could be trained and saved on GPU, and be loaded on CPU,
172 | for example.
173 | strict_keys: If True keys in state_dict of both models should be
174 | identical.
175 | """
176 | ckpt = torch.load(ckpt_path, map_location=device)
177 | pretrained_dict = ckpt["state_dict"]
178 | # For incorrectly saved DataParallel models
179 | pretrained_dict = {k.replace("module.", ""): v for k, v in pretrained_dict.items()}
180 | model.load_state_dict(pretrained_dict, strict=strict_keys)
181 |
182 | return model
183 |
184 |
185 | def load_ensemble(checkpoints_path: str, device: torch.device) -> List[torch.nn.Module]:
186 | """Traverse all subdirs and load checkpoints."""
187 | models = list()
188 | for root, dirs, files in os.walk(checkpoints_path):
189 | dirs.sort()
190 | # a directory is a legitimate checkpoint directory if the root has a config.yaml
191 | if "config.yaml" in files:
192 | with open(os.path.join(root, "config.yaml")) as f:
193 | model_config = yaml.safe_load(f).get("parameters")
194 | arch = model_config["arch"]
195 | if (
196 | model_config.get("model", {})
197 | .get("classification", {})
198 | .get("pretrained", None)
199 | is not None
200 | ):
201 | model_config["model"]["classification"]["pretrained"] = None
202 | model = create_model(arch, model_config)
203 | ckpt_path = get_path_to_checkpoint(root)
204 | model = init_model_from_ckpt(model, ckpt_path, device)
205 | models.append(model)
206 | return models
207 |
208 |
209 | def get_path_to_checkpoint(artifact_dir: str, epoch=None) -> str:
210 | """Find checkpoint file path in an artifact directory.
211 |
212 | Args:
213 | artifact_dir: path to an experiment artifact directory
214 | to load checkpoints from.
215 | epoch: If given, tries to load that epoch's checkpoint;
216 | otherwise loads the latest. This function assumes
217 | checkpoints are saved as `checkpoint_{epoch}.tar`.
218 | """
219 | ckpts_path = os.path.join(artifact_dir, "checkpoints")
220 | ckpts_list = os.listdir(ckpts_path)
221 | ckpts_dict = {
222 | int(ckpt.split("_")[1].split(".")[0]): os.path.join(ckpts_path, ckpt)
223 | for ckpt in ckpts_list
224 | }
225 | if len(ckpts_list) == 0:
226 | msg = "No checkpoint exists!"
227 | raise ValueError(msg)
228 | if epoch is not None:
229 | if epoch not in ckpts_dict.keys():
230 | msg = "Could not find checkpoint for epoch {} !"
231 | raise ValueError(msg.format(epoch))
232 | else:
233 | epoch = max(ckpts_dict.keys())
234 | return ckpts_dict[epoch]
235 |
236 |
237 | def convert_dict_to_dotted(
238 | c: Dict[str, Any],
239 | prefix: str = "",
240 | ) -> Generator[Union[Any, Tuple[str, Any]], None, None]:
241 | """Convert a nested dictionary of configs to flat dotted notation."""
242 | if isinstance(c, dict):
243 | prefix += "." if prefix != "" else ""
244 | for k, v in c.items():
245 | for x in convert_dict_to_dotted(v, prefix + k):
246 | yield x
247 | else:
248 | yield (prefix, c)
249 |
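250 |
251 | # Example (illustrative): convert_dict_to_dotted flattens a nested config,
252 | # e.g. dict(convert_dict_to_dotted({"model": {"name": "resnet", "depth": 50}}))
253 | # yields {"model.name": "resnet", "model.depth": 50}.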
--------------------------------------------------------------------------------
/reinforce.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | """Reinforce a dataset given a pretrained teacher and save the metadata."""
5 |
6 | import os
7 | import time
8 | import argparse
9 | from typing import Any, Dict, List, Tuple
10 | import yaml
11 | import random
12 | import warnings
13 | import copy
14 | import joblib
15 |
16 | import torch
17 | import torch.nn.parallel
18 | import torch.distributed as dist
19 | import torch.backends.cudnn as cudnn
20 | import torch.multiprocessing as mp
21 | import torch.optim
22 | import torch.utils.data
23 | import torch.utils.data.distributed
24 | import dr.transforms as T_dr
25 | from dr.data import IndexedDataset, DatasetWithParameters
26 | from dr.utils import sparsify
30 | from models import load_model
31 |
32 | from utils import ProgressMeter, AverageMeter
33 | from data import (
34 | download_dataset,
35 | get_dataset,
36 | get_dataset_num_classes,
37 | )
38 |
39 | import logging
40 |
41 | logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO)
42 |
43 |
44 | def parse_args() -> argparse.Namespace:
45 | """Parse command-line arguments."""
46 | parser = argparse.ArgumentParser(description="PyTorch Reinforcing")
47 | parser.add_argument(
48 | "--config", type=str, required=False, help="Path to a yaml config file."
49 | )
50 | args = parser.parse_args()
51 | return args
52 |
53 |
54 | def main(args) -> None:
55 | """Reinforce a model with the configurations specified in given arguments."""
56 | if args.config is not None:
57 | # Read parameters from yaml config for local run
58 | yaml_path = args.config
59 | with open(yaml_path, "r") as file:
60 | config = yaml.safe_load(file).get("parameters")
61 |
62 | dataset = config.get("dataset", "imagenet")
63 | config["num_classes"] = get_dataset_num_classes(dataset)
64 |
65 | if config["seed"] is not None:
66 | random.seed(config["seed"])
67 | torch.manual_seed(config["seed"])
68 | cudnn.deterministic = True
69 | cudnn.benchmark = False
70 | warnings.warn(
71 | "You have chosen to seed training. "
72 | "This will turn on the CUDNN deterministic setting, "
73 | "which can slow down your training considerably! "
74 | "You may see unexpected behavior when restarting "
75 | "from checkpoints."
76 | )
77 |
78 | ngpus_per_node = torch.cuda.device_count()
79 | if config["gpu"] is not None:
80 | warnings.warn(
81 | "You have chosen a specific GPU. This will completely "
82 | "disable data parallelism."
83 | )
84 | ngpus_per_node = 1
85 |
86 | if config["dist_url"] == "env://" and config["world_size"] == -1:
87 | config["world_size"] = int(os.environ["WORLD_SIZE"])
88 |
89 | config["distributed"] = (
90 | config["world_size"] > 1 or config["multiprocessing_distributed"]
91 | )
92 |
93 | if config["download_data"]:
94 | config["data_path"] = download_dataset(config["data_path"], config["dataset"])
95 | if config["multiprocessing_distributed"]:
96 | # Since we have ngpus_per_node processes per node, the total world_size
97 | # needs to be adjusted accordingly
98 | config["world_size"] = ngpus_per_node * config["world_size"]
99 | # Use torch.multiprocessing.spawn to launch distributed processes: the
100 | # main_worker process function
101 | mp.spawn(
102 | main_worker,
103 | nprocs=ngpus_per_node,
104 | args=(ngpus_per_node, copy.deepcopy(config)),
105 | )
106 |
107 | else:
108 | # Simply call main_worker function
109 | main_worker(config["gpu"], ngpus_per_node, copy.deepcopy(config))
110 |
111 | rdata_ext_path = "/mnt/reinforced_data"
112 | fnames = next(os.walk(rdata_ext_path))[2]
113 | max_id = max(
114 | [
115 | int(f.replace(".pth.tar", "").replace(".jb", ""))
116 | for f in fnames
117 | if (".pth.tar" in f or ".jb" in f) and f != "config.pth.tar"
118 | ]
119 | )
120 | assert max_id + 1 == len(fnames), "Saved {} files but max id is {}.".format(
121 | len(fnames), max_id
122 | )
123 | logging.info("Saving new data.")
124 | torch.save(config, os.path.join(rdata_ext_path, "config.pth.tar"))
125 | artifact_path = config["artifact_path"]
126 | gzip = config["reinforce"]["gzip"]
127 | cmd = "tar c{}f {} -C {} reinforced_data".format(
128 | "z" if gzip else "",
129 | os.path.join(
130 | artifact_path, "reinforced_data.tar{}".format(".gz" if gzip else "")
131 | ),
132 | "/mnt/",
133 | )
134 | os.system(cmd)
135 |
136 |
137 | def main_worker(gpu: int, ngpus_per_node: int, config: Dict[str, Any]) -> None:
138 | """Reinforce data with a single process. In distributed training, run on one GPU."""
139 | config["gpu"] = gpu
140 |
141 | if config["gpu"] is not None:
142 | logging.info("Use GPU: {} for datagen".format(config["gpu"]))
143 |
144 | if config["distributed"]:
145 | if config["dist_url"] == "env://" and config["rank"] == -1:
146 | config["rank"] = int(os.environ["RANK"])
147 | if config["multiprocessing_distributed"]:
148 | # For multiprocessing distributed training, rank needs to be the
149 | # global rank among all the processes
150 | config["rank"] = config["rank"] * ngpus_per_node + gpu
151 | dist.init_process_group(
152 | backend=config["dist_backend"],
153 | init_method=config["dist_url"],
154 | world_size=config["world_size"],
155 | rank=config["rank"],
156 | )
157 |
158 | # Disable logging on all workers except rank=0
159 | if config["multiprocessing_distributed"] and config["rank"] % ngpus_per_node != 0:
160 | logging.basicConfig(
161 | format="%(asctime)s %(message)s", level=logging.WARNING, force=True
162 | )
163 |
164 | model = load_model(gpu=config["gpu"], config=config["teacher"])
165 | if config["gpu"] is not None:
166 | torch.cuda.set_device(config["gpu"])
167 | model = model.cuda(config["gpu"])
168 | else:
169 | model.cuda()
170 |
171 | # Set model to eval
172 | model.eval()
173 |
174 | # Print number of model parameters
175 | logging.info(
176 | "Number of model parameters: {}".format(
177 | sum(p.numel() for p in model.parameters() if p.requires_grad)
178 | )
179 | )
180 |
181 | # Data loading code
182 | logging.info("Instantiating dataset.")
183 | config["image_augmentation"] = {"train": {}, "val": {}}
184 | train_dataset, _ = get_dataset(config)
185 | num_samples = config["reinforce"]["num_samples"]
186 | # Set transforms to None and wrap the dataset
187 | train_dataset.transform = None
188 | tr = T_dr.compose_from_config(
189 | T_dr.before_collate_config(config["reinforce"]["image_augmentation"])
190 | )
191 | train_dataset = DatasetWithParameters(
192 | train_dataset, transform=tr, num_samples=num_samples
193 | )
194 |
195 | train_dataset = IndexedDataset(train_dataset)
196 | if config["distributed"]:
197 | # When using a single GPU per process and per
198 | # DistributedDataParallel, we need to divide the batch size
199 | # ourselves based on the total number of GPUs we have
200 | config["batch_size"] = int(config["batch_size"] / ngpus_per_node)
201 | config["workers"] = int(
202 | (config["workers"] + ngpus_per_node - 1) / ngpus_per_node
203 | )
204 | # Shuffle=True is better for mixing reinforcements
205 | train_sampler = torch.utils.data.distributed.DistributedSampler(
206 | train_dataset, shuffle=True, drop_last=False
207 | )
208 | train_sampler.set_epoch(0)
209 | else:
210 | train_sampler = None
211 |
212 | train_loader = torch.utils.data.DataLoader(
213 | train_dataset,
214 | batch_size=config["batch_size"],
215 | shuffle=(train_sampler is None),
216 | drop_last=False,
217 | num_workers=config["workers"],
218 | pin_memory=config["pin_memory"],
219 | sampler=train_sampler,
220 | )
221 |
222 | reinforce(train_loader, model, config)
223 |
224 |
225 | def reinforce(
226 | train_loader: torch.utils.data.DataLoader,
227 | model: torch.nn.Module,
228 | config: Dict[str, Any],
229 | ) -> None:
230 | """Generate reinforcements and save in individual files per training sample."""
231 | batch_time = AverageMeter("Time", ":6.3f")
232 | progress = ProgressMeter(len(train_loader), [batch_time])
233 |
234 | num_samples = config["reinforce"]["num_samples"]
235 | transforms = T_dr.compose_from_config(config["reinforce"]["image_augmentation"])
236 |
237 | rdata_ext_path = "/mnt/reinforced_data"
238 | os.makedirs(rdata_ext_path, exist_ok=True)
239 |
240 | with torch.no_grad():
241 | end = time.time()
242 | for batch_i, data in enumerate(train_loader):
243 | ids, images, target, coords = data
244 | if config["gpu"] is not None:
245 | images = images.cuda(config["gpu"], non_blocking=True)
246 | target = target.cuda(config["gpu"], non_blocking=True)
247 |
248 | # Apply transformations
249 | images_aug, params_aug = transform_batch(
250 | ids, images, target, coords, transforms, config
251 | )
252 | images_aug = images_aug.reshape((-1,) + images.shape[2:])
253 |
254 | # Compute output
255 | prob = model(images_aug, return_prob=True)
256 | prob = prob.reshape((images.shape[0], num_samples, -1))
257 | prob = prob.cpu().numpy()
258 | for j in range(images.shape[0]):
259 | new_samples = [
260 | {
261 | "params": params_aug[j][k],
262 | "prob": sparsify(prob[j][k], config["reinforce"]["topk"]),
263 | }
264 | for k in range(num_samples)
265 | ]
266 | fname = os.path.join(rdata_ext_path, "{}.pth.tar".format(int(ids[j])))
267 | if not os.path.exists(fname):
268 | if not config["reinforce"]["joblib"]:
269 | # protocol=4 gives ~1.3x compression vs proto=1
270 | torch.save((int(ids[j]), new_samples), fname, pickle_protocol=4)
271 | else:
272 | # joblib gives ~2.2x compression vs torch.save(proto=4)
273 | # but 1.6x slower to save and 7x slower to load
274 | joblib.dump((int(ids[j]), new_samples), fname, compress=True)
275 |
276 | # Measure elapsed time
277 | batch_time.update(time.time() - end)
278 | end = time.time()
279 |
280 | if batch_i % config["print_freq"] == 0:
281 | progress.display(batch_i)
282 |
283 | time.sleep(0.01) # Preventing broken pipe error
284 |
285 |
286 | def transform_batch(
287 | ids: torch.Tensor,
288 | images: torch.Tensor,
289 | target: torch.Tensor,
290 | coords: Dict[str, torch.Tensor],
291 | transforms: T_dr.Compose,
292 | config: Dict[str, Any],
293 | ) -> Tuple[torch.Tensor, List[List[Dict[str, Any]]]]:
294 | """Apply image transformations to a batch and return parameters.
295 |
296 | Args:
297 | ids: Image indices. Shape: [batch_size, 1]
298 | images: Image crops. If random-resized-crop is enabled, `images` has
299 | multiple random crops per image.
300 | Shape without RRC: [batch_size, n_channels, crop_height, crop_width],
301 | Shape with RRC: [batch_size, n_samples, n_channels, crop_height, crop_width]
302 | target: Ground-truth labels. Shape: [batch_size, 1]
303 | coords: Dict of random-resized-crop coordinates, each of shape [batch_size, n_samples, 4]
304 | transforms: A list of transformations to be applied randomly and stored.
305 | config: A dictionary of configurations with `reinforce` key.
306 |
307 | Returns:
308 | A tuple of transformed images and transformation parameters.
309 | Shape [0]: [batch_size, n_samples, n_channels, crop_height, crop_width]
310 | Shape [1]: A nested list of parameter dicts, sized [batch_size][n_samples].
311 | """
312 | num_samples = config["reinforce"]["num_samples"]
313 | images_aug = []
314 | params_aug = []
315 | for i in range(images.shape[0]):
316 | images_aug += [[]]
317 | params_aug += [[]]
318 | for j in range(num_samples):
319 | # Choose a sample with pre-collate transformation parameters
320 | img = images[i, j]
321 |
322 | # Sample an image pair
323 | index = random.randint(0, images.shape[0] - 1)
324 | j2 = random.randint(0, images.shape[1] - 1)
325 | img2 = images[index, j2]
326 | id2 = int(ids[index])
327 |
328 | # Apply remaining transformations after collate
329 | img, params = transforms(img, img2, after_collate=True)
330 |
331 | # Update parameters with before collate transformations
332 | params0 = {
333 | k: (tuple(v[i, j].tolist()), tuple(v[index, j2].tolist()))
334 | for k, v in coords.items()
335 | }
336 | params.update(params0)
337 |
338 | if "mixup" in params or "cutmix" in params:
339 | params["id2"] = id2
340 | if config["reinforce"]["compress"]:
341 | params = transforms.compress(params)
342 | images_aug[-1] += [img]
343 | params_aug[-1] += [params]
344 | images_aug[-1] = torch.stack(images_aug[-1], dim=0)
345 | images_aug = torch.stack(images_aug, dim=0)
346 | return images_aug, params_aug
347 |
348 |
349 | if __name__ == "__main__":
350 | args = parse_args()
351 | main(args)
352 |
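353 | # Example usage (the config path is one of the files provided under configs/):
354 | #   python reinforce.py --config configs/imagenet/reinforce/randaug.yaml
355 | #
356 | # This script writes one file per training-sample id (named `<id>.pth.tar`
357 | # here), storing (id, [{"params": ..., "prob": ...}, ...]) with one entry
358 | # per reinforcement sample.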
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | scikit-image
3 | torch
4 | torchvision
5 | timm
6 | pyyaml
7 | joblib
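8 |
9 | # Install with: pip install -r requirements.txt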
--------------------------------------------------------------------------------
/results/table_E150.md:
--------------------------------------------------------------------------------
1 | | Name | Mode | Parameters | ImageNet Top-1 | ImageNet+ Top-1 | ImageNet Top-1 (EMA) | ImageNet+ Top-1 (EMA) | ImageNet Links | ImageNet+ Links |
2 | |:-------------|:---------|:-------------|:------------------------------------------------|:-----------------------------------------------------|:------------------------------------------------|:-----------------------------------------------------|:---|:---|
3 | | MobileNetV1 | 0.25 | 0.5M | 55.2 | 55.4 | 55.4 | 55.4 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv1_0.25_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv1_0.25_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv1_0.25_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv1_0.25_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv1_0.25_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv1_0.25_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv1_0.25_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv1_0.25_E150/metrics.jb) |
4 | | MobileNetV1 | 0.5 | 1.3M | 66.2 | 67.1 | 66.4 | 67.1 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv1_0.5_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv1_0.5_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv1_0.5_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv1_0.5_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv1_0.5_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv1_0.5_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv1_0.5_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv1_0.5_E150/metrics.jb) |
5 | | MobileNetV1 | 1.0 | 4.3M | 73.5 | 75.1 | 73.6 | 75.1 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv1_1.0_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv1_1.0_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv1_1.0_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv1_1.0_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv1_1.0_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv1_1.0_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv1_1.0_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv1_1.0_E150/metrics.jb) |
6 | | MobileNetV2 | 0.25 | 1.5M | 54.7 | 54.2 | 54.7 | 54.2 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv2_0.25_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv2_0.25_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv2_0.25_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv2_0.25_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv2_0.25_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv2_0.25_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv2_0.25_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv2_0.25_E150/metrics.jb) |
7 | | MobileNetV2 | 0.5 | 2.0M | 65.7 | 65.7 | 65.7 | 65.7 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv2_0.5_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv2_0.5_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv2_0.5_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv2_0.5_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv2_0.5_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv2_0.5_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv2_0.5_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv2_0.5_E150/metrics.jb) |
8 | | MobileNetV2 | 1.0 | 3.5M | 72.8 | 73.8 | 72.8 | 73.9 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv2_1.0_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv2_1.0_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv2_1.0_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv2_1.0_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv2_1.0_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv2_1.0_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv2_1.0_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv2_1.0_E150/metrics.jb) |
9 | | MobileNetV3 | small | 2.5M | 66.6 | 67.7 | 66.7 | 67.7 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv3_small_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv3_small_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv3_small_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv3_small_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv3_small_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv3_small_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv3_small_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv3_small_E150/metrics.jb) |
10 | | MobileNetV3 | large | 5.5M | 74.7 | 76.5 | 74.8 | 76.5 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv3_large_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv3_large_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv3_large_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilenetv3_large_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv3_large_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv3_large_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv3_large_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilenetv3_large_E150/metrics.jb) |
11 | | MobileViT | xx_small | 1.3M | 66.0 | 67.4 | 66.7 | 67.9 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilevit_xx_small_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilevit_xx_small_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilevit_xx_small_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilevit_xx_small_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilevit_xx_small_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilevit_xx_small_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilevit_xx_small_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilevit_xx_small_E150/metrics.jb) |
12 | | MobileViT | x_small | 2.3M | 72.6 | 74.0 | 73.3 | 74.7 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilevit_x_small_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilevit_x_small_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilevit_x_small_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilevit_x_small_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilevit_x_small_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilevit_x_small_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilevit_x_small_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilevit_x_small_E150/metrics.jb) |
13 | | MobileViT | small | 5.6M | 76.3 | 78.3 | 76.7 | 78.6 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilevit_small_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilevit_small_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilevit_small_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/mobilevit_small_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilevit_small_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilevit_small_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilevit_small_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/mobilevit_small_E150/metrics.jb) |
14 | | ResNet | 18 | 11.7M | 69.9 | 73.2 | 69.8 | 73.2 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_18_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_18_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_18_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_18_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_18_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_18_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_18_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_18_E150/metrics.jb) |
15 | | ResNet | 34 | 21.8M | 74.6 | 76.9 | 74.7 | 76.9 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_34_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_34_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_34_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_34_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_34_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_34_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_34_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_34_E150/metrics.jb) |
16 | | ResNet | 50 | 25.6M | 79.0 | 80.3 | 79.1 | 80.3 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_50_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_50_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_50_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_50_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_50_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_50_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_50_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_50_E150/metrics.jb) |
17 | | ResNet | 101 | 44.6M | 80.5 | 81.8 | 80.5 | 81.9 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_101_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_101_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_101_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_101_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_101_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_101_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_101_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_101_E150/metrics.jb) |
18 | | ResNet | 152 | 60.3M | 81.3 | 82.2 | 81.3 | 82.3 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_152_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_152_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_152_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/resnet_152_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_152_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_152_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_152_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/resnet_152_E150/metrics.jb) |
19 | | EfficientNet | b2 | 9.3M | 79.5 | 81.5 | 79.5 | 81.6 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/efficientnet_b2_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/efficientnet_b2_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/efficientnet_b2_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/efficientnet_b2_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/efficientnet_b2_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/efficientnet_b2_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/efficientnet_b2_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/efficientnet_b2_E150/metrics.jb) |
20 | | EfficientNet | b3 | 12.4M | 80.9 | 82.4 | 80.8 | 82.5 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/efficientnet_b3_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/efficientnet_b3_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/efficientnet_b3_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/efficientnet_b3_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/efficientnet_b3_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/efficientnet_b3_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/efficientnet_b3_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/efficientnet_b3_E150/metrics.jb) |
21 | | EfficientNet | b4 | 19.7M | 82.7 | 83.6 | 82.7 | 83.7 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/efficientnet_b4_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/efficientnet_b4_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/efficientnet_b4_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/efficientnet_b4_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/efficientnet_b4_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/efficientnet_b4_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/efficientnet_b4_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/efficientnet_b4_E150/metrics.jb) |
22 | | ViT | tiny | 5.6M | 72.1 | 74.3 | 72.1 | 74.4 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/vit_tiny_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/vit_tiny_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/vit_tiny_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/vit_tiny_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit_tiny_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit_tiny_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit_tiny_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit_tiny_E150/metrics.jb) |
23 | | ViT | small | 21.9M | 78.4 | 79.8 | 78.7 | 79.9 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/vit_small_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/vit_small_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/vit_small_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/vit_small_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit_small_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit_small_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit_small_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit_small_E150/metrics.jb) |
24 | | ViT | base | 86.7M | 79.5 | 81.7 | 80.6 | 81.7 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/vit_base_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/vit_base_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/vit_base_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/vit_base_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit_base_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit_base_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit_base_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit_base_E150/metrics.jb) |
25 | | ViT-384 | base | 86.7M | 80.5 | 83.0 | 81.9 | 83.1 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/vit-384_base_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/vit-384_base_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/vit-384_base_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/vit-384_base_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit-384_base_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit-384_base_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit-384_base_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/vit-384_base_E150/metrics.jb) |
26 | | Swin | tiny | 28.3M | 80.5 | 82.1 | 80.3 | 81.9 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/swin_tiny_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/swin_tiny_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/swin_tiny_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/swin_tiny_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_tiny_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_tiny_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_tiny_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_tiny_E150/metrics.jb) |
27 | | Swin | small | 49.7M | 82.2 | 83.6 | 81.9 | 83.3 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/swin_small_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/swin_small_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/swin_small_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/swin_small_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_small_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_small_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_small_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_small_E150/metrics.jb) |
28 | | Swin | base | 87.8M | 82.7 | 83.9 | 82.2 | 83.7 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/swin_base_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/swin_base_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/swin_base_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/swin_base_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_base_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_base_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_base_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin_base_E150/metrics.jb) |
29 | | Swin-384 | base | 87.8M | 82.6 | 83.2 | 82.4 | 83.0 | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/swin-384_base_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/swin-384_base_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/swin-384_base_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-cvnets/swin-384_base_E150/metrics.jb) | [[best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin-384_base_E150/best.pt) [[ema_best.pt]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin-384_base_E150/ema_best.pt) [[config.yaml]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin-384_base_E150/config.yaml) [[metrics.jb]](https://docs-assets.developer.apple.com/ml-research/models/dr/imagenet-plus-cvnets/swin-384_base_E150/metrics.jb) |
30 |
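31 | *Top-1 accuracy (%); the pairing of the duplicate accuracy columns with the `best.pt` vs. `ema_best.pt` checkpoints is inferred from the links in each row.*
32 |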
--------------------------------------------------------------------------------
/tests/test_reinforce.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 |
5 | """Tests for dataset generation."""
6 |
7 | from typing import Any, Dict
8 | import pytest
9 | import torch
10 |
11 | import dr.transforms as T_dr
12 | from reinforce import transform_batch
13 |
15 |
16 |
17 | def single_transform_config_test(
18 | config: Dict[str, Any],
19 | num_images: int,
20 | num_samples: int,
21 | height: int,
22 | width: int,
23 | crop_height: int,
24 | crop_width: int,
25 | compress: bool,
26 | single_image_test: bool,
27 | ) -> None:
28 | """Test applying and reapplying transformations to a batch of images."""
29 | ids = torch.arange(num_images)
30 | images_orig = torch.rand(size=(num_images, 3, height, width))
32 | target = torch.randint(low=0, high=10, size=(num_images,))
33 | # Simulate data loader
34 | transforms = T_dr.compose_from_config(
35 | T_dr.before_collate_config(config["reinforce"]["image_augmentation"])
36 | )
37 |
38 | images = []
39 | coords = []
40 | for i in range(num_images):
41 | sample_all, params_all = T_dr.before_collate_apply(
42 | images_orig[i], transforms, num_samples
43 | )
44 | images += [sample_all]
45 | coords += [params_all]
46 | images = torch.utils.data.default_collate(images)
47 | coords = torch.utils.data.default_collate(coords)
48 |
49 | # Apply after collate transformations
50 | transforms = T_dr.compose_from_config(config["reinforce"]["image_augmentation"])
51 | images_aug, params_aug = transform_batch(
52 | ids, images, target, coords, transforms, config
53 | )
54 | assert images_aug.shape == (
55 | num_images,
56 | num_samples,
57 | 3,
58 | crop_height,
59 | crop_width,
60 | ), "Incorrect shape of transformed image."
61 | transforms = T_dr.compose_from_config(config["reinforce"]["image_augmentation"])
62 |
63 | # Single test
64 | # TODO: support transforms(img) and reapply(img, params) without img2
65 | if single_image_test:
66 | img, param = transforms(images_orig[0], images_orig[0])
67 | img2, param2 = transforms.reapply(images_orig[0], param)
68 | torch.testing.assert_close(
69 | actual=img,
70 | expected=img2,
71 | )
72 |
73 | if num_images > 1 and num_samples > 1:
74 | assert (
75 | len(set([str(q) for p in params_aug for q in p])) > 1
76 | ), "Parameters are not random."
77 |
78 | # Test reapply
79 | for i in range(num_images):
80 | for j in range(num_samples):
81 | if compress:
82 | params_aug[i][j] = transforms.decompress(params_aug[i][j])
83 | img2 = None
84 | if "mixup" in params_aug[i][j] or "cutmix" in params_aug[i][j]:
85 | img2 = images_orig[params_aug[i][j]["id2"]]
86 | out, _ = transforms.reapply(images_orig[i], params_aug[i][j], img2)
87 | torch.testing.assert_close(actual=out, expected=images_aug[i][j])
88 |
89 |
90 | @pytest.mark.parametrize("compress", [False, True])
91 | def test_random_resized_crop(compress: bool) -> None:
92 | """Test RRC with other transformations."""
93 | num_images = 2
94 | num_samples = 4
95 | height = 10
96 | width = 10
97 | crop_height = 5
98 | crop_width = 5
99 | single_image_test = True
100 | config = {
101 | "reinforce": {
102 | "num_samples": num_samples,
103 | "compress": compress,
104 | "image_augmentation": {
105 | "uint8": {"enable": True},
106 | "random_resized_crop": {
107 | "enable": True,
108 | "size": [crop_height, crop_width],
109 | },
110 | "random_horizontal_flip": {"enable": True, "p": 0.5},
111 | "rand_augment": {"enable": True, "p": 0.5},
112 | "to_tensor": {"enable": True},
113 | "normalize": {
114 | "enable": True,
115 | "mean": [0.485, 0.456, 0.406],
116 | "std": [0.229, 0.224, 0.225],
117 | },
118 | "random_erase": {"enable": True, "p": 0.25},
119 | },
120 | }
121 | }
122 | single_transform_config_test(
123 | config,
124 | num_images,
125 | num_samples,
126 | height,
127 | width,
128 | crop_height,
129 | crop_width,
130 | compress,
131 | single_image_test,
132 | )
133 |
134 |
135 | @pytest.mark.parametrize("compress", [False, True])
136 | def test_center_crop(compress: bool) -> None:
137 | """Test center-crop with other transformations."""
138 | num_images = 2
139 | num_samples = 4
140 | height = 10
141 | width = 10
142 | crop_height = 5
143 | crop_width = 5
144 | single_image_test = True
145 | config = {
146 | "reinforce": {
147 | "num_samples": num_samples,
148 | "compress": compress,
149 | "image_augmentation": {
150 | "uint8": {"enable": True},
151 | "center_crop": {"enable": True, "size": [crop_height, crop_width]},
152 | "random_horizontal_flip": {"enable": True, "p": 0.5},
153 | "rand_augment": {"enable": True, "p": 0.5},
154 | "to_tensor": {"enable": True},
155 | "normalize": {
156 | "enable": True,
157 | "mean": [0.485, 0.456, 0.406],
158 | "std": [0.229, 0.224, 0.225],
159 | },
160 | "random_erase": {"enable": True, "p": 0.25},
161 | },
162 | }
163 | }
164 | single_transform_config_test(
165 | config,
166 | num_images,
167 | num_samples,
168 | height,
169 | width,
170 | crop_height,
171 | crop_width,
172 | compress,
173 | single_image_test,
174 | )
175 |
176 |
177 | @pytest.mark.parametrize("compress", [False, True])
178 | def test_mixing(compress: bool) -> None:
179 | """Test MixUp/CutMix with other transformations."""
180 | num_images = 10
181 | num_samples = 20
182 | height = 256
183 | width = 256
184 | crop_height = 224
185 | crop_width = 224
186 | single_image_test = False
187 | config = {
188 | "reinforce": {
189 | "num_samples": num_samples,
190 | "compress": compress,
191 | "image_augmentation": {
192 | "uint8": {"enable": True},
193 | "random_resized_crop": {
194 | "enable": True,
195 | "size": [crop_height, crop_width],
196 | },
197 | "random_horizontal_flip": {"enable": True, "p": 0.5},
198 | "rand_augment": {"enable": True, "p": 0.5},
199 | "to_tensor": {"enable": True},
200 | "normalize": {
201 | "enable": True,
202 | "mean": [0.485, 0.456, 0.406],
203 | "std": [0.229, 0.224, 0.225],
204 | },
205 | "random_erase": {"enable": True, "p": 0.25},
206 | "mixup": {"enable": True, "alpha": 1.0, "p": 0.5},
207 | "cutmix": {"enable": True, "alpha": 1.0, "p": 0.5},
208 | },
209 | }
210 | }
211 | single_transform_config_test(
212 | config,
213 | num_images,
214 | num_samples,
215 | height,
216 | width,
217 | crop_height,
218 | crop_width,
219 | compress,
220 | single_image_test,
221 | )
222 |
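223 |
224 | # Run these tests with, e.g.: pytest tests/test_reinforce.py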
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | """Modification of Pytorch ImageNet training code to handle additional datasets."""
5 | import argparse
6 | import os
7 | import random
8 | import shutil
9 | import warnings
10 | import yaml
11 | import copy
12 | from typing import List, Dict, Any
13 |
14 | import torch
15 | import torch.nn.parallel
16 | import torch.backends.cudnn as cudnn
17 | import torch.distributed as dist
18 | import torch.optim
19 | import torch.multiprocessing as mp
20 | import torch.utils.data
21 | import torch.utils.data.distributed
22 |
23 | from trainers import get_trainer
24 | from utils import CosineLR
25 | from data import (
26 | download_dataset,
27 | get_dataset,
28 | get_dataset_num_classes,
29 | get_dataset_size,
30 | )
31 |
32 | from dr.data import ReinforcedDataset
33 |
34 | import logging
35 |
36 | logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO)
37 |
38 | best_acc1 = 0
39 |
40 |
41 | def parse_args() -> argparse.Namespace:
42 | """Parse command-line arguments."""
43 | parser = argparse.ArgumentParser(description="PyTorch Training")
44 | parser.add_argument(
45 | "--config", type=str, required=False, help="Path to a yaml config file."
46 | )
47 | args = parser.parse_args()
48 | return args
49 |
50 |
51 | def get_trainable_parameters(
52 | model: torch.nn.Module,
53 | weight_decay: float,
54 | no_decay_bn_filter_bias: bool,
55 | *args,
56 | **kwargs
57 | ) -> List[Dict[str, List[torch.nn.Parameter]]]:
58 | """Return trainable model parameters excluding biases and normalization layers.
59 |
60 | Args:
61 | model: The Torch model to be trained.
62 | weight_decay: The weight decay coefficient.
63 | no_decay_bn_filter_bias: If True exclude biases and normalization layer params.
64 |
65 | Returns:
66 | A list with two dictionaries for parameters with and without weight decay.
67 | """
68 | with_decay = []
69 | without_decay = []
70 | for _, param in model.named_parameters():
71 | if param.requires_grad and len(param.shape) == 1 and no_decay_bn_filter_bias:
72 | # biases and normalization layer parameters are 1-D tensors
73 | without_decay.append(param)
74 | elif param.requires_grad:
75 | with_decay.append(param)
76 | param_list = [{"params": with_decay, "weight_decay": weight_decay}]
77 | if len(without_decay) > 0:
78 | param_list.append({"params": without_decay, "weight_decay": 0.0})
79 | return param_list
80 |
81 |
82 | def get_optimizer(
83 | model: torch.nn.Module, config: Dict[str, Any]
84 | ) -> torch.optim.Optimizer:
85 | """Initialize an optimizer with parameters of a model.
86 |
87 | Args:
88 | model: The model to be trained.
89 | config: A dictionary of optimizer hyperparameters and configurations. The
90 | configuration should at least have a `name`.
91 |
92 | Returns:
93 | A Torch optimizer.
94 | """
95 | optim_name = config["name"]
96 |
97 | params = get_trainable_parameters(
98 | model,
99 | weight_decay=config.get("weight_decay", 0.0),
100 | no_decay_bn_filter_bias=config.get("no_decay_bn_filter_bias", False),
101 | )
102 |
103 | if optim_name == "sgd":
104 | optimizer = torch.optim.SGD(
105 | params=params,
106 | lr=config["lr"],
107 | momentum=config["momentum"],
108 | nesterov=config.get("nesterov", True),
109 | )
110 | elif optim_name == "adamw":
111 | optimizer = torch.optim.AdamW(
112 | params=params,
113 | lr=config["lr"],
114 | betas=(config["beta1"], config["beta2"]),
115 | )
116 | else:
117 | raise NotImplementedError("Unsupported optimizer: {}".format(optim_name))
118 | return optimizer
119 |
120 |
121 | def main(args) -> None:
122 | """Train a model with the configurations specified in given arguments."""
123 | if args.config is not None:
124 | # Read parameters from yaml config for local run
125 | yaml_path = args.config
126 | with open(yaml_path, "r") as file:
127 | config = yaml.safe_load(file).get("parameters")
128 |
129 | dataset = config.get("dataset", "imagenet")
130 | config["num_classes"] = get_dataset_num_classes(dataset)
131 |
132 | # Print args before training
133 | logging.info("Training args: {}".format(config))
134 |
135 | if config["seed"] is not None:
136 | random.seed(config["seed"])
137 | torch.manual_seed(config["seed"])
138 | cudnn.deterministic = True
139 | cudnn.benchmark = False
140 | warnings.warn(
141 | "You have chosen to seed training. "
142 | "This will turn on the CUDNN deterministic setting, "
143 | "which can slow down your training considerably! "
144 | "You may see unexpected behavior when restarting "
145 | "from checkpoints."
146 | )
147 |
148 | ngpus_per_node = torch.cuda.device_count()
149 | if config["gpu"] is not None:
150 | warnings.warn(
151 | "You have chosen a specific GPU. This will completely "
152 | "disable data parallelism."
153 | )
154 | ngpus_per_node = 1
155 |
156 | if config["dist_url"] == "env://" and config["world_size"] == -1:
157 | config["world_size"] = int(os.environ["WORLD_SIZE"])
158 |
159 | config["distributed"] = (
160 | config["world_size"] > 1 or config["multiprocessing_distributed"]
161 | )
162 |
163 | if config["download_data"]:
164 | config["data_path"] = download_dataset(config["data_path"], config["dataset"])
165 |
166 | if config["multiprocessing_distributed"]:
167 | # Since we have ngpus_per_node processes per node, the total world_size
168 | # needs to be adjusted accordingly
169 | config["world_size"] = ngpus_per_node * config["world_size"]
170 | # Use torch.multiprocessing.spawn to launch distributed processes: the
171 | # main_worker process function
172 | mp.spawn(
173 | main_worker,
174 | nprocs=ngpus_per_node,
175 | args=(ngpus_per_node, copy.deepcopy(config)),
176 | )
177 | else:
178 | # Simply call main_worker function
179 | main_worker(config["gpu"], ngpus_per_node, copy.deepcopy(config))
180 |
181 |
182 | def main_worker(gpu: int, ngpus_per_node: int, config: Dict[str, Any]) -> None:
183 | """Train a model with a single process. In distributed training, run on one GPU."""
184 | global best_acc1
185 | config["gpu"] = gpu
186 |
187 | if config["gpu"] is not None:
188 | logging.info("Use GPU: {} for training".format(config["gpu"]))
189 |
190 | if config["distributed"]:
191 | if config["dist_url"] == "env://" and config["rank"] == -1:
192 | config["rank"] = int(os.environ["RANK"])
193 | if config["multiprocessing_distributed"]:
194 | # For multiprocessing distributed training, rank needs to be the
195 | # global rank among all the processes
196 | config["rank"] = config["rank"] * ngpus_per_node + gpu
197 | dist.init_process_group(
198 | backend=config["dist_backend"],
199 | init_method=config["dist_url"],
200 | world_size=config["world_size"],
201 | rank=config["rank"],
202 | )
203 | # Disable logging on all workers except rank=0
204 | if config["multiprocessing_distributed"] and config["rank"] % ngpus_per_node != 0:
205 | logging.basicConfig(
206 | format="%(asctime)s %(message)s", level=logging.WARNING, force=True
207 | )
208 |
209 |     # Initialize a trainer class that handles the different modes of training:
210 |     # standard training (ERM), knowledge distillation (KD), and Dataset Reinforcement (DR)
211 | trainer = get_trainer(config)
212 |
213 | # Create the model to train. Also create and load the teacher model for KD
214 | model = trainer.get_model()
215 |
216 | # Print number of model parameters
217 | logging.info(
218 | "Number of model parameters: {}".format(
219 | sum(p.numel() for p in model.parameters() if p.requires_grad)
220 | )
221 | )
222 |
223 | # Define loss function (criterion) and optimizer
224 | criterion = trainer.get_criterion()
225 | logging.info("Criterion: {}".format(criterion))
226 |
227 | # Define optimizer
228 | optimizer = get_optimizer(model, config["optim"])
229 | logging.info("Optimizer: {}".format(optimizer))
230 |
231 | dataset_size = get_dataset_size(config["dataset"])
232 | # Compute warmup and total iterations on each gpu
233 | warmup_length = (
234 | config["optim"]["warmup_length"]
235 | * dataset_size
236 | // config["batch_size"]
237 | // ngpus_per_node
238 | )
239 | total_steps = (
240 | config["epochs"] * dataset_size // config["batch_size"] // ngpus_per_node
241 | )
242 | lr_scheduler = CosineLR(
243 | optimizer,
244 | warmup_length=warmup_length,
245 | total_steps=total_steps,
246 | lr=config["optim"]["lr"],
247 | end_lr=config["optim"].get("end_lr", 0.0),
248 | )
249 |
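    # For example (hypothetical numbers): on ImageNet (~1,281,167 images) with
    # batch_size=256, 8 GPUs, and warmup_length=5 epochs, each GPU performs
    # 5 * 1281167 // 256 // 8 = 3127 warmup steps before cosine annealing begins.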
250 | # Resume from a checkpoint
251 | if config["resume"]:
252 | if os.path.isfile(config["resume"]):
253 | logging.info("=> loading checkpoint '{}'".format(config["resume"]))
254 | if config["gpu"] is None:
255 | checkpoint = torch.load(config["resume"])
256 | else:
257 | # Map model to be loaded to specified single gpu.
258 | loc = "cuda:{}".format(config["gpu"])
259 | checkpoint = torch.load(config["resume"], map_location=loc)
260 | config["start_epoch"] = checkpoint["epoch"]
261 | best_acc1 = checkpoint["best_acc1"]
262 | if hasattr(model, "module"):
263 | model.module.load_state_dict(checkpoint["state_dict"])
264 | else:
265 | model.load_state_dict(checkpoint["state_dict"])
266 | optimizer.load_state_dict(checkpoint["optimizer"])
267 | lr_scheduler.load_state_dict(checkpoint["scheduler"])
268 | logging.info(
269 | "=> loaded checkpoint '{}' (epoch {})".format(
270 | config["resume"], checkpoint["epoch"]
271 | )
272 | )
273 | else:
274 | logging.info("=> no checkpoint found at '{}'".format(config["resume"]))
275 |
276 | # Data loading code
277 | logging.info("Instantiating dataset.")
278 | train_dataset, val_dataset = get_dataset(config)
279 | logging.info("Training Dataset: {}".format(train_dataset))
280 | logging.info("Validation Dataset: {}".format(val_dataset))
281 |
282 | if config["trainer"] == "DR":
283 | # Reinforce dataset
284 | train_dataset = ReinforcedDataset(
285 | train_dataset,
286 | rdata_path=config["reinforce"]["data_path"],
287 | config=config["reinforce"],
288 | num_classes=config["num_classes"],
289 | )
290 |
291 | if config["distributed"]:
292 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
293 | val_sampler = torch.utils.data.distributed.DistributedSampler(
294 | val_dataset, shuffle=False, drop_last=True
295 | )
296 | else:
297 | train_sampler = None
298 | val_sampler = None
299 |
300 | train_loader = torch.utils.data.DataLoader(
301 | train_dataset,
302 | batch_size=config["batch_size"],
303 | shuffle=(train_sampler is None),
304 | num_workers=config["workers"],
305 | pin_memory=config["pin_memory"],
306 | persistent_workers=config['persistent_workers'],
307 | sampler=train_sampler,
308 | )
309 |
310 | val_loader = torch.utils.data.DataLoader(
311 | val_dataset,
312 | batch_size=config["batch_size"],
313 | shuffle=False,
314 | num_workers=config["workers"],
315 | pin_memory=config["pin_memory"],
316 | persistent_workers=config['persistent_workers'],
317 | sampler=val_sampler,
318 | )
319 |
320 | if config["evaluate"]:
321 | # Evaluate a pretrained model without training
322 | trainer.validate(val_loader, model, criterion, config)
323 | return
324 |
325 | # Evaluate a pretrained teacher before training
326 | trainer.validate_pretrained(val_loader, model, criterion, config)
327 |
328 | # Start training
329 | for epoch in range(config["start_epoch"], config["epochs"]):
330 | if config["distributed"]:
331 | train_sampler.set_epoch(epoch)
332 |
333 | # Train for one epoch
334 | train_metrics = trainer.train(
335 | train_loader,
336 | model,
337 | criterion,
338 | optimizer,
339 | epoch,
340 | config,
341 | lr_scheduler,
342 | )
343 |
344 | # Evaluate on validation set
345 | val_metrics = trainer.validate(val_loader, model, criterion, config)
346 | val_acc1 = val_metrics["val_accuracy@top1"]
347 |
348 | # remember best acc@1 and save checkpoint
349 | is_best = val_acc1 > best_acc1
350 | best_acc1 = max(val_acc1, best_acc1)
351 |
352 | if not config["multiprocessing_distributed"] or (
353 | config["multiprocessing_distributed"]
354 | and config["rank"] % ngpus_per_node == 0
355 | ):
356 | metrics = dict(train_metrics)
357 | metrics.update(val_metrics)
358 | if isinstance(model, torch.nn.DataParallel) or isinstance(
359 | model, torch.nn.parallel.DistributedDataParallel
360 | ):
361 | model_state_dict = model.module.state_dict()
362 | else:
363 | model_state_dict = model.state_dict()
364 | is_save_epoch = (epoch + 1) % config.get("save_freq", 1) == 0
365 | is_save_epoch = is_save_epoch or ((epoch + 1) == config["epochs"])
366 | save_checkpoint(
367 | {
368 | "epoch": epoch + 1,
369 | "config": config,
370 | "state_dict": model_state_dict,
371 | "best_acc1": best_acc1,
372 | "optimizer": optimizer.state_dict(),
373 | "scheduler": lr_scheduler.state_dict(),
374 | },
375 | is_best,
376 | is_save_epoch,
377 | config["artifact_path"],
378 | )
379 |
380 |
381 | def save_checkpoint(state: Dict[str, Any], is_best: bool, is_save_epoch: bool, artifact_path: str) -> None:
382 | """Save checkpoint and update the best model."""
383 | fname = os.path.join(artifact_path, "checkpoint.pth.tar")
384 | torch.save(state, fname)
385 |
386 | if is_save_epoch:
387 | checkpoint_fname = os.path.join(
388 | artifact_path, "checkpoint_{}.pth.tar".format(state["epoch"])
389 | )
390 | shutil.copyfile(fname, checkpoint_fname)
391 | if is_best:
392 | best_model_fname = os.path.join(artifact_path, "model_best.pth.tar")
393 | shutil.copyfile(fname, best_model_fname)
394 |
395 |
396 | if __name__ == "__main__":
397 | args = parse_args()
398 | main(args)
399 |
--------------------------------------------------------------------------------
/trainers.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | """Training methods for ERM, Knowledge Distillation, and Dataset Reinforcement."""
5 | from abc import ABC
6 | import time
7 | import logging
8 | from typing import Callable, Dict, Any, Type
9 |
10 | import torch
11 | from torch import Tensor
12 | import torch.nn as nn
13 | from torch.utils.data import DataLoader, Subset
14 | from torch.nn import functional as F
15 |
16 | from utils import AverageMeter, CosineLR, ProgressMeter, Summary, accuracy
17 | from models import move_to_device, load_model, create_model
18 | from transforms import MixingTransforms
19 |
20 |
21 | class Trainer(ABC):
22 | """Abstract class for various training methodologies."""
23 |
24 | def get_model(self) -> nn.Module:
25 | """Create and initialize the model to train using self.config."""
26 | raise NotImplementedError("Implement `get_model` to initialize a model.")
27 |
28 | def get_criterion(self) -> nn.Module:
29 | """Return the training criterion."""
30 | raise NotImplementedError("Implement `get_criterion`.")
31 |
32 | def train(
33 | self,
34 | train_loader: DataLoader,
35 | model: torch.nn.Module,
36 | criterion: torch.nn.Module,
37 | optimizer: torch.optim.Optimizer,
38 | epoch: int,
39 | config: Dict[str, Any],
40 | lr_scheduler: Type[CosineLR],
41 | ) -> Dict[str, Any]:
42 | """Train a model for a single epoch and return training metrics dictionary."""
43 | raise NotImplementedError("Implement `train` method.")
44 |
45 | def validate_pretrained(self, *args, **kwargs) -> None:
46 | """Validate pretrained teacher model."""
47 | pass
48 |
49 | def validate(self, *args, **kwargs) -> Dict[str, Any]:
50 | """Validate the model that is being trained and return a metrics dictionary."""
51 | return validate(*args, **kwargs)
52 |
53 |
54 | def get_trainer(config: Dict[str, Any]) -> Trainer:
55 | """Initialize a trainer given a configuration dictionary."""
56 | trainer_type = config["trainer"]
57 | if trainer_type == "ERM":
58 | return ERMTrainer(config)
59 | elif trainer_type == "KD":
60 | return KDTrainer(config)
61 | elif trainer_type == "DR":
62 | return ReinforcedTrainer(config)
63 |     raise NotImplementedError("Trainer '{}' is not implemented.".format(trainer_type))
64 |
65 |
66 | class ERMTrainer(Trainer):
67 | """Trainer class for Empirical Risk Minimization (ERM) with cross-entropy."""
68 |
69 | def __init__(self, config: Dict[str, Any]) -> None:
70 | """Initialize ERMTrainer."""
71 | self.config = config
72 | self.label_smoothing = config["loss"].get("label_smoothing", 0.0)
73 |
74 | def get_model(self) -> torch.nn.Module:
75 | """Create and initialize the model to train using self.config."""
76 | arch = self.config["arch"]
77 | model = create_model(arch, self.config)
78 | model = move_to_device(model, self.config)
79 | return model
80 |
81 | def get_criterion(self) -> torch.nn.Module:
82 | """Return the training criterion."""
83 | criterion = nn.CrossEntropyLoss(label_smoothing=self.label_smoothing).cuda(
84 | self.config["gpu"]
85 | )
86 | return criterion
87 |
88 | def train(
89 | self,
90 | train_loader: DataLoader,
91 | model: torch.nn.Module,
92 | criterion: torch.nn.Module,
93 | optimizer: torch.optim.Optimizer,
94 | epoch: int,
95 | config: Dict[str, Any],
96 | lr_scheduler: Type[CosineLR],
97 | ) -> Dict[str, Any]:
98 | """Train a model for a single epoch and return training metrics dictionary."""
99 | batch_time = AverageMeter("Time", ":6.3f")
100 | data_time = AverageMeter("Data", ":6.3f")
101 | losses = AverageMeter("Loss", ":.6f")
102 | top1 = AverageMeter("Acc@1", ":6.2f")
103 | top5 = AverageMeter("Acc@5", ":6.2f")
104 | lrs = AverageMeter("Lr", ":.4f")
105 | conf = AverageMeter("Confidence", ":.5f")
106 | progress = ProgressMeter(
107 | len(train_loader),
108 | [batch_time, data_time, losses, top1, top5, lrs, conf],
109 | prefix="Epoch: [{}]".format(epoch),
110 | )
111 |
112 | # switch to train mode
113 | model.train()
114 |
115 | mixing_transforms = MixingTransforms(
116 | config["image_augmentation"], config["num_classes"]
117 | )
118 |
119 | end = time.time()
120 | for i, (images, target) in enumerate(train_loader):
121 | # measure data loading time
122 | data_time.update(time.time() - end)
123 |
124 | if config["gpu"] is not None:
125 | images = images.cuda(config["gpu"], non_blocking=True)
126 | target = target.cuda(config["gpu"], non_blocking=True)
127 |
128 | # apply mixup / cutmix
129 | mix_images, mix_target = mixing_transforms(images, target)
130 |
131 | # compute output
132 | output = model(mix_images)
133 |
134 | # classification loss
135 | loss = criterion(output, mix_target)
136 |
137 | # measure accuracy and record loss
138 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
139 | losses.update(loss.item(), images.size(0))
140 | top1.update(acc1, images.size(0))
141 | top5.update(acc5, images.size(0))
142 | lrs.update(lr_scheduler.get_last_lr()[0])
143 |
144 | # measure confidence
145 | prob = torch.nn.functional.softmax(output, dim=1)
146 | conf.update(prob.max(1).values.mean().item(), images.size(0))
147 |
148 | # compute gradient and do SGD step
149 | optimizer.zero_grad()
150 | loss.backward()
151 | optimizer.step()
152 |
153 | # measure elapsed time
154 | batch_time.update(time.time() - end)
155 | end = time.time()
156 |
157 | lr_scheduler.step()
158 |
159 | if i % config["print_freq"] == 0:
160 | progress.display(i)
161 |
162 | if config["distributed"]:
163 | top1.all_reduce()
164 | top5.all_reduce()
165 |
166 | metrics = {
167 | "train_accuracy@top1": top1.avg,
168 | "train_accuracy@top5": top5.avg,
169 | "train_loss": losses.avg,
170 | "lr": lrs.avg,
171 | "train_confidence": conf.avg,
172 | }
173 | return metrics
174 |
175 |
176 | class ReinforcedTrainer(ERMTrainer):
177 | """Trainer with a reinforced dataset. Same as ERM Trainer with KL loss."""
178 |
179 | def get_criterion(self) -> Callable[[Tensor, Tensor], Tensor]:
180 | """Return KL loss instead of cross-entropy."""
181 | return lambda output, target: F.kl_div(
182 | F.log_softmax(output, dim=1), target, reduction="batchmean"
183 | )
184 |
185 |
186 | class KDTrainer(ERMTrainer):
187 | """Trainer for Knowledge Distillation."""
188 |
189 | def __init__(self, config: Dict[str, Any]) -> None:
190 | """Initialize trainer and set hyperparameters of KD."""
191 | # Loss config
192 | self.lambda_kd = config["loss"].get("lambda_kd", 1.0)
193 | self.lambda_cls = config["loss"].get("lambda_cls", 0.0)
194 | self.temperature = config["loss"].get("temperature", 1.0)
195 | assert self.temperature > 0, "Softmax with temperature=0 is undefined."
196 | self.label_smoothing = config["loss"].get("label_smoothing", 0.0)
197 |
198 | self.config = config
199 | self.teacher_model = None
200 |
201 | def get_model(self) -> torch.nn.Module:
202 | """Create and initialize student and teacher models."""
203 | config = self.config
204 |
205 | # Instantiate student model for training.
206 | student_arch = config["student"]["arch"]
207 | model = create_model(student_arch, config["student"])
208 | model = move_to_device(model, self.config)
209 |
210 | # Instantiate teacher model
211 | teacher_model = load_model(config["gpu"], config["teacher"])
212 |
213 | if config["gpu"] is not None:
214 | torch.cuda.set_device(config["gpu"])
215 | teacher_model = teacher_model.cuda(config["gpu"])
216 | else:
217 | teacher_model.cuda()
218 | # Set teacher to eval mode
219 | teacher_model.eval()
220 | self.teacher_model = teacher_model
221 |
222 | return model
223 |
224 | def validate_pretrained(
225 | self,
226 | val_loader: DataLoader,
227 | model: torch.nn.Module,
228 | criterion: torch.nn.Module,
229 | config: Dict[str, Any],
230 | ) -> None:
231 | """Validate teacher accuracy before training."""
232 | teacher_model = self.teacher_model
233 | do_validate = config.get("teacher", {}).get("validate", True)
234 | if teacher_model is not None and do_validate:
235 | logging.info(
236 | "Validation loader resizes to standard 256x256 resolution"
237 |                 " which is not necessarily the optimal resolution for the teacher."
238 | )
239 | val_metrics = validate(val_loader, teacher_model, criterion, config)
240 | logging.info(
241 | "Teacher accuracy@top1: {}, @top5: {}".format(
242 | val_metrics["val_accuracy@top1"], val_metrics["val_accuracy@top5"]
243 | )
244 | )
245 |
246 | def train(
247 | self,
248 | train_loader: DataLoader,
249 | model: torch.nn.Module,
250 | criterion: torch.nn.Module,
251 | optimizer: torch.optim.Optimizer,
252 | epoch: int,
253 | config: Dict[str, Any],
254 | lr_scheduler: Type[CosineLR],
255 | ) -> Dict[str, Any]:
256 | """Train a model for a single epoch and return training metrics dictionary."""
257 | batch_time = AverageMeter("Time", ":6.3f")
258 | data_time = AverageMeter("Data", ":6.3f")
259 | losses = AverageMeter("Loss", ":.6f")
260 | kd_losses = AverageMeter("KD Loss", ":.6f")
261 | overall_losses = AverageMeter("Loss", ":.6f")
262 | top1 = AverageMeter("Acc@1", ":6.2f")
263 | top5 = AverageMeter("Acc@5", ":6.2f")
264 | lrs = AverageMeter("Lr", ":.4f")
265 | progress = ProgressMeter(
266 | len(train_loader),
267 | [batch_time, data_time, losses, kd_losses, overall_losses, top1, top5, lrs],
268 | prefix="Epoch: [{}]".format(epoch),
269 | )
270 |
271 | # Switch to train mode
272 | model.train()
273 |
274 | mixing_transforms = MixingTransforms(
275 | config["image_augmentation"], config["num_classes"]
276 | )
277 |
278 | end = time.time()
279 | for i, (images, target) in enumerate(train_loader):
280 | # Measure data loading time
281 | data_time.update(time.time() - end)
282 |
283 | if config["gpu"] is not None:
284 | images = images.cuda(config["gpu"], non_blocking=True)
285 | target = target.cuda(config["gpu"], non_blocking=True)
286 |
287 | # Apply mixup / cutmix
288 | mix_images, mix_target = mixing_transforms(images, target)
289 |
290 |             # Compute the student output; only 224x224 student inputs are supported, so resize if needed
291 | mix_images_small = mix_images
292 | if mix_images.shape[-1] != 224:
293 | mix_images_small = F.interpolate(
294 | mix_images, size=(224, 224), mode="bilinear"
295 | )
296 | output = model(mix_images_small)
297 |
298 | # Classification loss
299 | loss = criterion(output, mix_target)
300 | losses.update(loss.item(), images.size(0))
301 |
302 | # Distillation loss
303 | # Get teacher's output for this input
304 | with torch.no_grad():
305 | teacher_probs = self.teacher_model(
306 | mix_images, return_prob=True, temperature=self.temperature
307 | ).detach()
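            # Scaling the KL term by temperature^2 keeps gradient magnitudes
            # comparable across temperatures (Hinton et al., 2015).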
308 | kd_loss = F.kl_div(
309 | F.log_softmax(output / self.temperature, dim=1),
310 | teacher_probs,
311 | reduction="batchmean",
312 | ) * (self.temperature**2)
313 | kd_losses.update(kd_loss.item(), images.size(0))
314 |
315 | # Overall loss is a combination of kd loss and classification loss
316 | loss = self.lambda_cls * loss + self.lambda_kd * kd_loss
317 |
318 | # measure accuracy and record loss
319 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
320 | overall_losses.update(loss.item(), images.size(0))
321 | top1.update(acc1, images.size(0))
322 | top5.update(acc5, images.size(0))
323 | lrs.update(lr_scheduler.get_last_lr()[0])
324 |
325 | # compute gradient and do SGD step
326 | optimizer.zero_grad()
327 | loss.backward()
328 | optimizer.step()
329 |
330 | lr_scheduler.step()
331 |
332 | # measure elapsed time
333 | batch_time.update(time.time() - end)
334 | end = time.time()
335 |
336 | if i % config["print_freq"] == 0:
337 | progress.display(i)
338 |
339 | if config["distributed"]:
340 | top1.all_reduce()
341 | top5.all_reduce()
342 |
343 | metrics = {
344 | "train_accuracy@top1": top1.avg,
345 | "train_accuracy@top5": top5.avg,
346 | "train_loss_ce": losses.avg,
347 | "train_loss_kd": kd_losses.avg,
348 | "train_loss_total": overall_losses.avg,
349 | "lr": lrs.avg,
350 | }
351 | return metrics
352 |
353 |
354 | def validate(
355 | val_loader: DataLoader,
356 | model: torch.nn.Module,
357 | criterion: torch.nn.Module,
358 | config: Dict[str, Any],
359 | ) -> Dict[str, Any]:
360 | """Validate the model that is being trained and return a metrics dictionary."""
361 |
362 | def run_validate(loader: DataLoader, base_progress: int = 0) -> None:
363 | with torch.no_grad():
364 | end = time.time()
365 |             for i, (images, target) in enumerate(loader):
366 | i = base_progress + i
367 | if config["gpu"] is not None:
368 | images = images.cuda(config["gpu"], non_blocking=True)
369 | target = target.cuda(config["gpu"], non_blocking=True)
370 |
371 | # compute output
372 | output = model(images)
373 | # for validation, compute standard CE loss without label smoothing
374 | loss = F.cross_entropy(output, target)
375 |
376 | # measure accuracy and record loss
377 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
378 | losses.update(loss.item(), images.size(0))
379 | top1.update(acc1, images.size(0))
380 | top5.update(acc5, images.size(0))
381 |
382 | # measure confidence
383 | prob = torch.nn.functional.softmax(output, dim=1)
384 | conf.update(prob.max(1).values.mean().item(), images.size(0))
385 |
386 | # measure elapsed time
387 | batch_time.update(time.time() - end)
388 | end = time.time()
389 |
390 | if i % config["print_freq"] == 0:
391 | progress.display(i)
392 |
393 | batch_time = AverageMeter("Time", ":6.3f", Summary.NONE)
394 | losses = AverageMeter("Loss", ":.6f", Summary.NONE)
395 | top1 = AverageMeter("Acc@1", ":6.2f", Summary.AVERAGE)
396 | top5 = AverageMeter("Acc@5", ":6.2f", Summary.AVERAGE)
397 | conf = AverageMeter("Confidence", ":.5f", Summary.AVERAGE)
398 | progress = ProgressMeter(
399 | len(val_loader)
400 | + (
401 | config["distributed"]
402 | and (
403 | len(val_loader.sampler) * config["world_size"] < len(val_loader.dataset)
404 | )
405 | ),
406 |         [batch_time, losses, top1, top5, conf],
407 | prefix="Test: ",
408 | )
409 |
410 | # switch to evaluate mode
411 | model.eval()
412 |
413 | # run validation using all nodes in a distributed env and aggregate results
414 | run_validate(val_loader)
415 | if config["distributed"]:
416 | top1.all_reduce()
417 | top5.all_reduce()
418 |
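    # DistributedSampler with drop_last=True may skip a few tail samples; each
    # rank evaluates the leftover subset below so the summary covers the full
    # validation set.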
419 | if config["distributed"] and (
420 | len(val_loader.sampler) * config["world_size"] < len(val_loader.dataset)
421 | ):
422 | aux_val_dataset = Subset(
423 | val_loader.dataset,
424 | range(
425 | len(val_loader.sampler) * config["world_size"], len(val_loader.dataset)
426 | ),
427 | )
428 | aux_val_loader = torch.utils.data.DataLoader(
429 | aux_val_dataset,
430 | batch_size=config["batch_size"],
431 | shuffle=False,
432 | num_workers=config["workers"],
433 | pin_memory=True,
434 | )
435 | run_validate(aux_val_loader, len(val_loader))
436 |
437 | progress.display_summary()
438 |
439 | metrics = {
440 | "val_loss": losses.avg,
441 | "val_accuracy@top1": top1.avg,
442 | "val_accuracy@top5": top5.avg,
443 | "val_confidence": conf.avg,
444 | }
445 | return metrics
446 |
--------------------------------------------------------------------------------
/transforms.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 |
5 | """Simplified composition of PyTorch transformations from a configuration dictionary."""
6 |
7 | import math
8 | import random
9 | from typing import Any, Dict, Optional, OrderedDict, Tuple
10 | import numpy as np
11 |
12 | import timm
13 | from timm.data.transforms import str_to_interp_mode
14 | import torch
15 | from torch import Tensor
16 | import torchvision.transforms as T
17 | from torch.nn import functional as F
18 |
19 |
20 | INTERPOLATION_MODE_MAP = {
21 | "nearest": T.InterpolationMode.NEAREST,
22 | "bilinear": T.InterpolationMode.BILINEAR,
23 | "bicubic": T.InterpolationMode.BICUBIC,
24 | "cubic": T.InterpolationMode.BICUBIC,
25 | "box": T.InterpolationMode.BOX,
26 | "hamming": T.InterpolationMode.HAMMING,
27 | "lanczos": T.InterpolationMode.LANCZOS,
28 | }
29 |
30 |
31 | class AutoAugment(T.AutoAugment):
32 | """Extend PyTorch's AutoAugment to init from a policy and an interpolation name."""
33 |
34 | def __init__(
35 | self, policy: str = "imagenet", interpolation: str = "bilinear", *args, **kwargs
36 | ) -> None:
37 |         """Init from a policy name and an interpolation name."""
38 | if "cifar" in policy.lower():
39 | policy = T.AutoAugmentPolicy.CIFAR10
40 | elif "svhn" in policy.lower():
41 | policy = T.AutoAugmentPolicy.SVHN
42 | else:
43 | policy = T.AutoAugmentPolicy.IMAGENET
44 | interpolation = INTERPOLATION_MODE_MAP[interpolation]
45 | super().__init__(*args, policy=policy, interpolation=interpolation, **kwargs)
46 |
47 |
48 | class RandAugment(T.RandAugment):
49 | """Extend PyTorch's RandAugment to init from an interpolation name."""
50 |
51 | def __init__(self, interpolation: str = "bilinear", *args, **kwargs) -> None:
52 | """Init from an interpolation name."""
53 | interpolation = INTERPOLATION_MODE_MAP[interpolation]
54 | super().__init__(*args, interpolation=interpolation, **kwargs)
55 |
56 |
57 | class TrivialAugmentWide(T.TrivialAugmentWide):
58 | """Extend PyTorch's TrivialAugmentWide to init from an interpolation name."""
59 |
60 | def __init__(self, interpolation: str = "bilinear", *args, **kwargs) -> None:
61 | """Init from an interpolation name."""
62 | interpolation = INTERPOLATION_MODE_MAP[interpolation]
63 | super().__init__(*args, interpolation=interpolation, **kwargs)
64 |
65 |
66 | # Transformations are composed according to the order in this dict, not the order in
67 | # yaml config
68 | TRANSFORMATION_TO_NAME = OrderedDict(
69 | [
70 | ("resize", T.Resize),
71 | ("center_crop", T.CenterCrop),
72 | ("random_crop", T.RandomCrop),
73 | ("random_resized_crop", T.RandomResizedCrop),
74 | ("random_horizontal_flip", T.RandomHorizontalFlip),
75 | ("rand_augment", RandAugment),
76 | ("auto_augment", AutoAugment),
77 | ("trivial_augment_wide", TrivialAugmentWide),
78 | ("to_tensor", T.ToTensor),
79 | ("random_erase", T.RandomErasing),
80 | ("normalize", T.Normalize),
81 | ]
82 | )
83 |
84 |
85 | def timm_resize_crop_norm(config: Dict[str, Any]) -> Dict[str, Any]:
86 | """Set Resize/RandomCrop/Normalization parameters from configs of a Timm teacher."""
87 | teacher_name = config["timm_resize_crop_norm"]["name"]
88 | cfg = timm.models.get_pretrained_cfg(teacher_name).to_dict()
89 | if "test_input_size" in cfg:
90 | img_size = list(cfg["test_input_size"])[-1]
91 | else:
92 | img_size = list(cfg["input_size"])[-1]
93 | # Crop ratio and image size for optimal performance of a Timm model
94 | crop_pct = cfg["crop_pct"]
95 | scale_size = int(math.floor(img_size / crop_pct))
96 | interpolation = cfg["interpolation"]
97 | config["resize"] = {
98 | "size": scale_size,
99 | "interpolation": str_to_interp_mode(interpolation),
100 | }
101 | config["random_crop"] = {
102 | "size": img_size,
103 | "pad_if_needed": True,
104 | }
105 | config["normalize"] = {"mean": cfg["mean"], "std": cfg["std"]}
106 | return config
107 |
108 |
109 | def clean_config(config: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
110 |     """Return a copy of the config, dropping disabled entries and their `enable` flags."""
111 | new_config = {}
112 | for k, v in config.items():
113 | vv = dict(v)
114 | if vv.pop("enable", True):
115 | new_config[k] = vv
116 | return new_config
117 |
118 |
119 | def compose_from_config(config_tr: Dict[str, Any]) -> T.Compose:
120 | """Initialize transformations given the dataset name and configurations.
121 |
122 | Args:
123 | config_tr: A dictionary of transformation parameters.
124 |
125 | Returns a composition of transformations.
126 | """
127 | config_tr = clean_config(config_tr)
128 | if "timm_resize_crop_norm" in config_tr:
129 | config_tr = timm_resize_crop_norm(config_tr)
130 | transforms = []
131 | for t_name, t_class in TRANSFORMATION_TO_NAME.items():
132 | if t_name in config_tr:
133 | # TODO: warn for every key in config_tr that was not used
134 | transforms += [t_class(**config_tr[t_name])]
135 | return T.Compose(transforms)
136 |
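# A hypothetical example of a transformation config: keys are matched against
# TRANSFORMATION_TO_NAME and composed in that dict's order, regardless of the
# order in the yaml file:
#
#     config_tr = {
#         "random_resized_crop": {"size": 224},
#         "random_horizontal_flip": {"p": 0.5},
#         "to_tensor": {},
#         "normalize": {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]},
#     }
#     train_transform = compose_from_config(config_tr)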
137 |
138 | class MixUp(torch.nn.Module):
139 | r"""MixUp image transformation.
140 |
141 |     For an input x, the output is :math:`\lambda x + (1-\lambda) x_p`, where
142 |     :math:`x_p` is a random permutation of `x` along the batch dimension and
143 |     :math:`\lambda` is a random number between 0 and 1.
144 |
145 |     See https://arxiv.org/abs/1710.09412 for more details.
146 | """
147 |
148 | def __init__(
149 | self, alpha: float = 1.0, p: float = 1.0, div_by: float = 1.0, *args, **kwargs
150 | ) -> None:
151 | """Initialize MixUp transformation.
152 |
153 | Args:
154 | alpha: A positive real number that determines the sampling
155 | distribution. Each mixed sample is a convex combination of two
156 | examples from the batch with mixing coefficient lambda.
157 | lambda is sampled from a symmetric Beta distribution with
158 | parameter alpha. When alpha=0 no mixing happens. Defaults to 1.0.
159 | p: Mixing is applied with probability `p`. Defaults to 1.0.
160 | div_by: Divide the lambda by a constant. Set to 2.0 to make sure mixing is
161 | biased towards the first input. Defaults to 1.0.
162 | """
163 | super().__init__(*args, **kwargs)
164 | assert alpha >= 0
165 | assert p >= 0 and p <= 1.0
166 | assert div_by >= 1.0
167 | self.alpha = alpha
168 | self.p = p
169 | self.div_by = div_by
170 |
171 |     def get_params(self, alpha: float, div_by: float) -> Optional[float]:
172 |         """Return the MixUp mixing coefficient, or None when mixing is skipped."""
173 |         # Skip mixing with probability 1-self.p
174 |         if alpha == 0 or torch.rand(1) > self.p:
175 | return None
176 |
177 | lam = np.random.beta(alpha, alpha) / div_by
178 | return lam
179 |
180 | def forward(
181 | self,
182 | x: Tensor,
183 | x2: Optional[Tensor] = None,
184 | y: Optional[Tensor] = None,
185 | y2: Optional[Tensor] = None,
186 | ) -> Tuple[Tensor, Tensor]:
187 | r"""Apply pixel-space mixing to a batch of examples.
188 |
189 | Args:
190 | x: A tensor with a batch of samples. Shape: [batch_size, ...].
191 | x2: A tensor with exactly one matching sample for any input in `x`. Shape:
192 | [batch_size, ...].
193 | y: A tensor of target labels. Shape: [batch_size, ...].
194 | y2: A tensor of target labels for paired samples. Shape: [batch_size, ...].
195 |
196 | Returns:
197 |             Mixed x tensor and mixed y labels, or the unmixed inputs when mixing is skipped.
198 | """
199 | alpha = self.alpha
200 | # Randomly sample lambda if not provided
201 | params = self.get_params(alpha, self.div_by)
202 | if params is None:
203 | return x, y
204 | lam = params
205 |
206 | # Randomly sample second input from the same mini-batch if not provided
207 | if x2 is None:
208 | batch_size = int(x.size()[0])
209 | index = torch.randperm(batch_size, device=x.device)
210 | x2 = x[index, :]
211 | y2 = y[index, :] if y is not None else None
212 |
213 | # Mix inputs and labels
214 | mixed_x = lam * x + (1 - lam) * x2
215 | mixed_y = y
216 | if y is not None:
217 | mixed_y = lam * y + (1 - lam) * y2
218 |
219 | return mixed_x, mixed_y
220 |
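# A minimal sketch (hypothetical shapes) of MixUp on a batch; labels must
# already be dense vectors (e.g. one-hot), as MixingTransforms below ensures:
#
#     mixup = MixUp(alpha=1.0)
#     x = torch.randn(8, 3, 224, 224)
#     y = F.one_hot(torch.randint(0, 10, (8,)), num_classes=10).float()
#     mixed_x, mixed_y = mixup(x, y=y)  # lam * x + (1 - lam) * x[perm]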
221 |
222 | class CutMix(torch.nn.Module):
223 | r"""CutMix image transformation.
224 |
225 | Please see the full paper for more details:
226 | https://arxiv.org/pdf/1905.04899.pdf
227 | """
228 |
229 | def __init__(self, alpha: float = 1.0, p: float = 1.0, *args, **kwargs) -> None:
230 | """Initialize CutMix transformation.
231 |
232 | Args:
233 |             alpha: Beta distribution parameter for sampling the mixing lambda. Mixing is applied with probability `p`.
234 | """
235 | super().__init__(*args, **kwargs)
236 | assert alpha >= 0
237 | assert p >= 0 and p <= 1.0
238 | self.alpha = alpha
239 | self.p = p
240 |
241 | @staticmethod
242 | def rand_bbox(size: torch.Size, lam: float) -> Tuple[int, int, int, int]:
243 | """Return a random bbox coordinates.
244 |
245 | Args:
246 | size: model input tensor shape in this format: (...,H,W)
247 | lam: lambda sampling parameter in CutMix method. See equation 1
248 | in the original paper: https://arxiv.org/pdf/1905.04899.pdf
249 |
250 | Returns:
251 | The output bbox format is a tuple: (x1, y1, x2, y2), where (x1,
252 | y1) and (x2,y2) are the coordinates of the top-left and bottom-right
253 | corners of the bbox in the pixel-space.
254 | """
255 | assert lam >= 0 and lam <= 1.0
256 | h = size[-2]
257 | w = size[-1]
258 | cut_rat = np.sqrt(1.0 - lam)
259 | cut_h = int(h * cut_rat)
260 | cut_w = int(w * cut_rat)
261 |
262 | # uniform
263 | cx = np.random.randint(h)
264 | cy = np.random.randint(w)
265 |
266 | bbx1 = np.clip(cx - cut_h // 2, 0, h)
267 | bby1 = np.clip(cy - cut_w // 2, 0, w)
268 | bbx2 = np.clip(cx + cut_h // 2, 0, h)
269 | bby2 = np.clip(cy + cut_w // 2, 0, w)
270 |
271 | return (bbx1, bby1, bbx2, bby2)
272 |
273 | def get_params(
274 | self, size: torch.Size, alpha: float
275 |     ) -> Optional[Tuple[float, Tuple[int, int, int, int]]]:
276 |         """Return CutMix lambda and bbox coordinates, or None when mixing is skipped."""
277 |         # Skip mixing with probability 1-self.p
278 |         if alpha == 0 or torch.rand(1) > self.p:
279 | return None
280 |
281 | lam = np.random.beta(alpha, alpha)
282 | # Compute mask
283 | bbx1, bby1, bbx2, bby2 = self.rand_bbox(size, lam)
284 | return lam, (bbx1, bby1, bbx2, bby2)
285 |
286 | def forward(
287 | self,
288 | x: Tensor,
289 | x2: Optional[Tensor] = None,
290 | y: Optional[Tensor] = None,
291 | y2: Optional[Tensor] = None,
292 | ) -> Tuple[Tensor, Tensor]:
293 | """Mix images by replacing random patches from one to the other.
294 |
295 | Args:
296 | x: A tensor with a batch of samples. Shape: [batch_size, ...].
297 | x2: A tensor with exactly one matching sample for any input in `x`. Shape:
298 | [batch_size, ...].
299 | y: A tensor of target labels. Shape: [batch_size, ...].
300 | y2: A tensor of target labels for paired samples. Shape: [batch_size, ...].
301 |         Returns:
302 |             Mixed x tensor and mixed y labels, or the unmixed inputs when mixing is skipped.
303 |         """
304 | alpha = self.alpha
305 |
306 | # Randomly sample lambda and bbox coordinates if not provided
307 | params = self.get_params(x.shape, alpha)
308 | if params is None:
309 | return x, y
310 | lam, (bbx1, bby1, bbx2, bby2) = params
311 |
312 | # Randomly sample second input from the same mini-batch if not provided
313 | if x2 is None:
314 | batch_size = int(x.size()[0])
315 | index = torch.randperm(batch_size, device=x.device)
316 | x2 = x[index, :]
317 | y2 = y[index, :] if y is not None else None
318 |
319 | # Mix inputs and labels
320 | mixed_x = x.detach().clone()
321 |         mixed_x[..., bbx1:bbx2, bby1:bby2] = x2[..., bbx1:bbx2, bby1:bby2]
322 | mixed_y = y
323 | if y is not None:
324 | # Adjust lambda
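            # (the sampled bbox may have been clipped at the image borders, so
            # recompute lambda from the area actually pasted)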
325 | lam = 1.0 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2]))
326 | mixed_y = lam * y + (1 - lam) * y2
327 |
328 | return mixed_x, mixed_y
329 |
330 |
331 | class MixingTransforms:
332 | """Randomly apply only one of MixUp or CutMix. Used for standard training."""
333 |
334 | def __init__(self, config_tr: Dict[str, Any], num_classes: int) -> None:
335 | """Initialize mixup and/or cutmix."""
336 | config_tr = clean_config(config_tr)
337 | self.mixing_transforms = []
338 | if "mixup" in config_tr:
339 | self.mixing_transforms += [MixUp(**config_tr["mixup"])]
340 | if "cutmix" in config_tr:
341 | self.mixing_transforms += [CutMix(**config_tr["cutmix"])]
342 | self.num_classes = num_classes
343 |
344 | def __call__(self, images: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
345 | """Apply only one of MixUp or CutMix."""
346 | if len(self.mixing_transforms) > 0:
347 | one_hot_label = F.one_hot(target, num_classes=self.num_classes)
348 | mix_f = random.choice(self.mixing_transforms)
349 | images, target = mix_f(x=images, y=one_hot_label)
350 | return images, target
351 |
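# A hypothetical usage sketch inside a training step (see ERMTrainer.train):
#
#     mixing = MixingTransforms({"mixup": {"alpha": 1.0}, "cutmix": {"alpha": 1.0}},
#                               num_classes=1000)
#     mix_images, mix_target = mixing(images, target)  # target: integer labels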
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | """Utilities for training."""
5 | from enum import Enum
6 | from typing import Dict, Any, Iterable, List, Optional
7 |
8 | import torch
9 | from torch import Tensor
10 | import numpy as np
11 | import logging
12 |
13 | import torch.distributed as dist
14 |
15 |
16 | class Summary(Enum):
17 | """Meter value types."""
18 |
19 | NONE = 0
20 | AVERAGE = 1
21 | SUM = 2
22 | COUNT = 3
23 |
24 |
25 | class AverageMeter(object):
26 | """Computes and stores the average and current value."""
27 |
28 | def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE):
29 | self.name = name
30 | self.fmt = fmt
31 | self.summary_type = summary_type
32 | self.reset()
33 |
34 | def reset(self):
35 | self.val = 0
36 | self.avg = 0
37 | self.sum = 0
38 | self.count = 0
39 |
40 | def update(self, val, n=1):
41 | self.val = val
42 | self.sum += val * n
43 | self.count += n
44 | self.avg = self.sum / self.count
45 |
46 | def all_reduce(self):
47 | if torch.cuda.is_available():
48 | device = torch.device("cuda")
49 | elif torch.backends.mps.is_available():
50 | device = torch.device("mps")
51 | else:
52 | device = torch.device("cpu")
53 | total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
54 | dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
55 | self.sum, self.count = total.tolist()
56 | self.avg = self.sum / self.count
57 |
58 | def __str__(self):
59 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
60 | return fmtstr.format(**self.__dict__)
61 |
62 | def summary(self):
63 | fmtstr = ""
64 | if self.summary_type is Summary.NONE:
65 | fmtstr = ""
66 | elif self.summary_type is Summary.AVERAGE:
67 | fmtstr = "{name} {avg:.3f}"
68 | elif self.summary_type is Summary.SUM:
69 | fmtstr = "{name} {sum:.3f}"
70 | elif self.summary_type is Summary.COUNT:
71 | fmtstr = "{name} {count:.3f}"
72 | else:
73 | raise ValueError("invalid summary type %r" % self.summary_type)
74 |
75 | return fmtstr.format(**self.__dict__)
76 |
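# Example (hypothetical values): the meter keeps a running, count-weighted average.
#
#     >>> m = AverageMeter("Acc@1", ":6.2f")
#     >>> m.update(50.0, n=2); m.update(100.0, n=2)
#     >>> m.avg
#     75.0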
77 |
78 | class ProgressMeter(object):
79 | def __init__(self, num_batches, meters, prefix=""):
80 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
81 | self.meters = meters
82 | self.prefix = prefix
83 |
84 | def display(self, batch):
85 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
86 | entries += [str(meter) for meter in self.meters]
87 | logging.info("\t".join(entries))
88 |
89 | def display_summary(self):
90 | entries = [" *"]
91 | entries += [meter.summary() for meter in self.meters]
92 | logging.info(" ".join(entries))
93 |
94 | def _get_batch_fmtstr(self, num_batches):
95 |         num_digits = len(str(num_batches))
96 | fmt = "{:" + str(num_digits) + "d}"
97 | return "[" + fmt + "/" + fmt.format(num_batches) + "]"
98 |
99 |
100 | def accuracy(
101 |     output: Tensor, target: Tensor, topk: Iterable[int] = (1,)
102 | ) -> List[float]:
103 | """Compute the accuracy over the k top predictions for the specified values of k."""
104 | with torch.no_grad():
105 | maxk = max(topk)
106 | batch_size = target.size(0)
107 |
108 | if len(target.shape) > 1 and target.shape[1] > 1:
109 | # soft labels
110 | _, target = target.max(dim=1)
111 |
112 | _, pred = output.topk(maxk, 1, True, True)
113 | pred = pred.t()
114 | correct = pred.eq(target.view(1, -1).expand_as(pred))
115 |
116 | res = []
117 | for k in topk:
118 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
119 | res.append(correct_k.mul_(100.0 / batch_size).item())
120 | return res
121 |
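# A worked example (hypothetical logits): the top-1 prediction is correct for
# the first sample only, giving 50% accuracy.
#
#     >>> out = torch.tensor([[0.1, 0.9], [0.8, 0.2]])
#     >>> accuracy(out, torch.tensor([1, 1]), topk=(1,))
#     [50.0]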
122 |
123 | def assign_learning_rate(optimizer: torch.optim.Optimizer, new_lr: float) -> None:
124 | """Update lr parameter of an optimizer.
125 |
126 | Args:
127 | optimizer: A torch optimizer.
128 | new_lr: updated value of learning rate.
129 | """
130 | for param_group in optimizer.param_groups:
131 | param_group["lr"] = new_lr
132 |
133 |
134 | def _warmup_lr(base_lr: float, warmup_length: int, n_iter: int) -> float:
135 | """Get updated lr after applying initial warmup.
136 |
137 | Args:
138 | base_lr: Nominal learning rate.
139 | warmup_length: Number of total iterations for initial warmup.
140 | n_iter: Current iteration number.
141 |
142 | Returns:
143 | Warmup-updated learning rate.
144 | """
145 | return base_lr * (n_iter + 1) / warmup_length
146 |
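# For example, _warmup_lr(base_lr=0.5, warmup_length=100, n_iter=49) == 0.25:
# the learning rate ramps linearly from base_lr/warmup_length up to base_lr.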
147 |
148 | class CosineLR:
149 | """LR adjustment callable with cosine schedule.
150 |
151 | Args:
152 | optimizer: A torch optimizer.
153 | warmup_length: Number of iterations for initial warmup.
154 | total_steps: Total number of iterations.
155 | lr: Nominal learning rate value.
156 |
157 | Returns:
158 | A callable to adjust learning rate per iteration.
159 | """
160 |
161 | def __init__(
162 | self,
163 | optimizer: torch.optim.Optimizer,
164 | warmup_length: int,
165 | total_steps: int,
166 | lr: float,
167 | end_lr: float = 0.0,
168 | **kwargs
169 | ) -> None:
170 | """Set parameters of cosine learning rate with warmup."""
171 | assert lr > end_lr, (
172 |             "End LR should be less than the LR. Got: lr={} and end_lr={}"
173 | ).format(lr, end_lr)
174 | self.optimizer = optimizer
175 | self.warmup_length = warmup_length
176 | self.total_steps = total_steps
177 | self.lr = lr
178 | self.last_lr = 0
179 | self.end_lr = end_lr
180 | self.last_n_iter = 0
181 |
182 |     def step(self) -> None:
183 |         """Update the learning rate for the next iteration."""
184 | self.last_n_iter += 1
185 | n_iter = self.last_n_iter
186 |
187 | if n_iter < self.warmup_length:
188 | new_lr = _warmup_lr(self.lr, self.warmup_length, n_iter)
189 | else:
190 | e = n_iter - self.warmup_length + 1
191 | es = self.total_steps - self.warmup_length
192 |
193 | new_lr = self.end_lr + 0.5 * (self.lr - self.end_lr) * (
194 | 1 + np.cos(np.pi * e / es)
195 | )
196 |
197 | assign_learning_rate(self.optimizer, new_lr)
198 |
199 | self.last_lr = new_lr
200 |
201 | def get_last_lr(self) -> List[float]:
202 | """Return the value of the last learning rate."""
203 | return [self.last_lr]
204 |
205 | def state_dict(self) -> Dict[str, Any]:
206 | """Return the state dictionary to recover optimization in training restart."""
207 | return {
208 | "warmup_length": self.warmup_length,
209 | "total_steps": self.total_steps,
210 | "lr": self.lr,
211 | "last_lr": self.last_lr,
212 | "last_n_iter": self.last_n_iter,
213 | }
214 |
215 | def load_state_dict(self, state: Dict[str, Any]) -> None:
216 | """Restore scheduler state."""
217 | self.warmup_length = state["warmup_length"]
218 | self.total_steps = state["total_steps"]
219 | self.lr = state["lr"]
220 | self.last_lr = state["last_lr"]
221 | self.last_n_iter = state["last_n_iter"]
222 |
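# A minimal sketch (hypothetical values) of driving the scheduler; unlike most
# torch.optim schedulers, this one is stepped once per iteration, not per epoch:
#
#     scheduler = CosineLR(optimizer, warmup_length=100, total_steps=1000, lr=0.1)
#     for _ in range(1000):
#         ...  # forward/backward/optimizer.step()
#         scheduler.step()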
--------------------------------------------------------------------------------