├── ACKNOWLEDGEMENTS
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── __init__.py
├── checkpoints
│   └── README.md
├── configs
│   ├── cifar100_backbone_new.yaml
│   ├── cifar100_backbone_old.yaml
│   ├── cifar100_eval_new_new.yaml
│   ├── cifar100_eval_old_new_fastfill.yaml
│   ├── cifar100_eval_old_new_fct.yaml
│   ├── cifar100_eval_old_old.yaml
│   ├── cifar100_fastfill_transformation.yaml
│   ├── cifar100_fct_transformation.yaml
│   ├── imagenet_backbone_new.yaml
│   ├── imagenet_backbone_old.yaml
│   ├── imagenet_eval_new_new.yaml
│   ├── imagenet_eval_old_new_fastfill.yaml
│   ├── imagenet_eval_old_new_fct.yaml
│   ├── imagenet_eval_old_old.yaml
│   ├── imagenet_fastfill_transformation.yaml
│   └── imagenet_fct_transformation.yaml
├── dataset
│   ├── __init__.py
│   ├── data_transforms.py
│   └── sub_image_folder.py
├── demo.gif
├── eval.py
├── fct_logo.png
├── get_pretrained_models.sh
├── models
│   ├── __init__.py
│   ├── resnet.py
│   └── transformations.py
├── prepare_dataset.py
├── requirements.txt
├── train_backbone.py
├── train_transformation.py
├── trainers
│   ├── __init__.py
│   ├── backbone_trainer.py
│   └── transformation_trainer.py
└── utils
    ├── __init__.py
    ├── eval_utils.py
    ├── getters.py
    ├── logging_utils.py
    ├── net_utils.py
    ├── objectives.py
    └── schedulers.py
/ACKNOWLEDGEMENTS:
--------------------------------------------------------------------------------
1 | Acknowledgements
2 | Portions of this Software may utilize the following copyrighted
3 | material, the use of which is hereby acknowledged.
4 |
5 | _____________________
6 |
7 | Shen, Yantao and Xiong, Yuanjun and Xia, Wei and Soatto, Stefano (OpenBCT)
8 | Copyright (c) 2020, Yantao Shen
9 | All rights reserved.
10 |
11 | Redistribution and use in source and binary forms, with or without
12 | modification, are permitted provided that the following conditions are met:
13 |
14 | 1. Redistributions of source code must retain the above copyright notice, this
15 | list of conditions and the following disclaimer.
16 |
17 | 2. Redistributions in binary form must reproduce the above copyright notice,
18 | this list of conditions and the following disclaimer in the documentation
19 | and/or other materials provided with the distribution.
20 |
21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 |
2 | # Contributor Covenant Code of Conduct
3 |
4 | ## Our Pledge
5 |
6 | We as members, contributors, and leaders pledge to make participation in our
7 | community a harassment-free experience for everyone, regardless of age, body
8 | size, visible or invisible disability, ethnicity, sex characteristics, gender
9 | identity and expression, level of experience, education, socio-economic status,
10 | nationality, personal appearance, race, caste, color, religion, or sexual
11 | identity and orientation.
12 |
13 | We pledge to act and interact in ways that contribute to an open, welcoming,
14 | diverse, inclusive, and healthy community.
15 |
16 | ## Our Standards
17 |
18 | Examples of behavior that contributes to a positive environment for our
19 | community include:
20 |
21 | * Demonstrating empathy and kindness toward other people
22 | * Being respectful of differing opinions, viewpoints, and experiences
23 | * Giving and gracefully accepting constructive feedback
24 | * Accepting responsibility and apologizing to those affected by our mistakes,
25 | and learning from the experience
26 | * Focusing on what is best not just for us as individuals, but for the overall
27 | community
28 |
29 | Examples of unacceptable behavior include:
30 |
31 | * The use of sexualized language or imagery, and sexual attention or advances of
32 | any kind
33 | * Trolling, insulting or derogatory comments, and personal or political attacks
34 | * Public or private harassment
35 | * Publishing others' private information, such as a physical or email address,
36 | without their explicit permission
37 | * Other conduct which could reasonably be considered inappropriate in a
38 | professional setting
39 |
40 | ## Enforcement Responsibilities
41 |
42 | Community leaders are responsible for clarifying and enforcing our standards of
43 | acceptable behavior and will take appropriate and fair corrective action in
44 | response to any behavior that they deem inappropriate, threatening, offensive,
45 | or harmful.
46 |
47 | Community leaders have the right and responsibility to remove, edit, or reject
48 | comments, commits, code, wiki edits, issues, and other contributions that are
49 | not aligned to this Code of Conduct, and will communicate reasons for moderation
50 | decisions when appropriate.
51 |
52 | ## Scope
53 |
54 | This Code of Conduct applies within all community spaces, and also applies when
55 | an individual is officially representing the community in public spaces.
56 | Examples of representing our community include using an official e-mail address,
57 | posting via an official social media account, or acting as an appointed
58 | representative at an online or offline event.
59 |
60 | ## Enforcement
61 |
62 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
63 | reported to the community leaders responsible for enforcement at
64 | [INSERT CONTACT METHOD].
65 | All complaints will be reviewed and investigated promptly and fairly.
66 |
67 | All community leaders are obligated to respect the privacy and security of the
68 | reporter of any incident.
69 |
70 | ## Enforcement Guidelines
71 |
72 | Community leaders will follow these Community Impact Guidelines in determining
73 | the consequences for any action they deem in violation of this Code of Conduct:
74 |
75 | ### 1. Correction
76 |
77 | **Community Impact**: Use of inappropriate language or other behavior deemed
78 | unprofessional or unwelcome in the community.
79 |
80 | **Consequence**: A private, written warning from community leaders, providing
81 | clarity around the nature of the violation and an explanation of why the
82 | behavior was inappropriate. A public apology may be requested.
83 |
84 | ### 2. Warning
85 |
86 | **Community Impact**: A violation through a single incident or series of
87 | actions.
88 |
89 | **Consequence**: A warning with consequences for continued behavior. No
90 | interaction with the people involved, including unsolicited interaction with
91 | those enforcing the Code of Conduct, for a specified period of time. This
92 | includes avoiding interactions in community spaces as well as external channels
93 | like social media. Violating these terms may lead to a temporary or permanent
94 | ban.
95 |
96 | ### 3. Temporary Ban
97 |
98 | **Community Impact**: A serious violation of community standards, including
99 | sustained inappropriate behavior.
100 |
101 | **Consequence**: A temporary ban from any sort of interaction or public
102 | communication with the community for a specified period of time. No public or
103 | private interaction with the people involved, including unsolicited interaction
104 | with those enforcing the Code of Conduct, is allowed during this period.
105 | Violating these terms may lead to a permanent ban.
106 |
107 | ### 4. Permanent Ban
108 |
109 | **Community Impact**: Demonstrating a pattern of violation of community
110 | standards, including sustained inappropriate behavior, harassment of an
111 | individual, or aggression toward or disparagement of classes of individuals.
112 |
113 | **Consequence**: A permanent ban from any sort of public interaction within the
114 | community.
115 |
116 | ## Attribution
117 |
118 | This Code of Conduct is adapted from the [Contributor Covenant][homepage],
119 | version 2.1, available at
120 | [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
121 |
122 | Community Impact Guidelines were inspired by
123 | [Mozilla's code of conduct enforcement ladder][Mozilla CoC].
124 |
125 | For answers to common questions about this code of conduct, see the FAQ at
126 | [https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
127 | [https://www.contributor-covenant.org/translations][translations].
128 |
129 | [homepage]: https://www.contributor-covenant.org
130 | [v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
131 | [Mozilla CoC]: https://github.com/mozilla/diversity
132 | [FAQ]: https://www.contributor-covenant.org/faq
133 | [translations]: https://www.contributor-covenant.org/translations
134 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contribution Guide
2 |
3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository.
4 |
5 | While we welcome new pull requests and issues, please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged.
6 |
7 | ## Before you get started
8 |
9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE).
10 |
11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md).
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (C) 2020 Apple Inc. All Rights Reserved.
2 |
3 | IMPORTANT: This Apple software is supplied to you by Apple
4 | Inc. ("Apple") in consideration of your agreement to the following
5 | terms, and your use, installation, modification or redistribution of
6 | this Apple software constitutes acceptance of these terms. If you do
7 | not agree with these terms, please do not use, install, modify or
8 | redistribute this Apple software.
9 |
10 | In consideration of your agreement to abide by the following terms, and
11 | subject to these terms, Apple grants you a personal, non-exclusive
12 | license, under Apple's copyrights in this original Apple software (the
13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple
14 | Software, with or without modifications, in source and/or binary forms;
15 | provided that if you redistribute the Apple Software in its entirety and
16 | without modifications, you must retain this notice and the following
17 | text and disclaimers in all such redistributions of the Apple Software.
18 | Neither the name, trademarks, service marks or logos of Apple Inc. may
19 | be used to endorse or promote products derived from the Apple Software
20 | without specific prior written permission from Apple. Except as
21 | expressly stated in this notice, no other rights or licenses, express or
22 | implied, are granted by Apple herein, including but not limited to any
23 | patent rights that may be infringed by your derivative works or by other
24 | works in which the Apple Software may be incorporated.
25 |
26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE
27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
31 |
32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
39 | POSSIBILITY OF SUCH DAMAGE.
40 |
41 | -------------------------------------------------------------------------------
42 | SOFTWARE DISTRIBUTED WITH ML-FCT:
43 |
44 | The ML-FCT software includes a number of subcomponents with separate
45 | copyright notices and license terms - please see the file ACKNOWLEDGEMENTS.
46 | -------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Compatibility for Machine Learning Model Update
2 | This repository contains PyTorch implementation of [Forward Compatible Training for Large-Scale Embedding Retrieval Systems (CVPR 2022)](https://arxiv.org/abs/2112.02805):
3 |
4 |
5 |
6 | and [FastFill: Efficient Compatible Model Update (ICLR 2023)](https://openreview.net/pdf?id=rnRiiHw8Vy):
7 |
8 |
9 |
10 | **The code requires Python 3.8 or above.**
11 |
12 | ## Requirements
13 |
14 | We suggest you first create a virtual environment and install dependencies in the virtual environment.
15 |
16 | ```bash
17 | # Go to the repository root
18 | cd <path-to-repo>
19 | # Create virtual environment ...
20 | python -m venv .venv
21 | # ... and activate it
22 | source .venv/bin/activate
23 | # Upgrade to the latest versions of pip and wheel
24 | pip install -U pip wheel
25 | pip install -r requirements.txt
26 | ```
27 |
28 | ## CIFAR-100 Experiments (quick start)
29 |
30 |
31 | We provide CIFAR-100 experiments for fast exploration. The code runs and produces results for both FCT and FastFill.
32 | Here is the sequence of commands for the CIFAR-100 experiments (similar to ImageNet, but with faster cycles):
33 |
34 | ```bash
35 | # Get data: the following command puts the data in data_store/cifar-100-python
36 | python prepare_dataset.py
37 |
38 | # Train old embedding model:
39 | # Note: config files assume training with 8 GPUs. Modify them according to your environment.
40 | python train_backbone.py --config configs/cifar100_backbone_old.yaml
41 |
42 | # Evaluate the old model (single GPU is OK):
43 | python eval.py --config configs/cifar100_eval_old_old.yaml
44 |
45 | # Train new embedding model:
46 | python train_backbone.py --config configs/cifar100_backbone_new.yaml
47 |
48 | # Evaluate the new model (single GPU is OK):
49 | python eval.py --config configs/cifar100_eval_new_new.yaml
50 |
51 | # Download pre-trained models if training with side-information:
52 | source get_pretrained_models.sh
53 |
54 | # Train FCT transformation:
55 | # If training with a side-info model, add its path to the config file below. You
56 | # can use the same side-info model as for the ImageNet experiments here.
57 | python train_transformation.py --config configs/cifar100_fct_transformation.yaml
58 |
59 | # Evaluate transformed model vs new model (single GPU is OK):
60 | python eval.py --config configs/cifar100_eval_old_new_fct.yaml
61 |
62 | # Train FastFill transformation:
63 | python train_transformation.py --config configs/cifar100_fastfill_transformation.yaml
64 |
65 | # Evaluate transformed model vs new model (single GPU is OK):
66 | python eval.py --config configs/cifar100_eval_old_new_fastfill.yaml
67 | ```
68 |
69 | ### CIFAR-100 (FCT, without backfilling):
70 | * These results are *not* averaged over multiple runs.
71 |
72 | | Case | `Side-Info` | `CMC Top-1 (%)` | `CMC Top-5 (%)` | `mAP (%)` |
73 | |-------------------------------------------------------|:------------:|:------------------:|:---------------:|:---------:|
74 | | [old/old](./configs/cifar100_backbone_old.yaml) | N/A | 34.2 | 60.6 | 16.5 |
75 | | [new/new](./configs/cifar100_backbone_new.yaml) | N/A | 56.5 | 77.0 | 36.3 |
76 | | [FCT new/old](./configs/cifar100_fct_transformation.yaml) | No | 47.2 | 72.6 | 25.8 |
77 | | [FCT new/old](./configs/cifar100_fct_transformation.yaml) | Yes | 50.2 | 73.7 | 32.2 |
78 |
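The `CMC Top-k` and `mAP` columns follow the retrieval protocol driven by `eval.py`: gallery images are embedded with the gallery model (here the transformed old model), query images with the query model (the new model), and neighbors are ranked by the L2 distance set in the eval configs. Below is a minimal sketch of CMC Top-k under that protocol; the function name, tensor shapes, and the assumption of disjoint gallery/query sets are illustrative only and do not reproduce the exact implementation in `utils/eval_utils.py` (the eval configs use the test split as both gallery and query, which additionally requires handling self-matches).

```python
import torch

def cmc_top_k_sketch(gallery_emb, gallery_labels, query_emb, query_labels, k=1):
    """Illustrative CMC Top-k: fraction of queries whose k nearest gallery items
    (by L2 distance) contain at least one item of the same class.

    gallery_emb: (G, D), query_emb: (Q, D); labels are 1-D tensors of class ids.
    This is a sketch, not the implementation in utils/eval_utils.py.
    """
    dists = torch.cdist(query_emb, gallery_emb)            # (Q, G) pairwise L2 distances
    nearest = dists.topk(k, dim=1, largest=False).indices  # k closest gallery indices per query
    hits = (gallery_labels[nearest] == query_labels[:, None]).any(dim=1)
    return hits.float().mean().item()
```
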
79 | ### CIFAR-100 (FastFill, with backfilling):
80 | * These results are *not* averaged over multiple runs.
81 | * AUC: Area Under the backfilling Curve. For old/old and new/new we report performance corresponding to
82 | no model update and full model update, respectively.
83 |
84 | | Case | `Side-Info` | `Backfilling` | `AUC CMC Top-1 (%)` | `AUC CMC Top-5 (%)` | `AUC mAP (%)` |
85 | |----------------------------------------------------------------|:-----------:|:-------------:|:-------------------:|:-------------------:|:-------------:|
86 | | [old/old](./configs/cifar100_backbone_old.yaml) | N/A | N/A | 34.2 | 60.6 | 16.5 |
87 | | [new/new](./configs/cifar100_backbone_new.yaml) | N/A | N/A | 56.5 | 77.0 | 36.3 |
88 | | [FCT new/old](./configs/cifar100_fct_transformation.yaml) | No | Random | 49.1 | 73.6 | 29.1 |
90 | | [FastFill new/old](./configs/cifar100_fastfill_transformation.yaml) | No | Uncertainty | 53.6 | 75.3 | 32.5 |
90 |
91 |
92 | ## ImageNet-1k Experiments
93 |
94 | Here is the sequence of commands for the ImageNet experiments:
95 |
96 | ```bash
97 | # Get data: Prepare the full ImageNet-1k dataset and provide its path in all config
98 | # files. The path should include training and validation directories.
99 |
100 | # Train old embedding model:
101 | # Note: config files assume training with 8 GPUs. Modify them according to your environment.
102 | python train_backbone.py --config configs/imagenet_backbone_old.yaml
103 |
104 | # Evaluate the old model:
105 | python eval.py --config configs/imagenet_eval_old_old.yaml
106 |
107 | # Train new embedding model:
108 | python train_backbone.py --config configs/imagenet_backbone_new.yaml
109 |
110 | # Evaluate the new model:
111 | python eval.py --config configs/imagenet_eval_new_new.yaml
112 |
113 | # Download pre-trained models if training with side-information:
114 | source get_pretrained_models.sh
115 |
116 | # Train FCT transformation:
117 | # (If training with a side-info model, add its path to the config file below.)
118 | python train_transformation.py --config configs/imagenet_fct_transformation.yaml
119 |
120 | # Evaluate transformed model vs new model:
121 | python eval.py --config configs/imagenet_eval_old_new_fct.yaml
122 |
123 | # Train FastFill transformation:
124 | python train_transformation.py --config configs/imagenet_fastfill_transformation.yaml
125 |
126 | # Evaluate transformed model vs new model:
127 | python eval.py --config configs/imagenet_eval_old_new_fastfill.yaml
128 | ```
129 |
130 | ### ImageNet-1k (FCT, without backfilling):
131 |
132 | | Case | `Side-Info` | `CMC Top-1 (%)` | `CMC Top-5 (%)` | `mAP (%)` |
133 | |-------------------------------------------------------|:-----------:|:---------------:|:---------------:|:-------:|
134 | | [old/old](./configs/imagenet_backbone_old.yaml) | N/A | 46.4 | 65.1 | 28.3 |
135 | | [new/new](./configs/imagenet_backbone_new.yaml) | N/A | 68.4 | 84.7 | 45.6 |
136 | | [FCT new/old](./configs/imagenet_fct_transformation.yaml) | No | 61.8 | 80.5 | 39.9 |
137 | | [FCT new/old](./configs/imagenet_fct_transformation.yaml) | Yes | 65.1 | 82.7 | 44.0 |
138 |
139 | ### ImageNet-1k (FastFill, with backfilling):
140 | * AUC: Area Under the backfilling Curve. For old/old and new/new we report performance corresponding to
141 | no model update and full model update, respectively.
142 |
143 | | Case | `Side-Info` | `Backfilling` | `AUC CMC Top-1 (%)` | `AUC CMC Top-5 (%)` | `AUC mAP (%)` |
144 | |---------------------------------------------------------------------|:-----------:|:-------------:|:-------------------:|:-------------------:|:------------:|
145 | | [old/old](./configs/imagenet_backbone_old.yaml) | N/A | N/A | 46.6 | 65.2 | 28.5 |
146 | | [new/new](./configs/imagenet_backbone_new.yaml) | N/A | N/A | 68.2 | 84.6 | 45.3 |
147 | | [FCT new/old](./configs/imagenet_fct_transformation.yaml) | No | Random | 62.8 | 81.1 | 40.5 |
148 | | [FastFill new/old](./configs/imagenet_fastfill_transformation.yaml) | No | Uncertainty | 66.5 | 83.6 | 44.8 |
149 | | [FCT new/old](./configs/imagenet_fct_transformation.yaml) | Yes | Random | 64.7 | 82.4 | 42.6 |
150 | | [FastFill new/old](./configs/imagenet_fastfill_transformation.yaml) | Yes | Uncertainty | 67.8 | 84.2 | 46.2 |
151 |
152 | ## Contact
153 |
154 | * **Hadi Pouransari**: mpouransari@apple.com
155 |
156 | ## Citation
157 |
158 | ```bibtex
159 | @article{ramanujan2022forward,
160 | title={Forward Compatible Training for Large-Scale Embedding Retrieval Systems},
161 | author={Ramanujan, Vivek and Vasu, Pavan Kumar Anasosalu and Farhadi, Ali and Tuzel, Oncel and Pouransari, Hadi},
162 | journal={Proceedings of the IEEE conference on computer vision and pattern recognition},
163 | year={2022}
164 | }
165 |
166 | @inproceedings{jaeckle2023fastfill,
167 | title={FastFill: Efficient Compatible Model Update},
168 | author={Jaeckle, Florian and Faghri, Fartash and Farhadi, Ali and Tuzel, Oncel and Pouransari, Hadi},
169 |   booktitle={International Conference on Learning Representations},
170 | year={2023}
171 | }
172 | ```
173 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 |
--------------------------------------------------------------------------------
/checkpoints/README.md:
--------------------------------------------------------------------------------
1 | # Pretrained models
2 | We provide pretrained side-information models trained with [SimCLR](https://arxiv.org/abs/2002.05709)
3 | using the same training setup as in [Stochastic Contrastive Learning](https://arxiv.org/abs/2110.00552):
4 |
5 | - [ResNet50-128-ImageNet250](https://devpubs.s3.amazonaws.com/ml-research/models/fct/imagenet_250_simclr.pt)
6 | - [ResNet50-128-ImageNet500](https://devpubs.s3.amazonaws.com/ml-research/models/fct/imagenet_500_simclr.pt)
7 | - [ResNet50-128-ImageNet1000](https://devpubs.s3.amazonaws.com/ml-research/models/fct/imagenet_1000_simclr.pt)
8 |
9 | The above models use the [ResNet50](https://arxiv.org/abs/1512.03385) architecture with a feature dimension of 128. The models are trained on the first 250, 500, and 1000 classes of ImageNet, respectively.
10 |
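These are the checkpoints referenced by `side_info_model_path` in the transformation configs and downloaded by `get_pretrained_models.sh`. A minimal loading sketch follows; it assumes the files are TorchScript archives (as `eval.py` loads model checkpoints with `torch.jit.load`) and feeds a dummy 224x224 input matching the transforms in `dataset/data_transforms.py`. The exact output format of the side-information model may differ.

```python
import torch

# After `source get_pretrained_models.sh`, the checkpoint lives under checkpoints/.
side_info_model = torch.jit.load("checkpoints/imagenet_500_simclr.pt", map_location="cpu")
side_info_model.eval()

with torch.no_grad():
    out = side_info_model(torch.randn(1, 3, 224, 224))  # dummy ImageNet-sized batch
print(type(out))  # inspect whether the model returns a tensor or a tuple
```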
--------------------------------------------------------------------------------
/configs/cifar100_backbone_new.yaml:
--------------------------------------------------------------------------------
1 | arch_params:
2 | arch: ResNet50
3 | num_classes: 100 # This is the number of classes for architecture FC layer.
4 | embedding_dim: 128
5 | last_nonlin: True
6 |
7 | optimizer_params:
8 | algorithm: sgd
9 | lr: 1.024
10 | weight_decay: 0.000030517578125
11 | no_bn_decay: False
12 | momentum: 0.875
13 | nesterov: False
14 |
15 | dataset_params:
16 | name: cifar100
17 | data_root: data_store/cifar-100-python # This should contain training and validation dirs.
18 | num_classes: 100 # This is the number of classes to include for training.
19 | num_workers: 20
20 | batch_size: 1024
21 |
22 | lr_policy_params:
23 | algorithm: cosine_lr
24 | warmup_length: 5
25 | epochs: 100
26 | lr: 1.024
27 |
28 | epochs: 100
29 | label_smoothing: 0.1
30 | output_model_path: checkpoints/cifar100_new.pt
31 |
--------------------------------------------------------------------------------
/configs/cifar100_backbone_old.yaml:
--------------------------------------------------------------------------------
1 | arch_params:
2 | arch: ResNet50
3 | num_classes: 100 # This is the number of classes for architecture FC layer.
4 | embedding_dim: 128
5 | last_nonlin: True
6 |
7 | optimizer_params:
8 | algorithm: sgd
9 | lr: 1.024
10 | weight_decay: 0.000030517578125
11 | no_bn_decay: False
12 | momentum: 0.875
13 | nesterov: False
14 |
15 | dataset_params:
16 | name: cifar100
17 | data_root: data_store/cifar-100-python # This should contain training and validation dirs.
18 | num_classes: 50 # This is the number of classes to include for training.
19 | num_workers: 20
20 | batch_size: 1024
21 |
22 | lr_policy_params:
23 | algorithm: cosine_lr
24 | warmup_length: 5
25 | epochs: 100
26 | lr: 1.024
27 |
28 | epochs: 100
29 | label_smoothing: 0.1
30 | output_model_path: checkpoints/cifar100_old.pt
31 |
--------------------------------------------------------------------------------
/configs/cifar100_eval_new_new.yaml:
--------------------------------------------------------------------------------
1 | gallery_model_path: checkpoints/cifar100_new.pt
2 | query_model_path: checkpoints/cifar100_new.pt
3 |
4 | eval_params:
5 | distance_metric_name: l2
6 | verbose: True
7 | compute_map: True
8 |
9 | dataset_params: # Test split of the dataset will be used as both gallery and query sets.
10 | name: cifar100
11 | data_root: data_store/cifar-100-python
12 | num_workers: 20
13 | batch_size: 1024
--------------------------------------------------------------------------------
/configs/cifar100_eval_old_new_fastfill.yaml:
--------------------------------------------------------------------------------
1 | gallery_model_path: checkpoints/cifar100_old_fastfill_transformed.pt
2 | query_model_path: checkpoints/cifar100_new.pt
3 |
4 | eval_params:
5 | distance_metric_name: l2
6 | verbose: True
7 | compute_map: True
8 | backfilling: 20
9 | backfilling_list: [ 'sigma' ] # choices are 'random', 'distance', and 'sigma'
10 | backfilling_result_path: checkpoints/cifar100_fastfill_backfilling.npy
11 |
12 | dataset_params: # Test split of the dataset will be used as both gallery and query sets.
13 | name: cifar100
14 | data_root: data_store/cifar-100-python
15 | num_workers: 20
16 | batch_size: 1024
17 |
--------------------------------------------------------------------------------
/configs/cifar100_eval_old_new_fct.yaml:
--------------------------------------------------------------------------------
1 | gallery_model_path: checkpoints/cifar100_old_fct_transformed.pt
2 | query_model_path: checkpoints/cifar100_new.pt
3 |
4 | eval_params:
5 | distance_metric_name: l2
6 | verbose: True
7 | compute_map: True
8 | backfilling: 20 # put null here for no backfilling, only transformed(old) vs new.
9 | backfilling_list: [ 'random' ] # choices are 'random' and 'distance'
10 | backfilling_result_path: checkpoints/cifar100_fct_backfilling.npy
11 |
12 | dataset_params: # Test split of the dataset will be used as both gallery and query sets.
13 | name: cifar100
14 | data_root: data_store/cifar-100-python
15 | num_workers: 20
16 | batch_size: 1024
17 |
--------------------------------------------------------------------------------
/configs/cifar100_eval_old_old.yaml:
--------------------------------------------------------------------------------
1 | gallery_model_path: checkpoints/cifar100_old.pt
2 | query_model_path: checkpoints/cifar100_old.pt
3 |
4 | eval_params:
5 | distance_metric_name: l2
6 | verbose: True
7 | compute_map: True
8 |
9 | dataset_params: # Test split of the dataset will be used as both gallery and query sets.
10 | name: cifar100
11 | data_root: data_store/cifar-100-python
12 | num_workers: 20
13 | batch_size: 1024
14 |
--------------------------------------------------------------------------------
/configs/cifar100_fastfill_transformation.yaml:
--------------------------------------------------------------------------------
1 | old_model_path: checkpoints/cifar100_old.pt
2 | new_model_path: checkpoints/cifar100_new.pt
3 | #side_info_model_path: checkpoints/imagenet_1000_simclr.pt # Comment this line for no side-info experiment.
4 |
5 | arch_params:
6 | arch: MLP_BN_SIDE_PROJECTION_SIGMA
7 | old_embedding_dim: 128
8 | new_embedding_dim: 128
9 | side_info_dim: 128
10 | inner_dim: 2048
11 |
12 | optimizer_params:
13 | algorithm: adam
14 | lr: 0.0005
15 | weight_decay: 0.000030517578125
16 |
17 | dataset_params:
18 | name: cifar100
19 | data_root: data_store/cifar-100-python # This should contain training and validation dirs.
20 | num_classes: 100 # This is the number of classes to include for training.
21 | num_workers: 20
22 | batch_size: 1024
23 |
24 | lr_policy_params:
25 | algorithm: cosine_lr
26 | warmup_length: 5
27 | epochs: 80
28 | lr: 0.0005
29 |
30 | objective_params:
31 | similarity_loss:
32 | name: mse
33 | discriminative_loss:
34 | name: LabelSmoothing
35 | label_smoothing: 0.1
36 | mu_disc: 1
37 | uncertainty_loss:
38 | mu_uncertainty: 0.5
39 |
40 | epochs: 80
41 | switch_mode_to_eval: True
42 | output_transformation_path: checkpoints/cifar100_fastfill_transformation.pt
43 | output_transformed_old_model_path: checkpoints/cifar100_old_fastfill_transformed.pt
44 |
--------------------------------------------------------------------------------
/configs/cifar100_fct_transformation.yaml:
--------------------------------------------------------------------------------
1 | old_model_path: checkpoints/cifar100_old.pt
2 | new_model_path: checkpoints/cifar100_new.pt
3 | #side_info_model_path: checkpoints/imagenet_1000_simclr.pt # Comment this line for no side-info experiment.
4 |
5 | arch_params:
6 | arch: MLP_BN_SIDE_PROJECTION
7 | old_embedding_dim: 128
8 | new_embedding_dim: 128
9 | side_info_dim: 128
10 | inner_dim: 2048
11 |
12 | optimizer_params:
13 | algorithm: adam
14 | lr: 0.0005
15 | weight_decay: 0.000030517578125
16 |
17 | dataset_params:
18 | name: cifar100
19 | data_root: data_store/cifar-100-python # This should contain training and validation dirs.
20 | num_classes: 100 # This is the number of classes to include for training.
21 | num_workers: 20
22 | batch_size: 1024
23 |
24 | lr_policy_params:
25 | algorithm: cosine_lr
26 | warmup_length: 5
27 | epochs: 80
28 | lr: 0.0005
29 |
30 | objective_params:
31 | similarity_loss:
32 | name: mse
33 |
34 | epochs: 80
35 | switch_mode_to_eval: True
36 | output_transformation_path: checkpoints/cifar100_fct_transformation.pt
37 | output_transformed_old_model_path: checkpoints/cifar100_old_fct_transformed.pt
38 |
--------------------------------------------------------------------------------
/configs/imagenet_backbone_new.yaml:
--------------------------------------------------------------------------------
1 | arch_params:
2 | arch: ResNet50
3 | num_classes: 1000 # This is the number of classes for architecture FC layer.
4 | embedding_dim: 128
5 | last_nonlin: True
6 |
7 | optimizer_params:
8 | algorithm: sgd
9 | lr: 1.024
10 | weight_decay: 0.000030517578125
11 | no_bn_decay: False
12 | momentum: 0.875
13 | nesterov: False
14 |
15 | dataset_params:
16 | name: imagenet
17 | data_root: data_store/imagenet-1.0.2/data/raw # This should contain training and validation dirs.
18 | num_classes: 1000 # This is the number of classes to include for training.
19 | num_workers: 20
20 | batch_size: 1024
21 |
22 | lr_policy_params:
23 | algorithm: cosine_lr
24 | warmup_length: 5
25 | epochs: 100
26 | lr: 1.024
27 |
28 | epochs: 100
29 | label_smoothing: 0.1
30 | output_model_path: checkpoints/imagenet_new.pt
31 |
--------------------------------------------------------------------------------
/configs/imagenet_backbone_old.yaml:
--------------------------------------------------------------------------------
1 | arch_params:
2 | arch: ResNet50
3 | num_classes: 1000 # This is the number of classes for architecture FC layer.
4 | embedding_dim: 128
5 | last_nonlin: True
6 |
7 | optimizer_params:
8 | algorithm: sgd
9 | lr: 1.024
10 | weight_decay: 0.000030517578125
11 | no_bn_decay: False
12 | momentum: 0.875
13 | nesterov: False
14 |
15 | dataset_params:
16 | name: imagenet
17 | data_root: data_store/imagenet-1.0.2/data/raw # This should contain training and validation dirs.
18 | num_classes: 500 # This is the number of classes to include for training.
19 | num_workers: 20
20 | batch_size: 1024
21 |
22 | lr_policy_params:
23 | algorithm: cosine_lr
24 | warmup_length: 5
25 | epochs: 100
26 | lr: 1.024
27 |
28 | epochs: 100
29 | label_smoothing: 0.1
30 | output_model_path: checkpoints/imagenet_old.pt
31 |
--------------------------------------------------------------------------------
/configs/imagenet_eval_new_new.yaml:
--------------------------------------------------------------------------------
1 | gallery_model_path: checkpoints/imagenet_new.pt
2 | query_model_path: checkpoints/imagenet_new.pt
3 |
4 | eval_params:
5 | distance_metric_name: l2
6 | verbose: True
7 | compute_map: True
8 |
9 | dataset_params: # Test split of the dataset will be used as both gallery and query sets.
10 | name: imagenet
11 | data_root: data_store/imagenet-1.0.2/data/raw
12 | num_workers: 20
13 | batch_size: 1024
14 |
--------------------------------------------------------------------------------
/configs/imagenet_eval_old_new_fastfill.yaml:
--------------------------------------------------------------------------------
1 | gallery_model_path: checkpoints/imagenet_old_fastfill_transformed.pt
2 | query_model_path: checkpoints/imagenet_new.pt
3 |
4 | eval_params:
5 | distance_metric_name: l2
6 | verbose: True
7 | compute_map: True
8 | backfilling: 20
9 | backfilling_list: ['sigma'] # choices are 'random', 'distance', and 'sigma'
10 | backfilling_result_path: checkpoints/imagenet_fastfill_backfilling.npy
11 |
12 | dataset_params: # Test split of the dataset will be used as both gallery and query sets.
13 | name: imagenet
14 | data_root: data_store/imagenet-1.0.2/data/raw
15 | num_workers: 20
16 | batch_size: 1024
17 |
--------------------------------------------------------------------------------
/configs/imagenet_eval_old_new_fct.yaml:
--------------------------------------------------------------------------------
1 | gallery_model_path: checkpoints/imagenet_old_fct_transformed.pt
2 | query_model_path: checkpoints/imagenet_new.pt
3 |
4 | eval_params:
5 | distance_metric_name: l2
6 | verbose: True
7 | compute_map: True
8 | backfilling: 20 # put null here for no backfilling, only transformed(old) vs new.
9 | backfilling_list: ['random'] # choices are 'random' and 'distance'
10 | backfilling_result_path: checkpoints/imagenet_fct_backfilling.npy
11 |
12 | dataset_params: # Test split of the dataset will be used as both gallery and query sets.
13 | name: imagenet
14 | data_root: data_store/imagenet-1.0.2/data/raw
15 | num_workers: 20
16 | batch_size: 1024
17 |
--------------------------------------------------------------------------------
/configs/imagenet_eval_old_old.yaml:
--------------------------------------------------------------------------------
1 | gallery_model_path: checkpoints/imagenet_old.pt
2 | query_model_path: checkpoints/imagenet_old.pt
3 |
4 | eval_params:
5 | distance_metric_name: l2
6 | verbose: True
7 | compute_map: True
8 |
9 | dataset_params: # Test split of the dataset will be used as both gallery and query sets.
10 | name: imagenet
11 | data_root: data_store/imagenet-1.0.2/data/raw
12 | num_workers: 20
13 | batch_size: 1024
14 |
--------------------------------------------------------------------------------
/configs/imagenet_fastfill_transformation.yaml:
--------------------------------------------------------------------------------
1 | old_model_path: checkpoints/imagenet_old.pt
2 | new_model_path: checkpoints/imagenet_new.pt
3 | #side_info_model_path: checkpoints/imagenet_500_simclr.pt # Comment this line for no side-info experiment.
4 |
5 | arch_params:
6 | arch: MLP_BN_SIDE_PROJECTION_SIGMA
7 | old_embedding_dim: 128
8 | new_embedding_dim: 128
9 | side_info_dim: 128
10 | inner_dim: 2048
11 |
12 | optimizer_params:
13 | algorithm: adam
14 | lr: 0.0005
15 | weight_decay: 0.000030517578125
16 |
17 | dataset_params:
18 | name: imagenet
19 | data_root: data_store/imagenet-1.0.2/data/raw # This should contain training and validation dirs.
20 | num_classes: 1000 # This is the number of classes to include for training.
21 | num_workers: 20
22 | batch_size: 1024
23 |
24 | lr_policy_params:
25 | algorithm: cosine_lr
26 | warmup_length: 5
27 | epochs: 80
28 | lr: 0.0005
29 |
30 | objective_params:
31 | similarity_loss:
32 | name: mse
33 | discriminative_loss:
34 | name: LabelSmoothing
35 | label_smoothing: 0.1
36 | mu_disc: 1
37 | uncertainty_loss:
38 | mu_uncertainty: 0.5
39 |
40 | epochs: 80
41 | switch_mode_to_eval: True
42 | output_transformation_path: checkpoints/imagenet_fastfill_transformation.pt
43 | output_transformed_old_model_path: checkpoints/imagenet_old_fastfill_transformed.pt
44 |
--------------------------------------------------------------------------------
/configs/imagenet_fct_transformation.yaml:
--------------------------------------------------------------------------------
1 | old_model_path: checkpoints/imagenet_old.pt
2 | new_model_path: checkpoints/imagenet_new.pt
3 | #side_info_model_path: checkpoints/imagenet_500_simclr.pt # Comment this line for no side-info experiment.
4 |
5 | arch_params:
6 | arch: MLP_BN_SIDE_PROJECTION
7 | old_embedding_dim: 128
8 | new_embedding_dim: 128
9 | side_info_dim: 128
10 | inner_dim: 2048
11 |
12 | optimizer_params:
13 | algorithm: adam
14 | lr: 0.0005
15 | weight_decay: 0.000030517578125
16 |
17 | dataset_params:
18 | name: imagenet
19 | data_root: data_store/imagenet-1.0.2/data/raw # This should contain training and validation dirs.
20 | num_classes: 1000 # This is the number of classes to include for training.
21 | num_workers: 20
22 | batch_size: 1024
23 |
24 | lr_policy_params:
25 | algorithm: cosine_lr
26 | warmup_length: 5
27 | epochs: 80
28 | lr: 0.0005
29 |
30 | objective_params:
31 | similarity_loss:
32 | name: mse
33 |
34 | epochs: 80
35 | switch_mode_to_eval: True
36 | output_transformation_path: checkpoints/imagenet_fct_transformation.pt
37 | output_transformed_old_model_path: checkpoints/imagenet_old_fct_transformed.pt
38 |
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 |
5 | from .sub_image_folder import SubImageFolder
6 |
--------------------------------------------------------------------------------
/dataset/data_transforms.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 |
5 | from typing import Tuple
6 | from torchvision import transforms
7 |
8 |
9 | def imagenet_transforms() -> Tuple[transforms.Compose, transforms.Compose]:
10 | """Get training and validation transformations for Imagenet."""
11 | normalize = transforms.Normalize(
12 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
13 | )
14 | train_transforms = transforms.Compose(
15 | [
16 | transforms.RandomResizedCrop(224),
17 | transforms.RandomHorizontalFlip(),
18 | transforms.ToTensor(),
19 | normalize,
20 | ]
21 | )
22 |
23 | val_transforms = transforms.Compose(
24 | [
25 | transforms.Resize(256),
26 | transforms.CenterCrop(224),
27 | transforms.ToTensor(),
28 | normalize,
29 | ]
30 | )
31 | return train_transforms, val_transforms
32 |
33 |
34 | def cifar100_transforms() -> Tuple[transforms.Compose, transforms.Compose]:
35 | """Get training and validation transformations for Cifar100.
36 |
37 | Note that these are not optimal transformations (including normalization),
38 |     but they are provided for quick experimentation in a setup similar to ImageNet
39 |     (and its corresponding side-information model).
40 | """
41 | normalize = transforms.Normalize(
42 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
43 | )
44 |
45 | train_transforms = transforms.Compose(
46 | [
47 | transforms.Resize(224),
48 | transforms.RandomHorizontalFlip(),
49 | transforms.ToTensor(),
50 | normalize,
51 | ]
52 | )
53 |
54 | val_transforms = transforms.Compose(
55 | [
56 | transforms.Resize(224),
57 | transforms.ToTensor(),
58 | normalize,
59 | ]
60 | )
61 | return train_transforms, val_transforms
62 |
63 |
64 | data_transforms_map = {"cifar100": cifar100_transforms, "imagenet": imagenet_transforms}
65 |
66 |
67 | def get_data_transforms(
68 | dataset_name: str,
69 | ) -> Tuple[transforms.Compose, transforms.Compose]:
70 | """Get training and validation transforms of a dataset.
71 |
72 | :param dataset_name: Name of the dataset (e.g., cifar100, imagenet)
73 | :return: Tuple of training and validation transformations.
74 | """
75 | return data_transforms_map.get(dataset_name)()
76 |
--------------------------------------------------------------------------------
/dataset/sub_image_folder.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 |
5 | import os
6 |
7 | import torch
8 | from torchvision import datasets
9 |
10 | from .data_transforms import get_data_transforms
11 |
12 |
13 | class SubImageFolder:
14 | """Class to support training on subset of classes."""
15 |
16 | def __init__(
17 | self,
18 | name: str,
19 | data_root: str,
20 | num_workers: int,
21 | batch_size: int,
22 | num_classes=None,
23 | ) -> None:
24 | """Construct a SubImageFolder module.
25 |
26 | :param name: Name of the dataset (e.g., cifar100, imagenet).
27 | :param data_root: Path to a directory with training and validation
28 | subdirs of the dataset.
29 | :param num_workers: Number of workers for data loader.
30 | :param batch_size: Global batch size. Per GPU batch size = batch_size/num_gpus.
31 | :param num_classes: Number of classes to use for training. This should
32 | be smaller than or equal to the total number of classes in the dataset.
33 | Note that for evaluation we use all classes.
34 | """
35 | super(SubImageFolder, self).__init__()
36 |
37 | use_cuda = torch.cuda.is_available()
38 | kwargs = {"num_workers": num_workers, "pin_memory": True} if use_cuda else {}
39 |
40 | traindir = os.path.join(data_root, "training")
41 | valdir = os.path.join(data_root, "validation")
42 |
43 | train_transforms, val_transforms = get_data_transforms(name)
44 |
45 | self.train_dataset = datasets.ImageFolder(
46 | traindir,
47 | train_transforms,
48 | )
49 |
50 | # Filtering out some classes
51 | if num_classes is not None:
52 | self.train_dataset.imgs = [
53 | (path, cls_num)
54 | for path, cls_num in self.train_dataset.imgs
55 | if cls_num < num_classes
56 | ]
57 |
58 | self.train_dataset.samples = self.train_dataset.imgs
59 |
60 | self.train_sampler = None
61 |
62 | self.train_loader = torch.utils.data.DataLoader(
63 | self.train_dataset,
64 | batch_size=batch_size,
65 | sampler=self.train_sampler,
66 | shuffle=self.train_sampler is None,
67 | **kwargs
68 | )
69 |
70 | # Note: for evaluation we use all classes.
71 | self.val_loader = torch.utils.data.DataLoader(
72 | datasets.ImageFolder(
73 | valdir,
74 | val_transforms,
75 | ),
76 | batch_size=batch_size,
77 | shuffle=False,
78 | **kwargs
79 | )
80 |
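# Usage sketch (illustrative values mirroring configs/imagenet_backbone_old.yaml;
# data_root must point to a directory with "training" and "validation" subdirs):
#
#   data = SubImageFolder(
#       name="imagenet",
#       data_root="data_store/imagenet-1.0.2/data/raw",
#       num_workers=20,
#       batch_size=1024,
#       num_classes=500,  # keep only the first 500 classes for training
#   )
#   images, labels = next(iter(data.train_loader))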
--------------------------------------------------------------------------------
/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-fct/1c658df902c70b2556ce9677bbca660c80fdf9e6/demo.gif
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 |
5 | from typing import Dict
6 | from argparse import ArgumentParser
7 |
8 | import yaml
9 | import torch
10 |
11 | from dataset import SubImageFolder
12 | from utils.eval_utils import cmc_evaluate
13 |
14 |
15 | def main(config: Dict) -> None:
16 | """Run evaluation.
17 |
18 | :param config: A dictionary with all configurations to run evaluation.
19 | """
20 | device = torch.device("cuda:0") if torch.cuda.is_available() else "cpu"
21 |
22 | # Load models:
23 | gallery_model = torch.jit.load(config.get("gallery_model_path"))
24 | query_model = torch.jit.load(config.get("query_model_path"))
25 |
26 | data = SubImageFolder(**config.get("dataset_params"))
27 |
28 | cmc_evaluate(
29 | gallery_model, query_model, data.val_loader, device, **config.get("eval_params")
30 | )
31 |
32 |
33 | if __name__ == "__main__":
34 | parser = ArgumentParser()
35 | parser.add_argument(
36 | "--config",
37 | type=str,
38 | required=True,
39 | help="Path to config file for this pipeline.",
40 | )
41 | args = parser.parse_args()
42 | with open(args.config) as f:
43 | read_config = yaml.safe_load(f)
44 | main(read_config)
45 |
--------------------------------------------------------------------------------
/fct_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-fct/1c658df902c70b2556ce9677bbca660c80fdf9e6/fct_logo.png
--------------------------------------------------------------------------------
/get_pretrained_models.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | #
3 | # Copyright (C) 2023 Apple Inc. All rights reserved.
4 | #
5 |
6 | wget https://devpubs.s3.amazonaws.com/ml-research/models/fct/imagenet_250_simclr.pt -P checkpoints
7 | wget https://devpubs.s3.amazonaws.com/ml-research/models/fct/imagenet_500_simclr.pt -P checkpoints
8 | wget https://devpubs.s3.amazonaws.com/ml-research/models/fct/imagenet_1000_simclr.pt -P checkpoints
9 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 |
5 | from .resnet import ResNet18, ResNet50, ResNet101, WideResNet50_2, WideResNet101_2
6 | from .transformations import MLP_BN_SIDE_PROJECTION, MLP_BN_SIDE_PROJECTION_SIGMA
7 |
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | from typing import Optional, List, Tuple
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 | class BasicBlock(nn.Module):
12 | """Resnet basic block module."""
13 |
14 | expansion = 1
15 |
16 | def __init__(
17 | self,
18 | inplanes: int,
19 | planes: int,
20 | stride: int = 1,
21 | downsample: Optional[nn.Module] = None,
22 | base_width: int = 64,
23 | nonlin: bool = True,
24 | embedding_dim: Optional[int] = None,
25 | ) -> None:
26 | """Construct a BasicBlock module.
27 |
28 | :param inplanes: Number of input channels.
29 | :param planes: Number of output channels.
30 | :param stride: Stride size.
31 | :param downsample: Down-sampling for residual path.
32 | :param base_width: Base width of the block.
33 | :param nonlin: Whether to apply non-linearity before output.
34 | :param embedding_dim: Size of the output embedding dimension.
35 | """
36 | super(BasicBlock, self).__init__()
37 | if base_width / 64 > 1:
38 | raise ValueError("Base width >64 does not work for BasicBlock")
39 | if embedding_dim is not None:
40 | planes = embedding_dim
41 | self.conv1 = nn.Conv2d(
42 | inplanes, planes, kernel_size=3, padding=1, stride=stride, bias=False
43 | )
44 | self.bn1 = nn.BatchNorm2d(planes)
45 | self.relu = nn.ReLU(inplace=True)
46 | self.conv2 = nn.Conv2d(
47 | planes, planes, kernel_size=3, padding=1, stride=1, bias=False
48 | )
49 | self.bn2 = nn.BatchNorm2d(planes)
50 | self.downsample = downsample
51 | self.stride = stride
52 | self.nonlin = nonlin
53 |
54 | def forward(self, x: torch.Tensor) -> torch.Tensor:
55 | """Apply forward pass."""
56 | residual = x
57 |
58 | out = self.conv1(x)
59 | if self.bn1 is not None:
60 | out = self.bn1(out)
61 |
62 | out = self.relu(out)
63 |
64 | out = self.conv2(out)
65 |
66 | if self.bn2 is not None:
67 | out = self.bn2(out)
68 |
69 | if self.downsample is not None:
70 | residual = self.downsample(x)
71 |
72 | out += residual
73 | if self.nonlin:
74 | out = self.relu(out)
75 |
76 | return out
77 |
78 |
79 | class Bottleneck(nn.Module):
80 | """Resnet bottleneck block module."""
81 |
82 | expansion = 4
83 |
84 | def __init__(
85 | self,
86 | inplanes: int,
87 | planes: int,
88 | stride: int = 1,
89 | downsample: Optional[nn.Module] = None,
90 | base_width: int = 64,
91 | nonlin: bool = True,
92 | embedding_dim: Optional[int] = None,
93 | ) -> None:
94 | """Construct a Bottleneck module.
95 |
96 | :param inplanes: Number of input channels.
97 | :param planes: Number of output channels.
98 | :param stride: Stride size.
99 | :param downsample: Down-sampling for residual path.
100 | :param base_width: Base width of the block.
101 | :param nonlin: Whether to apply non-linearity before output.
102 | :param embedding_dim: Size of the output embedding dimension.
103 | """
104 | super(Bottleneck, self).__init__()
105 | width = int(planes * base_width / 64)
106 | if embedding_dim is not None:
107 | out_dim = embedding_dim
108 | else:
109 | out_dim = planes * self.expansion
110 | self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, stride=1, bias=False)
111 | self.bn1 = nn.BatchNorm2d(width)
112 | self.conv2 = nn.Conv2d(
113 | width, width, kernel_size=3, padding=1, stride=stride, bias=False
114 | )
115 | self.bn2 = nn.BatchNorm2d(width)
116 | self.conv3 = nn.Conv2d(width, out_dim, kernel_size=1, stride=1, bias=False)
117 | self.bn3 = nn.BatchNorm2d(out_dim)
118 | self.relu = nn.ReLU(inplace=True)
119 | self.downsample = downsample
120 | self.stride = stride
121 | self.nonlin = nonlin
122 |
123 | def forward(self, x: torch.Tensor) -> torch.Tensor:
124 | """Apply forward pass."""
125 | residual = x
126 |
127 | out = self.conv1(x)
128 | out = self.bn1(out)
129 | out = self.relu(out)
130 |
131 | out = self.conv2(out)
132 | out = self.bn2(out)
133 | out = self.relu(out)
134 |
135 | out = self.conv3(out)
136 | out = self.bn3(out)
137 |
138 | if self.downsample is not None:
139 | residual = self.downsample(x)
140 |
141 | out += residual
142 |
143 | if self.nonlin:
144 | out = self.relu(out)
145 |
146 | return out
147 |
148 |
149 | class ResNet(nn.Module):
150 | """Resnet module."""
151 |
152 | def __init__(
153 | self,
154 | block: nn.Module,
155 | layers: List[int],
156 | num_classes: int = 1000,
157 | base_width: int = 64,
158 | embedding_dim: Optional[int] = None,
159 | last_nonlin: bool = True,
160 | norm_feature: bool = False,
161 | ) -> None:
162 | """Construct a ResNet module.
163 |
164 | :param block: Block module to use in Resnet architecture.
165 | :param layers: List of number of blocks per layer.
166 | :param num_classes: Number of classes in the dataset. It is used to
167 | form linear classifier weights.
168 | :param base_width: Base width of the blocks.
169 | :param embedding_dim: Size of the output embedding dimension.
170 | :param last_nonlin: Whether to apply non-linearity before output.
171 |         :param norm_feature: Whether to normalize output embeddings.
172 | """
173 | self.inplanes = 64
174 | super(ResNet, self).__init__()
175 |
176 | self.OUTPUT_SHAPE = [embedding_dim, 1, 1]
177 | self.is_normalized = norm_feature
178 | self.base_width = base_width
179 | if self.base_width // 64 > 1:
180 | print(f"==> Using {self.base_width // 64}x wide model")
181 |
182 | if embedding_dim is not None:
183 | print("Using given embedding dimension = {}".format(embedding_dim))
184 | self.embedding_dim = embedding_dim
185 | else:
186 | self.embedding_dim = 512 * block.expansion
187 |
188 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, padding=3, stride=2, bias=False)
189 | self.bn1 = nn.BatchNorm2d(64)
190 | self.relu = nn.ReLU(inplace=True)
191 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
192 | self.layer1 = self._make_layer(
193 | block, 64, layers[0], embedding_dim=64 * block.expansion
194 | )
195 | self.layer2 = self._make_layer(
196 | block,
197 | 128,
198 | layers[1],
199 | stride=2,
200 | embedding_dim=128 * block.expansion,
201 | )
202 | self.layer3 = self._make_layer(
203 | block,
204 | 256,
205 | layers[2],
206 | stride=2,
207 | embedding_dim=256 * block.expansion,
208 | )
209 | self.layer4 = self._make_layer(
210 | block,
211 | 512,
212 | layers[3],
213 | stride=2,
214 | nonlin=last_nonlin,
215 | embedding_dim=self.embedding_dim,
216 | )
217 | self.avgpool = nn.AdaptiveAvgPool2d(1)
218 |
219 | # self.fc = nn.Linear(512 * block.expansion, num_classes)
220 | self.fc = nn.Conv2d(
221 | self.embedding_dim, num_classes, kernel_size=1, stride=1, bias=False
222 | )
223 |
224 | def _make_layer(
225 | self,
226 | block: nn.Module,
227 | planes: int,
228 | blocks: int,
229 | embedding_dim: int,
230 | stride: int = 1,
231 | nonlin: bool = True,
232 | ):
233 | """Make a layer of resnet architecture.
234 |
235 | :param block: Block module to use in this layer.
236 | :param planes: Number of output channels.
237 | :param blocks: Number of blocks in this layer.
238 | :param embedding_dim: Size of the output embedding dimension.
239 | :param stride: Stride size.
240 | :param nonlin: Whether to apply non-linearity before output.
241 |         :return: An nn.Sequential containing the constructed blocks.
242 | """
243 | downsample = None
244 | if stride != 1 or self.inplanes != planes * block.expansion:
245 | dconv = nn.Conv2d(
246 | self.inplanes,
247 | planes * block.expansion,
248 | kernel_size=1,
249 | stride=stride,
250 | bias=False,
251 | )
252 | dbn = nn.BatchNorm2d(planes * block.expansion)
253 | if dbn is not None:
254 | downsample = nn.Sequential(dconv, dbn)
255 | else:
256 | downsample = dconv
257 |
258 | last_downsample = None
259 |
260 | layers = []
261 | if blocks == 1: # If this layer has only one-block
262 | if stride != 1 or self.inplanes != embedding_dim:
263 | dconv = nn.Conv2d(
264 | self.inplanes,
265 | embedding_dim,
266 | kernel_size=1,
267 | stride=stride,
268 | bias=False,
269 | )
270 | dbn = nn.BatchNorm2d(embedding_dim)
271 | if dbn is not None:
272 | last_downsample = nn.Sequential(dconv, dbn)
273 | else:
274 | last_downsample = dconv
275 | layers.append(
276 | block(
277 | self.inplanes,
278 | planes,
279 | stride,
280 | last_downsample,
281 | base_width=self.base_width,
282 | nonlin=nonlin,
283 | embedding_dim=embedding_dim,
284 | )
285 | )
286 | return nn.Sequential(*layers)
287 | else:
288 | layers.append(
289 | block(
290 | self.inplanes,
291 | planes,
292 | stride,
293 | downsample,
294 | base_width=self.base_width,
295 | )
296 | )
297 | self.inplanes = planes * block.expansion
298 | for i in range(1, blocks - 1):
299 | layers.append(block(self.inplanes, planes, base_width=self.base_width))
300 |
301 | if self.inplanes != embedding_dim:
302 | dconv = nn.Conv2d(
303 | self.inplanes, embedding_dim, stride=1, kernel_size=1, bias=False
304 | )
305 | dbn = nn.BatchNorm2d(embedding_dim)
306 | if dbn is not None:
307 | last_downsample = nn.Sequential(dconv, dbn)
308 | else:
309 | last_downsample = dconv
310 | layers.append(
311 | block(
312 | self.inplanes,
313 | planes,
314 | downsample=last_downsample,
315 | base_width=self.base_width,
316 | nonlin=nonlin,
317 | embedding_dim=embedding_dim,
318 | )
319 | )
320 |
321 | return nn.Sequential(*layers)
322 |
323 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
324 | """Apply forward pass.
325 |
326 | :param x: input to the model with shape (N, C, H, W).
327 | :return: Tuple of (logits, embedding)
328 | """
329 | x = self.conv1(x)
330 |
331 | if self.bn1 is not None:
332 | x = self.bn1(x)
333 | x = self.relu(x)
334 | x = self.maxpool(x)
335 |
336 | x = self.layer1(x)
337 | x = self.layer2(x)
338 | x = self.layer3(x)
339 | x = self.layer4(x)
340 |
341 | feature = self.avgpool(x)
342 | if self.is_normalized:
343 | feature = F.normalize(feature)
344 |
345 | x = self.fc(feature)
346 | x = x.view(x.size(0), -1)
347 |
348 | return x, feature
349 |
350 |
351 | def ResNet18(
352 | num_classes: int, embedding_dim: int, last_nonlin: bool = True, **kwargs
353 | ) -> nn.Module:
354 | """Get a ResNet18 model.
355 |
356 | :param num_classes: Number of classes in the dataset.
357 | :param embedding_dim: Size of the output embedding dimension.
358 | :param last_nonlin: Whether to apply non-linearity before output.
359 | :return: ResNet18 Model.
360 | """
361 | return ResNet(
362 | BasicBlock,
363 | [2, 2, 2, 2],
364 | num_classes=num_classes,
365 | embedding_dim=embedding_dim,
366 | last_nonlin=last_nonlin,
367 | )
368 |
369 |
370 | def ResNet50(
371 | num_classes: int, embedding_dim: int, last_nonlin: bool = True, **kwargs
372 | ) -> nn.Module:
373 | """Get a ResNet50 model.
374 |
375 | :param num_classes: Number of classes in the dataset.
376 | :param embedding_dim: Size of the output embedding dimension.
377 | :param last_nonlin: Whether to apply non-linearity before output.
378 |     :return: ResNet50 Model.
379 | """
380 | return ResNet(
381 | Bottleneck,
382 | [3, 4, 6, 3],
383 | num_classes=num_classes,
384 | embedding_dim=embedding_dim,
385 | last_nonlin=last_nonlin,
386 | )
387 |
388 |
389 | def ResNet101(
390 | num_classes: int, embedding_dim: int, last_nonlin: bool = True, **kwargs
391 | ) -> nn.Module:
392 | """Get a ResNet101 model.
393 |
394 | :param num_classes: Number of classes in the dataset.
395 | :param embedding_dim: Size of the output embedding dimension.
396 | :param last_nonlin: Whether to apply non-linearity before output.
397 |     :return: ResNet101 Model.
398 | """
399 | return ResNet(
400 | Bottleneck,
401 | [3, 4, 23, 3],
402 | num_classes=num_classes,
403 |         embedding_dim=embedding_dim,
404 | last_nonlin=last_nonlin,
405 | )
406 |
407 |
408 | def WideResNet50_2(
409 | num_classes: int, embedding_dim: int, last_nonlin: bool = True, **kwargs
410 | ) -> nn.Module:
411 | """Get a WideResNet50 model.
412 |
413 | :param num_classes: Number of classes in the dataset.
414 | :param embedding_dim: Size of the output embedding dimension.
415 | :param last_nonlin: Whether to apply non-linearity before output.
416 |     :return: WideResNet50-2 Model.
417 | """
418 | return ResNet(
419 | Bottleneck,
420 | [3, 4, 6, 3],
421 | num_classes=num_classes,
422 | base_width=64 * 2,
423 | embedding_dim=embedding_dim,
424 | last_nonlin=last_nonlin,
425 | )
426 |
427 |
428 | def WideResNet101_2(
429 | num_classes: int, embedding_dim: int, last_nonlin: bool = True, **kwargs
430 | ) -> nn.Module:
431 | """Get a WideResNet101 model.
432 |
433 | :param num_classes: Number of classes in the dataset.
434 | :param embedding_dim: Size of the output embedding dimension.
435 | :param last_nonlin: Whether to apply non-linearity before output.
436 |     :return: WideResNet101-2 Model.
437 | """
438 | return ResNet(
439 | Bottleneck,
440 | [3, 4, 23, 3],
441 | num_classes=num_classes,
442 | base_width=64 * 2,
443 | embedding_dim=embedding_dim,
444 | last_nonlin=last_nonlin,
445 | )
446 |
--------------------------------------------------------------------------------
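As a quick sanity check of the factory functions above (models/resnet.py), the sketch below instantiates ResNet18 and inspects the two outputs of its forward pass. The class count, embedding size, and input resolution are illustrative values rather than the repo's actual configs, and the printed shapes are only what the forward implementation above implies.

```python
import torch

from models.resnet import ResNet18

# Illustrative hyperparameters; in practice these come from the YAML configs.
model = ResNet18(num_classes=100, embedding_dim=128, last_nonlin=True)
model.eval()

x = torch.randn(2, 3, 224, 224)  # (N, C, H, W)
with torch.no_grad():
    logits, embedding = model(x)  # forward returns (logits, embedding)

print(logits.shape)     # expected: torch.Size([2, 100])
print(embedding.shape)  # expected: torch.Size([2, 128, 1, 1]), the pooled feature
```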
/models/transformations.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 |
5 | from typing import Optional
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 |
11 | class ConvBlock(nn.Module):
12 | """Convenience convolution module."""
13 |
14 | def __init__(
15 | self,
16 | channels_in: int,
17 | channels_out: int,
18 | kernel_size: int = 1,
19 | stride: int = 1,
20 | normalizer: Optional[nn.Module] = nn.BatchNorm2d,
21 | activation: Optional[nn.Module] = nn.ReLU,
22 | ) -> None:
23 | """Construct a ConvBlock module.
24 |
25 | :param channels_in: Number of input channels.
26 | :param channels_out: Number of output channels.
27 | :param kernel_size: Size of the kernel.
28 | :param stride: Size of the convolution stride.
29 | :param normalizer: Optional normalization to use.
30 | :param activation: Optional activation module to use.
31 | """
32 | super().__init__()
33 |
34 | self.conv = nn.Conv2d(
35 | channels_in,
36 | channels_out,
37 | kernel_size=kernel_size,
38 | stride=stride,
39 | bias=normalizer is None,
40 | padding=kernel_size // 2,
41 | )
42 | if normalizer is not None:
43 | self.normalizer = normalizer(channels_out)
44 | else:
45 | self.normalizer = None
46 | if activation is not None:
47 | self.activation = activation()
48 | else:
49 | self.activation = None
50 |
51 | def forward(self, x: torch.Tensor) -> torch.Tensor:
52 | """Apply forward pass."""
53 | x = self.conv(x)
54 | if self.normalizer is not None:
55 | x = self.normalizer(x)
56 | if self.activation is not None:
57 | x = self.activation(x)
58 | return x
59 |
60 |
61 | class MLP_BN_SIDE_PROJECTION(nn.Module):
62 | """FCT transformation module."""
63 |
64 | def __init__(
65 | self,
66 | old_embedding_dim: int,
67 | new_embedding_dim: int,
68 | side_info_dim: int,
69 | inner_dim: int = 2048,
70 | **kwargs
71 | ) -> None:
72 | """Construct MLP_BN_SIDE_PROJECTION module.
73 |
74 | :param old_embedding_dim: Size of the old embeddings.
75 | :param new_embedding_dim: Size of the new embeddings.
76 | :param side_info_dim: Size of the side-information.
77 | :param inner_dim: Dimension of transformation MLP inner layer.
78 | """
79 | super().__init__()
80 |
81 | self.inner_dim = inner_dim
82 | self.p1 = nn.Sequential(
83 | ConvBlock(old_embedding_dim, 2 * old_embedding_dim),
84 | ConvBlock(2 * old_embedding_dim, 2 * new_embedding_dim),
85 | )
86 |
87 | self.p2 = nn.Sequential(
88 | ConvBlock(side_info_dim, 2 * side_info_dim),
89 | ConvBlock(2 * side_info_dim, 2 * new_embedding_dim),
90 | )
91 |
92 | self.mixer = nn.Sequential(
93 | ConvBlock(4 * new_embedding_dim, self.inner_dim),
94 | ConvBlock(self.inner_dim, self.inner_dim),
95 | ConvBlock(
96 | self.inner_dim, new_embedding_dim, normalizer=None, activation=None
97 | ),
98 | )
99 |
100 | def forward(
101 | self, old_feature: torch.Tensor, side_info: torch.Tensor
102 | ) -> torch.Tensor:
103 | """Apply forward pass.
104 |
105 | :param old_feature: Old embedding.
106 | :param side_info: Side-information.
107 | :return: Recycled old embedding compatible with new embeddings.
108 | """
109 | x1 = self.p1(old_feature)
110 | x2 = self.p2(side_info)
111 | return self.mixer(torch.cat([x1, x2], dim=1))
112 |
113 |
114 | class MLP_BN_SIDE_PROJECTION_SIGMA(nn.Module):
115 | """FastFill transformation module."""
116 |
117 | def __init__(
118 | self,
119 | old_embedding_dim: int,
120 | new_embedding_dim: int,
121 | side_info_dim: int,
122 | inner_dim: int = 2048,
123 | sigma_dim: int = 1,
124 | **kwargs
125 | ) -> None:
126 |         """Construct MLP_BN_SIDE_PROJECTION_SIGMA module.
127 |
128 | :param old_embedding_dim: Size of the old embeddings.
129 | :param new_embedding_dim: Size of the new embeddings.
130 | :param side_info_dim: Size of the side-information.
131 | :param inner_dim: Dimension of transformation MLP inner layer.
132 |         :param sigma_dim: Size of the uncertainty output; choices are [1, new_embedding_dim].
133 | """
134 | super().__init__()
135 |
136 | assert sigma_dim in [1, new_embedding_dim]
137 | self.uncertainty_head = nn.Linear(new_embedding_dim, sigma_dim)
138 | self.inner_dim = inner_dim
139 | self.p1 = nn.Sequential(
140 | ConvBlock(old_embedding_dim, 2 * old_embedding_dim),
141 | ConvBlock(2 * old_embedding_dim, 2 * new_embedding_dim),
142 | )
143 |
144 | self.p2 = nn.Sequential(
145 | ConvBlock(side_info_dim, 2 * side_info_dim),
146 | ConvBlock(2 * side_info_dim, 2 * new_embedding_dim),
147 | )
148 |
149 | self.mixer = nn.Sequential(
150 | ConvBlock(4 * new_embedding_dim, self.inner_dim),
151 | ConvBlock(self.inner_dim, self.inner_dim),
152 | ConvBlock(
153 | self.inner_dim, new_embedding_dim, normalizer=None, activation=None
154 | ),
155 | )
156 |
157 | def forward(
158 | self, old_feature: torch.Tensor, side_info: torch.Tensor
159 | ) -> torch.Tensor:
160 | """Apply forward pass.
161 |
162 | :param old_feature: Old embedding.
163 | :param side_info: Side-information.
164 |         :return: Recycled old embedding (compatible with new embeddings), concatenated with sigma along the channel dimension.
165 | """
166 | x1 = self.p1(old_feature)
167 | x2 = self.p2(side_info)
168 |
169 | transformed_features = self.mixer(torch.cat([x1, x2], dim=1))
170 | sigma = self.uncertainty_head(transformed_features.squeeze())
171 | transformed_features = torch.cat(
172 | (transformed_features, sigma.unsqueeze(dim=2).unsqueeze(dim=3)), dim=1
173 | )
174 |
175 | return transformed_features
176 |
--------------------------------------------------------------------------------
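A minimal sketch of driving the FCT transformation module above with toy tensors. The dimensions are illustrative; the inputs are the (N, d, 1, 1)-shaped pooled features that the backbones produce. The FastFill variant (MLP_BN_SIDE_PROJECTION_SIGMA) behaves the same way but appends sigma_dim extra channels to the output.

```python
import torch

from models.transformations import MLP_BN_SIDE_PROJECTION

# Illustrative dimensions; real values come from the transformation configs.
old_dim, new_dim, side_dim, batch = 128, 256, 128, 4

h = MLP_BN_SIDE_PROJECTION(
    old_embedding_dim=old_dim,
    new_embedding_dim=new_dim,
    side_info_dim=side_dim,
)
h.train()  # BatchNorm layers need more than one sample per batch in train mode

old_feature = torch.randn(batch, old_dim, 1, 1)  # pooled feature from the old backbone
side_info = torch.randn(batch, side_dim, 1, 1)   # side-information feature
recycled = h(old_feature, side_info)
print(recycled.shape)  # expected: torch.Size([4, 256, 1, 1])
```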
/prepare_dataset.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 |
5 | import os
6 | import pickle
7 |
8 | from PIL import Image
9 | from tqdm import tqdm
10 |
11 | from torchvision.datasets import CIFAR100
12 |
13 |
14 | def save_pngs(untar_dir: str, split: str) -> None:
15 | """Save loaded data as png files.
16 |
17 | :param untar_dir: Path to untared dataset.
18 | :param split: Split name (e.g., train, test)
19 | """
20 | split_map = {"train": "training", "test": "validation"}
21 | split_dir = os.path.join(untar_dir, split_map.get(split))
22 |
23 | os.makedirs(split_dir, exist_ok=True)
24 |
25 | for i in range(100):
26 | class_dir = os.path.join(split_dir, str(i))
27 | os.makedirs(class_dir, exist_ok=True)
28 |
29 | with open(os.path.join(untar_dir, split), "rb") as f:
30 | data_dict = pickle.load(f, encoding="latin1")
31 |
32 | data = data_dict.get("data") # numpy array
33 | # Reshape and cast
34 | data = data.reshape(data.shape[0], 3, 32, 32)
35 | data = data.transpose(0, 2, 3, 1).astype("uint8")
36 |
37 | labels = data_dict.get("fine_labels")
38 |
39 | for i, (datum, label) in tqdm(enumerate(zip(data, labels)), total=len(labels)):
40 | image = Image.fromarray(datum)
41 | image = image.convert("RGB")
42 | file_path = os.path.join(split_dir, str(label), "{}.png".format(i))
43 | image.save(file_path)
44 |
45 |
46 | def get_cifar100() -> None:
47 | """Get and reformat cifar100 dataset.
48 |
49 | See https://www.cs.toronto.edu/~kriz/cifar.html for dataset description.
50 | """
51 | data_store_dir = "data_store"
52 |
53 | if not os.path.exists(data_store_dir):
54 | os.makedirs(data_store_dir)
55 |
56 | dataset = CIFAR100(root=data_store_dir, download=True)
57 |
58 | # Load files and convert to PNG
59 | untar_dir = os.path.join(data_store_dir, dataset.base_folder)
60 | save_pngs(untar_dir, "test")
61 | save_pngs(untar_dir, "train")
62 |
63 |
64 | if __name__ == "__main__":
65 | get_cifar100()
66 |
--------------------------------------------------------------------------------
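After prepare_dataset.py runs, save_pngs() leaves one sub-folder per class under a training/ and a validation/ directory inside torchvision's CIFAR-100 extraction folder. The training scripts consume this layout through the repo's own SubImageFolder; the sketch below only illustrates the resulting layout with torchvision's generic ImageFolder, and the "cifar-100-python" folder name is assumed to be torchvision's default base folder for CIFAR100.

```python
import os

from torchvision import transforms
from torchvision.datasets import ImageFolder

# Layout produced by save_pngs(): data_store/<base_folder>/{training,validation}/<class>/<i>.png
root = os.path.join("data_store", "cifar-100-python", "training")
dataset = ImageFolder(root, transform=transforms.ToTensor())
print(len(dataset), len(dataset.classes))  # expected: 50000 images across 100 class folders
```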
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | Pillow
3 | torch
4 | torchvision
5 | tqdm
6 | scikit-learn
7 | PyYAML
8 |
--------------------------------------------------------------------------------
/train_backbone.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 |
5 | from typing import Dict
6 | from argparse import ArgumentParser
7 |
8 | import yaml
9 | from PIL import ImageFile
10 |
11 | ImageFile.LOAD_TRUNCATED_IMAGES = True
12 |
13 | import torch
14 | import torch.nn as nn
15 |
16 | from trainers import BackboneTrainer
17 | from dataset import SubImageFolder
18 | from utils.net_utils import LabelSmoothing, backbone_to_torchscript
19 | from utils.schedulers import get_policy
20 | from utils.getters import get_model, get_optimizer
21 |
22 |
23 | def main(config: Dict) -> None:
24 | """Run training.
25 |
26 | :param config: A dictionary with all configurations to run training.
27 | :return:
28 | """
29 | model = get_model(config.get("arch_params"))
30 |
31 | device = torch.device("cuda:0") if torch.cuda.is_available() else "cpu"
32 | torch.backends.cudnn.benchmark = True
33 |
34 | if torch.cuda.is_available():
35 | model = torch.nn.DataParallel(model)
36 | model.to(device)
37 |
38 | trainer = BackboneTrainer()
39 | optimizer = get_optimizer(model, **config.get("optimizer_params"))
40 | data = SubImageFolder(**config.get("dataset_params"))
41 | lr_policy = get_policy(optimizer, **config.get("lr_policy_params"))
42 |
43 | if config.get("label_smoothing") is None:
44 | criterion = nn.CrossEntropyLoss()
45 | else:
46 | criterion = LabelSmoothing(smoothing=config.get("label_smoothing"))
47 |
48 | # Training loop
49 | for epoch in range(config.get("epochs")):
50 | lr_policy(epoch, iteration=None)
51 |
52 | train_acc1, train_acc5, train_loss = trainer.train(
53 | train_loader=data.train_loader,
54 | model=model,
55 | criterion=criterion,
56 | optimizer=optimizer,
57 | device=device,
58 | )
59 |
60 | print(
61 | "Train: epoch = {}, Loss = {}, Top 1 = {}, Top 5 = {}".format(
62 | epoch, train_loss, train_acc1, train_acc5
63 | )
64 | )
65 |
66 | test_acc1, test_acc5, test_loss = trainer.validate(
67 | val_loader=data.val_loader,
68 | model=model,
69 | criterion=criterion,
70 | device=device,
71 | )
72 |
73 | print(
74 | "Test: epoch = {}, Loss = {}, Top 1 = {}, Top 5 = {}".format(
75 | epoch, test_loss, test_acc1, test_acc5
76 | )
77 | )
78 |
79 | backbone_to_torchscript(model, config.get("output_model_path"))
80 |
81 |
82 | if __name__ == "__main__":
83 | parser = ArgumentParser()
84 | parser.add_argument(
85 | "--config",
86 | type=str,
87 | required=True,
88 | help="Path to config file for this pipeline.",
89 | )
90 | args = parser.parse_args()
91 | with open(args.config) as f:
92 | read_config = yaml.safe_load(f)
93 | main(read_config)
94 |
--------------------------------------------------------------------------------
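main() above consumes a plain dictionary loaded from the --config YAML file, so its structure can be mirrored directly in Python. The sketch below is illustrative, not one of the repo's actual configs: the arch/optimizer/lr-policy sub-keys follow the signatures of get_model, get_optimizer, and get_policy defined in utils/getters.py and utils/schedulers.py, the dataset_params sub-keys depend on SubImageFolder and are left as a placeholder, and the output path is hypothetical.

```python
# Illustrative config mirroring what train_backbone.py reads from YAML.
config = {
    "arch_params": {"arch": "ResNet18", "num_classes": 100, "embedding_dim": 128, "last_nonlin": True},
    "optimizer_params": {"algorithm": "sgd", "lr": 0.1, "weight_decay": 1e-4, "momentum": 0.9},
    "lr_policy_params": {"algorithm": "cosine_lr", "warmup_length": 5, "epochs": 50, "lr": 0.1},
    "dataset_params": {},        # SubImageFolder kwargs; see dataset/sub_image_folder.py
    "label_smoothing": 0.1,      # or omit/None to fall back to plain cross-entropy
    "epochs": 50,
    "output_model_path": "backbone_old.pt",  # hypothetical output path
}
# main(config) would then run the same loop as `python train_backbone.py --config <file>.yaml`.
```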
/train_transformation.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 |
5 | from typing import Dict
6 |
7 | import yaml
8 | from argparse import ArgumentParser
9 |
10 | import torch
11 |
12 | from PIL import ImageFile
13 |
14 | ImageFile.LOAD_TRUNCATED_IMAGES = True
15 |
16 | from trainers import TransformationTrainer
17 | from dataset import SubImageFolder
18 | from utils.net_utils import transformation_to_torchscripts
19 | from utils.schedulers import get_policy
20 | from utils.getters import get_model, get_optimizer, get_criteria
21 |
22 |
23 | def main(config: Dict) -> None:
24 | """Run training.
25 |
26 | :param config: A dictionary with all configurations to run training.
27 | :return:
28 | """
29 | device = torch.device("cuda:0") if torch.cuda.is_available() else "cpu"
30 | torch.backends.cudnn.benchmark = True
31 |
32 | model = get_model(config.get("arch_params"))
33 | old_model = torch.jit.load(config.get("old_model_path"))
34 | new_model = torch.jit.load(config.get("new_model_path"))
35 |
36 | if torch.cuda.is_available():
37 | model = torch.nn.DataParallel(model)
38 | old_model = torch.nn.DataParallel(old_model)
39 | new_model = torch.nn.DataParallel(new_model)
40 |
41 | model.to(device)
42 | old_model.to(device)
43 | new_model.to(device)
44 |
45 | if config.get("side_info_model_path") is not None:
46 | side_info_model = torch.jit.load(config.get("side_info_model_path"))
47 | if torch.cuda.is_available():
48 | side_info_model = torch.nn.DataParallel(side_info_model)
49 | side_info_model.to(device)
50 | else:
51 | side_info_model = old_model
52 |
53 | optimizer = get_optimizer(model, **config.get("optimizer_params"))
54 | data = SubImageFolder(**config.get("dataset_params"))
55 | lr_policy = get_policy(optimizer, **config.get("lr_policy_params"))
56 | mus, criteria = get_criteria(**config.get("objective_params", {}))
57 | trainer = TransformationTrainer(
58 | old_model, new_model, side_info_model, **mus, **criteria
59 | )
60 |
61 | for epoch in range(config.get("epochs")):
62 | lr_policy(epoch, iteration=None)
63 |
64 | if config.get("switch_mode_to_eval"):
65 | switch_mode_to_eval = epoch >= config.get("epochs") / 2
66 | else:
67 | switch_mode_to_eval = False
68 |
69 | train_loss = trainer.train(
70 | train_loader=data.train_loader,
71 | model=model,
72 | optimizer=optimizer,
73 | device=device,
74 | switch_mode_to_eval=switch_mode_to_eval,
75 | )
76 |
77 | print("Train: epoch = {}, Average Loss = {}".format(epoch, train_loss))
78 |
79 | # evaluate on validation set
80 | test_loss = trainer.validate(
81 | val_loader=data.val_loader,
82 | model=model,
83 | device=device,
84 | )
85 |
86 | print("Test: epoch = {}, Average Loss = {}".format(epoch, test_loss))
87 |
88 | transformation_to_torchscripts(
89 | old_model,
90 | side_info_model,
91 | model,
92 | config.get("output_transformation_path"),
93 | config.get("output_transformed_old_model_path"),
94 | )
95 |
96 |
97 | if __name__ == "__main__":
98 | parser = ArgumentParser()
99 | parser.add_argument(
100 | "--config",
101 | type=str,
102 | required=True,
103 | help="Path to config file for this pipeline.",
104 | )
105 | args = parser.parse_args()
106 | with open(args.config) as f:
107 | read_config = yaml.safe_load(f)
108 | main(read_config)
109 |
--------------------------------------------------------------------------------
/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 |
5 | from .backbone_trainer import BackboneTrainer
6 | from .transformation_trainer import TransformationTrainer
7 |
--------------------------------------------------------------------------------
/trainers/backbone_trainer.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 |
5 | from typing import Tuple, Callable
6 |
7 | import tqdm
8 | import torch
9 | import torch.nn as nn
10 |
11 | from utils.logging_utils import AverageMeter
12 | from utils.eval_utils import accuracy
13 |
14 |
15 | class BackboneTrainer:
16 | """Class to train and evaluate backbones."""
17 |
18 | def train(
19 | self,
20 | train_loader: torch.utils.data.DataLoader,
21 | model: nn.Module,
22 | criterion: Callable,
23 | optimizer: torch.optim.Optimizer,
24 | device: torch.device,
25 | ) -> Tuple[float, float, float]:
26 | """Run one epoch of training.
27 |
28 | :param train_loader: Data loader to train the model.
29 | :param model: Model to be trained.
30 | :param criterion: Loss criterion module.
31 | :param optimizer: A torch optimizer object.
32 | :param device: Device the model is on.
33 | :return: average of top-1, top-5, and loss on current epoch.
34 | """
35 | losses = AverageMeter("Loss", ":.3f")
36 | top1 = AverageMeter("Acc@1", ":6.2f")
37 | top5 = AverageMeter("Acc@5", ":6.2f")
38 |
39 | model.train()
40 |
41 | for i, (images, target) in tqdm.tqdm(
42 | enumerate(train_loader), ascii=True, total=len(train_loader)
43 | ):
44 | images = images.to(device, non_blocking=True)
45 | target = target.to(device, non_blocking=True)
46 |
47 | output, _ = model(images)
48 | loss = criterion(output, target)
49 |
50 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
51 | losses.update(loss.item(), images.size(0))
52 | top1.update(acc1.item(), images.size(0))
53 | top5.update(acc5.item(), images.size(0))
54 |
55 | optimizer.zero_grad()
56 | loss.backward()
57 | optimizer.step()
58 |
59 | return top1.avg, top5.avg, losses.avg
60 |
61 | def validate(
62 | self,
63 | val_loader: torch.utils.data.DataLoader,
64 | model: nn.Module,
65 | criterion: Callable,
66 | device: torch.device,
67 | ) -> Tuple[float, float, float]:
68 | """Run validation.
69 |
70 | :param val_loader: Data loader to evaluate the model.
71 | :param model: Model to be evaluated.
72 | :param criterion: Loss criterion module.
73 | :param device: Device the model is on.
74 | :return: average of top-1, top-5, and loss on current epoch.
75 | """
76 | losses = AverageMeter("Loss", ":.3f")
77 | top1 = AverageMeter("Acc@1", ":6.2f")
78 | top5 = AverageMeter("Acc@5", ":6.2f")
79 |
80 | model.eval()
81 |
82 | with torch.no_grad():
83 | for i, (images, target) in tqdm.tqdm(
84 | enumerate(val_loader), ascii=True, total=len(val_loader)
85 | ):
86 | images = images.to(device, non_blocking=True)
87 | target = target.to(device, non_blocking=True)
88 |
89 | output, _ = model(images)
90 |
91 | loss = criterion(output, target)
92 |
93 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
94 | losses.update(loss.item(), images.size(0))
95 | top1.update(acc1.item(), images.size(0))
96 | top5.update(acc5.item(), images.size(0))
97 |
98 | return top1.avg, top5.avg, losses.avg
99 |
--------------------------------------------------------------------------------
/trainers/transformation_trainer.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 |
5 | from typing import Optional, Union
6 |
7 | import tqdm
8 | import torch
9 | import torch.nn as nn
10 |
11 | from utils.logging_utils import AverageMeter
12 |
13 |
14 | class TransformationTrainer:
15 | """Class to train and evaluate transformation models."""
16 |
17 | def __init__(
18 | self,
19 | old_model: Union[nn.Module, torch.jit.ScriptModule],
20 | new_model: Union[nn.Module, torch.jit.ScriptModule],
21 | side_info_model: Union[nn.Module, torch.jit.ScriptModule],
22 | mu_similarity: float = 1,
23 | mu_disc: Optional[float] = None,
24 | criterion_similarity: Optional[nn.Module] = None,
25 | criterion_disc: Optional[Union[torch.jit.ScriptModule, nn.Module]] = None,
26 | criterion_uncertainty: Optional[nn.Module] = None,
27 | **kwargs
28 | ) -> None:
29 | """Construct a TransformationTrainer module.
30 |
31 | :param old_model: A model that returns old embedding given x.
32 | :param new_model: A model that returns new embedding given x.
33 | :param side_info_model: A model that returns side-info given x.
34 |         :param mu_similarity: Coefficient of the similarity loss.
35 |         :param mu_disc: Coefficient of the classification (discriminative) loss.
36 |         :param criterion_similarity: Objective function computing the similarity between new and h(old) features.
37 |         :param criterion_disc: Objective function for the classification head applied to h(old).
38 |         :param criterion_uncertainty: Uncertainty-based loss function.
39 | """
40 |
41 | self.old_model = old_model
42 | self.old_model.eval()
43 | self.new_model = new_model
44 | self.new_model.eval()
45 | self.side_info_model = side_info_model
46 | self.side_info_model.eval()
47 |
48 | self.mu_similarity = mu_similarity
49 | self.mu_disc = mu_disc
50 |
51 | self.criterion_similarity = criterion_similarity
52 | self.criterion_disc = criterion_disc
53 | self.criterion_uncertainty = criterion_uncertainty
54 |
55 | def compute_loss(
56 | self,
57 | new_feature: torch.Tensor,
58 | recycled_feature: torch.Tensor,
59 | target: torch.Tensor,
60 | sigma: torch.Tensor,
61 | ) -> torch.Tensor:
62 | """Compute total loss for a batch.
63 |
64 | :param new_feature: Tensor of features computed by new model.
65 | :param recycled_feature: Tensor of features computed by transformation model.
66 | :param target: Labels tensor.
67 | :param sigma: Tensor of sigmas computed by transformation.
68 | :return: Total loss tensor.
69 | """
70 | loss = 0
71 | if self.criterion_similarity is not None:
72 | similarity_loss = self.criterion_similarity(
73 | new_feature.squeeze(), recycled_feature.squeeze()
74 | )
75 | if len(similarity_loss.shape) > 1:
76 | similarity_loss = similarity_loss.mean(dim=1)
77 | loss += self.mu_similarity * similarity_loss
78 |
79 | if self.criterion_disc is not None:
80 | if isinstance(self.new_model, torch.nn.DataParallel):
81 | fc_layer = self.new_model.module.model.fc
82 | else:
83 | fc_layer = self.new_model.model.fc
84 | logits = fc_layer(recycled_feature[:, : new_feature.size()[1]])
85 | loss_disc = self.criterion_disc(logits.squeeze(), target)
86 | loss += self.mu_disc * loss_disc
87 |
88 | if self.criterion_uncertainty:
89 | loss = self.criterion_uncertainty(sigma, loss)
90 | return loss
91 |
92 | def train(
93 | self,
94 | train_loader: torch.utils.data.DataLoader,
95 | model: nn.Module,
96 | optimizer: torch.optim.Optimizer,
97 | device: torch.device,
98 | switch_mode_to_eval: bool,
99 | ) -> float:
100 | """Run one epoch of training.
101 |
102 | :param train_loader: Data loader to train the model.
103 | :param model: Model to be trained.
104 | :param optimizer: A torch optimizer object.
105 | :param device: Device the model is on.
106 |         :param switch_mode_to_eval: If True, the model is kept in eval mode during training (normalization layers use running statistics).
107 | :return: Average loss on current epoch.
108 | """
109 | losses = AverageMeter("Loss", ":.3f")
110 |
111 | if switch_mode_to_eval:
112 | model.eval()
113 | else:
114 | model.train()
115 |
116 | for i, (images, target) in tqdm.tqdm(
117 | enumerate(train_loader), ascii=True, total=len(train_loader)
118 | ):
119 | images = images.to(device, non_blocking=True)
120 | target = target.to(device, non_blocking=True) # only needed by L_disc
121 |
122 | with torch.no_grad():
123 | old_feature = self.old_model(images)
124 | new_feature = self.new_model(images)
125 | side_info = self.side_info_model(images)
126 |
127 | recycled_feature = model(old_feature, side_info)
128 | sigma = recycled_feature[:, new_feature.size()[1] :]
129 | recycled_feature = recycled_feature[:, : new_feature.size()[1]]
130 |
131 | loss = self.compute_loss(new_feature, recycled_feature, target, sigma)
132 | losses.update(loss.item(), images.size(0))
133 |
134 | optimizer.zero_grad()
135 | loss.backward()
136 | optimizer.step()
137 |
138 | return losses.avg
139 |
140 | def validate(
141 | self,
142 | val_loader: torch.utils.data.DataLoader,
143 | model: nn.Module,
144 | device: torch.device,
145 | ) -> float:
146 | """Run validation.
147 |
148 | :param val_loader: Data loader to evaluate the model.
149 | :param model: Model to be evaluated.
150 | :param device: Device the model is on.
151 | :return: Average loss on current epoch.
152 | """
153 | losses = AverageMeter("Loss", ":.3f")
154 | model.eval()
155 |
156 | for i, (images, target) in tqdm.tqdm(
157 | enumerate(val_loader), ascii=True, total=len(val_loader)
158 | ):
159 | images = images.to(device, non_blocking=True)
160 | target = target.to(device, non_blocking=True) # only needed by L_disc
161 |
162 | with torch.no_grad():
163 | old_feature = self.old_model(images)
164 | new_feature = self.new_model(images)
165 | side_info = self.side_info_model(images)
166 |
167 | recycled_feature = model(old_feature, side_info)
168 | sigma = recycled_feature[:, new_feature.size()[1] :]
169 | recycled_feature = recycled_feature[:, : new_feature.size()[1]]
170 |
171 | loss = self.compute_loss(new_feature, recycled_feature, target, sigma)
172 | losses.update(loss.item(), images.size(0))
173 |
174 | return losses.avg
175 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 |
--------------------------------------------------------------------------------
/utils/eval_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 |
5 | from typing import Union, Tuple, List, Optional, Callable
6 | import copy
7 |
8 | from sklearn.metrics import average_precision_score
9 | import numpy as np
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import tqdm
14 |
15 |
16 | def accuracy(output, target, topk=(1,)):
17 | """Compute the accuracy over the k top predictions.
18 |
19 | From https://github.com/YantaoShen/openBCT/blob/main/main.py
20 | """
21 | with torch.no_grad():
22 | maxk = max(topk)
23 | batch_size = target.shape[0]
24 |
25 | _, pred = output.topk(maxk, 1, True, True)
26 | pred = pred.t()
27 | correct = pred.eq(target.view(1, -1).expand_as(pred))
28 |
29 | res = []
30 | for k in topk:
31 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
32 | res.append(correct_k.mul_(100.0 / batch_size))
33 | return res
34 |
35 |
36 | def mean_ap(
37 | distance_matrix: torch.Tensor,
38 | labels: torch.Tensor,
39 | ) -> float:
40 |     """Compute mean average precision (mAP) over all queries.
41 |
42 | :param distance_matrix: pairwise distance matrix between embeddings of gallery and query sets, shape = (n, n)
43 | :param labels: labels for the query data (assuming the same as gallery), shape = (n,)
44 |
45 | :return: mean average precision (float)
46 | """
47 | distance_matrix = distance_matrix
48 | m, n = distance_matrix.shape
49 | assert m == n
50 |
51 | # Sort and find correct matches
52 | distance_matrix, gallery_matched_indices = torch.sort(distance_matrix, dim=1)
53 | distance_matrix = distance_matrix.cpu().numpy()
54 | gallery_matched_indices = gallery_matched_indices.cpu().numpy()
55 |
56 | truth_mask = labels[gallery_matched_indices] == labels[:, None]
57 | truth_mask = truth_mask.cpu().numpy()
58 |
59 | # Compute average precision for each query
60 | average_precisions = list()
61 | for query_index in range(n):
62 |
63 | valid_sorted_match_indices = (
64 | gallery_matched_indices[query_index, :] != query_index
65 | )
66 | y_true = truth_mask[query_index, valid_sorted_match_indices]
67 | y_score = -distance_matrix[query_index][valid_sorted_match_indices]
68 | if not np.any(y_true):
69 | continue # if a query does not have any match, we exclude it from mAP calculation.
70 | average_precisions.append(average_precision_score(y_true, y_score))
71 | return np.mean(average_precisions)
72 |
73 |
74 | def cosine_distance_matrix(
75 | x: torch.Tensor, y: torch.Tensor, diag_only: bool = False
76 | ) -> torch.Tensor:
77 | """Get pair-wise cosine distances.
78 |
79 | :param x: A torch feature tensor with shape (n, d).
80 | :param y: A torch feature tensor with shape (n, d).
81 | :param diag_only: if True, only diagonal of distance matrix is computed and returned.
82 | :return: Distance tensor between features x and y with shape (n, n) if diag_only is False. Otherwise, elementwise
83 | distance tensor with shape (n,).
84 | """
85 |     x_norm = F.normalize(x, p=2, dim=-1)  # L2-normalize so the dot product is the cosine similarity
86 |     y_norm = F.normalize(y, p=2, dim=-1)
87 | if diag_only:
88 | return 1.0 - torch.sum(x_norm * y_norm, dim=1)
89 | return 1.0 - x_norm @ y_norm.T
90 |
91 |
92 | def l2_distance_matrix(
93 | x: torch.Tensor, y: torch.Tensor, diag_only: bool = False
94 | ) -> torch.Tensor:
95 | """Get pair-wise l2 distances.
96 |
97 | :param x: A torch feature tensor with shape (n, d).
98 | :param y: A torch feature tensor with shape (n, d).
99 | :param diag_only: if True, only diagonal of distance matrix is computed and returned.
100 | :return: Distance tensor between features x and y with shape (n, n) if diag_only is False. Otherwise, elementwise
101 | distance tensor with shape (n,).
102 | """
103 | if diag_only:
104 | return torch.norm(x - y, dim=1, p=2)
105 | return torch.cdist(x, y, p=2)
106 |
107 |
108 | def cmc_optimized(
109 | distmat: torch.Tensor,
110 | query_ids: Optional[torch.Tensor] = None,
111 | topk: int = 5,
112 | ) -> Tuple[float, float]:
113 | """Compute Cumulative Matching Characteristics metric.
114 |
115 | :param distmat: pairwise distance matrix between embeddings of gallery and query sets
116 | :param query_ids: labels for the query data. We're assuming query_ids and gallery_ids are the same.
117 |     :param topk: Parameter for top-k retrieval.
118 |     :return: CMC top-1 and top-5 floats.
119 | """
120 | distmat = copy.deepcopy(distmat)
121 | query_ids = copy.deepcopy(query_ids)
122 |
123 | distmat.fill_diagonal_(float("inf"))
124 | distmat_new_old_sorted, indices = torch.sort(distmat)
125 | labels = query_ids.unsqueeze(dim=0).repeat(query_ids.shape[0], 1)
126 | sorted_labels = torch.gather(labels, 1, indices)
127 |
128 | top1_retrieval = sorted_labels[:, 0] == query_ids
129 | top5_retrieval = (
130 | (sorted_labels[:, :topk] == query_ids.unsqueeze(1)).sum(dim=1).clamp(max=1)
131 | )
132 |
133 | top1 = top1_retrieval.sum() / query_ids.shape[0]
134 | top5 = top5_retrieval.sum() / query_ids.shape[0]
135 |
136 | return float(top1), float(top5)
137 |
138 |
139 | def generate_feature_matrix(
140 | gallery_model: Union[nn.Module, torch.jit.ScriptModule],
141 | query_model: Union[nn.Module, torch.jit.ScriptModule],
142 | val_loader: torch.utils.data.DataLoader,
143 | device: torch.device,
144 | verbose: bool = False,
145 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
146 |     """Generate gallery and query feature matrices.
147 | :param gallery_model: Model to compute gallery features.
148 | :param query_model: Model to compute query features.
149 | :param val_loader: Data loader to get gallery/query data.
150 | :param device: Device to use for computations.
151 | :param verbose: Whether to be verbose.
152 | :return: Three tensors gallery_features (n, d), query_features (n, d), labels (n,), where n is size of val dataset,
153 | and d is the embedding dimension.
154 | """
155 |
156 | gallery_model.eval()
157 | query_model.eval()
158 |
159 | gallery_model.to(device)
160 | query_model.to(device)
161 |
162 | gallery_features = []
163 | query_features = []
164 | labels = []
165 |
166 | iterator = tqdm.tqdm(val_loader) if verbose else val_loader
167 |
168 | with torch.no_grad():
169 | for data, label in iterator:
170 | data = data.to(device)
171 | label = label.to(device)
172 | gallery_feature = gallery_model(data)
173 | query_feature = query_model(data)
174 |
175 | gallery_features.append(gallery_feature.squeeze())
176 | query_features.append(query_feature.squeeze())
177 |
178 | labels.append(label)
179 |
180 | gallery_features = torch.cat(gallery_features)
181 | query_features = torch.cat(query_features)
182 | labels = torch.cat(labels)
183 |
184 | return gallery_features, query_features, labels
185 |
186 |
187 | def get_backfilling_orders(
188 | backfilling_list: List[str],
189 | query_features: torch.Tensor,
190 | gallery_features: torch.Tensor,
191 | distance_metric: Callable,
192 | sigma: Optional[torch.Tensor] = None,
193 | ) -> List[Tuple[torch.Tensor, str]]:
194 | """Compute backfilling ordering.
195 |
196 |     :param backfilling_list: list of desired backfilling orders from ["random", "distance", "sigma"].
197 | :param query_features: Tensor of query features with shape (n, d), where n is dataset size and d is embedding dim.
198 | :param gallery_features: Tensor of gallery features with shape (n, d).
199 | :param distance_metric: Callable to compute distance between features.
200 | :param sigma: Tensor of computed sigmas with shape (n,).
201 |
202 | :return: List of (ordering, ordering_name) tuples. ordering is a permutation of [0, 1, ..., n-1] determining the
203 | backfilling ordering, and ordering_name is the name of ordering. For example, if ordering=[3, 0, 2, 1] it means
204 | first backfill gallery data at index 3, followed by elements at indices 0, 2, and 1, respectively.
205 | """
206 | orderings_list = []
207 | n = query_features.shape[0]
208 | for ordering_name in backfilling_list:
209 | if ordering_name.lower() == "random":
210 | ordering = torch.randperm(n)
211 | elif ordering_name.lower() == "distance":
212 | distances = distance_metric(
213 | query_features.cpu(), gallery_features.cpu(), diag_only=True
214 | )
215 | ordering = torch.argsort(distances, descending=True)
216 | elif ordering_name.lower() == "sigma":
217 | assert sigma.numel() == n
218 | ordering = torch.argsort(sigma, dim=0, descending=True)
219 | else:
220 | print(f"{ordering_name} is not implemented for backfilling")
221 | continue
222 |
223 | # Sanity checks:
224 | assert torch.unique(ordering).shape == ordering.shape
225 | assert ordering.min() == 0
226 | assert ordering.max() == n - 1
227 | orderings_list.append((ordering, ordering_name))
228 |
229 | return orderings_list
230 |
231 |
232 | def cmc_evaluate(
233 | gallery_model: Union[nn.Module, torch.jit.ScriptModule],
234 | query_model: Union[nn.Module, torch.jit.ScriptModule],
235 | val_loader: torch.utils.data.DataLoader,
236 | device: torch.device,
237 | distance_metric_name: str,
238 | verbose: bool = False,
239 | compute_map: bool = False,
240 | backfilling: Optional[int] = None,
241 | backfilling_list: List[str] = ["random"],
242 | backfilling_result_path: Optional[str] = None,
243 | **kwargs,
244 | ) -> None:
245 | """Run CMC and mAP evaluations.
246 |
247 | :param gallery_model: Model to compute gallery features.
248 | :param query_model: Model to compute query features.
249 | :param val_loader: Data loader to get gallery/query data.
250 | :param device: Device to use for computations.
251 | :param distance_metric_name: Name of distance metric to use. Choose from ['l2', 'cosine'].
252 | :param verbose: Whether to be verbose.
253 | :param compute_map: Whether to compute mean average precision.
254 | :param backfilling: Number of intermediate backfilling steps. None means 0 intermediate backfilling. In this case
255 | only results for 0% backfilling will be computed.
256 |     :param backfilling_list: list of desired backfilling orders from ["random", "distance", "sigma"]. Default is
257 | ['random'].
258 | :param backfilling_result_path: path to save backfilling results.
259 | """
260 | distance_map = {"l2": l2_distance_matrix, "cosine": cosine_distance_matrix}
261 | distance_metric = distance_map.get(distance_metric_name.lower())
262 |
263 | print("Generating Feature Matrix")
264 | gallery_features, query_features, labels = generate_feature_matrix(
265 | gallery_model, query_model, val_loader, device, verbose
266 | )
267 |
268 | # (Possibly) Split gallery features and sigmas
269 | embedding_dim = query_features.shape[1]
270 | sigma = gallery_features[:, embedding_dim:].squeeze() # (n,)
271 | sigma = None if sigma.numel() == 0 else sigma
272 | gallery_features = gallery_features[:, :embedding_dim] # (n, d)
273 |
274 | n = query_features.shape[0] # dataset size
275 | backfilling = backfilling if backfilling is not None else -1
276 |
277 | orderings_list = get_backfilling_orders(
278 | backfilling_list=backfilling_list,
279 | query_features=query_features,
280 | gallery_features=gallery_features,
281 | distance_metric=distance_metric,
282 | sigma=sigma,
283 | )
284 |
285 | backfilling_results = {}
286 | for ordering, ordering_name in orderings_list:
287 | print(f"\nBackfilling evaluation with {ordering_name} ordering.")
288 | gallery_features_reordered = copy.deepcopy(gallery_features[ordering]).cpu()
289 | query_features_reordered = copy.deepcopy(query_features[ordering]).cpu()
290 | labels_reordered = copy.deepcopy(labels[ordering]).cpu()
291 |
292 |         # Lists to store top1, top5, and mAP
293 | outputs = {"CMC-top1": [], "CMC-top5": [], "mAP": []}
294 |
295 | iterator = (
296 | tqdm.tqdm(range(backfilling + 2)) if verbose else range(backfilling + 2)
297 | )
298 |
299 | for i in iterator:
300 | if backfilling >= 0:
301 | cutoff_index = (i * n) // (backfilling + 1)
302 | else:
303 | cutoff_index = 0
304 | backfilling_mask = torch.zeros((n, 1), dtype=torch.bool)
305 | backfilling_mask[torch.arange(n) < cutoff_index] = True
306 | backfilled_gallery = torch.where(
307 | backfilling_mask, query_features_reordered, gallery_features_reordered
308 | )
309 |
310 | distmat = distance_metric(query_features_reordered, backfilled_gallery)
311 |
312 | top1, top5 = cmc_optimized(
313 | distmat=distmat,
314 | query_ids=labels_reordered,
315 | topk=5,
316 | )
317 |
318 | if compute_map:
319 | mean_ap_out = mean_ap(distance_matrix=distmat, labels=labels_reordered)
320 | else:
321 | mean_ap_out = None
322 |
323 | outputs["CMC-top1"].append(top1)
324 | outputs["CMC-top5"].append(top5)
325 | outputs["mAP"].append(mean_ap_out)
326 |
327 | backfilling_results[ordering_name] = outputs
328 |
329 | for metric_name, metric_data in outputs.items():
330 | print_partial_backfilling(metric_data, metric_name)
331 |
332 | if backfilling_result_path is not None:
333 | np.save(backfilling_result_path, backfilling_results)
334 |
335 |
336 | def print_partial_backfilling(data: List, metric_name: str) -> None:
337 | """Print partial backfilling results.
338 |
339 |     :param data: list of metric values in the [0, 1] range; None entries (e.g., mAP when it is not computed) are skipped.
340 | :param metric_name: name of the metric.
341 | """
342 | print(f"*** {metric_name} ***:")
343 |     data_str_list = ["{:.2f}".format(100 * x) for x in data if x is not None]
344 |     print(" -> ".join(data_str_list))
345 |     if len(data_str_list) > 1:
346 |         print("AUC: {:.2f} %".format(100 * np.mean([x for x in data if x is not None])))
347 | print("\n")
348 |
--------------------------------------------------------------------------------
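A toy end-to-end use of the retrieval metrics above, with random features standing in for real gallery/query embeddings; shapes follow the (n, d) convention used throughout this file, and the numbers printed are meaningless for random data.

```python
import torch

from utils.eval_utils import cmc_optimized, cosine_distance_matrix, mean_ap

n, d = 32, 16                       # toy dataset size and embedding dimension
gallery = torch.randn(n, d)         # e.g. transformed-old-model features
query = torch.randn(n, d)           # e.g. new-model features
labels = torch.randint(0, 4, (n,))  # toy class labels shared by gallery and query

distmat = cosine_distance_matrix(query, gallery)  # (n, n)
top1, top5 = cmc_optimized(distmat=distmat, query_ids=labels, topk=5)
print(top1, top5)
print(mean_ap(distance_matrix=distmat, labels=labels))
```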
/utils/getters.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 |
5 | from typing import Dict, Optional, Tuple, Callable
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | import models
11 | from utils.objectives import UncertaintyLoss, LabelSmoothing
12 |
13 |
14 | def get_model(arch_params: Dict, **kwargs) -> nn.Module:
15 | """Get a model given its configurations.
16 |
17 | :param arch_params: A dictionary containing all model parameters.
18 | :return: A torch model.
19 | """
20 | print("=> Creating model '{}'".format(arch_params.get("arch")))
21 | model = models.__dict__[arch_params.get("arch")](**arch_params)
22 | return model
23 |
24 |
25 | def get_optimizer(
26 | model: nn.Module,
27 | algorithm: str,
28 | lr: float,
29 | weight_decay: float,
30 | momentum: Optional[float] = None,
31 | no_bn_decay: bool = False,
32 | nesterov: bool = False,
33 | **kwargs
34 | ) -> torch.optim.Optimizer:
35 | """Get an optimizer given its configurations.
36 |
37 | :param model: A torch model (with parameters to be trained).
38 | :param algorithm: String defining what optimization algorithm to use.
39 | :param lr: Learning rate.
40 | :param weight_decay: Weight decay coefficient.
41 | :param momentum: Momentum value.
42 | :param no_bn_decay: Whether to avoid weight decay for Batch Norm params.
43 | :param nesterov: Whether to use Nesterov update.
44 |     :return: A torch optimizer object.
45 | """
46 | if algorithm == "sgd":
47 | parameters = list(model.named_parameters())
48 | bn_params = [v for n, v in parameters if ("bn" in n) and v.requires_grad]
49 | rest_params = [v for n, v in parameters if ("bn" not in n) and v.requires_grad]
50 | optimizer = torch.optim.SGD(
51 | [
52 | {
53 | "params": bn_params,
54 | "weight_decay": 0 if no_bn_decay else weight_decay,
55 | },
56 | {"params": rest_params, "weight_decay": weight_decay},
57 | ],
58 | lr,
59 | momentum=momentum,
60 | weight_decay=weight_decay,
61 | nesterov=nesterov,
62 | )
63 | elif algorithm == "adam":
64 | optimizer = torch.optim.Adam(
65 | filter(lambda p: p.requires_grad, model.parameters()),
66 | lr=lr,
67 | weight_decay=weight_decay,
68 | )
69 |
70 | return optimizer
71 |
72 |
73 | def get_criteria(similarity_loss: Optional[Dict] = None,
74 | discriminative_loss: Optional[Dict] = None,
75 | uncertainty_loss: Optional[Dict] = None,
76 | ) -> Tuple[Dict[str, float], Dict[str, Callable]]:
77 | """Get training criteria.
78 |
79 | :param similarity_loss: Optional dictionary to determine similarity loss. It should have a name key with values from
80 | ["mse", "l1", "cosine"]. "mu_similarity" key determines coefficient of the similarity loss, with default value of 1.
81 | :param discriminative_loss: Optional dictionary to determine discriminative loss. Name key can take value from
82 | ["LabelSmoothing", "CE"]. If "LabelSmoothing" is picked, the "label_smoothing" key can be used to determine the
83 | value of label smoothing (from [0,1] interval). "mu_disc" key determines coefficient of the discriminative loss.
84 | :param uncertainty_loss: Optional dictionary to determine whether to have uncertainty in loss. "mu_uncertainty" key
85 | determines coefficient of the regularization term in Bayesian uncertainty estimation formulation.
86 |
87 | :return: Two dictionaries. First one has coefficients: {"mu_similarity": mu_similarity, "mu_disc": mu_disc}
88 | second one has callable loss functions with "criterion_similarity", "criterion_disc", and "criterion_uncertainty"
89 | keys.
90 | """
91 | if uncertainty_loss is not None:
92 | criterion_uncertainty = UncertaintyLoss(**uncertainty_loss)
93 | reduction = "none"
94 | else:
95 | criterion_uncertainty = None
96 | reduction = "mean"
97 |
98 | if similarity_loss is not None:
99 | if similarity_loss.get("name") == "mse":
100 | criterion_similarity = nn.MSELoss(reduction=reduction)
101 | elif similarity_loss.get("name") == "l1":
102 | criterion_similarity = nn.L1Loss(reduction=reduction)
103 | elif similarity_loss.get("name") == "cosine":
104 |
105 | def similarity_cosine(x, y):
106 | return nn.CosineEmbeddingLoss(reduction=reduction)(
107 | x.squeeze(),
108 | y.squeeze(),
109 | target=torch.ones(x.shape[0], device=y.device),
110 | )
111 |
112 | criterion_similarity = similarity_cosine
113 | else:
114 | raise NotImplementedError("Similarity loss not implemented!")
115 | mu_similarity = similarity_loss.get("mu_similarity", 1.0)
116 | else:
117 | criterion_similarity = None
118 | mu_similarity = None
119 |
120 | if discriminative_loss is not None:
121 | if discriminative_loss.get("name") == "LabelSmoothing":
122 | criterion_disc = LabelSmoothing(
123 | smoothing=discriminative_loss.get("label_smoothing"),
124 | reduction=reduction,
125 | )
126 | elif discriminative_loss.get("name") == "CE":
127 | criterion_disc = nn.CrossEntropyLoss(reduction=reduction)
128 | else:
129 | raise NotImplementedError("discriminative loss not implemented")
130 | mu_disc = discriminative_loss.get("mu_disc")
131 | else:
132 | criterion_disc = None
133 | mu_disc = None
134 |
135 | mus_dict = {"mu_similarity": mu_similarity, "mu_disc": mu_disc}
136 | criteria_dict = {
137 | "criterion_similarity": criterion_similarity,
138 | "criterion_disc": criterion_disc,
139 | "criterion_uncertainty": criterion_uncertainty,
140 | }
141 | return mus_dict, criteria_dict
142 |
--------------------------------------------------------------------------------
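get_criteria above returns exactly the two dictionaries that TransformationTrainer unpacks in train_transformation.py. A small sketch with illustrative objective_params follows; the key names mirror the docstring above, while the coefficient values are made up.

```python
from utils.getters import get_criteria

objective_params = {
    "similarity_loss": {"name": "mse", "mu_similarity": 1.0},
    "discriminative_loss": {"name": "CE", "mu_disc": 0.5},
    "uncertainty_loss": None,  # e.g. {"mu_uncertainty": 0.5} for FastFill-style training
}
mus, criteria = get_criteria(**objective_params)
print(mus)       # {'mu_similarity': 1.0, 'mu_disc': 0.5}
print(criteria)  # MSELoss / CrossEntropyLoss callables; criterion_uncertainty is None
# TransformationTrainer(old_model, new_model, side_info_model, **mus, **criteria)
```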
/utils/logging_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 |
5 |
6 | class AverageMeter:
7 | """Computes and stores the average and current value."""
8 |
9 | def __init__(self, name: str, fmt: str = ":f") -> None:
10 | """Construct an AverageMeter module.
11 |
12 | :param name: Name of the metric to be tracked.
13 | :param fmt: Output format string.
14 | """
15 | self.name = name
16 | self.fmt = fmt
17 | self.reset()
18 |
19 | def reset(self):
20 | """Reset internal states."""
21 | self.val = 0
22 | self.avg = 0
23 | self.sum = 0
24 | self.count = 0
25 |
26 | def update(self, val: float, n: int = 1) -> None:
27 | """Update internal states given new values.
28 |
29 | :param val: New metric value.
30 | :param n: Step size for update.
31 | """
32 | self.val = val
33 | self.sum += val * n
34 | self.count += n
35 | self.avg = self.sum / self.count
36 |
37 | def __str__(self):
38 | """Get string name of the object."""
39 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
40 | return fmtstr.format(**self.__dict__)
41 |
--------------------------------------------------------------------------------
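AverageMeter above keeps a running average weighted by batch size; a short sketch of how the trainers use it:

```python
from utils.logging_utils import AverageMeter

losses = AverageMeter("Loss", ":.3f")
losses.update(0.75, n=32)  # batch-mean loss for a batch of 32
losses.update(0.25, n=32)
print(losses)      # Loss 0.250 (0.500)
print(losses.avg)  # 0.5
```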
/utils/net_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 |
5 | from typing import Union
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 |
11 | class LabelSmoothing(nn.Module):
12 | """NLL loss with label smoothing."""
13 |
14 | def __init__(self, smoothing: float = 0.0):
15 | """Construct LabelSmoothing module.
16 |
17 | :param smoothing: label smoothing factor
18 | """
19 | super(LabelSmoothing, self).__init__()
20 | self.confidence = 1.0 - smoothing
21 | self.smoothing = smoothing
22 |
23 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
24 | """Apply forward pass.
25 |
26 | :param x: Logits tensor.
27 | :param target: Ground truth target classes.
28 | :return: Loss tensor.
29 | """
30 | logprobs = torch.nn.functional.log_softmax(x, dim=-1)
31 |
32 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
33 | nll_loss = nll_loss.squeeze(1)
34 | smooth_loss = -logprobs.mean(dim=-1)
35 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss
36 | return loss.mean()
37 |
38 |
39 | class FeatureExtractor(nn.Module):
40 | """A wrapper class to return only features (no logits)."""
41 |
42 | def __init__(self, model: Union[nn.Module, torch.jit.ScriptModule]) -> None:
43 | """Construct FeatureExtractor module.
44 |
45 | :param model: A model that outputs both logits and features.
46 | """
47 | super().__init__()
48 | self.model = model
49 |
50 | def forward(self, x: torch.Tensor) -> torch.Tensor:
51 | """Apply forward pass.
52 |
53 | :param x: Input data.
54 | :return: Feature tensor computed for x.
55 | """
56 | _, feature = self.model(x)
57 | return feature
58 |
59 |
60 | class TransformedOldModel(nn.Module):
61 | """A wrapper class to return transformed features."""
62 |
63 | def __init__(
64 | self,
65 | old_model: Union[nn.Module, torch.jit.ScriptModule],
66 | side_model: Union[nn.Module, torch.jit.ScriptModule],
67 | transformation: Union[nn.Module, torch.jit.ScriptModule],
68 | ) -> None:
69 | """Construct TransformedOldModel module.
70 |
71 | :param old_model: Old model.
72 | :param side_model: Side information model.
73 | :param transformation: Transformation model.
74 | """
75 | super().__init__()
76 | self.old_model = old_model
77 | self.transformation = transformation
78 | self.side_info_model = side_model
79 |
80 | def forward(self, x: torch.Tensor) -> torch.Tensor:
81 | """Apply forward pass.
82 |
83 | :param x: Input data
84 | :return: Transformed old feature.
85 | """
86 | old_feature = self.old_model(x)
87 | side_info = self.side_info_model(x)
88 | recycled_feature = self.transformation(old_feature, side_info)
89 | return recycled_feature
90 |
91 |
92 | def prepare_model_for_export(
93 | model: Union[nn.Module, torch.jit.ScriptModule]
94 | ) -> Union[nn.Module, torch.jit.ScriptModule]:
95 | """Prepare a model to be exported as torchscript."""
96 | if isinstance(model, torch.nn.DataParallel):
97 | model = model.module
98 | model.eval()
99 | model.cpu()
100 | return model
101 |
102 |
103 | def backbone_to_torchscript(
104 | model: Union[nn.Module, torch.jit.ScriptModule], output_model_path: str
105 | ) -> None:
106 | """Convert a backbone model to torchscript.
107 |
108 | :param model: A backbone model to be converted to torch script.
109 | :param output_model_path: Path to save torch script.
110 | """
111 | model = prepare_model_for_export(model)
112 | f = FeatureExtractor(model)
113 | model_script = torch.jit.script(f)
114 | torch.jit.save(model_script, output_model_path)
115 |
116 |
117 | def transformation_to_torchscripts(
118 | old_model: Union[nn.Module, torch.jit.ScriptModule],
119 | side_model: Union[nn.Module, torch.jit.ScriptModule],
120 | transformation: Union[nn.Module, torch.jit.ScriptModule],
121 | output_transformation_path: str,
122 | output_transformed_old_model_path: str,
123 | ) -> None:
124 | """Convert a transformation model to torchscript.
125 |
126 | :param old_model: Old model.
127 | :param side_model: Side information model.
128 | :param transformation: Transformation model.
129 | :param output_transformation_path: Path to store transformation torch
130 | script.
131 | :param output_transformed_old_model_path: Path to store combined old and
132 | transformation models' torch script.
133 | """
134 | transformation = prepare_model_for_export(transformation)
135 | old_model = prepare_model_for_export(old_model)
136 | side_model = prepare_model_for_export(side_model)
137 |
138 | model_script = torch.jit.script(transformation)
139 | torch.jit.save(model_script, output_transformation_path)
140 |
141 | f = TransformedOldModel(old_model, side_model, transformation)
142 | model_script = torch.jit.script(f)
143 | torch.jit.save(model_script, output_transformed_old_model_path)
144 |
--------------------------------------------------------------------------------
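The export helpers above are what the two training scripts call at the end of training. A sketch mirroring the backbone case; the output path is hypothetical and the model here is freshly initialized rather than trained.

```python
from models.resnet import ResNet18
from utils.net_utils import backbone_to_torchscript

model = ResNet18(num_classes=100, embedding_dim=128)
backbone_to_torchscript(model, "old_backbone.pt")  # hypothetical output path

# The saved TorchScript module is a FeatureExtractor: it returns only the embedding, e.g.
# loaded = torch.jit.load("old_backbone.pt"); feature = loaded(images)
```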
/utils/objectives.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 |
5 | import torch
6 | from torch import nn
7 |
8 |
9 | class UncertaintyLoss(nn.Module):
10 |     """Loss that weights a per-sample loss vector by predicted uncertainty (sigma) and adds a sigma regularizer."""
11 |
12 | def __init__(
13 | self, mu_uncertainty: float = 0.5, sigma_dim: int = 1, **kwargs
14 | ) -> None:
15 | super(UncertaintyLoss, self).__init__()
16 | self.mu = mu_uncertainty
17 | self.sigma_dim = sigma_dim
18 |
19 | def forward(self, sigma: torch.Tensor, loss: torch.Tensor) -> torch.Tensor:
20 | """
21 |         :param sigma: Tensor of predicted sigmas with shape (N, sigma_dim), where N is the batch size.
22 |         :param loss: Per-sample loss tensor with shape (N,).
23 |
24 |         :return: Scalar loss tensor (mean over the batch).
25 | """
26 | batch = sigma.size(0)
27 | reg = self.mu * sigma.view(batch, -1).mean(dim=1)
28 | loss_value = 0.5 * torch.exp(-sigma.squeeze()) * loss + reg
29 |
30 | return loss_value.mean()
31 |
32 |
33 | class LabelSmoothing(nn.Module):
34 | """NLL loss with label smoothing."""
35 |
36 | def __init__(self, smoothing: float = 0.0, reduction: str = "mean"):
37 | """Construct LabelSmoothing module.
38 |
39 | :param smoothing: label smoothing factor. Default is 0.0.
40 | :param reduction: reduction method to use from ["mean", "none"]. Default is "mean".
41 | """
42 | super(LabelSmoothing, self).__init__()
43 | self.confidence = 1.0 - smoothing
44 | self.smoothing = smoothing
45 | assert reduction in ["mean", "none"]
46 | self.reduction = reduction
47 |
48 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
49 | """Apply forward pass.
50 |
51 | :param x: Logits tensor.
52 | :param target: Ground truth target classes.
53 | :return: Loss tensor.
54 | """
55 | logprobs = torch.nn.functional.log_softmax(x, dim=-1)
56 |
57 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
58 | nll_loss = nll_loss.squeeze(1)
59 | smooth_loss = -logprobs.mean(dim=-1)
60 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss
61 | if self.reduction == "mean":
62 | return loss.mean()
63 | else:
64 | return loss
65 |
--------------------------------------------------------------------------------
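A worked micro-example of UncertaintyLoss above: with sigma = 0 the regularizer vanishes and each per-sample loss is simply halved before averaging (0.5 * exp(-sigma) * loss + mu * sigma).

```python
import torch

from utils.objectives import UncertaintyLoss

criterion = UncertaintyLoss(mu_uncertainty=0.5, sigma_dim=1)

per_sample_loss = torch.tensor([1.0, 2.0])  # losses computed with reduction="none"
sigma = torch.zeros(2, 1)                   # one sigma per sample
print(criterion(sigma, per_sample_loss))    # expected: tensor(0.7500) = mean(0.5 * [1.0, 2.0])
```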
/utils/schedulers.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 |
5 | from typing import Callable, Optional
6 |
7 | import torch
8 | import numpy as np
9 |
10 | __all__ = ["multistep_lr", "cosine_lr", "constant_lr", "get_policy"]
11 |
12 |
13 | def get_policy(optimizer: torch.optim.Optimizer, algorithm: str, **kwargs) -> Callable:
14 | """Get learning policy given its configurations.
15 |
16 | :param optimizer: A torch optimizer.
17 | :param algorithm: Name of the learning rate scheduling algorithm.
18 | :return: A callable to adjust learning rate for each epoch.
19 | """
20 | out_dict = {
21 | "constant_lr": constant_lr,
22 | "cosine_lr": cosine_lr,
23 | "multistep_lr": multistep_lr,
24 | }
25 | return out_dict[algorithm](optimizer, **kwargs)
26 |
27 |
28 | def assign_learning_rate(optimizer: torch.optim.Optimizer, new_lr: float) -> None:
29 | """Update lr parameter of an optimizer.
30 |
31 | :param optimizer: A torch optimizer.
32 | :param new_lr: updated value of learning rate.
33 | """
34 | for param_group in optimizer.param_groups:
35 | param_group["lr"] = new_lr
36 |
37 |
38 | def constant_lr(
39 | optimizer: torch.optim.Optimizer, warmup_length: int, lr: float, **kwargs
40 | ) -> Callable:
41 | """Get lr adjustment callable with constant schedule.
42 |
43 | :param optimizer: A torch optimizer.
44 | :param warmup_length: Number of epochs for initial warmup.
45 | :param lr: Nominal learning rate value.
46 | :return: A callable to adjust learning rate per epoch.
47 | """
48 |
49 | def _lr_adjuster(epoch: int, iteration: Optional[int]) -> float:
50 | """Get updated learning rate.
51 |
52 | :param epoch: Epoch number.
53 | :param iteration: Iteration number.
54 | :return: Updated learning rate value.
55 | """
56 | if epoch < warmup_length:
57 | new_lr = _warmup_lr(lr, warmup_length, epoch)
58 | else:
59 | new_lr = lr
60 |
61 | assign_learning_rate(optimizer, new_lr)
62 |
63 | return new_lr
64 |
65 | return _lr_adjuster
66 |
67 |
68 | def cosine_lr(
69 | optimizer: torch.optim.Optimizer,
70 | warmup_length: int,
71 | epochs: int,
72 | lr: float,
73 | **kwargs
74 | ) -> Callable:
75 | """Get lr adjustment callable with cosine schedule.
76 |
77 | :param optimizer: A torch optimizer.
78 | :param warmup_length: Number of epochs for initial warmup.
79 |     :param epochs: Total number of training epochs.
80 | :param lr: Nominal learning rate value.
81 | :return: A callable to adjust learning rate per epoch.
82 | """
83 |
84 | def _lr_adjuster(epoch: int, iteration: Optional[int]) -> float:
85 | """Get updated learning rate.
86 |
87 | :param epoch: Epoch number.
88 | :param iteration: Iteration number.
89 | :return: Updated learning rate value.
90 | """
91 | if epoch < warmup_length:
92 | new_lr = _warmup_lr(lr, warmup_length, epoch)
93 | else:
94 | e = epoch - warmup_length
95 | es = epochs - warmup_length
96 | new_lr = 0.5 * (1 + np.cos(np.pi * e / es)) * lr
97 |
98 | assign_learning_rate(optimizer, new_lr)
99 |
100 | return new_lr
101 |
102 | return _lr_adjuster
103 |
104 |
105 | def multistep_lr(
106 | optimizer: torch.optim.Optimizer,
107 | lr_gamma: float,
108 | lr_adjust: int,
109 | lr: float,
110 | **kwargs
111 | ) -> Callable:
112 | """Get lr adjustment callable with multi-step schedule.
113 |
114 | :param optimizer: A torch optimizer.
115 | :param lr_gamma: Learning rate decay factor.
116 | :param lr_adjust: Number of epochs to apply decay.
117 | :param lr: Nominal Learning rate.
118 | :return: A callable to adjust learning rate per epoch.
119 | """
120 |
121 | def _lr_adjuster(epoch: int, iteration: Optional[int]) -> float:
122 | """Get updated learning rate.
123 |
124 | :param epoch: Epoch number.
125 | :param iteration: Iteration number.
126 | :return: Updated learning rate value.
127 | """
128 | new_lr = lr * (lr_gamma ** (epoch // lr_adjust))
129 |
130 | assign_learning_rate(optimizer, new_lr)
131 |
132 | return new_lr
133 |
134 | return _lr_adjuster
135 |
136 |
137 | def _warmup_lr(base_lr: float, warmup_length: int, epoch: int) -> float:
138 | """Get updated lr after applying initial warmup.
139 |
140 | :param base_lr: Nominal learning rate.
141 | :param warmup_length: Number of epochs for initial warmup.
142 | :param epoch: Epoch number.
143 | :return: Warmup-updated learning rate.
144 | """
145 | return base_lr * (epoch + 1) / warmup_length
146 |
--------------------------------------------------------------------------------
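A short sketch of the cosine schedule with warmup defined above, driving a throwaway optimizer; the hyperparameters are illustrative.

```python
import torch

from utils.schedulers import cosine_lr

params = [torch.nn.Parameter(torch.zeros(1))]  # dummy parameter to hold a param group
optimizer = torch.optim.SGD(params, lr=0.1)

adjust = cosine_lr(optimizer, warmup_length=2, epochs=10, lr=0.1)
for epoch in range(10):
    print(epoch, round(adjust(epoch, iteration=None), 4))
# Epochs 0-1 warm up linearly (0.05, 0.1); afterwards the rate follows a cosine decay toward 0.
```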