├── ACKNOWLEDGEMENTS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── __init__.py ├── checkpoints └── README.md ├── configs ├── cifar100_backbone_new.yaml ├── cifar100_backbone_old.yaml ├── cifar100_eval_new_new.yaml ├── cifar100_eval_old_new_fastfill.yaml ├── cifar100_eval_old_new_fct.yaml ├── cifar100_eval_old_old.yaml ├── cifar100_fastfill_transformation.yaml ├── cifar100_fct_transformation.yaml ├── imagenet_backbone_new.yaml ├── imagenet_backbone_old.yaml ├── imagenet_eval_new_new.yaml ├── imagenet_eval_old_new_fastfill.yaml ├── imagenet_eval_old_new_fct.yaml ├── imagenet_eval_old_old.yaml ├── imagenet_fastfill_transformation.yaml └── imagenet_fct_transformation.yaml ├── dataset ├── __init__.py ├── data_transforms.py └── sub_image_folder.py ├── demo.gif ├── eval.py ├── fct_logo.png ├── get_pretrained_models.sh ├── models ├── __init__.py ├── resnet.py └── transformations.py ├── prepare_dataset.py ├── requirements.txt ├── train_backbone.py ├── train_transformation.py ├── trainers ├── __init__.py ├── backbone_trainer.py └── transformation_trainer.py └── utils ├── __init__.py ├── eval_utils.py ├── getters.py ├── logging_utils.py ├── net_utils.py ├── objectives.py └── schedulers.py /ACKNOWLEDGEMENTS: -------------------------------------------------------------------------------- 1 | Acknowledgements 2 | Portions of this Software may utilize the following copyrighted 3 | material, the use of which is hereby acknowledged. 4 | 5 | _____________________ 6 | 7 | Shen, Yantao and Xiong, Yuanjun and Xia, Wei and Soatto, Stefano (OpenBCT) 8 | Copyright (c) 2020, Yantao Shen 9 | All rights reserved. 10 | 11 | Redistribution and use in source and binary forms, with or without 12 | modification, are permitted provided that the following conditions are met: 13 | 14 | 1. Redistributions of source code must retain the above copyright notice, this 15 | list of conditions and the following disclaimer. 16 | 17 | 2. Redistributions in binary form must reproduce the above copyright notice, 18 | this list of conditions and the following disclaimer in the documentation 19 | and/or other materials provided with the distribution. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 25 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 26 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 27 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 28 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 29 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | 2 | # Contributor Covenant Code of Conduct 3 | 4 | ## Our Pledge 5 | 6 | We as members, contributors, and leaders pledge to make participation in our 7 | community a harassment-free experience for everyone, regardless of age, body 8 | size, visible or invisible disability, ethnicity, sex characteristics, gender 9 | identity and expression, level of experience, education, socio-economic status, 10 | nationality, personal appearance, race, caste, color, religion, or sexual 11 | identity and orientation. 12 | 13 | We pledge to act and interact in ways that contribute to an open, welcoming, 14 | diverse, inclusive, and healthy community. 15 | 16 | ## Our Standards 17 | 18 | Examples of behavior that contributes to a positive environment for our 19 | community include: 20 | 21 | * Demonstrating empathy and kindness toward other people 22 | * Being respectful of differing opinions, viewpoints, and experiences 23 | * Giving and gracefully accepting constructive feedback 24 | * Accepting responsibility and apologizing to those affected by our mistakes, 25 | and learning from the experience 26 | * Focusing on what is best not just for us as individuals, but for the overall 27 | community 28 | 29 | Examples of unacceptable behavior include: 30 | 31 | * The use of sexualized language or imagery, and sexual attention or advances of 32 | any kind 33 | * Trolling, insulting or derogatory comments, and personal or political attacks 34 | * Public or private harassment 35 | * Publishing others' private information, such as a physical or email address, 36 | without their explicit permission 37 | * Other conduct which could reasonably be considered inappropriate in a 38 | professional setting 39 | 40 | ## Enforcement Responsibilities 41 | 42 | Community leaders are responsible for clarifying and enforcing our standards of 43 | acceptable behavior and will take appropriate and fair corrective action in 44 | response to any behavior that they deem inappropriate, threatening, offensive, 45 | or harmful. 46 | 47 | Community leaders have the right and responsibility to remove, edit, or reject 48 | comments, commits, code, wiki edits, issues, and other contributions that are 49 | not aligned to this Code of Conduct, and will communicate reasons for moderation 50 | decisions when appropriate. 51 | 52 | ## Scope 53 | 54 | This Code of Conduct applies within all community spaces, and also applies when 55 | an individual is officially representing the community in public spaces. 56 | Examples of representing our community include using an official e-mail address, 57 | posting via an official social media account, or acting as an appointed 58 | representative at an online or offline event. 59 | 60 | ## Enforcement 61 | 62 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 63 | reported to the community leaders responsible for enforcement at 64 | [INSERT CONTACT METHOD]. 65 | All complaints will be reviewed and investigated promptly and fairly. 66 | 67 | All community leaders are obligated to respect the privacy and security of the 68 | reporter of any incident. 69 | 70 | ## Enforcement Guidelines 71 | 72 | Community leaders will follow these Community Impact Guidelines in determining 73 | the consequences for any action they deem in violation of this Code of Conduct: 74 | 75 | ### 1. 
Correction 76 | 77 | **Community Impact**: Use of inappropriate language or other behavior deemed 78 | unprofessional or unwelcome in the community. 79 | 80 | **Consequence**: A private, written warning from community leaders, providing 81 | clarity around the nature of the violation and an explanation of why the 82 | behavior was inappropriate. A public apology may be requested. 83 | 84 | ### 2. Warning 85 | 86 | **Community Impact**: A violation through a single incident or series of 87 | actions. 88 | 89 | **Consequence**: A warning with consequences for continued behavior. No 90 | interaction with the people involved, including unsolicited interaction with 91 | those enforcing the Code of Conduct, for a specified period of time. This 92 | includes avoiding interactions in community spaces as well as external channels 93 | like social media. Violating these terms may lead to a temporary or permanent 94 | ban. 95 | 96 | ### 3. Temporary Ban 97 | 98 | **Community Impact**: A serious violation of community standards, including 99 | sustained inappropriate behavior. 100 | 101 | **Consequence**: A temporary ban from any sort of interaction or public 102 | communication with the community for a specified period of time. No public or 103 | private interaction with the people involved, including unsolicited interaction 104 | with those enforcing the Code of Conduct, is allowed during this period. 105 | Violating these terms may lead to a permanent ban. 106 | 107 | ### 4. Permanent Ban 108 | 109 | **Community Impact**: Demonstrating a pattern of violation of community 110 | standards, including sustained inappropriate behavior, harassment of an 111 | individual, or aggression toward or disparagement of classes of individuals. 112 | 113 | **Consequence**: A permanent ban from any sort of public interaction within the 114 | community. 115 | 116 | ## Attribution 117 | 118 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 119 | version 2.1, available at 120 | [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. 121 | 122 | Community Impact Guidelines were inspired by 123 | [Mozilla's code of conduct enforcement ladder][Mozilla CoC]. 124 | 125 | For answers to common questions about this code of conduct, see the FAQ at 126 | [https://www.contributor-covenant.org/faq][FAQ]. Translations are available at 127 | [https://www.contributor-covenant.org/translations][translations]. 128 | 129 | [homepage]: https://www.contributor-covenant.org 130 | [v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html 131 | [Mozilla CoC]: https://github.com/mozilla/diversity 132 | [FAQ]: https://www.contributor-covenant.org/faq 133 | [translations]: https://www.contributor-covenant.org/translations 134 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository. 4 | 5 | While we welcome new pull requests and issues, please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged.
6 | 7 | ## Before you get started 8 | 9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE). 10 | 11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2020 Apple Inc. All Rights Reserved. 2 | 3 | IMPORTANT: This Apple software is supplied to you by Apple 4 | Inc. ("Apple") in consideration of your agreement to the following 5 | terms, and your use, installation, modification or redistribution of 6 | this Apple software constitutes acceptance of these terms. If you do 7 | not agree with these terms, please do not use, install, modify or 8 | redistribute this Apple software. 9 | 10 | In consideration of your agreement to abide by the following terms, and 11 | subject to these terms, Apple grants you a personal, non-exclusive 12 | license, under Apple's copyrights in this original Apple software (the 13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple 14 | Software, with or without modifications, in source and/or binary forms; 15 | provided that if you redistribute the Apple Software in its entirety and 16 | without modifications, you must retain this notice and the following 17 | text and disclaimers in all such redistributions of the Apple Software. 18 | Neither the name, trademarks, service marks or logos of Apple Inc. may 19 | be used to endorse or promote products derived from the Apple Software 20 | without specific prior written permission from Apple. Except as 21 | expressly stated in this notice, no other rights or licenses, express or 22 | implied, are granted by Apple herein, including but not limited to any 23 | patent rights that may be infringed by your derivative works or by other 24 | works in which the Apple Software may be incorporated. 25 | 26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE 27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION 28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS 29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND 30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. 31 | 32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL 33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, 36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED 37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), 38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE 39 | POSSIBILITY OF SUCH DAMAGE. 40 | 41 | ------------------------------------------------------------------------------- 42 | SOFTWARE DISTRIBUTED WITH ML-FCT: 43 | 44 | The ML-FCT software includes a number of subcomponents with separate 45 | copyright notices and license terms - please see the file ACKNOWLEDGEMENTS. 
46 | ------------------------------------------------------------------------------- -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Compatibility for Machine Learning Model Update 2 | This repository contains the PyTorch implementation of [Forward Compatible Training for Large-Scale Embedding Retrieval Systems (CVPR 2022)](https://arxiv.org/abs/2112.02805): 3 | 4 | 5 | 6 | and [FastFill: Efficient Compatible Model Update (ICLR 2023)](https://openreview.net/pdf?id=rnRiiHw8Vy): 7 | 8 | 9 | 10 | **The code is written to use Python 3.8 or above.** 11 | 12 | ## Requirements 13 | 14 | We suggest you first create a virtual environment and install dependencies in the virtual environment. 15 | 16 | ```bash 17 | # Go to repo 18 | cd <path-to-repo> 19 | # Create virtual environment ... 20 | python -m venv .venv 21 | # ... and activate it 22 | source .venv/bin/activate 23 | # Upgrade to the latest versions of pip and wheel 24 | pip install -U pip wheel 25 | pip install -r requirements.txt 26 | ``` 27 | 28 | ## CIFAR-100 Experiments (quick start) 29 | 30 | 31 | We provide CIFAR-100 experiments for fast exploration. The code runs and produces results for both FCT and FastFill. 32 | Here is the sequence of commands for the CIFAR-100 experiments (similar to ImageNet, but with faster cycles): 33 | 34 | ```bash 35 | # Get data: the following command puts the data in data_store/cifar-100-python 36 | python prepare_dataset.py 37 | 38 | # Train old embedding model: 39 | # Note: config files assume training with 8 GPUs. Modify them according to your environment. 40 | python train_backbone.py --config configs/cifar100_backbone_old.yaml 41 | 42 | # Evaluate the old model (single GPU is OK): 43 | python eval.py --config configs/cifar100_eval_old_old.yaml 44 | 45 | # Train new embedding model: 46 | python train_backbone.py --config configs/cifar100_backbone_new.yaml 47 | 48 | # Evaluate the new model (single GPU is OK): 49 | python eval.py --config configs/cifar100_eval_new_new.yaml 50 | 51 | # Download pre-trained models if training with side-information: 52 | source get_pretrained_models.sh 53 | 54 | # Train FCT transformation: 55 | # If training with a side-info model, add its path to the config file below. You 56 | # can use the same side-info model as for the ImageNet experiment here. 57 | python train_transformation.py --config configs/cifar100_fct_transformation.yaml 58 | 59 | # Evaluate transformed model vs new model (single GPU is OK): 60 | python eval.py --config configs/cifar100_eval_old_new_fct.yaml 61 | 62 | # Train FastFill transformation: 63 | python train_transformation.py --config configs/cifar100_fastfill_transformation.yaml 64 | 65 | # Evaluate transformed model vs new model (single GPU is OK): 66 | python eval.py --config configs/cifar100_eval_old_new_fastfill.yaml 67 | ``` 68 | 69 | ### CIFAR-100 (FCT, without backfilling): 70 | * These results are *not* averaged over multiple runs.
71 | 72 | | Case | `Side-Info` | `CMC Top-1 (%)` | `CMC Top-5 (%)` | `mAP (%)` | 73 | |-----------------------------------------------------------|:------------:|:------------------:|:---------------:|:---------:| 74 | | [old/old](./configs/cifar100_backbone_old.yaml) | N/A | 34.2 | 60.6 | 16.5 | 75 | | [new/new](./configs/cifar100_backbone_new.yaml) | N/A | 56.5 | 77.0 | 36.3 | 76 | | [FCT new/old](./configs/cifar100_fct_transformation.yaml) | No | 47.2 | 72.6 | 25.8 | 77 | | [FCT new/old](./configs/cifar100_fct_transformation.yaml) | Yes | 50.2 | 73.7 | 32.2 | 78 | 79 | ### CIFAR-100 (FastFill, with backfilling): 80 | * These results are *not* averaged over multiple runs. 81 | * AUC: Area Under the backfilling Curve. For old/old and new/new we report performance corresponding to 82 | no model update and full model update, respectively. 83 | 84 | | Case | `Side-Info` | `Backfilling` | `AUC CMC Top-1 (%)` | `AUC CMC Top-5 (%)` | `AUC mAP (%)` | 85 | |----------------------------------------------------------------------|:-----------:|:-------------:|:-------------------:|:-------------------:|:-------------:| 86 | | [old/old](./configs/cifar100_backbone_old.yaml) | N/A | N/A | 34.2 | 60.6 | 16.5 | 87 | | [new/new](./configs/cifar100_backbone_new.yaml) | N/A | N/A | 56.5 | 77.0 | 36.3 | 88 | | [FCT new/old](./configs/cifar100_fct_transformation.yaml) | No | Random | 49.1 | 73.6 | 29.1 | 89 | | [FastFill new/old](./configs/cifar100_fastfill_transformation.yaml) | No | Uncertainty | 53.6 | 75.3 | 32.5 | 90 | 91 | 92 | ## ImageNet-1k Experiments 93 | 94 | Here is the sequence of commands for the ImageNet experiments: 95 | 96 | ```bash 97 | # Get data: prepare the full ImageNet-1k dataset and provide its path in all config 98 | # files. The path should include training and validation directories. 99 | 100 | # Train old embedding model: 101 | # Note: config files assume training with 8 GPUs. Modify them according to your environment. 102 | python train_backbone.py --config configs/imagenet_backbone_old.yaml 103 | 104 | # Evaluate the old model: 105 | python eval.py --config configs/imagenet_eval_old_old.yaml 106 | 107 | # Train new embedding model: 108 | python train_backbone.py --config configs/imagenet_backbone_new.yaml 109 | 110 | # Evaluate the new model: 111 | python eval.py --config configs/imagenet_eval_new_new.yaml 112 | 113 | # Download pre-trained models if training with side-information: 114 | source get_pretrained_models.sh 115 | 116 | # Train FCT transformation: 117 | # (If training with a side-info model, add its path to the config file below.)
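# For example, to use the side-info model downloaded above, uncomment the line "side_info_model_path: checkpoints/imagenet_500_simclr.pt" in configs/imagenet_fct_transformation.yaml.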
118 | python train_transformation.py --config configs/imagenet_fct_transformation.yaml 119 | 120 | # Evaluate transformed model vs new model: 121 | python eval.py --config configs/imagenet_eval_old_new_fct.yaml 122 | 123 | # Train FastFill transformation: 124 | python train_transformation.py --config configs/imagenet_fastfill_transformation.yaml 125 | 126 | # Evaluate transformed model vs new model: 127 | python eval.py --config configs/imagenet_eval_old_new_fastfill.yaml 128 | ``` 129 | 130 | ### ImageNet-1k (FCT, without backfilling): 131 | 132 | | Case | `Side-Info` | `CMC Top-1 (%)` | `CMC Top-5 (%)` | `mAP (%)` | 133 | |-----------------------------------------------------------|:-----------:|:---------------:|:---------------:|:-------:| 134 | | [old/old](./configs/imagenet_backbone_old.yaml) | N/A | 46.4 | 65.1 | 28.3 | 135 | | [new/new](./configs/imagenet_backbone_new.yaml) | N/A | 68.4 | 84.7 | 45.6 | 136 | | [FCT new/old](./configs/imagenet_fct_transformation.yaml) | No | 61.8 | 80.5 | 39.9 | 137 | | [FCT new/old](./configs/imagenet_fct_transformation.yaml) | Yes | 65.1 | 82.7 | 44.0 | 138 | 139 | ### ImageNet-1k (FastFill, with backfilling): 140 | * AUC: Area Under the backfilling Curve. For old/old and new/new we report performance corresponding to 141 | no model update and full model update, respectively. 142 | 143 | | Case | `Side-Info` | `Backfilling` | `AUC CMC Top-1 (%)` | `AUC CMC Top-5 (%)` | `AUC mAP (%)` | 144 | |---------------------------------------------------------------------|:-----------:|:-------------:|:-------------------:|:-------------------:|:------------:| 145 | | [old/old](./configs/imagenet_backbone_old.yaml) | N/A | N/A | 46.6 | 65.2 | 28.5 | 146 | | [new/new](./configs/imagenet_backbone_new.yaml) | N/A | N/A | 68.2 | 84.6 | 45.3 | 147 | | [FCT new/old](./configs/imagenet_fct_transformation.yaml) | No | Random | 62.8 | 81.1 | 40.5 | 148 | | [FastFill new/old](./configs/imagenet_fastfill_transformation.yaml) | No | Uncertainty | 66.5 | 83.6 | 44.8 | 149 | | [FCT new/old](./configs/imagenet_fct_transformation.yaml) | Yes | Random | 64.7 | 82.4 | 42.6 | 150 | | [FastFill new/old](./configs/imagenet_fastfill_transformation.yaml) | Yes | Uncertainty | 67.8 | 84.2 | 46.2 | 151 | 152 | ## Contact 153 | 154 | * **Hadi Pouransari**: mpouransari@apple.com 155 | 156 | ## Citation 157 | 158 | ```bibtex 159 | @inproceedings{ramanujan2022forward, 160 | title={Forward Compatible Training for Large-Scale Embedding Retrieval Systems}, 161 | author={Ramanujan, Vivek and Vasu, Pavan Kumar Anasosalu and Farhadi, Ali and Tuzel, Oncel and Pouransari, Hadi}, 162 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 163 | year={2022} 164 | } 165 | 166 | @inproceedings{jaeckle2023fastfill, 167 | title={FastFill: Efficient Compatible Model Update}, 168 | author={Jaeckle, Florian and Faghri, Fartash and Farhadi, Ali and Tuzel, Oncel and Pouransari, Hadi}, 169 | booktitle={International Conference on Learning Representations}, 170 | year={2023} 171 | } 172 | ``` 173 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | # 4 | -------------------------------------------------------------------------------- /checkpoints/README.md: -------------------------------------------------------------------------------- 1 | # Pretrained models 2 | We provide pretrained side-information models trained with [SimCLR](https://arxiv.org/abs/2002.05709) 3 | using the same training setup as in [Stochastic Contrastive Learning](https://arxiv.org/abs/2110.00552): 4 | 5 | - [ResNet50-128-ImageNet250](https://devpubs.s3.amazonaws.com/ml-research/models/fct/imagenet_250_simclr.pt) 6 | - [ResNet50-128-ImageNet500](https://devpubs.s3.amazonaws.com/ml-research/models/fct/imagenet_500_simclr.pt) 7 | - [ResNet50-128-ImageNet1000](https://devpubs.s3.amazonaws.com/ml-research/models/fct/imagenet_1000_simclr.pt) 8 | 9 | The above models have [ResNet50](https://arxiv.org/abs/1512.03385) architecture with feature dimension of 128. The models are trained on the first 250, 500, and 1000 classes of ImageNet, respectively. 10 | -------------------------------------------------------------------------------- /configs/cifar100_backbone_new.yaml: -------------------------------------------------------------------------------- 1 | arch_params: 2 | arch: ResNet50 3 | num_classes: 100 # This is the number of classes for architecture FC layer. 4 | embedding_dim: 128 5 | last_nonlin: True 6 | 7 | optimizer_params: 8 | algorithm: sgd 9 | lr: 1.024 10 | weight_decay: 0.000030517578125 11 | no_bn_decay: False 12 | momentum: 0.875 13 | nesterov: False 14 | 15 | dataset_params: 16 | name: cifar100 17 | data_root: data_store/cifar-100-python # This should contain training and validation dirs. 18 | num_classes: 100 # This is the number of classes to include for training. 19 | num_workers: 20 20 | batch_size: 1024 21 | 22 | lr_policy_params: 23 | algorithm: cosine_lr 24 | warmup_length: 5 25 | epochs: 100 26 | lr: 1.024 27 | 28 | epochs: 100 29 | label_smoothing: 0.1 30 | output_model_path: checkpoints/cifar100_new.pt 31 | -------------------------------------------------------------------------------- /configs/cifar100_backbone_old.yaml: -------------------------------------------------------------------------------- 1 | arch_params: 2 | arch: ResNet50 3 | num_classes: 100 # This is the number of classes for architecture FC layer. 4 | embedding_dim: 128 5 | last_nonlin: True 6 | 7 | optimizer_params: 8 | algorithm: sgd 9 | lr: 1.024 10 | weight_decay: 0.000030517578125 11 | no_bn_decay: False 12 | momentum: 0.875 13 | nesterov: False 14 | 15 | dataset_params: 16 | name: cifar100 17 | data_root: data_store/cifar-100-python # This should contain training and validation dirs. 18 | num_classes: 50 # This is the number of classes to include for training. 19 | num_workers: 20 20 | batch_size: 1024 21 | 22 | lr_policy_params: 23 | algorithm: cosine_lr 24 | warmup_length: 5 25 | epochs: 100 26 | lr: 1.024 27 | 28 | epochs: 100 29 | label_smoothing: 0.1 30 | output_model_path: checkpoints/cifar100_old.pt 31 | -------------------------------------------------------------------------------- /configs/cifar100_eval_new_new.yaml: -------------------------------------------------------------------------------- 1 | gallery_model_path: checkpoints/cifar100_new.pt 2 | query_model_path: checkpoints/cifar100_new.pt 3 | 4 | eval_params: 5 | distance_metric_name: l2 6 | verbose: True 7 | compute_map: True 8 | 9 | dataset_params: # Test split of the dataset will be used as both gallery and query sets. 
10 | name: cifar100 11 | data_root: data_store/cifar-100-python 12 | num_workers: 20 13 | batch_size: 1024 -------------------------------------------------------------------------------- /configs/cifar100_eval_old_new_fastfill.yaml: -------------------------------------------------------------------------------- 1 | gallery_model_path: checkpoints/cifar100_old_fastfill_transformed.pt 2 | query_model_path: checkpoints/cifar100_new.pt 3 | 4 | eval_params: 5 | distance_metric_name: l2 6 | verbose: True 7 | compute_map: True 8 | backfilling: 20 9 | backfilling_list: [ 'sigma' ] # choices are 'random', 'distance', and 'sigma' 10 | backfilling_result_path: checkpoints/cifar100_fastfill_backfilling.npy 11 | 12 | dataset_params: # Test split of the dataset will be used as both gallery and query sets. 13 | name: cifar100 14 | data_root: data_store/cifar-100-python 15 | num_workers: 20 16 | batch_size: 1024 17 | -------------------------------------------------------------------------------- /configs/cifar100_eval_old_new_fct.yaml: -------------------------------------------------------------------------------- 1 | gallery_model_path: checkpoints/cifar100_old_fct_transformed.pt 2 | query_model_path: checkpoints/cifar100_new.pt 3 | 4 | eval_params: 5 | distance_metric_name: l2 6 | verbose: True 7 | compute_map: True 8 | backfilling: 20 # put null here for no backfilling, only transformed(old) vs new. 9 | backfilling_list: [ 'random' ] # choices are 'random' and 'distance' 10 | backfilling_result_path: checkpoints/cifar100_fct_backfilling.npy 11 | 12 | dataset_params: # Test split of the dataset will be used as both gallery and query sets. 13 | name: cifar100 14 | data_root: data_store/cifar-100-python 15 | num_workers: 20 16 | batch_size: 1024 17 | -------------------------------------------------------------------------------- /configs/cifar100_eval_old_old.yaml: -------------------------------------------------------------------------------- 1 | gallery_model_path: checkpoints/cifar100_old.pt 2 | query_model_path: checkpoints/cifar100_old.pt 3 | 4 | eval_params: 5 | distance_metric_name: l2 6 | verbose: True 7 | compute_map: True 8 | 9 | dataset_params: # Test split of the dataset will be used as both gallery and query sets. 10 | name: cifar100 11 | data_root: data_store/cifar-100-python 12 | num_workers: 20 13 | batch_size: 1024 14 | -------------------------------------------------------------------------------- /configs/cifar100_fastfill_transformation.yaml: -------------------------------------------------------------------------------- 1 | old_model_path: checkpoints/cifar100_old.pt 2 | new_model_path: checkpoints/cifar100_new.pt 3 | #side_info_model_path: checkpoints/imagenet_1000_simclr.pt # Comment this line for no side-info experiment. 4 | 5 | arch_params: 6 | arch: MLP_BN_SIDE_PROJECTION_SIGMA 7 | old_embedding_dim: 128 8 | new_embedding_dim: 128 9 | side_info_dim: 128 10 | inner_dim: 2048 11 | 12 | optimizer_params: 13 | algorithm: adam 14 | lr: 0.0005 15 | weight_decay: 0.000030517578125 16 | 17 | dataset_params: 18 | name: cifar100 19 | data_root: data_store/cifar-100-python # This should contain training and validation dirs. 20 | num_classes: 100 # This is the number of classes to include for training. 
21 | num_workers: 20 22 | batch_size: 1024 23 | 24 | lr_policy_params: 25 | algorithm: cosine_lr 26 | warmup_length: 5 27 | epochs: 80 28 | lr: 0.0005 29 | 30 | objective_params: 31 | similarity_loss: 32 | name: mse 33 | discriminative_loss: 34 | name: LabelSmoothing 35 | label_smoothing: 0.1 36 | mu_disc: 1 37 | uncertainty_loss: 38 | mu_uncertainty: 0.5 39 | 40 | epochs: 80 41 | switch_mode_to_eval: True 42 | output_transformation_path: checkpoints/cifar100_fastfill_transformation.pt 43 | output_transformed_old_model_path: checkpoints/cifar100_old_fastfill_transformed.pt 44 | -------------------------------------------------------------------------------- /configs/cifar100_fct_transformation.yaml: -------------------------------------------------------------------------------- 1 | old_model_path: checkpoints/cifar100_old.pt 2 | new_model_path: checkpoints/cifar100_new.pt 3 | #side_info_model_path: checkpoints/imagenet_1000_simclr.pt # Comment this line for no side-info experiment. 4 | 5 | arch_params: 6 | arch: MLP_BN_SIDE_PROJECTION 7 | old_embedding_dim: 128 8 | new_embedding_dim: 128 9 | side_info_dim: 128 10 | inner_dim: 2048 11 | 12 | optimizer_params: 13 | algorithm: adam 14 | lr: 0.0005 15 | weight_decay: 0.000030517578125 16 | 17 | dataset_params: 18 | name: cifar100 19 | data_root: data_store/cifar-100-python # This should contain training and validation dirs. 20 | num_classes: 100 # This is the number of classes to include for training. 21 | num_workers: 20 22 | batch_size: 1024 23 | 24 | lr_policy_params: 25 | algorithm: cosine_lr 26 | warmup_length: 5 27 | epochs: 80 28 | lr: 0.0005 29 | 30 | objective_params: 31 | similarity_loss: 32 | name: mse 33 | 34 | epochs: 80 35 | switch_mode_to_eval: True 36 | output_transformation_path: checkpoints/cifar100_fct_transformation.pt 37 | output_transformed_old_model_path: checkpoints/cifar100_old_fct_transformed.pt 38 | -------------------------------------------------------------------------------- /configs/imagenet_backbone_new.yaml: -------------------------------------------------------------------------------- 1 | arch_params: 2 | arch: ResNet50 3 | num_classes: 1000 # This is the number of classes for architecture FC layer. 4 | embedding_dim: 128 5 | last_nonlin: True 6 | 7 | optimizer_params: 8 | algorithm: sgd 9 | lr: 1.024 10 | weight_decay: 0.000030517578125 11 | no_bn_decay: False 12 | momentum: 0.875 13 | nesterov: False 14 | 15 | dataset_params: 16 | name: imagenet 17 | data_root: data_store/imagenet-1.0.2/data/raw # This should contain training and validation dirs. 18 | num_classes: 1000 # This is the number of classes to include for training. 19 | num_workers: 20 20 | batch_size: 1024 21 | 22 | lr_policy_params: 23 | algorithm: cosine_lr 24 | warmup_length: 5 25 | epochs: 100 26 | lr: 1.024 27 | 28 | epochs: 100 29 | label_smoothing: 0.1 30 | output_model_path: checkpoints/imagenet_new.pt 31 | -------------------------------------------------------------------------------- /configs/imagenet_backbone_old.yaml: -------------------------------------------------------------------------------- 1 | arch_params: 2 | arch: ResNet50 3 | num_classes: 1000 # This is the number of classes for architecture FC layer. 
4 | embedding_dim: 128 5 | last_nonlin: True 6 | 7 | optimizer_params: 8 | algorithm: sgd 9 | lr: 1.024 10 | weight_decay: 0.000030517578125 11 | no_bn_decay: False 12 | momentum: 0.875 13 | nesterov: False 14 | 15 | dataset_params: 16 | name: imagenet 17 | data_root: data_store/imagenet-1.0.2/data/raw # This should contain training and validation dirs. 18 | num_classes: 500 # This is the number of classes to include for training. 19 | num_workers: 20 20 | batch_size: 1024 21 | 22 | lr_policy_params: 23 | algorithm: cosine_lr 24 | warmup_length: 5 25 | epochs: 100 26 | lr: 1.024 27 | 28 | epochs: 100 29 | label_smoothing: 0.1 30 | output_model_path: checkpoints/imagenet_old.pt 31 | -------------------------------------------------------------------------------- /configs/imagenet_eval_new_new.yaml: -------------------------------------------------------------------------------- 1 | gallery_model_path: checkpoints/imagenet_new.pt 2 | query_model_path: checkpoints/imagenet_new.pt 3 | 4 | eval_params: 5 | distance_metric_name: l2 6 | verbose: True 7 | compute_map: True 8 | 9 | dataset_params: # Test split of the dataset will be used as both gallery and query sets. 10 | name: imagenet 11 | data_root: data_store/imagenet-1.0.2/data/raw 12 | num_workers: 20 13 | batch_size: 1024 14 | -------------------------------------------------------------------------------- /configs/imagenet_eval_old_new_fastfill.yaml: -------------------------------------------------------------------------------- 1 | gallery_model_path: checkpoints/imagenet_old_fastfill_transformed.pt 2 | query_model_path: checkpoints/imagenet_new.pt 3 | 4 | eval_params: 5 | distance_metric_name: l2 6 | verbose: True 7 | compute_map: True 8 | backfilling: 20 9 | backfilling_list: ['sigma'] # choices are 'random', 'distance', and 'sigma' 10 | backfilling_result_path: checkpoints/imagenet_fastfill_backfilling.npy 11 | 12 | dataset_params: # Test split of the dataset will be used as both gallery and query sets. 13 | name: imagenet 14 | data_root: data_store/imagenet-1.0.2/data/raw 15 | num_workers: 20 16 | batch_size: 1024 17 | -------------------------------------------------------------------------------- /configs/imagenet_eval_old_new_fct.yaml: -------------------------------------------------------------------------------- 1 | gallery_model_path: checkpoints/imagenet_old_fct_transformed.pt 2 | query_model_path: checkpoints/imagenet_new.pt 3 | 4 | eval_params: 5 | distance_metric_name: l2 6 | verbose: True 7 | compute_map: True 8 | backfilling: 20 # put null here for no backfilling, only transformed(old) vs new. 9 | backfilling_list: ['random'] # choices are 'random' and 'distance' 10 | backfilling_result_path: checkpoints/imagenet_fct_backfilling.npy 11 | 12 | dataset_params: # Test split of the dataset will be used as both gallery and query sets. 13 | name: imagenet 14 | data_root: data_store/imagenet-1.0.2/data/raw 15 | num_workers: 20 16 | batch_size: 1024 17 | -------------------------------------------------------------------------------- /configs/imagenet_eval_old_old.yaml: -------------------------------------------------------------------------------- 1 | gallery_model_path: checkpoints/imagenet_old.pt 2 | query_model_path: checkpoints/imagenet_old.pt 3 | 4 | eval_params: 5 | distance_metric_name: l2 6 | verbose: True 7 | compute_map: True 8 | 9 | dataset_params: # Test split of the dataset will be used as both gallery and query sets. 
10 | name: imagenet 11 | data_root: data_store/imagenet-1.0.2/data/raw 12 | num_workers: 20 13 | batch_size: 1024 14 | -------------------------------------------------------------------------------- /configs/imagenet_fastfill_transformation.yaml: -------------------------------------------------------------------------------- 1 | old_model_path: checkpoints/imagenet_old.pt 2 | new_model_path: checkpoints/imagenet_new.pt 3 | #side_info_model_path: checkpoints/imagenet_500_simclr.pt # Comment this line for no side-info experiment. 4 | 5 | arch_params: 6 | arch: MLP_BN_SIDE_PROJECTION_SIGMA 7 | old_embedding_dim: 128 8 | new_embedding_dim: 128 9 | side_info_dim: 128 10 | inner_dim: 2048 11 | 12 | optimizer_params: 13 | algorithm: adam 14 | lr: 0.0005 15 | weight_decay: 0.000030517578125 16 | 17 | dataset_params: 18 | name: imagenet 19 | data_root: data_store/imagenet-1.0.2/data/raw # This should contain training and validation dirs. 20 | num_classes: 1000 # This is the number of classes to include for training. 21 | num_workers: 20 22 | batch_size: 1024 23 | 24 | lr_policy_params: 25 | algorithm: cosine_lr 26 | warmup_length: 5 27 | epochs: 80 28 | lr: 0.0005 29 | 30 | objective_params: 31 | similarity_loss: 32 | name: mse 33 | discriminative_loss: 34 | name: LabelSmoothing 35 | label_smoothing: 0.1 36 | mu_disc: 1 37 | uncertainty_loss: 38 | mu_uncertainty: 0.5 39 | 40 | epochs: 80 41 | switch_mode_to_eval: True 42 | output_transformation_path: checkpoints/imagenet_fastfill_transformation.pt 43 | output_transformed_old_model_path: checkpoints/imagenet_old_fastfill_transformed.pt 44 | -------------------------------------------------------------------------------- /configs/imagenet_fct_transformation.yaml: -------------------------------------------------------------------------------- 1 | old_model_path: checkpoints/imagenet_old.pt 2 | new_model_path: checkpoints/imagenet_new.pt 3 | #side_info_model_path: checkpoints/imagenet_500_simclr.pt # Comment this line for no side-info experiment. 4 | 5 | arch_params: 6 | arch: MLP_BN_SIDE_PROJECTION 7 | old_embedding_dim: 128 8 | new_embedding_dim: 128 9 | side_info_dim: 128 10 | inner_dim: 2048 11 | 12 | optimizer_params: 13 | algorithm: adam 14 | lr: 0.0005 15 | weight_decay: 0.000030517578125 16 | 17 | dataset_params: 18 | name: imagenet 19 | data_root: data_store/imagenet-1.0.2/data/raw # This should contain training and validation dirs. 20 | num_classes: 1000 # This is the number of classes to include for training. 21 | num_workers: 20 22 | batch_size: 1024 23 | 24 | lr_policy_params: 25 | algorithm: cosine_lr 26 | warmup_length: 5 27 | epochs: 80 28 | lr: 0.0005 29 | 30 | objective_params: 31 | similarity_loss: 32 | name: mse 33 | 34 | epochs: 80 35 | switch_mode_to_eval: True 36 | output_transformation_path: checkpoints/imagenet_fct_transformation.pt 37 | output_transformed_old_model_path: checkpoints/imagenet_old_fct_transformed.pt 38 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | 5 | from .sub_image_folder import SubImageFolder 6 | -------------------------------------------------------------------------------- /dataset/data_transforms.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 
3 | # 4 | 5 | from typing import Tuple 6 | from torchvision import transforms 7 | 8 | 9 | def imagenet_transforms() -> Tuple[transforms.Compose, transforms.Compose]: 10 | """Get training and validation transformations for Imagenet.""" 11 | normalize = transforms.Normalize( 12 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 13 | ) 14 | train_transforms = transforms.Compose( 15 | [ 16 | transforms.RandomResizedCrop(224), 17 | transforms.RandomHorizontalFlip(), 18 | transforms.ToTensor(), 19 | normalize, 20 | ] 21 | ) 22 | 23 | val_transforms = transforms.Compose( 24 | [ 25 | transforms.Resize(256), 26 | transforms.CenterCrop(224), 27 | transforms.ToTensor(), 28 | normalize, 29 | ] 30 | ) 31 | return train_transforms, val_transforms 32 | 33 | 34 | def cifar100_transforms() -> Tuple[transforms.Compose, transforms.Compose]: 35 | """Get training and validation transformations for Cifar100. 36 | 37 | Note that these are not optimal transformations (including normalization), 38 | yet provided for quick experimentation similar to Imagenet 39 | (and its corresponding side-information model). 40 | """ 41 | normalize = transforms.Normalize( 42 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 43 | ) 44 | 45 | train_transforms = transforms.Compose( 46 | [ 47 | transforms.Resize(224), 48 | transforms.RandomHorizontalFlip(), 49 | transforms.ToTensor(), 50 | normalize, 51 | ] 52 | ) 53 | 54 | val_transforms = transforms.Compose( 55 | [ 56 | transforms.Resize(224), 57 | transforms.ToTensor(), 58 | normalize, 59 | ] 60 | ) 61 | return train_transforms, val_transforms 62 | 63 | 64 | data_transforms_map = {"cifar100": cifar100_transforms, "imagenet": imagenet_transforms} 65 | 66 | 67 | def get_data_transforms( 68 | dataset_name: str, 69 | ) -> Tuple[transforms.Compose, transforms.Compose]: 70 | """Get training and validation transforms of a dataset. 71 | 72 | :param dataset_name: Name of the dataset (e.g., cifar100, imagenet) 73 | :return: Tuple of training and validation transformations. 74 | """ 75 | return data_transforms_map.get(dataset_name)() 76 | -------------------------------------------------------------------------------- /dataset/sub_image_folder.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | 5 | import os 6 | 7 | import torch 8 | from torchvision import datasets 9 | 10 | from .data_transforms import get_data_transforms 11 | 12 | 13 | class SubImageFolder: 14 | """Class to support training on subset of classes.""" 15 | 16 | def __init__( 17 | self, 18 | name: str, 19 | data_root: str, 20 | num_workers: int, 21 | batch_size: int, 22 | num_classes=None, 23 | ) -> None: 24 | """Construct a SubImageFolder module. 25 | 26 | :param name: Name of the dataset (e.g., cifar100, imagenet). 27 | :param data_root: Path to a directory with training and validation 28 | subdirs of the dataset. 29 | :param num_workers: Number of workers for data loader. 30 | :param batch_size: Global batch size. Per GPU batch size = batch_size/num_gpus. 31 | :param num_classes: Number of classes to use for training. This should 32 | be smaller than or equal to the total number of classes in the dataset. 33 | Note that for evaluation we use all classes. 
34 | """ 35 | super(SubImageFolder, self).__init__() 36 | 37 | use_cuda = torch.cuda.is_available() 38 | kwargs = {"num_workers": num_workers, "pin_memory": True} if use_cuda else {} 39 | 40 | traindir = os.path.join(data_root, "training") 41 | valdir = os.path.join(data_root, "validation") 42 | 43 | train_transforms, val_transforms = get_data_transforms(name) 44 | 45 | self.train_dataset = datasets.ImageFolder( 46 | traindir, 47 | train_transforms, 48 | ) 49 | 50 | # Filtering out some classes 51 | if num_classes is not None: 52 | self.train_dataset.imgs = [ 53 | (path, cls_num) 54 | for path, cls_num in self.train_dataset.imgs 55 | if cls_num < num_classes 56 | ] 57 | 58 | self.train_dataset.samples = self.train_dataset.imgs 59 | 60 | self.train_sampler = None 61 | 62 | self.train_loader = torch.utils.data.DataLoader( 63 | self.train_dataset, 64 | batch_size=batch_size, 65 | sampler=self.train_sampler, 66 | shuffle=self.train_sampler is None, 67 | **kwargs 68 | ) 69 | 70 | # Note: for evaluation we use all classes. 71 | self.val_loader = torch.utils.data.DataLoader( 72 | datasets.ImageFolder( 73 | valdir, 74 | val_transforms, 75 | ), 76 | batch_size=batch_size, 77 | shuffle=False, 78 | **kwargs 79 | ) 80 | -------------------------------------------------------------------------------- /demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-fct/1c658df902c70b2556ce9677bbca660c80fdf9e6/demo.gif -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | 5 | from typing import Dict 6 | from argparse import ArgumentParser 7 | 8 | import yaml 9 | import torch 10 | 11 | from dataset import SubImageFolder 12 | from utils.eval_utils import cmc_evaluate 13 | 14 | 15 | def main(config: Dict) -> None: 16 | """Run evaluation. 17 | 18 | :param config: A dictionary with all configurations to run evaluation. 19 | """ 20 | device = torch.device("cuda:0") if torch.cuda.is_available() else "cpu" 21 | 22 | # Load models: 23 | gallery_model = torch.jit.load(config.get("gallery_model_path")) 24 | query_model = torch.jit.load(config.get("query_model_path")) 25 | 26 | data = SubImageFolder(**config.get("dataset_params")) 27 | 28 | cmc_evaluate( 29 | gallery_model, query_model, data.val_loader, device, **config.get("eval_params") 30 | ) 31 | 32 | 33 | if __name__ == "__main__": 34 | parser = ArgumentParser() 35 | parser.add_argument( 36 | "--config", 37 | type=str, 38 | required=True, 39 | help="Path to config file for this pipeline.", 40 | ) 41 | args = parser.parse_args() 42 | with open(args.config) as f: 43 | read_config = yaml.safe_load(f) 44 | main(read_config) 45 | -------------------------------------------------------------------------------- /fct_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-fct/1c658df902c70b2556ce9677bbca660c80fdf9e6/fct_logo.png -------------------------------------------------------------------------------- /get_pretrained_models.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # Copyright (C) 2023 Apple Inc. All rights reserved. 
4 | # 5 | 6 | wget https://devpubs.s3.amazonaws.com/ml-research/models/fct/imagenet_250_simclr.pt -P checkpoints 7 | wget https://devpubs.s3.amazonaws.com/ml-research/models/fct/imagenet_500_simclr.pt -P checkpoints 8 | wget https://devpubs.s3.amazonaws.com/ml-research/models/fct/imagenet_1000_simclr.pt -P checkpoints 9 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | 5 | from .resnet import ResNet18, ResNet50, ResNet101, WideResNet50_2, WideResNet101_2 6 | from .transformations import MLP_BN_SIDE_PROJECTION, MLP_BN_SIDE_PROJECTION_SIGMA 7 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | from typing import Optional, List, Tuple 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class BasicBlock(nn.Module): 12 | """Resnet basic block module.""" 13 | 14 | expansion = 1 15 | 16 | def __init__( 17 | self, 18 | inplanes: int, 19 | planes: int, 20 | stride: int = 1, 21 | downsample: Optional[nn.Module] = None, 22 | base_width: int = 64, 23 | nonlin: bool = True, 24 | embedding_dim: Optional[int] = None, 25 | ) -> None: 26 | """Construct a BasicBlock module. 27 | 28 | :param inplanes: Number of input channels. 29 | :param planes: Number of output channels. 30 | :param stride: Stride size. 31 | :param downsample: Down-sampling for residual path. 32 | :param base_width: Base width of the block. 33 | :param nonlin: Whether to apply non-linearity before output. 34 | :param embedding_dim: Size of the output embedding dimension. 35 | """ 36 | super(BasicBlock, self).__init__() 37 | if base_width / 64 > 1: 38 | raise ValueError("Base width >64 does not work for BasicBlock") 39 | if embedding_dim is not None: 40 | planes = embedding_dim 41 | self.conv1 = nn.Conv2d( 42 | inplanes, planes, kernel_size=3, padding=1, stride=stride, bias=False 43 | ) 44 | self.bn1 = nn.BatchNorm2d(planes) 45 | self.relu = nn.ReLU(inplace=True) 46 | self.conv2 = nn.Conv2d( 47 | planes, planes, kernel_size=3, padding=1, stride=1, bias=False 48 | ) 49 | self.bn2 = nn.BatchNorm2d(planes) 50 | self.downsample = downsample 51 | self.stride = stride 52 | self.nonlin = nonlin 53 | 54 | def forward(self, x: torch.Tensor) -> torch.Tensor: 55 | """Apply forward pass.""" 56 | residual = x 57 | 58 | out = self.conv1(x) 59 | if self.bn1 is not None: 60 | out = self.bn1(out) 61 | 62 | out = self.relu(out) 63 | 64 | out = self.conv2(out) 65 | 66 | if self.bn2 is not None: 67 | out = self.bn2(out) 68 | 69 | if self.downsample is not None: 70 | residual = self.downsample(x) 71 | 72 | out += residual 73 | if self.nonlin: 74 | out = self.relu(out) 75 | 76 | return out 77 | 78 | 79 | class Bottleneck(nn.Module): 80 | """Resnet bottleneck block module.""" 81 | 82 | expansion = 4 83 | 84 | def __init__( 85 | self, 86 | inplanes: int, 87 | planes: int, 88 | stride: int = 1, 89 | downsample: Optional[nn.Module] = None, 90 | base_width: int = 64, 91 | nonlin: bool = True, 92 | embedding_dim: Optional[int] = None, 93 | ) -> None: 94 | """Construct a Bottleneck module. 95 | 96 | :param inplanes: Number of input channels. 97 | :param planes: Number of output channels. 98 | :param stride: Stride size. 
99 | :param downsample: Down-sampling for residual path. 100 | :param base_width: Base width of the block. 101 | :param nonlin: Whether to apply non-linearity before output. 102 | :param embedding_dim: Size of the output embedding dimension. 103 | """ 104 | super(Bottleneck, self).__init__() 105 | width = int(planes * base_width / 64) 106 | if embedding_dim is not None: 107 | out_dim = embedding_dim 108 | else: 109 | out_dim = planes * self.expansion 110 | self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, stride=1, bias=False) 111 | self.bn1 = nn.BatchNorm2d(width) 112 | self.conv2 = nn.Conv2d( 113 | width, width, kernel_size=3, padding=1, stride=stride, bias=False 114 | ) 115 | self.bn2 = nn.BatchNorm2d(width) 116 | self.conv3 = nn.Conv2d(width, out_dim, kernel_size=1, stride=1, bias=False) 117 | self.bn3 = nn.BatchNorm2d(out_dim) 118 | self.relu = nn.ReLU(inplace=True) 119 | self.downsample = downsample 120 | self.stride = stride 121 | self.nonlin = nonlin 122 | 123 | def forward(self, x: torch.Tensor) -> torch.Tensor: 124 | """Apply forward pass.""" 125 | residual = x 126 | 127 | out = self.conv1(x) 128 | out = self.bn1(out) 129 | out = self.relu(out) 130 | 131 | out = self.conv2(out) 132 | out = self.bn2(out) 133 | out = self.relu(out) 134 | 135 | out = self.conv3(out) 136 | out = self.bn3(out) 137 | 138 | if self.downsample is not None: 139 | residual = self.downsample(x) 140 | 141 | out += residual 142 | 143 | if self.nonlin: 144 | out = self.relu(out) 145 | 146 | return out 147 | 148 | 149 | class ResNet(nn.Module): 150 | """Resnet module.""" 151 | 152 | def __init__( 153 | self, 154 | block: nn.Module, 155 | layers: List[int], 156 | num_classes: int = 1000, 157 | base_width: int = 64, 158 | embedding_dim: Optional[int] = None, 159 | last_nonlin: bool = True, 160 | norm_feature: bool = False, 161 | ) -> None: 162 | """Construct a ResNet module. 163 | 164 | :param block: Block module to use in Resnet architecture. 165 | :param layers: List of number of blocks per layer. 166 | :param num_classes: Number of classes in the dataset. It is used to 167 | form linear classifier weights. 168 | :param base_width: Base width of the blocks. 169 | :param embedding_dim: Size of the output embedding dimension. 170 | :param last_nonlin: Whether to apply non-linearity before output. 171 | :param norm_feature: Whether to normalize output embeddings.
172 | """ 173 | self.inplanes = 64 174 | super(ResNet, self).__init__() 175 | 176 | self.OUTPUT_SHAPE = [embedding_dim, 1, 1] 177 | self.is_normalized = norm_feature 178 | self.base_width = base_width 179 | if self.base_width // 64 > 1: 180 | print(f"==> Using {self.base_width // 64}x wide model") 181 | 182 | if embedding_dim is not None: 183 | print("Using given embedding dimension = {}".format(embedding_dim)) 184 | self.embedding_dim = embedding_dim 185 | else: 186 | self.embedding_dim = 512 * block.expansion 187 | 188 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, padding=3, stride=2, bias=False) 189 | self.bn1 = nn.BatchNorm2d(64) 190 | self.relu = nn.ReLU(inplace=True) 191 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 192 | self.layer1 = self._make_layer( 193 | block, 64, layers[0], embedding_dim=64 * block.expansion 194 | ) 195 | self.layer2 = self._make_layer( 196 | block, 197 | 128, 198 | layers[1], 199 | stride=2, 200 | embedding_dim=128 * block.expansion, 201 | ) 202 | self.layer3 = self._make_layer( 203 | block, 204 | 256, 205 | layers[2], 206 | stride=2, 207 | embedding_dim=256 * block.expansion, 208 | ) 209 | self.layer4 = self._make_layer( 210 | block, 211 | 512, 212 | layers[3], 213 | stride=2, 214 | nonlin=last_nonlin, 215 | embedding_dim=self.embedding_dim, 216 | ) 217 | self.avgpool = nn.AdaptiveAvgPool2d(1) 218 | 219 | # self.fc = nn.Linear(512 * block.expansion, num_classes) 220 | self.fc = nn.Conv2d( 221 | self.embedding_dim, num_classes, kernel_size=1, stride=1, bias=False 222 | ) 223 | 224 | def _make_layer( 225 | self, 226 | block: nn.Module, 227 | planes: int, 228 | blocks: int, 229 | embedding_dim: int, 230 | stride: int = 1, 231 | nonlin: bool = True, 232 | ): 233 | """Make a layer of resnet architecture. 234 | 235 | :param block: Block module to use in this layer. 236 | :param planes: Number of output channels. 237 | :param blocks: Number of blocks in this layer. 238 | :param embedding_dim: Size of the output embedding dimension. 239 | :param stride: Stride size. 240 | :param nonlin: Whether to apply non-linearity before output. 
241 | :return: 242 | """ 243 | downsample = None 244 | if stride != 1 or self.inplanes != planes * block.expansion: 245 | dconv = nn.Conv2d( 246 | self.inplanes, 247 | planes * block.expansion, 248 | kernel_size=1, 249 | stride=stride, 250 | bias=False, 251 | ) 252 | dbn = nn.BatchNorm2d(planes * block.expansion) 253 | if dbn is not None: 254 | downsample = nn.Sequential(dconv, dbn) 255 | else: 256 | downsample = dconv 257 | 258 | last_downsample = None 259 | 260 | layers = [] 261 | if blocks == 1: # If this layer has only one-block 262 | if stride != 1 or self.inplanes != embedding_dim: 263 | dconv = nn.Conv2d( 264 | self.inplanes, 265 | embedding_dim, 266 | kernel_size=1, 267 | stride=stride, 268 | bias=False, 269 | ) 270 | dbn = nn.BatchNorm2d(embedding_dim) 271 | if dbn is not None: 272 | last_downsample = nn.Sequential(dconv, dbn) 273 | else: 274 | last_downsample = dconv 275 | layers.append( 276 | block( 277 | self.inplanes, 278 | planes, 279 | stride, 280 | last_downsample, 281 | base_width=self.base_width, 282 | nonlin=nonlin, 283 | embedding_dim=embedding_dim, 284 | ) 285 | ) 286 | return nn.Sequential(*layers) 287 | else: 288 | layers.append( 289 | block( 290 | self.inplanes, 291 | planes, 292 | stride, 293 | downsample, 294 | base_width=self.base_width, 295 | ) 296 | ) 297 | self.inplanes = planes * block.expansion 298 | for i in range(1, blocks - 1): 299 | layers.append(block(self.inplanes, planes, base_width=self.base_width)) 300 | 301 | if self.inplanes != embedding_dim: 302 | dconv = nn.Conv2d( 303 | self.inplanes, embedding_dim, stride=1, kernel_size=1, bias=False 304 | ) 305 | dbn = nn.BatchNorm2d(embedding_dim) 306 | if dbn is not None: 307 | last_downsample = nn.Sequential(dconv, dbn) 308 | else: 309 | last_downsample = dconv 310 | layers.append( 311 | block( 312 | self.inplanes, 313 | planes, 314 | downsample=last_downsample, 315 | base_width=self.base_width, 316 | nonlin=nonlin, 317 | embedding_dim=embedding_dim, 318 | ) 319 | ) 320 | 321 | return nn.Sequential(*layers) 322 | 323 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 324 | """Apply forward pass. 325 | 326 | :param x: input to the model with shape (N, C, H, W). 327 | :return: Tuple of (logits, embedding) 328 | """ 329 | x = self.conv1(x) 330 | 331 | if self.bn1 is not None: 332 | x = self.bn1(x) 333 | x = self.relu(x) 334 | x = self.maxpool(x) 335 | 336 | x = self.layer1(x) 337 | x = self.layer2(x) 338 | x = self.layer3(x) 339 | x = self.layer4(x) 340 | 341 | feature = self.avgpool(x) 342 | if self.is_normalized: 343 | feature = F.normalize(feature) 344 | 345 | x = self.fc(feature) 346 | x = x.view(x.size(0), -1) 347 | 348 | return x, feature 349 | 350 | 351 | def ResNet18( 352 | num_classes: int, embedding_dim: int, last_nonlin: bool = True, **kwargs 353 | ) -> nn.Module: 354 | """Get a ResNet18 model. 355 | 356 | :param num_classes: Number of classes in the dataset. 357 | :param embedding_dim: Size of the output embedding dimension. 358 | :param last_nonlin: Whether to apply non-linearity before output. 359 | :return: ResNet18 Model. 360 | """ 361 | return ResNet( 362 | BasicBlock, 363 | [2, 2, 2, 2], 364 | num_classes=num_classes, 365 | embedding_dim=embedding_dim, 366 | last_nonlin=last_nonlin, 367 | ) 368 | 369 | 370 | def ResNet50( 371 | num_classes: int, embedding_dim: int, last_nonlin: bool = True, **kwargs 372 | ) -> nn.Module: 373 | """Get a ResNet50 model. 374 | 375 | :param num_classes: Number of classes in the dataset. 
376 | :param embedding_dim: Size of the output embedding dimension. 377 | :param last_nonlin: Whether to apply non-linearity before output. 378 | :return: ResNet50 Model. 379 | """ 380 | return ResNet( 381 | Bottleneck, 382 | [3, 4, 6, 3], 383 | num_classes=num_classes, 384 | embedding_dim=embedding_dim, 385 | last_nonlin=last_nonlin, 386 | ) 387 | 388 | 389 | def ResNet101( 390 | num_classes: int, embedding_dim: int, last_nonlin: bool = True, **kwargs 391 | ) -> nn.Module: 392 | """Get a ResNet101 model. 393 | 394 | :param num_classes: Number of classes in the dataset. 395 | :param embedding_dim: Size of the output embedding dimension. 396 | :param last_nonlin: Whether to apply non-linearity before output. 397 | :return: ResNet101 Model. 398 | """ 399 | return ResNet( 400 | Bottleneck, 401 | [3, 4, 23, 3], 402 | num_classes=num_classes, 403 | embedding_dim=embedding_dim, 404 | last_nonlin=last_nonlin, 405 | ) 406 | 407 | 408 | def WideResNet50_2( 409 | num_classes: int, embedding_dim: int, last_nonlin: bool = True, **kwargs 410 | ) -> nn.Module: 411 | """Get a WideResNet50-2 model. 412 | 413 | :param num_classes: Number of classes in the dataset. 414 | :param embedding_dim: Size of the output embedding dimension. 415 | :param last_nonlin: Whether to apply non-linearity before output. 416 | :return: WideResNet50-2 Model. 417 | """ 418 | return ResNet( 419 | Bottleneck, 420 | [3, 4, 6, 3], 421 | num_classes=num_classes, 422 | base_width=64 * 2, 423 | embedding_dim=embedding_dim, 424 | last_nonlin=last_nonlin, 425 | ) 426 | 427 | 428 | def WideResNet101_2( 429 | num_classes: int, embedding_dim: int, last_nonlin: bool = True, **kwargs 430 | ) -> nn.Module: 431 | """Get a WideResNet101-2 model. 432 | 433 | :param num_classes: Number of classes in the dataset. 434 | :param embedding_dim: Size of the output embedding dimension. 435 | :param last_nonlin: Whether to apply non-linearity before output. 436 | :return: WideResNet101-2 Model. 437 | """ 438 | return ResNet( 439 | Bottleneck, 440 | [3, 4, 23, 3], 441 | num_classes=num_classes, 442 | base_width=64 * 2, 443 | embedding_dim=embedding_dim, 444 | last_nonlin=last_nonlin, 445 | ) 446 | -------------------------------------------------------------------------------- /models/transformations.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | 5 | from typing import Optional 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class ConvBlock(nn.Module): 12 | """Convenience convolution module.""" 13 | 14 | def __init__( 15 | self, 16 | channels_in: int, 17 | channels_out: int, 18 | kernel_size: int = 1, 19 | stride: int = 1, 20 | normalizer: Optional[nn.Module] = nn.BatchNorm2d, 21 | activation: Optional[nn.Module] = nn.ReLU, 22 | ) -> None: 23 | """Construct a ConvBlock module. 24 | 25 | :param channels_in: Number of input channels. 26 | :param channels_out: Number of output channels. 27 | :param kernel_size: Size of the kernel. 28 | :param stride: Size of the convolution stride. 29 | :param normalizer: Optional normalization to use. 30 | :param activation: Optional activation module to use.
31 | """ 32 | super().__init__() 33 | 34 | self.conv = nn.Conv2d( 35 | channels_in, 36 | channels_out, 37 | kernel_size=kernel_size, 38 | stride=stride, 39 | bias=normalizer is None, 40 | padding=kernel_size // 2, 41 | ) 42 | if normalizer is not None: 43 | self.normalizer = normalizer(channels_out) 44 | else: 45 | self.normalizer = None 46 | if activation is not None: 47 | self.activation = activation() 48 | else: 49 | self.activation = None 50 | 51 | def forward(self, x: torch.Tensor) -> torch.Tensor: 52 | """Apply forward pass.""" 53 | x = self.conv(x) 54 | if self.normalizer is not None: 55 | x = self.normalizer(x) 56 | if self.activation is not None: 57 | x = self.activation(x) 58 | return x 59 | 60 | 61 | class MLP_BN_SIDE_PROJECTION(nn.Module): 62 | """FCT transformation module.""" 63 | 64 | def __init__( 65 | self, 66 | old_embedding_dim: int, 67 | new_embedding_dim: int, 68 | side_info_dim: int, 69 | inner_dim: int = 2048, 70 | **kwargs 71 | ) -> None: 72 | """Construct MLP_BN_SIDE_PROJECTION module. 73 | 74 | :param old_embedding_dim: Size of the old embeddings. 75 | :param new_embedding_dim: Size of the new embeddings. 76 | :param side_info_dim: Size of the side-information. 77 | :param inner_dim: Dimension of transformation MLP inner layer. 78 | """ 79 | super().__init__() 80 | 81 | self.inner_dim = inner_dim 82 | self.p1 = nn.Sequential( 83 | ConvBlock(old_embedding_dim, 2 * old_embedding_dim), 84 | ConvBlock(2 * old_embedding_dim, 2 * new_embedding_dim), 85 | ) 86 | 87 | self.p2 = nn.Sequential( 88 | ConvBlock(side_info_dim, 2 * side_info_dim), 89 | ConvBlock(2 * side_info_dim, 2 * new_embedding_dim), 90 | ) 91 | 92 | self.mixer = nn.Sequential( 93 | ConvBlock(4 * new_embedding_dim, self.inner_dim), 94 | ConvBlock(self.inner_dim, self.inner_dim), 95 | ConvBlock( 96 | self.inner_dim, new_embedding_dim, normalizer=None, activation=None 97 | ), 98 | ) 99 | 100 | def forward( 101 | self, old_feature: torch.Tensor, side_info: torch.Tensor 102 | ) -> torch.Tensor: 103 | """Apply forward pass. 104 | 105 | :param old_feature: Old embedding. 106 | :param side_info: Side-information. 107 | :return: Recycled old embedding compatible with new embeddings. 108 | """ 109 | x1 = self.p1(old_feature) 110 | x2 = self.p2(side_info) 111 | return self.mixer(torch.cat([x1, x2], dim=1)) 112 | 113 | 114 | class MLP_BN_SIDE_PROJECTION_SIGMA(nn.Module): 115 | """FastFill transformation module.""" 116 | 117 | def __init__( 118 | self, 119 | old_embedding_dim: int, 120 | new_embedding_dim: int, 121 | side_info_dim: int, 122 | inner_dim: int = 2048, 123 | sigma_dim: int = 1, 124 | **kwargs 125 | ) -> None: 126 | """Construct MLP_BN_SIDE_PROJECTION module. 127 | 128 | :param old_embedding_dim: Size of the old embeddings. 129 | :param new_embedding_dim: Size of the new embeddings. 130 | :param side_info_dim: Size of the side-information. 131 | :param inner_dim: Dimension of transformation MLP inner layer. 132 | :param sigma_dim: Size of the uncertainty output choices=[1, new_embedding_dim]. 
133 | """ 134 | super().__init__() 135 | 136 | assert sigma_dim in [1, new_embedding_dim] 137 | self.uncertainty_head = nn.Linear(new_embedding_dim, sigma_dim) 138 | self.inner_dim = inner_dim 139 | self.p1 = nn.Sequential( 140 | ConvBlock(old_embedding_dim, 2 * old_embedding_dim), 141 | ConvBlock(2 * old_embedding_dim, 2 * new_embedding_dim), 142 | ) 143 | 144 | self.p2 = nn.Sequential( 145 | ConvBlock(side_info_dim, 2 * side_info_dim), 146 | ConvBlock(2 * side_info_dim, 2 * new_embedding_dim), 147 | ) 148 | 149 | self.mixer = nn.Sequential( 150 | ConvBlock(4 * new_embedding_dim, self.inner_dim), 151 | ConvBlock(self.inner_dim, self.inner_dim), 152 | ConvBlock( 153 | self.inner_dim, new_embedding_dim, normalizer=None, activation=None 154 | ), 155 | ) 156 | 157 | def forward( 158 | self, old_feature: torch.Tensor, side_info: torch.Tensor 159 | ) -> torch.Tensor: 160 | """Apply forward pass. 161 | 162 | :param old_feature: Old embedding. 163 | :param side_info: Side-information. 164 | :return: Recycled old embedding compatible with new embeddings concatenated with sigma. 165 | """ 166 | x1 = self.p1(old_feature) 167 | x2 = self.p2(side_info) 168 | 169 | transformed_features = self.mixer(torch.cat([x1, x2], dim=1)) 170 | sigma = self.uncertainty_head(transformed_features.squeeze()) 171 | transformed_features = torch.cat( 172 | (transformed_features, sigma.unsqueeze(dim=2).unsqueeze(dim=3)), dim=1 173 | ) 174 | 175 | return transformed_features 176 | -------------------------------------------------------------------------------- /prepare_dataset.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | 5 | import os 6 | import pickle 7 | 8 | from PIL import Image 9 | from tqdm import tqdm 10 | 11 | from torchvision.datasets import CIFAR100 12 | 13 | 14 | def save_pngs(untar_dir: str, split: str) -> None: 15 | """Save loaded data as png files. 16 | 17 | :param untar_dir: Path to untared dataset. 18 | :param split: Split name (e.g., train, test) 19 | """ 20 | split_map = {"train": "training", "test": "validation"} 21 | split_dir = os.path.join(untar_dir, split_map.get(split)) 22 | 23 | os.makedirs(split_dir, exist_ok=True) 24 | 25 | for i in range(100): 26 | class_dir = os.path.join(split_dir, str(i)) 27 | os.makedirs(class_dir, exist_ok=True) 28 | 29 | with open(os.path.join(untar_dir, split), "rb") as f: 30 | data_dict = pickle.load(f, encoding="latin1") 31 | 32 | data = data_dict.get("data") # numpy array 33 | # Reshape and cast 34 | data = data.reshape(data.shape[0], 3, 32, 32) 35 | data = data.transpose(0, 2, 3, 1).astype("uint8") 36 | 37 | labels = data_dict.get("fine_labels") 38 | 39 | for i, (datum, label) in tqdm(enumerate(zip(data, labels)), total=len(labels)): 40 | image = Image.fromarray(datum) 41 | image = image.convert("RGB") 42 | file_path = os.path.join(split_dir, str(label), "{}.png".format(i)) 43 | image.save(file_path) 44 | 45 | 46 | def get_cifar100() -> None: 47 | """Get and reformat cifar100 dataset. 48 | 49 | See https://www.cs.toronto.edu/~kriz/cifar.html for dataset description. 
50 | """ 51 | data_store_dir = "data_store" 52 | 53 | if not os.path.exists(data_store_dir): 54 | os.makedirs(data_store_dir) 55 | 56 | dataset = CIFAR100(root=data_store_dir, download=True) 57 | 58 | # Load files and convert to PNG 59 | untar_dir = os.path.join(data_store_dir, dataset.base_folder) 60 | save_pngs(untar_dir, "test") 61 | save_pngs(untar_dir, "train") 62 | 63 | 64 | if __name__ == "__main__": 65 | get_cifar100() 66 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | Pillow 3 | torch 4 | torchvision 5 | tqdm 6 | scikit-learn 7 | PyYAML 8 | -------------------------------------------------------------------------------- /train_backbone.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | 5 | from typing import Dict 6 | from argparse import ArgumentParser 7 | 8 | import yaml 9 | from PIL import ImageFile 10 | 11 | ImageFile.LOAD_TRUNCATED_IMAGES = True 12 | 13 | import torch 14 | import torch.nn as nn 15 | 16 | from trainers import BackboneTrainer 17 | from dataset import SubImageFolder 18 | from utils.net_utils import LabelSmoothing, backbone_to_torchscript 19 | from utils.schedulers import get_policy 20 | from utils.getters import get_model, get_optimizer 21 | 22 | 23 | def main(config: Dict) -> None: 24 | """Run training. 25 | 26 | :param config: A dictionary with all configurations to run training. 27 | :return: 28 | """ 29 | model = get_model(config.get("arch_params")) 30 | 31 | device = torch.device("cuda:0") if torch.cuda.is_available() else "cpu" 32 | torch.backends.cudnn.benchmark = True 33 | 34 | if torch.cuda.is_available(): 35 | model = torch.nn.DataParallel(model) 36 | model.to(device) 37 | 38 | trainer = BackboneTrainer() 39 | optimizer = get_optimizer(model, **config.get("optimizer_params")) 40 | data = SubImageFolder(**config.get("dataset_params")) 41 | lr_policy = get_policy(optimizer, **config.get("lr_policy_params")) 42 | 43 | if config.get("label_smoothing") is None: 44 | criterion = nn.CrossEntropyLoss() 45 | else: 46 | criterion = LabelSmoothing(smoothing=config.get("label_smoothing")) 47 | 48 | # Training loop 49 | for epoch in range(config.get("epochs")): 50 | lr_policy(epoch, iteration=None) 51 | 52 | train_acc1, train_acc5, train_loss = trainer.train( 53 | train_loader=data.train_loader, 54 | model=model, 55 | criterion=criterion, 56 | optimizer=optimizer, 57 | device=device, 58 | ) 59 | 60 | print( 61 | "Train: epoch = {}, Loss = {}, Top 1 = {}, Top 5 = {}".format( 62 | epoch, train_loss, train_acc1, train_acc5 63 | ) 64 | ) 65 | 66 | test_acc1, test_acc5, test_loss = trainer.validate( 67 | val_loader=data.val_loader, 68 | model=model, 69 | criterion=criterion, 70 | device=device, 71 | ) 72 | 73 | print( 74 | "Test: epoch = {}, Loss = {}, Top 1 = {}, Top 5 = {}".format( 75 | epoch, test_loss, test_acc1, test_acc5 76 | ) 77 | ) 78 | 79 | backbone_to_torchscript(model, config.get("output_model_path")) 80 | 81 | 82 | if __name__ == "__main__": 83 | parser = ArgumentParser() 84 | parser.add_argument( 85 | "--config", 86 | type=str, 87 | required=True, 88 | help="Path to config file for this pipeline.", 89 | ) 90 | args = parser.parse_args() 91 | with open(args.config) as f: 92 | read_config = yaml.safe_load(f) 93 | main(read_config) 94 | 
-------------------------------------------------------------------------------- /train_transformation.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | 5 | from typing import Dict 6 | 7 | import yaml 8 | from argparse import ArgumentParser 9 | 10 | import torch 11 | 12 | from PIL import ImageFile 13 | 14 | ImageFile.LOAD_TRUNCATED_IMAGES = True 15 | 16 | from trainers import TransformationTrainer 17 | from dataset import SubImageFolder 18 | from utils.net_utils import transformation_to_torchscripts 19 | from utils.schedulers import get_policy 20 | from utils.getters import get_model, get_optimizer, get_criteria 21 | 22 | 23 | def main(config: Dict) -> None: 24 | """Run training. 25 | 26 | :param config: A dictionary with all configurations to run training. 27 | :return: 28 | """ 29 | device = torch.device("cuda:0") if torch.cuda.is_available() else "cpu" 30 | torch.backends.cudnn.benchmark = True 31 | 32 | model = get_model(config.get("arch_params")) 33 | old_model = torch.jit.load(config.get("old_model_path")) 34 | new_model = torch.jit.load(config.get("new_model_path")) 35 | 36 | if torch.cuda.is_available(): 37 | model = torch.nn.DataParallel(model) 38 | old_model = torch.nn.DataParallel(old_model) 39 | new_model = torch.nn.DataParallel(new_model) 40 | 41 | model.to(device) 42 | old_model.to(device) 43 | new_model.to(device) 44 | 45 | if config.get("side_info_model_path") is not None: 46 | side_info_model = torch.jit.load(config.get("side_info_model_path")) 47 | if torch.cuda.is_available(): 48 | side_info_model = torch.nn.DataParallel(side_info_model) 49 | side_info_model.to(device) 50 | else: 51 | side_info_model = old_model 52 | 53 | optimizer = get_optimizer(model, **config.get("optimizer_params")) 54 | data = SubImageFolder(**config.get("dataset_params")) 55 | lr_policy = get_policy(optimizer, **config.get("lr_policy_params")) 56 | mus, criteria = get_criteria(**config.get("objective_params", {})) 57 | trainer = TransformationTrainer( 58 | old_model, new_model, side_info_model, **mus, **criteria 59 | ) 60 | 61 | for epoch in range(config.get("epochs")): 62 | lr_policy(epoch, iteration=None) 63 | 64 | if config.get("switch_mode_to_eval"): 65 | switch_mode_to_eval = epoch >= config.get("epochs") / 2 66 | else: 67 | switch_mode_to_eval = False 68 | 69 | train_loss = trainer.train( 70 | train_loader=data.train_loader, 71 | model=model, 72 | optimizer=optimizer, 73 | device=device, 74 | switch_mode_to_eval=switch_mode_to_eval, 75 | ) 76 | 77 | print("Train: epoch = {}, Average Loss = {}".format(epoch, train_loss)) 78 | 79 | # evaluate on validation set 80 | test_loss = trainer.validate( 81 | val_loader=data.val_loader, 82 | model=model, 83 | device=device, 84 | ) 85 | 86 | print("Test: epoch = {}, Average Loss = {}".format(epoch, test_loss)) 87 | 88 | transformation_to_torchscripts( 89 | old_model, 90 | side_info_model, 91 | model, 92 | config.get("output_transformation_path"), 93 | config.get("output_transformed_old_model_path"), 94 | ) 95 | 96 | 97 | if __name__ == "__main__": 98 | parser = ArgumentParser() 99 | parser.add_argument( 100 | "--config", 101 | type=str, 102 | required=True, 103 | help="Path to config file for this pipeline.", 104 | ) 105 | args = parser.parse_args() 106 | with open(args.config) as f: 107 | read_config = yaml.safe_load(f) 108 | main(read_config) 109 | -------------------------------------------------------------------------------- 
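The transformation script above wires its loss terms through get_criteria() in utils/getters.py. The snippet below sketches a plausible objective_params block as a Python dict, following the key names documented in get_criteria's docstring; the numeric values are placeholders and are not taken from the shipped configs/*_transformation.yaml files.

# Illustrative sketch only -- key names follow the get_criteria() docstring in utils/getters.py.
example_objective_params = {
    "similarity_loss": {"name": "mse", "mu_similarity": 1.0},  # "l1" and "cosine" are also supported
    "discriminative_loss": {"name": "LabelSmoothing", "label_smoothing": 0.1, "mu_disc": 1.0},
    "uncertainty_loss": {"mu_uncertainty": 0.5, "sigma_dim": 1},  # FastFill-style sigma head
}
# mus, criteria = get_criteria(**example_objective_params)
# trainer = TransformationTrainer(old_model, new_model, side_info_model, **mus, **criteria)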
/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | 5 | from .backbone_trainer import BackboneTrainer 6 | from .transformation_trainer import TransformationTrainer 7 | -------------------------------------------------------------------------------- /trainers/backbone_trainer.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | 5 | from typing import Tuple, Callable 6 | 7 | import tqdm 8 | import torch 9 | import torch.nn as nn 10 | 11 | from utils.logging_utils import AverageMeter 12 | from utils.eval_utils import accuracy 13 | 14 | 15 | class BackboneTrainer: 16 | """Class to train and evaluate backbones.""" 17 | 18 | def train( 19 | self, 20 | train_loader: torch.utils.data.DataLoader, 21 | model: nn.Module, 22 | criterion: Callable, 23 | optimizer: torch.optim.Optimizer, 24 | device: torch.device, 25 | ) -> Tuple[float, float, float]: 26 | """Run one epoch of training. 27 | 28 | :param train_loader: Data loader to train the model. 29 | :param model: Model to be trained. 30 | :param criterion: Loss criterion module. 31 | :param optimizer: A torch optimizer object. 32 | :param device: Device the model is on. 33 | :return: average of top-1, top-5, and loss on current epoch. 34 | """ 35 | losses = AverageMeter("Loss", ":.3f") 36 | top1 = AverageMeter("Acc@1", ":6.2f") 37 | top5 = AverageMeter("Acc@5", ":6.2f") 38 | 39 | model.train() 40 | 41 | for i, (images, target) in tqdm.tqdm( 42 | enumerate(train_loader), ascii=True, total=len(train_loader) 43 | ): 44 | images = images.to(device, non_blocking=True) 45 | target = target.to(device, non_blocking=True) 46 | 47 | output, _ = model(images) 48 | loss = criterion(output, target) 49 | 50 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 51 | losses.update(loss.item(), images.size(0)) 52 | top1.update(acc1.item(), images.size(0)) 53 | top5.update(acc5.item(), images.size(0)) 54 | 55 | optimizer.zero_grad() 56 | loss.backward() 57 | optimizer.step() 58 | 59 | return top1.avg, top5.avg, losses.avg 60 | 61 | def validate( 62 | self, 63 | val_loader: torch.utils.data.DataLoader, 64 | model: nn.Module, 65 | criterion: Callable, 66 | device: torch.device, 67 | ) -> Tuple[float, float, float]: 68 | """Run validation. 69 | 70 | :param val_loader: Data loader to evaluate the model. 71 | :param model: Model to be evaluated. 72 | :param criterion: Loss criterion module. 73 | :param device: Device the model is on. 74 | :return: average of top-1, top-5, and loss on current epoch. 
75 | """ 76 | losses = AverageMeter("Loss", ":.3f") 77 | top1 = AverageMeter("Acc@1", ":6.2f") 78 | top5 = AverageMeter("Acc@5", ":6.2f") 79 | 80 | model.eval() 81 | 82 | with torch.no_grad(): 83 | for i, (images, target) in tqdm.tqdm( 84 | enumerate(val_loader), ascii=True, total=len(val_loader) 85 | ): 86 | images = images.to(device, non_blocking=True) 87 | target = target.to(device, non_blocking=True) 88 | 89 | output, _ = model(images) 90 | 91 | loss = criterion(output, target) 92 | 93 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 94 | losses.update(loss.item(), images.size(0)) 95 | top1.update(acc1.item(), images.size(0)) 96 | top5.update(acc5.item(), images.size(0)) 97 | 98 | return top1.avg, top5.avg, losses.avg 99 | -------------------------------------------------------------------------------- /trainers/transformation_trainer.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | 5 | from typing import Optional, Union 6 | 7 | import tqdm 8 | import torch 9 | import torch.nn as nn 10 | 11 | from utils.logging_utils import AverageMeter 12 | 13 | 14 | class TransformationTrainer: 15 | """Class to train and evaluate transformation models.""" 16 | 17 | def __init__( 18 | self, 19 | old_model: Union[nn.Module, torch.jit.ScriptModule], 20 | new_model: Union[nn.Module, torch.jit.ScriptModule], 21 | side_info_model: Union[nn.Module, torch.jit.ScriptModule], 22 | mu_similarity: float = 1, 23 | mu_disc: Optional[float] = None, 24 | criterion_similarity: Optional[nn.Module] = None, 25 | criterion_disc: Optional[Union[torch.jit.ScriptModule, nn.Module]] = None, 26 | criterion_uncertainty: Optional[nn.Module] = None, 27 | **kwargs 28 | ) -> None: 29 | """Construct a TransformationTrainer module. 30 | 31 | :param old_model: A model that returns old embedding given x. 32 | :param new_model: A model that returns new embedding given x. 33 | :param side_info_model: A model that returns side-info given x. 34 | :param mu_similarity: hyperparameter for similarity loss 35 | :param mu_disc: hyperparameter for classification loss 36 | :param criterion_similarity: objective function computing the similarity between new and h(old) features. 37 | :param criterion_disc: objective function for the classification head for h(old). 38 | :param criterion_uncertainty: Uncertainty based Loss function. 39 | """ 40 | 41 | self.old_model = old_model 42 | self.old_model.eval() 43 | self.new_model = new_model 44 | self.new_model.eval() 45 | self.side_info_model = side_info_model 46 | self.side_info_model.eval() 47 | 48 | self.mu_similarity = mu_similarity 49 | self.mu_disc = mu_disc 50 | 51 | self.criterion_similarity = criterion_similarity 52 | self.criterion_disc = criterion_disc 53 | self.criterion_uncertainty = criterion_uncertainty 54 | 55 | def compute_loss( 56 | self, 57 | new_feature: torch.Tensor, 58 | recycled_feature: torch.Tensor, 59 | target: torch.Tensor, 60 | sigma: torch.Tensor, 61 | ) -> torch.Tensor: 62 | """Compute total loss for a batch. 63 | 64 | :param new_feature: Tensor of features computed by new model. 65 | :param recycled_feature: Tensor of features computed by transformation model. 66 | :param target: Labels tensor. 67 | :param sigma: Tensor of sigmas computed by transformation. 68 | :return: Total loss tensor. 
69 | """ 70 | loss = 0 71 | if self.criterion_similarity is not None: 72 | similarity_loss = self.criterion_similarity( 73 | new_feature.squeeze(), recycled_feature.squeeze() 74 | ) 75 | if len(similarity_loss.shape) > 1: 76 | similarity_loss = similarity_loss.mean(dim=1) 77 | loss += self.mu_similarity * similarity_loss 78 | 79 | if self.criterion_disc is not None: 80 | if isinstance(self.new_model, torch.nn.DataParallel): 81 | fc_layer = self.new_model.module.model.fc 82 | else: 83 | fc_layer = self.new_model.model.fc 84 | logits = fc_layer(recycled_feature[:, : new_feature.size()[1]]) 85 | loss_disc = self.criterion_disc(logits.squeeze(), target) 86 | loss += self.mu_disc * loss_disc 87 | 88 | if self.criterion_uncertainty: 89 | loss = self.criterion_uncertainty(sigma, loss) 90 | return loss 91 | 92 | def train( 93 | self, 94 | train_loader: torch.utils.data.DataLoader, 95 | model: nn.Module, 96 | optimizer: torch.optim.Optimizer, 97 | device: torch.device, 98 | switch_mode_to_eval: bool, 99 | ) -> float: 100 | """Run one epoch of training. 101 | 102 | :param train_loader: Data loader to train the model. 103 | :param model: Model to be trained. 104 | :param optimizer: A torch optimizer object. 105 | :param device: Device the model is on. 106 | :param switch_mode_to_eval: If true model is train on eval mode! 107 | :return: Average loss on current epoch. 108 | """ 109 | losses = AverageMeter("Loss", ":.3f") 110 | 111 | if switch_mode_to_eval: 112 | model.eval() 113 | else: 114 | model.train() 115 | 116 | for i, (images, target) in tqdm.tqdm( 117 | enumerate(train_loader), ascii=True, total=len(train_loader) 118 | ): 119 | images = images.to(device, non_blocking=True) 120 | target = target.to(device, non_blocking=True) # only needed by L_disc 121 | 122 | with torch.no_grad(): 123 | old_feature = self.old_model(images) 124 | new_feature = self.new_model(images) 125 | side_info = self.side_info_model(images) 126 | 127 | recycled_feature = model(old_feature, side_info) 128 | sigma = recycled_feature[:, new_feature.size()[1] :] 129 | recycled_feature = recycled_feature[:, : new_feature.size()[1]] 130 | 131 | loss = self.compute_loss(new_feature, recycled_feature, target, sigma) 132 | losses.update(loss.item(), images.size(0)) 133 | 134 | optimizer.zero_grad() 135 | loss.backward() 136 | optimizer.step() 137 | 138 | return losses.avg 139 | 140 | def validate( 141 | self, 142 | val_loader: torch.utils.data.DataLoader, 143 | model: nn.Module, 144 | device: torch.device, 145 | ) -> float: 146 | """Run validation. 147 | 148 | :param val_loader: Data loader to evaluate the model. 149 | :param model: Model to be evaluated. 150 | :param device: Device the model is on. 151 | :return: Average loss on current epoch. 
152 | """ 153 | losses = AverageMeter("Loss", ":.3f") 154 | model.eval() 155 | 156 | for i, (images, target) in tqdm.tqdm( 157 | enumerate(val_loader), ascii=True, total=len(val_loader) 158 | ): 159 | images = images.to(device, non_blocking=True) 160 | target = target.to(device, non_blocking=True) # only needed by L_disc 161 | 162 | with torch.no_grad(): 163 | old_feature = self.old_model(images) 164 | new_feature = self.new_model(images) 165 | side_info = self.side_info_model(images) 166 | 167 | recycled_feature = model(old_feature, side_info) 168 | sigma = recycled_feature[:, new_feature.size()[1] :] 169 | recycled_feature = recycled_feature[:, : new_feature.size()[1]] 170 | 171 | loss = self.compute_loss(new_feature, recycled_feature, target, sigma) 172 | losses.update(loss.item(), images.size(0)) 173 | 174 | return losses.avg 175 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | -------------------------------------------------------------------------------- /utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | 5 | from typing import Union, Tuple, List, Optional, Callable 6 | import copy 7 | 8 | from sklearn.metrics import average_precision_score 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import tqdm 14 | 15 | 16 | def accuracy(output, target, topk=(1,)): 17 | """Compute the accuracy over the k top predictions. 18 | 19 | From https://github.com/YantaoShen/openBCT/blob/main/main.py 20 | """ 21 | with torch.no_grad(): 22 | maxk = max(topk) 23 | batch_size = target.shape[0] 24 | 25 | _, pred = output.topk(maxk, 1, True, True) 26 | pred = pred.t() 27 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 28 | 29 | res = [] 30 | for k in topk: 31 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 32 | res.append(correct_k.mul_(100.0 / batch_size)) 33 | return res 34 | 35 | 36 | def mean_ap( 37 | distance_matrix: torch.Tensor, 38 | labels: torch.Tensor, 39 | ) -> float: 40 | """Get pair-wise cosine distances. 
41 | 42 | :param distance_matrix: pairwise distance matrix between embeddings of gallery and query sets, shape = (n, n) 43 | :param labels: labels for the query data (assuming the same as gallery), shape = (n,) 44 | 45 | :return: mean average precision (float) 46 | """ 47 | distance_matrix = distance_matrix 48 | m, n = distance_matrix.shape 49 | assert m == n 50 | 51 | # Sort and find correct matches 52 | distance_matrix, gallery_matched_indices = torch.sort(distance_matrix, dim=1) 53 | distance_matrix = distance_matrix.cpu().numpy() 54 | gallery_matched_indices = gallery_matched_indices.cpu().numpy() 55 | 56 | truth_mask = labels[gallery_matched_indices] == labels[:, None] 57 | truth_mask = truth_mask.cpu().numpy() 58 | 59 | # Compute average precision for each query 60 | average_precisions = list() 61 | for query_index in range(n): 62 | 63 | valid_sorted_match_indices = ( 64 | gallery_matched_indices[query_index, :] != query_index 65 | ) 66 | y_true = truth_mask[query_index, valid_sorted_match_indices] 67 | y_score = -distance_matrix[query_index][valid_sorted_match_indices] 68 | if not np.any(y_true): 69 | continue # if a query does not have any match, we exclude it from mAP calculation. 70 | average_precisions.append(average_precision_score(y_true, y_score)) 71 | return np.mean(average_precisions) 72 | 73 | 74 | def cosine_distance_matrix( 75 | x: torch.Tensor, y: torch.Tensor, diag_only: bool = False 76 | ) -> torch.Tensor: 77 | """Get pair-wise cosine distances. 78 | 79 | :param x: A torch feature tensor with shape (n, d). 80 | :param y: A torch feature tensor with shape (n, d). 81 | :param diag_only: if True, only diagonal of distance matrix is computed and returned. 82 | :return: Distance tensor between features x and y with shape (n, n) if diag_only is False. Otherwise, elementwise 83 | distance tensor with shape (n,). 84 | """ 85 | x_norm = F.normalize(x, p=2, dim=-1) 86 | y_norm = F.normalize(y, p=2, dim=-1) 87 | if diag_only: 88 | return 1.0 - torch.sum(x_norm * y_norm, dim=1) 89 | return 1.0 - x_norm @ y_norm.T 90 | 91 | 92 | def l2_distance_matrix( 93 | x: torch.Tensor, y: torch.Tensor, diag_only: bool = False 94 | ) -> torch.Tensor: 95 | """Get pair-wise l2 distances. 96 | 97 | :param x: A torch feature tensor with shape (n, d). 98 | :param y: A torch feature tensor with shape (n, d). 99 | :param diag_only: if True, only diagonal of distance matrix is computed and returned. 100 | :return: Distance tensor between features x and y with shape (n, n) if diag_only is False. Otherwise, elementwise 101 | distance tensor with shape (n,). 102 | """ 103 | if diag_only: 104 | return torch.norm(x - y, dim=1, p=2) 105 | return torch.cdist(x, y, p=2) 106 | 107 | 108 | def cmc_optimized( 109 | distmat: torch.Tensor, 110 | query_ids: Optional[torch.Tensor] = None, 111 | topk: int = 5, 112 | ) -> Tuple[float, float]: 113 | """Compute Cumulative Matching Characteristics metric. 114 | 115 | :param distmat: pairwise distance matrix between embeddings of gallery and query sets 116 | :param query_ids: labels for the query data. We're assuming query_ids and gallery_ids are the same. 117 | :param topk: parameter for top k retrieval 118 | :return: CMC top-1 and top-5 floats.
119 | """ 120 | distmat = copy.deepcopy(distmat) 121 | query_ids = copy.deepcopy(query_ids) 122 | 123 | distmat.fill_diagonal_(float("inf")) 124 | distmat_new_old_sorted, indices = torch.sort(distmat) 125 | labels = query_ids.unsqueeze(dim=0).repeat(query_ids.shape[0], 1) 126 | sorted_labels = torch.gather(labels, 1, indices) 127 | 128 | top1_retrieval = sorted_labels[:, 0] == query_ids 129 | top5_retrieval = ( 130 | (sorted_labels[:, :topk] == query_ids.unsqueeze(1)).sum(dim=1).clamp(max=1) 131 | ) 132 | 133 | top1 = top1_retrieval.sum() / query_ids.shape[0] 134 | top5 = top5_retrieval.sum() / query_ids.shape[0] 135 | 136 | return float(top1), float(top5) 137 | 138 | 139 | def generate_feature_matrix( 140 | gallery_model: Union[nn.Module, torch.jit.ScriptModule], 141 | query_model: Union[nn.Module, torch.jit.ScriptModule], 142 | val_loader: torch.utils.data.DataLoader, 143 | device: torch.device, 144 | verbose: bool = False, 145 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 146 | """Generate Feature Matrix 147 | :param gallery_model: Model to compute gallery features. 148 | :param query_model: Model to compute query features. 149 | :param val_loader: Data loader to get gallery/query data. 150 | :param device: Device to use for computations. 151 | :param verbose: Whether to be verbose. 152 | :return: Three tensors gallery_features (n, d), query_features (n, d), labels (n,), where n is size of val dataset, 153 | and d is the embedding dimension. 154 | """ 155 | 156 | gallery_model.eval() 157 | query_model.eval() 158 | 159 | gallery_model.to(device) 160 | query_model.to(device) 161 | 162 | gallery_features = [] 163 | query_features = [] 164 | labels = [] 165 | 166 | iterator = tqdm.tqdm(val_loader) if verbose else val_loader 167 | 168 | with torch.no_grad(): 169 | for data, label in iterator: 170 | data = data.to(device) 171 | label = label.to(device) 172 | gallery_feature = gallery_model(data) 173 | query_feature = query_model(data) 174 | 175 | gallery_features.append(gallery_feature.squeeze()) 176 | query_features.append(query_feature.squeeze()) 177 | 178 | labels.append(label) 179 | 180 | gallery_features = torch.cat(gallery_features) 181 | query_features = torch.cat(query_features) 182 | labels = torch.cat(labels) 183 | 184 | return gallery_features, query_features, labels 185 | 186 | 187 | def get_backfilling_orders( 188 | backfilling_list: List[str], 189 | query_features: torch.Tensor, 190 | gallery_features: torch.Tensor, 191 | distance_metric: Callable, 192 | sigma: Optional[torch.Tensor] = None, 193 | ) -> List[Tuple[torch.Tensor, str]]: 194 | """Compute backfilling ordering. 195 | 196 | :param backfilling_list: list of desired backfilling orders from ["random", "distance", "sigma'] 197 | :param query_features: Tensor of query features with shape (n, d), where n is dataset size and d is embedding dim. 198 | :param gallery_features: Tensor of gallery features with shape (n, d). 199 | :param distance_metric: Callable to compute distance between features. 200 | :param sigma: Tensor of computed sigmas with shape (n,). 201 | 202 | :return: List of (ordering, ordering_name) tuples. ordering is a permutation of [0, 1, ..., n-1] determining the 203 | backfilling ordering, and ordering_name is the name of ordering. For example, if ordering=[3, 0, 2, 1] it means 204 | first backfill gallery data at index 3, followed by elements at indices 0, 2, and 1, respectively. 
205 | """ 206 | orderings_list = [] 207 | n = query_features.shape[0] 208 | for ordering_name in backfilling_list: 209 | if ordering_name.lower() == "random": 210 | ordering = torch.randperm(n) 211 | elif ordering_name.lower() == "distance": 212 | distances = distance_metric( 213 | query_features.cpu(), gallery_features.cpu(), diag_only=True 214 | ) 215 | ordering = torch.argsort(distances, descending=True) 216 | elif ordering_name.lower() == "sigma": 217 | assert sigma.numel() == n 218 | ordering = torch.argsort(sigma, dim=0, descending=True) 219 | else: 220 | print(f"{ordering_name} is not implemented for backfilling") 221 | continue 222 | 223 | # Sanity checks: 224 | assert torch.unique(ordering).shape == ordering.shape 225 | assert ordering.min() == 0 226 | assert ordering.max() == n - 1 227 | orderings_list.append((ordering, ordering_name)) 228 | 229 | return orderings_list 230 | 231 | 232 | def cmc_evaluate( 233 | gallery_model: Union[nn.Module, torch.jit.ScriptModule], 234 | query_model: Union[nn.Module, torch.jit.ScriptModule], 235 | val_loader: torch.utils.data.DataLoader, 236 | device: torch.device, 237 | distance_metric_name: str, 238 | verbose: bool = False, 239 | compute_map: bool = False, 240 | backfilling: Optional[int] = None, 241 | backfilling_list: List[str] = ["random"], 242 | backfilling_result_path: Optional[str] = None, 243 | **kwargs, 244 | ) -> None: 245 | """Run CMC and mAP evaluations. 246 | 247 | :param gallery_model: Model to compute gallery features. 248 | :param query_model: Model to compute query features. 249 | :param val_loader: Data loader to get gallery/query data. 250 | :param device: Device to use for computations. 251 | :param distance_metric_name: Name of distance metric to use. Choose from ['l2', 'cosine']. 252 | :param verbose: Whether to be verbose. 253 | :param compute_map: Whether to compute mean average precision. 254 | :param backfilling: Number of intermediate backfilling steps. None means 0 intermediate backfilling. In this case 255 | only results for 0% backfilling will be computed. 256 | :param backfilling_list: list of desired backfilling orders from ["random", "distance", "sigma']. Default is 257 | ['random']. 258 | :param backfilling_result_path: path to save backfilling results. 
259 | """ 260 | distance_map = {"l2": l2_distance_matrix, "cosine": cosine_distance_matrix} 261 | distance_metric = distance_map.get(distance_metric_name.lower()) 262 | 263 | print("Generating Feature Matrix") 264 | gallery_features, query_features, labels = generate_feature_matrix( 265 | gallery_model, query_model, val_loader, device, verbose 266 | ) 267 | 268 | # (Possibly) Split gallery features and sigmas 269 | embedding_dim = query_features.shape[1] 270 | sigma = gallery_features[:, embedding_dim:].squeeze() # (n,) 271 | sigma = None if sigma.numel() == 0 else sigma 272 | gallery_features = gallery_features[:, :embedding_dim] # (n, d) 273 | 274 | n = query_features.shape[0] # dataset size 275 | backfilling = backfilling if backfilling is not None else -1 276 | 277 | orderings_list = get_backfilling_orders( 278 | backfilling_list=backfilling_list, 279 | query_features=query_features, 280 | gallery_features=gallery_features, 281 | distance_metric=distance_metric, 282 | sigma=sigma, 283 | ) 284 | 285 | backfilling_results = {} 286 | for ordering, ordering_name in orderings_list: 287 | print(f"\nBackfilling evaluation with {ordering_name} ordering.") 288 | gallery_features_reordered = copy.deepcopy(gallery_features[ordering]).cpu() 289 | query_features_reordered = copy.deepcopy(query_features[ordering]).cpu() 290 | labels_reordered = copy.deepcopy(labels[ordering]).cpu() 291 | 292 | # Lists to store top1, top5, and ,mAP 293 | outputs = {"CMC-top1": [], "CMC-top5": [], "mAP": []} 294 | 295 | iterator = ( 296 | tqdm.tqdm(range(backfilling + 2)) if verbose else range(backfilling + 2) 297 | ) 298 | 299 | for i in iterator: 300 | if backfilling >= 0: 301 | cutoff_index = (i * n) // (backfilling + 1) 302 | else: 303 | cutoff_index = 0 304 | backfilling_mask = torch.zeros((n, 1), dtype=torch.bool) 305 | backfilling_mask[torch.arange(n) < cutoff_index] = True 306 | backfilled_gallery = torch.where( 307 | backfilling_mask, query_features_reordered, gallery_features_reordered 308 | ) 309 | 310 | distmat = distance_metric(query_features_reordered, backfilled_gallery) 311 | 312 | top1, top5 = cmc_optimized( 313 | distmat=distmat, 314 | query_ids=labels_reordered, 315 | topk=5, 316 | ) 317 | 318 | if compute_map: 319 | mean_ap_out = mean_ap(distance_matrix=distmat, labels=labels_reordered) 320 | else: 321 | mean_ap_out = None 322 | 323 | outputs["CMC-top1"].append(top1) 324 | outputs["CMC-top5"].append(top5) 325 | outputs["mAP"].append(mean_ap_out) 326 | 327 | backfilling_results[ordering_name] = outputs 328 | 329 | for metric_name, metric_data in outputs.items(): 330 | print_partial_backfilling(metric_data, metric_name) 331 | 332 | if backfilling_result_path is not None: 333 | np.save(backfilling_result_path, backfilling_results) 334 | 335 | 336 | def print_partial_backfilling(data: List, metric_name: str) -> None: 337 | """Print partial backfilling results. 338 | 339 | :param data: list of floats for a given metric in [0,1] range. 340 | :param metric_name: name of the metric. 341 | """ 342 | print(f"*** {metric_name} ***:") 343 | data_str_list = ["{:.2f}".format(100 * x) for x in data] 344 | print(" -> ".join(data_str_list)) 345 | if len(data) > 1: 346 | print("AUC: {:.2f} %".format(100 * np.mean(data))) 347 | print("\n") 348 | -------------------------------------------------------------------------------- /utils/getters.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 
3 | # 4 | 5 | from typing import Dict, Optional, Tuple, Callable 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | import models 11 | from utils.objectives import UncertaintyLoss, LabelSmoothing 12 | 13 | 14 | def get_model(arch_params: Dict, **kwargs) -> nn.Module: 15 | """Get a model given its configurations. 16 | 17 | :param arch_params: A dictionary containing all model parameters. 18 | :return: A torch model. 19 | """ 20 | print("=> Creating model '{}'".format(arch_params.get("arch"))) 21 | model = models.__dict__[arch_params.get("arch")](**arch_params) 22 | return model 23 | 24 | 25 | def get_optimizer( 26 | model: nn.Module, 27 | algorithm: str, 28 | lr: float, 29 | weight_decay: float, 30 | momentum: Optional[float] = None, 31 | no_bn_decay: bool = False, 32 | nesterov: bool = False, 33 | **kwargs 34 | ) -> torch.optim.Optimizer: 35 | """Get an optimizer given its configurations. 36 | 37 | :param model: A torch model (with parameters to be trained). 38 | :param algorithm: String defining what optimization algorithm to use. 39 | :param lr: Learning rate. 40 | :param weight_decay: Weight decay coefficient. 41 | :param momentum: Momentum value. 42 | :param no_bn_decay: Whether to avoid weight decay for Batch Norm params. 43 | :param nesterov: Whether to use Nesterov update. 44 | :return: A torch optimizer objet. 45 | """ 46 | if algorithm == "sgd": 47 | parameters = list(model.named_parameters()) 48 | bn_params = [v for n, v in parameters if ("bn" in n) and v.requires_grad] 49 | rest_params = [v for n, v in parameters if ("bn" not in n) and v.requires_grad] 50 | optimizer = torch.optim.SGD( 51 | [ 52 | { 53 | "params": bn_params, 54 | "weight_decay": 0 if no_bn_decay else weight_decay, 55 | }, 56 | {"params": rest_params, "weight_decay": weight_decay}, 57 | ], 58 | lr, 59 | momentum=momentum, 60 | weight_decay=weight_decay, 61 | nesterov=nesterov, 62 | ) 63 | elif algorithm == "adam": 64 | optimizer = torch.optim.Adam( 65 | filter(lambda p: p.requires_grad, model.parameters()), 66 | lr=lr, 67 | weight_decay=weight_decay, 68 | ) 69 | 70 | return optimizer 71 | 72 | 73 | def get_criteria(similarity_loss: Optional[Dict] = None, 74 | discriminative_loss: Optional[Dict] = None, 75 | uncertainty_loss: Optional[Dict] = None, 76 | ) -> Tuple[Dict[str, float], Dict[str, Callable]]: 77 | """Get training criteria. 78 | 79 | :param similarity_loss: Optional dictionary to determine similarity loss. It should have a name key with values from 80 | ["mse", "l1", "cosine"]. "mu_similarity" key determines coefficient of the similarity loss, with default value of 1. 81 | :param discriminative_loss: Optional dictionary to determine discriminative loss. Name key can take value from 82 | ["LabelSmoothing", "CE"]. If "LabelSmoothing" is picked, the "label_smoothing" key can be used to determine the 83 | value of label smoothing (from [0,1] interval). "mu_disc" key determines coefficient of the discriminative loss. 84 | :param uncertainty_loss: Optional dictionary to determine whether to have uncertainty in loss. "mu_uncertainty" key 85 | determines coefficient of the regularization term in Bayesian uncertainty estimation formulation. 86 | 87 | :return: Two dictionaries. First one has coefficients: {"mu_similarity": mu_similarity, "mu_disc": mu_disc} 88 | second one has callable loss functions with "criterion_similarity", "criterion_disc", and "criterion_uncertainty" 89 | keys. 
90 | """ 91 | if uncertainty_loss is not None: 92 | criterion_uncertainty = UncertaintyLoss(**uncertainty_loss) 93 | reduction = "none" 94 | else: 95 | criterion_uncertainty = None 96 | reduction = "mean" 97 | 98 | if similarity_loss is not None: 99 | if similarity_loss.get("name") == "mse": 100 | criterion_similarity = nn.MSELoss(reduction=reduction) 101 | elif similarity_loss.get("name") == "l1": 102 | criterion_similarity = nn.L1Loss(reduction=reduction) 103 | elif similarity_loss.get("name") == "cosine": 104 | 105 | def similarity_cosine(x, y): 106 | return nn.CosineEmbeddingLoss(reduction=reduction)( 107 | x.squeeze(), 108 | y.squeeze(), 109 | target=torch.ones(x.shape[0], device=y.device), 110 | ) 111 | 112 | criterion_similarity = similarity_cosine 113 | else: 114 | raise NotImplementedError("Similarity loss not implemented!") 115 | mu_similarity = similarity_loss.get("mu_similarity", 1.0) 116 | else: 117 | criterion_similarity = None 118 | mu_similarity = None 119 | 120 | if discriminative_loss is not None: 121 | if discriminative_loss.get("name") == "LabelSmoothing": 122 | criterion_disc = LabelSmoothing( 123 | smoothing=discriminative_loss.get("label_smoothing"), 124 | reduction=reduction, 125 | ) 126 | elif discriminative_loss.get("name") == "CE": 127 | criterion_disc = nn.CrossEntropyLoss(reduction=reduction) 128 | else: 129 | raise NotImplementedError("discriminative loss not implemented") 130 | mu_disc = discriminative_loss.get("mu_disc") 131 | else: 132 | criterion_disc = None 133 | mu_disc = None 134 | 135 | mus_dict = {"mu_similarity": mu_similarity, "mu_disc": mu_disc} 136 | criteria_dict = { 137 | "criterion_similarity": criterion_similarity, 138 | "criterion_disc": criterion_disc, 139 | "criterion_uncertainty": criterion_uncertainty, 140 | } 141 | return mus_dict, criteria_dict 142 | -------------------------------------------------------------------------------- /utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | 5 | 6 | class AverageMeter: 7 | """Computes and stores the average and current value.""" 8 | 9 | def __init__(self, name: str, fmt: str = ":f") -> None: 10 | """Construct an AverageMeter module. 11 | 12 | :param name: Name of the metric to be tracked. 13 | :param fmt: Output format string. 14 | """ 15 | self.name = name 16 | self.fmt = fmt 17 | self.reset() 18 | 19 | def reset(self): 20 | """Reset internal states.""" 21 | self.val = 0 22 | self.avg = 0 23 | self.sum = 0 24 | self.count = 0 25 | 26 | def update(self, val: float, n: int = 1) -> None: 27 | """Update internal states given new values. 28 | 29 | :param val: New metric value. 30 | :param n: Step size for update. 31 | """ 32 | self.val = val 33 | self.sum += val * n 34 | self.count += n 35 | self.avg = self.sum / self.count 36 | 37 | def __str__(self): 38 | """Get string name of the object.""" 39 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" 40 | return fmtstr.format(**self.__dict__) 41 | -------------------------------------------------------------------------------- /utils/net_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 
3 | # 4 | 5 | from typing import Union 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class LabelSmoothing(nn.Module): 12 | """NLL loss with label smoothing.""" 13 | 14 | def __init__(self, smoothing: float = 0.0): 15 | """Construct LabelSmoothing module. 16 | 17 | :param smoothing: label smoothing factor 18 | """ 19 | super(LabelSmoothing, self).__init__() 20 | self.confidence = 1.0 - smoothing 21 | self.smoothing = smoothing 22 | 23 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 24 | """Apply forward pass. 25 | 26 | :param x: Logits tensor. 27 | :param target: Ground truth target classes. 28 | :return: Loss tensor. 29 | """ 30 | logprobs = torch.nn.functional.log_softmax(x, dim=-1) 31 | 32 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 33 | nll_loss = nll_loss.squeeze(1) 34 | smooth_loss = -logprobs.mean(dim=-1) 35 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 36 | return loss.mean() 37 | 38 | 39 | class FeatureExtractor(nn.Module): 40 | """A wrapper class to return only features (no logits).""" 41 | 42 | def __init__(self, model: Union[nn.Module, torch.jit.ScriptModule]) -> None: 43 | """Construct FeatureExtractor module. 44 | 45 | :param model: A model that outputs both logits and features. 46 | """ 47 | super().__init__() 48 | self.model = model 49 | 50 | def forward(self, x: torch.Tensor) -> torch.Tensor: 51 | """Apply forward pass. 52 | 53 | :param x: Input data. 54 | :return: Feature tensor computed for x. 55 | """ 56 | _, feature = self.model(x) 57 | return feature 58 | 59 | 60 | class TransformedOldModel(nn.Module): 61 | """A wrapper class to return transformed features.""" 62 | 63 | def __init__( 64 | self, 65 | old_model: Union[nn.Module, torch.jit.ScriptModule], 66 | side_model: Union[nn.Module, torch.jit.ScriptModule], 67 | transformation: Union[nn.Module, torch.jit.ScriptModule], 68 | ) -> None: 69 | """Construct TransformedOldModel module. 70 | 71 | :param old_model: Old model. 72 | :param side_model: Side information model. 73 | :param transformation: Transformation model. 74 | """ 75 | super().__init__() 76 | self.old_model = old_model 77 | self.transformation = transformation 78 | self.side_info_model = side_model 79 | 80 | def forward(self, x: torch.Tensor) -> torch.Tensor: 81 | """Apply forward pass. 82 | 83 | :param x: Input data 84 | :return: Transformed old feature. 85 | """ 86 | old_feature = self.old_model(x) 87 | side_info = self.side_info_model(x) 88 | recycled_feature = self.transformation(old_feature, side_info) 89 | return recycled_feature 90 | 91 | 92 | def prepare_model_for_export( 93 | model: Union[nn.Module, torch.jit.ScriptModule] 94 | ) -> Union[nn.Module, torch.jit.ScriptModule]: 95 | """Prepare a model to be exported as torchscript.""" 96 | if isinstance(model, torch.nn.DataParallel): 97 | model = model.module 98 | model.eval() 99 | model.cpu() 100 | return model 101 | 102 | 103 | def backbone_to_torchscript( 104 | model: Union[nn.Module, torch.jit.ScriptModule], output_model_path: str 105 | ) -> None: 106 | """Convert a backbone model to torchscript. 107 | 108 | :param model: A backbone model to be converted to torch script. 109 | :param output_model_path: Path to save torch script. 
110 | """ 111 | model = prepare_model_for_export(model) 112 | f = FeatureExtractor(model) 113 | model_script = torch.jit.script(f) 114 | torch.jit.save(model_script, output_model_path) 115 | 116 | 117 | def transformation_to_torchscripts( 118 | old_model: Union[nn.Module, torch.jit.ScriptModule], 119 | side_model: Union[nn.Module, torch.jit.ScriptModule], 120 | transformation: Union[nn.Module, torch.jit.ScriptModule], 121 | output_transformation_path: str, 122 | output_transformed_old_model_path: str, 123 | ) -> None: 124 | """Convert a transformation model to torchscript. 125 | 126 | :param old_model: Old model. 127 | :param side_model: Side information model. 128 | :param transformation: Transformation model. 129 | :param output_transformation_path: Path to store transformation torch 130 | script. 131 | :param output_transformed_old_model_path: Path to store combined old and 132 | transformation models' torch script. 133 | """ 134 | transformation = prepare_model_for_export(transformation) 135 | old_model = prepare_model_for_export(old_model) 136 | side_model = prepare_model_for_export(side_model) 137 | 138 | model_script = torch.jit.script(transformation) 139 | torch.jit.save(model_script, output_transformation_path) 140 | 141 | f = TransformedOldModel(old_model, side_model, transformation) 142 | model_script = torch.jit.script(f) 143 | torch.jit.save(model_script, output_transformed_old_model_path) 144 | -------------------------------------------------------------------------------- /utils/objectives.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | 5 | import torch 6 | from torch import nn 7 | 8 | 9 | class UncertaintyLoss(nn.Module): 10 | """Loss that takes a loss vector and a Sigma vector and applies uncertainty to it""" 11 | 12 | def __init__( 13 | self, mu_uncertainty: float = 0.5, sigma_dim: int = 1, **kwargs 14 | ) -> None: 15 | super(UncertaintyLoss, self).__init__() 16 | self.mu = mu_uncertainty 17 | self.sigma_dim = sigma_dim 18 | 19 | def forward(self, sigma: torch.Tensor, loss: torch.Tensor) -> torch.Tensor: 20 | """ 21 | :param sigma: feature vector (N x sigma_dim) that includes the sigma vector 22 | :param loss: loss vector (N x 1) where N is the batch size 23 | 24 | :return: return either a Nx1 or a 1x1 dimensional loss vector 25 | """ 26 | batch = sigma.size(0) 27 | reg = self.mu * sigma.view(batch, -1).mean(dim=1) 28 | loss_value = 0.5 * torch.exp(-sigma.squeeze()) * loss + reg 29 | 30 | return loss_value.mean() 31 | 32 | 33 | class LabelSmoothing(nn.Module): 34 | """NLL loss with label smoothing.""" 35 | 36 | def __init__(self, smoothing: float = 0.0, reduction: str = "mean"): 37 | """Construct LabelSmoothing module. 38 | 39 | :param smoothing: label smoothing factor. Default is 0.0. 40 | :param reduction: reduction method to use from ["mean", "none"]. Default is "mean". 41 | """ 42 | super(LabelSmoothing, self).__init__() 43 | self.confidence = 1.0 - smoothing 44 | self.smoothing = smoothing 45 | assert reduction in ["mean", "none"] 46 | self.reduction = reduction 47 | 48 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 49 | """Apply forward pass. 50 | 51 | :param x: Logits tensor. 52 | :param target: Ground truth target classes. 53 | :return: Loss tensor. 
54 | """ 55 | logprobs = torch.nn.functional.log_softmax(x, dim=-1) 56 | 57 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 58 | nll_loss = nll_loss.squeeze(1) 59 | smooth_loss = -logprobs.mean(dim=-1) 60 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 61 | if self.reduction == "mean": 62 | return loss.mean() 63 | else: 64 | return loss 65 | -------------------------------------------------------------------------------- /utils/schedulers.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | 5 | from typing import Callable, Optional 6 | 7 | import torch 8 | import numpy as np 9 | 10 | __all__ = ["multistep_lr", "cosine_lr", "constant_lr", "get_policy"] 11 | 12 | 13 | def get_policy(optimizer: torch.optim.Optimizer, algorithm: str, **kwargs) -> Callable: 14 | """Get learning policy given its configurations. 15 | 16 | :param optimizer: A torch optimizer. 17 | :param algorithm: Name of the learning rate scheduling algorithm. 18 | :return: A callable to adjust learning rate for each epoch. 19 | """ 20 | out_dict = { 21 | "constant_lr": constant_lr, 22 | "cosine_lr": cosine_lr, 23 | "multistep_lr": multistep_lr, 24 | } 25 | return out_dict[algorithm](optimizer, **kwargs) 26 | 27 | 28 | def assign_learning_rate(optimizer: torch.optim.Optimizer, new_lr: float) -> None: 29 | """Update lr parameter of an optimizer. 30 | 31 | :param optimizer: A torch optimizer. 32 | :param new_lr: updated value of learning rate. 33 | """ 34 | for param_group in optimizer.param_groups: 35 | param_group["lr"] = new_lr 36 | 37 | 38 | def constant_lr( 39 | optimizer: torch.optim.Optimizer, warmup_length: int, lr: float, **kwargs 40 | ) -> Callable: 41 | """Get lr adjustment callable with constant schedule. 42 | 43 | :param optimizer: A torch optimizer. 44 | :param warmup_length: Number of epochs for initial warmup. 45 | :param lr: Nominal learning rate value. 46 | :return: A callable to adjust learning rate per epoch. 47 | """ 48 | 49 | def _lr_adjuster(epoch: int, iteration: Optional[int]) -> float: 50 | """Get updated learning rate. 51 | 52 | :param epoch: Epoch number. 53 | :param iteration: Iteration number. 54 | :return: Updated learning rate value. 55 | """ 56 | if epoch < warmup_length: 57 | new_lr = _warmup_lr(lr, warmup_length, epoch) 58 | else: 59 | new_lr = lr 60 | 61 | assign_learning_rate(optimizer, new_lr) 62 | 63 | return new_lr 64 | 65 | return _lr_adjuster 66 | 67 | 68 | def cosine_lr( 69 | optimizer: torch.optim.Optimizer, 70 | warmup_length: int, 71 | epochs: int, 72 | lr: float, 73 | **kwargs 74 | ) -> Callable: 75 | """Get lr adjustment callable with cosine schedule. 76 | 77 | :param optimizer: A torch optimizer. 78 | :param warmup_length: Number of epochs for initial warmup. 79 | :param epochs: Epoch number. 80 | :param lr: Nominal learning rate value. 81 | :return: A callable to adjust learning rate per epoch. 82 | """ 83 | 84 | def _lr_adjuster(epoch: int, iteration: Optional[int]) -> float: 85 | """Get updated learning rate. 86 | 87 | :param epoch: Epoch number. 88 | :param iteration: Iteration number. 89 | :return: Updated learning rate value. 
90 | """ 91 | if epoch < warmup_length: 92 | new_lr = _warmup_lr(lr, warmup_length, epoch) 93 | else: 94 | e = epoch - warmup_length 95 | es = epochs - warmup_length 96 | new_lr = 0.5 * (1 + np.cos(np.pi * e / es)) * lr 97 | 98 | assign_learning_rate(optimizer, new_lr) 99 | 100 | return new_lr 101 | 102 | return _lr_adjuster 103 | 104 | 105 | def multistep_lr( 106 | optimizer: torch.optim.Optimizer, 107 | lr_gamma: float, 108 | lr_adjust: int, 109 | lr: float, 110 | **kwargs 111 | ) -> Callable: 112 | """Get lr adjustment callable with multi-step schedule. 113 | 114 | :param optimizer: A torch optimizer. 115 | :param lr_gamma: Learning rate decay factor. 116 | :param lr_adjust: Number of epochs to apply decay. 117 | :param lr: Nominal Learning rate. 118 | :return: A callable to adjust learning rate per epoch. 119 | """ 120 | 121 | def _lr_adjuster(epoch: int, iteration: Optional[int]) -> float: 122 | """Get updated learning rate. 123 | 124 | :param epoch: Epoch number. 125 | :param iteration: Iteration number. 126 | :return: Updated learning rate value. 127 | """ 128 | new_lr = lr * (lr_gamma ** (epoch // lr_adjust)) 129 | 130 | assign_learning_rate(optimizer, new_lr) 131 | 132 | return new_lr 133 | 134 | return _lr_adjuster 135 | 136 | 137 | def _warmup_lr(base_lr: float, warmup_length: int, epoch: int) -> float: 138 | """Get updated lr after applying initial warmup. 139 | 140 | :param base_lr: Nominal learning rate. 141 | :param warmup_length: Number of epochs for initial warmup. 142 | :param epoch: Epoch number. 143 | :return: Warmup-updated learning rate. 144 | """ 145 | return base_lr * (epoch + 1) / warmup_length 146 | --------------------------------------------------------------------------------