├── .gitattributes ├── LICENSE.md ├── README.md ├── docs └── ElasticTrainer_AE_Appendix.pdf ├── figures ├── Figure_test.pdf └── README.md ├── logs ├── README.md ├── test.txt └── test_log │ └── events.out.tfevents.1669669108.raspberrypi.15644.0.v2 ├── main.py ├── plot_bars_v1.py ├── plot_bars_v2.py ├── plot_curves.py ├── profile_extracted └── README.md ├── profiler.py ├── run_demo.sh ├── run_figure15.sh ├── run_figure15_ego.sh ├── run_figure15ad.sh ├── run_figure15ad_ego.sh ├── run_figure16.sh ├── run_figure16_ego.sh ├── run_figure17ac.sh ├── run_figure17ac_ego.sh ├── run_figure19.sh ├── run_figure19_ego.sh ├── saved_models └── README.md ├── selection_solver_DP.py ├── train.py ├── utils.py └── vit_utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Kai Huang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ElasticTrainer: Speeding Up On-Device Training with Runtime Elastic Tensor Selection (MobiSys'23) 2 | 3 | ## Introduction 4 | This is the official code repository for our MobiSys 2023 paper [ElasticTrainer: Speeding Up On-Device Training with Runtime Elastic Tensor Selection](https://dl.acm.org/doi/10.1145/3581791.3596852). ElasticTrainer can speed up on-device NN training by ***adaptively training the minimal set of important parameters on the fly*** :rocket:, with user-defined speedup and without noticeable accuracy loss. According to our paper, the code is intended to be run on embedded devices (e.g., Raspberry Pi and Nvidia Jetson TX2), but it is also applicable to stronger platforms such as workstations. 5 | 6 | :mag_right: Looking for the core of our implementation? 
We suggest you take a look at the following: 7 | * Tensor Timing Profiler -- [profiler.py](https://github.com/HelloKevin07/ElasticTrainer/blob/main/profiler.py) 8 | * Tensor Importance Evaluator -- [elastic_training](https://github.com/HelloKevin07/ElasticTrainer/blob/main/train.py#L407) in [train.py](https://github.com/HelloKevin07/ElasticTrainer/blob/main/train.py) 9 | * Tensor Selector by Dynamic Programming -- [selection_solver_DP.py](https://github.com/HelloKevin07/ElasticTrainer/blob/main/selection_solver_DP.py). 10 | 11 | :chart_with_upwards_trend: Want to reproduce our paper results? Please check the instructions [here](#reproducing-paper-results). 12 | 13 | :bulb: If you plan to reproduce our results, we highly recommend you check the experiment results **on tensorboard**, since there have been bug reports about our figure-plotting code. 14 | 15 | ## License 16 | 17 | Our source code is released under the MIT License. 18 | 19 | ## Requirements 20 | * Python 3.7+ 21 | * tensorflow 2 22 | * tensorflow-datasets 23 | * tensorflow-addons 24 | * [tensorboard_plugin_profile](https://www.tensorflow.org/guide/profiler) 25 | * [vit-keras](https://github.com/faustomorales/vit-keras) 26 | * matplotlib 27 | * tqdm 28 | 29 | The software versions are platform dependent. In general, installing the most recent versions should work for typical workstations. For Nvidia Jetson TX2 and Raspberry Pi 4B, please check our provided OS images described [here](https://github.com/HelloKevin07/ElasticTrainer#reproducing-paper-results). 30 | 31 | ## General Usage 32 | Select an NN model and a dataset to run. Use `python main.py --help` and `python profiler.py --help` to see the configurable parameters. The NN architectures and datasets should be downloaded automatically. We use the [tensorflow-datasets](https://www.tensorflow.org/datasets/api_docs/python/tfds) API to download datasets from tensorflow's [dataset list](https://www.tensorflow.org/datasets/catalog/overview#all_datasets). If you encounter any errors (e.g., a checksum error), please refer to [this instruction](https://www.tensorflow.org/datasets/overview#manual_download_if_download_fails) to download them manually. 33 | 34 | Supported NN architectures: 35 | * ResNet50 -- `resnet50` 36 | * VGG16 -- `vgg16` 37 | * MobileNetV2 -- `mobilenetv2` 38 | * Vision Transformer (16x16 patch) -- `vit` 39 | 40 | Supported datasets: 41 | * [CUB-200 (200 classes)](https://www.vision.caltech.edu/datasets/cub_200_2011/) -- `caltech_birds2011` 42 | * [Oxford-IIIT Pet (37 classes)](https://www.robots.ox.ac.uk/~vgg/data/pets/) -- `oxford_iiit_pet` 43 | * [Stanford Dogs (120 classes)](http://vision.stanford.edu/aditya86/ImageNetDogs/) -- `stanford_dogs` 44 | 45 | ## Running an Example 46 | 47 | Below is an example of training ResNet50 on the CUB-200 dataset with our ElasticTrainer. First, profile the tensor timings on your dedicated device: 48 | ``` 49 | python profiler.py --model_name resnet50 \ 50 | --num_classes 200 51 | ``` 52 | Then start training your model on the device with a speedup ratio of 0.5 (i.e., 2x faster in wall time): 53 | ``` 54 | python main.py --model_name resnet50 \ 55 | --dataset_name caltech_birds2011 \ 56 | --train_type elastic_training \ 57 | --rho 0.5 58 | ``` 59 | Please note that the training wall time should exclude validation time. 60 | 61 | ## FAQs 62 | **Q1: Why tensorflow, not pytorch?** 63 | 64 | We are aware that pytorch is a dominant NN library in the AI research community. However, to the best of our knowledge, pytorch's profiler is incapable of presenting structured timing information for ops in the backward pass, and some of the provided measurements are not even reliable. Check [link1](https://github.com/pytorch/kineto/issues/580), [link2](https://github.com/pytorch/kineto/pull/372), and [link3](https://github.com/pytorch/pytorch/issues/30711) for details. In comparison, the tensorflow profiler provides accurate, well-structured, and human-readable timing information for us to parse and group, but it only works for tensorflow models and code. If you insist on working with pytorch, we can only suggest using FLOPs instead of wall-clock time as the tensor timing metric. That means you need to write your own code to derive the FLOPs, and FLOPs cannot really reflect the actual wall-clock speedup.
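To make this concrete, here is a condensed sketch of how `profiler.py` captures backward-pass ops with the tensorflow profiler, so that their timings can later be parsed and grouped per tensor. It assumes `model` is an already-built Keras classifier, and `logs/my_profile` is a placeholder log directory:

```python
import tensorflow as tf

# Condensed sketch of what profiler.py does: run a few dummy backward passes
# under the TensorFlow profiler so per-op timings get recorded.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    return tape.gradient(loss, model.trainable_weights)

x = tf.random.normal((4, 224, 224, 3))  # dummy batch
y = tf.ones((4,))

tf.profiler.experimental.start('logs/my_profile')
for step in range(5):
    with tf.profiler.experimental.Trace('train', step_num=step, _r=1):
        train_step(x, y)
tf.profiler.experimental.stop()
```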
65 | 66 | **Q2: ElasticTrainer vs. Parameter-Efficient Fine-Tuning (PEFT) for recent Large Language Models (LLMs)?** 67 | 68 | * We have extended ElasticTrainer to [GreenTrainer](https://github.com/pittisl/GreenTrainer), which is designed to speed up finetuning LLMs. In our GreenTrainer paper, we made comparisons with existing PEFT methods. 69 | 70 | ~~If you are an NLP expert, you may know there are many existing PEFT works in the NLP area, such as [prompt tuning](https://arxiv.org/abs/2104.08691), [prefix tuning](https://arxiv.org/abs/2101.00190), and [LoRA](https://arxiv.org/abs/2106.09685). These works focus on minimizing the number of trainable parameters (usually to <1%) because they speculate that variance, rather than bias, is a dominant factor in model generalization.~~ 71 | 72 | ~~However, **solely minimizing the number of trainable parameters doesn't guarantee wall-time speedup**. For example, prompt tuning still requires error gradients to propagate through the entire network, which leads to very limited wall-time speedup. On the other hand, nobody can promise that variance is always a dominant factor in model generalization. Unless you want to use extremely large pretrained LLMs (e.g., GPT-3) with stunning zero-shot adaptability, applying PEFT to most medium-sized pretrained models would kill a lot of representational power for **complex generative tasks** (e.g., text summarization and math Q&A) and lose much accuracy.~~ 73 | 74 | **Q3: How are you able to select a subset of parameters to train?** 75 | 76 | In tensorflow, `model.trainable_weights` gives you a list of all the trainable parameters. You can extract the ones you want into another list, say `var_list`, and then pass `var_list` to the optimizer, i.e., `optimizer.apply_gradients(zip(gradients, var_list))`. This process can be done at runtime but may cause frequent retracing in tensorflow, so you may need to manually free old graphs to avoid increasing memory usage, which is what we implemented originally. We later realized that a better way to suppress retracing may be to configure the [tf.function](https://www.tensorflow.org/api_docs/python/tf/function) decorator: 77 | 78 | ```python 79 | @tf.function( 80 | experimental_relax_shapes=True, 81 | experimental_follow_type_hints=True, 82 | ) 83 | def train_step(...) 84 | 85 | # alternatively 86 | @tf.function(reduce_retracing=True) 87 | def train_step(...) 88 | ```
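As a minimal sketch of that mechanism (assuming `model`, `optimizer`, and a boolean list `selected` marking which entries of `model.trainable_weights` to train; all placeholder names, not code from this repo):

```python
import tensorflow as tf

# Train only the selected subset of parameters (placeholder names throughout).
var_list = [w for w, keep in zip(model.trainable_weights, selected) if keep]
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function(reduce_retracing=True)  # keep retracing under control
def train_step(x, y):
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    gradients = tape.gradient(loss, var_list)  # gradients only for the selected tensors
    optimizer.apply_gradients(zip(gradients, var_list))
    return loss
```

Whenever `var_list` is re-selected at runtime (e.g., at each tensor importance evaluation interval), `train_step` is retraced, which is exactly what the `tf.function` options above are meant to keep cheap.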
89 | 90 | **Q4: Why are some tensors' timings not counted in our Tensor Timing Profiler?** 91 | 92 | Because we cannot find the related timings for these tensors in tensorflow's profiling results. That is, even the tensorflow profiler may fail to capture a few NN ops during profiling for no apparent reason. We have no solution for that. One workaround is to use the known ops' timings to estimate the missing ops' timings based on their FLOPs relationships. 93 | 94 | **Q5: What's the meaning of `(rho - 1/3)*3/2` in `elastic_training` in `train.py`?** 95 | 96 | It converts the training speedup into a backward-pass speedup based on the 2:1 FLOPs relationship between the backward pass and the forward pass. We did so to bypass profiling the forward time. Please note this is only an approximation, which we adopted due to a tight schedule when we were rushing for this paper. To ensure precision, we highly recommend that you do profile the forward time `T_fp` and backward time `T_bp`, and use `rho * (1 + T_fp/T_bp) - T_fp/T_bp` for the conversion.
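As a quick sanity check of this conversion (an illustrative sketch, not code from `train.py`):

```python
def backward_speedup(rho, t_fp, t_bp):
    # general conversion from overall training speedup to backward-pass speedup
    return rho * (1 + t_fp / t_bp) - t_fp / t_bp

# Under the 2:1 backward-to-forward FLOPs assumption (t_bp = 2 * t_fp),
# the general formula reduces to the (rho - 1/3) * 3/2 used in train.py.
rho = 0.5
assert abs(backward_speedup(rho, 1.0, 2.0) - (rho - 1/3) * 3 / 2) < 1e-9
```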
97 | 98 | **Q6: Why is `rho` multiplied by `disco` in `elastic_training` in `train.py`?** 99 | 100 | `disco`, which is obtained [here](https://github.com/HelloKevin07/ElasticTrainer/blob/c9e53006f0ad64ca8392130b169952ff3c1cc57b/train.py#LL439C5-L439C72), is a heuristic factor that scales `rho` a bit to ensure the desired speedup can still be achieved even if `t_dy` and `t_dw` lose much resolution after downscaling. The downside of `disco` is that it sometimes becomes too small and suppresses too much of the parameter selection. In that case, feel free to try removing this factor. 101 | 102 | ## Reproducing Paper Results 103 | Please download our artifacts from Zenodo ([link1](https://doi.org/10.5281/zenodo.7812218) and [link2](https://doi.org/10.5281/zenodo.7812233)), and follow the detailed instructions in our [artifact appendix](docs/ElasticTrainer_AE_Appendix.pdf). 104 | We provide experimental workflows that allow people to reproduce our main results in the paper. However, running all the experiments could take an extremely long time (~800 hours), and thus we mark each experiment with its estimated execution time so that users can choose based on their time budget. After you finish running each script, the corresponding figure will be automatically generated under `figures/`. For Nvidia Jetson TX2, we run experiments with its text-only interface, and to view the figures, you will need to switch back to the graphic interface. 105 | 106 | We first describe how you can prepare the environment that allows you to run our experiments, and then we list the command lines to reproduce every figure in our main results. 107 | 108 | ### Preparing Nvidia Jetson TX2 109 | 1. Following our artifact appendix, flash the Jetson using our provided OS image. Insert the SSD. 110 | 2. Log in to the system, where both the username and password are `nvidia`. 111 | 3. Run the following commands to finish the preparation: 112 | ``` 113 | sudo su - 114 | cd ~/src/ElasticTrainer 115 | chmod +x *.sh 116 | ``` 117 | 118 | ### Preparing Raspberry Pi 4B 119 | 1. Flash the Raspberry Pi using our provided OS image. 120 | 2. Open a terminal and run the following commands to finish the preparation: 121 | ``` 122 | cd ~/src/ElasticTrainer 123 | . ../kai_stl_code/venv/bin/activate 124 | chmod +x *.sh 125 | ``` 126 | 127 | ### Figure 15(a)(d) - A minimal reproduction of main results (~10 hours) 128 | On Nvidia Jetson TX2: 129 | ``` 130 | ./run_figure15ad.sh 131 | ``` 132 | ### Figure 15 from (a) to (f) (~33 hours) 133 | On Nvidia Jetson TX2: 134 | ``` 135 | ./run_figure15.sh 136 | ``` 137 | Alternatively, if you want to exclude baseline schemes, run the following (~6.5 hours): 138 | ``` 139 | ./run_figure15_ego.sh 140 | ``` 141 | ### Figure 16 from (a) to (d) (~221 hours) 142 | On Raspberry Pi 4B: 143 | ``` 144 | ./run_figure16.sh 145 | ``` 146 | Alternatively, if you want to exclude baseline schemes, run the following (~52 hours): 147 | ``` 148 | ./run_figure16_ego.sh 149 | ``` 150 | ### Figure 17 (a)(c) (~15+190 hours) 151 | Run the following command on both Nvidia Jetson TX2 (~15 hours) and Raspberry Pi 4B (~190 hours): 152 | ``` 153 | ./run_figure17ac.sh 154 | ``` 155 | Alternatively, if you want to exclude baseline schemes, run the following command on both Nvidia Jetson TX2 (~9 hours) and Raspberry Pi 4B (~85 hours): 156 | ``` 157 | ./run_figure17ac_ego.sh 158 | ``` 159 | ### Figure 19 from (a) to (d) (~20+310 hours) 160 | Run the following command on both Nvidia Jetson TX2 (~20 hours) and Raspberry Pi 4B (~310 hours): 161 | ``` 162 | ./run_figure19.sh 163 | ``` 164 | Alternatively, if you want to exclude baseline schemes, run the following command on both Nvidia Jetson TX2 (~3.5 hours) and Raspberry Pi 4B (~50 hours): 165 | ``` 166 | ./run_figure19_ego.sh 167 | ``` 168 | 169 | ### Checking Results 170 | All the experiment results should be generated under `figures/`. On the Pi, directly click them to view. On the Jetson, to check the experiment results, you will need to switch to graphic mode: 171 | 172 | ``` 173 | sudo systemctl start graphical.target 174 | ``` 175 | In graphic mode, open a terminal, gain root privilege, and navigate to our code directory: 176 | ``` 177 | sudo su - 178 | cd ~/src/ElasticTrainer 179 | ``` 180 | All the figures are stored under `figures/`. Use the `ls` command to check their file names and the `evince` command to view them, for example, `evince xxx.pdf`. To go back to text-only mode, simply reboot the system. If you encounter any display issues, you can alternatively use tensorboard to view the results. To enable tensorboard: 181 | ``` 182 | tensorboard --logdir logs 183 | ``` 184 | Open a Chrome/Chromium browser and visit http://localhost:6006/. On the right sidebar, make sure you switch from "Step" to "Relative" under "Settings->General".
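If you prefer to pull the logged metrics into Python rather than read them off the tensorboard UI, a condensed sketch along the lines of `plot_curves.py` (with `logs/my_run` as a placeholder run directory) would be:

```python
import tensorflow as tf
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

# Read the logged test accuracy from a run's event file ('logs/my_run' is a placeholder).
event_acc = EventAccumulator('logs/my_run')
event_acc.Reload()
events = event_acc.Tensors('test/accuracy')
hours = [(e.wall_time - events[0].wall_time) / 3600 for e in events]       # relative wall time (h)
accuracy = [100 * tf.make_ndarray(e.tensor_proto).item() for e in events]  # accuracy (%)
print(list(zip(hours, accuracy)))
```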
185 | 186 | ## Citation 187 | ``` 188 | @inproceedings{huang2023elastictrainer, 189 | title={ElasticTrainer: Speeding Up On-Device Training with Runtime Elastic Tensor Selection}, 190 | author={Huang, Kai and Yang, Boyuan and Gao, Wei}, 191 | booktitle={Proceedings of the 21st Annual International Conference on Mobile Systems, Applications and Services}, 192 | pages={56--69}, 193 | year={2023} 194 | } 195 | ``` 196 | -------------------------------------------------------------------------------- /docs/ElasticTrainer_AE_Appendix.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pittisl/ElasticTrainer/6f41470f0ff94bdd22895f123efc39dfec07323e/docs/ElasticTrainer_AE_Appendix.pdf -------------------------------------------------------------------------------- /figures/Figure_test.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pittisl/ElasticTrainer/6f41470f0ff94bdd22895f123efc39dfec07323e/figures/Figure_test.pdf -------------------------------------------------------------------------------- /figures/README.md: -------------------------------------------------------------------------------- 1 | This folder stores figures -------------------------------------------------------------------------------- /logs/README.md: -------------------------------------------------------------------------------- 1 | This folder stores profile and training logs. -------------------------------------------------------------------------------- /logs/test.txt: -------------------------------------------------------------------------------- 1 | 0.3 2 | 0.4 3 | -------------------------------------------------------------------------------- /logs/test_log/events.out.tfevents.1669669108.raspberrypi.15644.0.v2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pittisl/ElasticTrainer/6f41470f0ff94bdd22895f123efc39dfec07323e/logs/test_log/events.out.tfevents.1669669108.raspberrypi.15644.0.v2 -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from utils import (port_datasets, 2 | port_pretrained_models, 3 | RepeatTimer, record_once, 4 | sig_stop_handler, 5 | my_bool) 6 | from train import (full_training, 7 | traditional_tl_training, 8 | bn_plus_bias_training, 9 | elastic_training, 10 | elastic_training_weight_magnitude, 11 | elastic_training_grad_magnitude, 12 | prune_training) 13 | import argparse 14 | import signal 15 | 16 | # import logging 17 | # logging.getLogger('tensorflow').setLevel(logging.WARNING) 18 | 19 | parser = argparse.ArgumentParser(description='Training a NN model with selected schemes') 20 | parser.add_argument('--model_name', type=str, default='resnet50', help='valid model names are resnet50, vgg16, mobilenetv2, vit') 21 | parser.add_argument('--dataset_name', type=str, default='caltech_birds2011', help='valid dataset names are caltech_birds2011, stanford_dogs, oxford_iiit_pet') 22 | parser.add_argument('--train_type', type=str, default='elastic_training', help='valid training schemes are full_training, traditional_tl_training,\ 23 | bn_plus_bias_training, elastic_training') 24 | parser.add_argument('--input_size', type=int, default=224, help='input resolution, e.g., 224 stands for 224x224') 25 | parser.add_argument('--batch_size', type=int, default=4, help='batch size used 
to run during profiling') 26 | parser.add_argument('--num_classes', type=int, default=200, help='number of categories model can classify') 27 | parser.add_argument('--optimizer', type=str, default='sgd', help='valid optimizers are sgd and adam, adam is recommended for vit') 28 | parser.add_argument('--learning_rate', type=float, default=1e-4, help='learning rate') 29 | parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay for sgd') 30 | parser.add_argument('--num_epochs', type=int, default=12, help='number of training epochs') 31 | parser.add_argument('--run_name', type=str, default='auto', help='whether to use auto-generated (auto) or user-defined run name') 32 | parser.add_argument('--save_model', type=my_bool, default=False, help='whether to save the trained model') 33 | parser.add_argument('--save_txt', type=my_bool, default=False, help='whether to save the final accuracy and wall time into txt') 34 | 35 | parser.add_argument('--interval', type=float, default=4, help='interval (in epoch) of tensor importance evaluation') 36 | parser.add_argument('--rho', type=float, default=0.533, help='speedup ratio') 37 | 38 | args = parser.parse_args() 39 | 40 | model_name = args.model_name 41 | dataset_name = args.dataset_name 42 | train_type = args.train_type 43 | input_size = args.input_size 44 | batch_size = args.batch_size 45 | num_classes = args.num_classes 46 | optimizer = args.optimizer 47 | learning_rate = args.learning_rate 48 | weight_decay = args.weight_decay 49 | num_epochs = args.num_epochs 50 | run_name = args.run_name 51 | save_model = args.save_model 52 | save_txt = args.save_txt 53 | interval = args.interval 54 | rho = args.rho 55 | 56 | disable_random_id = False 57 | 58 | if run_name == 'auto': 59 | run_name = model_name + '_' + dataset_name + '_' + train_type 60 | else: 61 | disable_random_id = True 62 | 63 | logdir = 'logs' 64 | timing_info = model_name + '_' + str(input_size) + '_' + str(num_classes) + '_' + str(batch_size) + '_' + 'profile' 65 | 66 | global timer 67 | timer = RepeatTimer(15, record_once) 68 | timer.start() 69 | 70 | signal.signal(signal.SIGINT, sig_stop_handler) 71 | signal.signal(signal.SIGTERM, sig_stop_handler) 72 | 73 | print('### Porting NN model...') 74 | 75 | model = port_pretrained_models( 76 | model_type=model_name, 77 | input_shape=(input_size, input_size, 3), 78 | num_classes=num_classes, 79 | ) 80 | 81 | print('### Porting dataset...') 82 | 83 | train_dataset, test_dataset = port_datasets( 84 | dataset_name=dataset_name, 85 | input_shape=(input_size, input_size, 3), 86 | batch_size=batch_size, 87 | ) 88 | 89 | print('### Start training...') 90 | 91 | if train_type == 'full_training': 92 | full_training( 93 | model, 94 | train_dataset, 95 | test_dataset, 96 | run_name, 97 | logdir, 98 | optim=optimizer, 99 | lr=learning_rate, 100 | weight_decay=weight_decay, 101 | epochs=num_epochs, 102 | disable_random_id=disable_random_id, 103 | save_model=save_model, 104 | save_txt=save_txt, 105 | ) 106 | elif train_type == 'traditional_tl_training': 107 | traditional_tl_training( 108 | model, 109 | train_dataset, 110 | test_dataset, 111 | run_name, 112 | logdir, 113 | optim=optimizer, 114 | lr=learning_rate, 115 | weight_decay=weight_decay, 116 | epochs=num_epochs, 117 | disable_random_id=disable_random_id, 118 | save_model=save_model, 119 | save_txt=save_txt, 120 | ) 121 | elif train_type == 'bn_plus_bias_training': 122 | bn_plus_bias_training( 123 | model, 124 | train_dataset, 125 | test_dataset, 126 | run_name, 127 | logdir, 128 | 
optim=optimizer, 129 | lr=learning_rate, 130 | weight_decay=weight_decay, 131 | epochs=num_epochs, 132 | disable_random_id=disable_random_id, 133 | save_model=save_model, 134 | save_txt=save_txt, 135 | ) 136 | elif train_type == 'elastic_training': 137 | elastic_training( 138 | model, 139 | model_name, 140 | train_dataset, 141 | test_dataset, 142 | run_name, 143 | logdir, 144 | timing_info, 145 | optim=optimizer, 146 | lr=learning_rate, 147 | weight_decay=weight_decay, 148 | epochs=num_epochs, 149 | interval=interval, 150 | rho=rho, 151 | disable_random_id=disable_random_id, 152 | save_model=save_model, 153 | save_txt=save_txt, 154 | ) 155 | 156 | elif train_type == 'elastic_training_weight_magnitude': 157 | elastic_training_weight_magnitude( 158 | model, 159 | model_name, 160 | train_dataset, 161 | test_dataset, 162 | run_name, 163 | logdir, 164 | timing_info, 165 | optim=optimizer, 166 | lr=learning_rate, 167 | weight_decay=weight_decay, 168 | epochs=num_epochs, 169 | interval=interval, 170 | rho=rho, 171 | disable_random_id=disable_random_id, 172 | save_model=save_model, 173 | save_txt=save_txt, 174 | ) 175 | 176 | elif train_type == 'elastic_training_grad_magnitude': 177 | elastic_training_grad_magnitude( 178 | model, 179 | model_name, 180 | train_dataset, 181 | test_dataset, 182 | run_name, 183 | logdir, 184 | timing_info, 185 | optim=optimizer, 186 | lr=learning_rate, 187 | weight_decay=weight_decay, 188 | epochs=num_epochs, 189 | interval=interval, 190 | rho=rho, 191 | disable_random_id=disable_random_id, 192 | save_model=save_model, 193 | save_txt=save_txt, 194 | ) 195 | 196 | elif train_type == 'prunetrain': 197 | prune_training( 198 | model, 199 | train_dataset, 200 | test_dataset, 201 | run_name, 202 | logdir, 203 | optim=optimizer, 204 | lr=learning_rate, 205 | weight_decay=weight_decay, 206 | epochs=num_epochs, 207 | disable_random_id=disable_random_id, 208 | save_model=save_model, 209 | save_txt=save_txt, 210 | ) 211 | 212 | else: 213 | raise NotImplementedError(f"Training scheme {train_type} has not been implemented yet") 214 | 215 | timer.cancel() 216 | -------------------------------------------------------------------------------- /plot_bars_v1.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib as mpl 3 | mpl.use('Agg') 4 | import matplotlib.pyplot as plt 5 | import argparse 6 | from utils import my_bool 7 | 8 | def plot_different_speedup_ratios( 9 | path_to_rho35, 10 | path_to_rho40, 11 | path_to_rho50, 12 | path_to_rho60, 13 | path_to_rho70, 14 | path_to_full_training, 15 | path_to_traditional_tl, 16 | path_to_bn_plus_bias, 17 | path_to_prunetrain, 18 | figure_id, 19 | figure_name, 20 | ): 21 | rho35 = np.loadtxt(path_to_rho35) # [time(h), accuracy(%)] 22 | rho40 = np.loadtxt(path_to_rho40) 23 | rho50 = np.loadtxt(path_to_rho50) 24 | rho60 = np.loadtxt(path_to_rho60) 25 | rho70 = np.loadtxt(path_to_rho70) 26 | ft = np.loadtxt(path_to_full_training) 27 | ttl = np.loadtxt(path_to_traditional_tl) 28 | bpb = np.loadtxt(path_to_bn_plus_bias) 29 | pt = np.loadtxt(path_to_prunetrain) 30 | 31 | fig = plt.figure(figure_id) 32 | 33 | x = ['rho=35%', 'rho=40%', 'rho=50%', 'rho=60%', 'rho=70%', 'Full training', \ 34 | 'Traditional TL', 'BN+Bias', 'PruneTrain'] 35 | 36 | y1 = np.array([rho35[1], rho40[1], rho50[1], rho60[1], rho70[1], ft[1], ttl[1], bpb[1], pt[1]]) 37 | y2 = np.array([rho35[0], rho40[0], rho50[0], rho60[0], rho70[0], ft[0], ttl[0], bpb[0], pt[0]]) 38 | pad = [0, 0, 0, 0, 0, 0, 0, 0, 0] 39 
| 40 | def subcategorybar(X, vals, color, width=0.8): 41 | n = len(vals) 42 | _X = np.arange(len(X)) 43 | for i in range(n): 44 | plt.bar(_X - width/2. + i/float(n)*width, vals[i], 45 | width=width/float(n), align="edge", color=color) 46 | plt.xticks(_X, X) 47 | plt.xticks(rotation=45, ha='right') 48 | plt.xticks(fontsize=16) 49 | plt.yticks(fontsize=16) 50 | 51 | subcategorybar(x, [y1, pad], [0, 0.4470, 0.7410]) 52 | plt.ylabel('Accuracy (%)', fontdict={'family': 'Arial', 53 | 'color': [0, 0.4470, 0.7410], 54 | 'weight': 'bold', 55 | 'size': 16, 56 | }) 57 | 58 | fig.axes[1] = fig.axes[0].twinx() 59 | 60 | subcategorybar(x, [pad, y2], [0.8500, 0.3250, 0.0980]) 61 | plt.ylabel('Wall-clock time (h)', fontdict={'family': 'Arial', 62 | 'color': [0.8500, 0.3250, 0.0980], 63 | 'weight': 'bold', 64 | 'size': 16, 65 | }) 66 | 67 | fig.axes[0].tick_params(axis='y', colors=[0, 0.4470, 0.7410]) 68 | fig.axes[0].spines['left'].set_color([0, 0.4470, 0.7410]) 69 | fig.axes[1].spines['left'].set_color([0, 0.4470, 0.7410]) 70 | 71 | fig.axes[1].tick_params(axis='y', colors=[0.8500, 0.3250, 0.0980]) 72 | fig.axes[1].spines['right'].set_color([0.8500, 0.3250, 0.0980]) 73 | fig.axes[0].spines['right'].set_color([0.8500, 0.3250, 0.0980]) 74 | 75 | plt.tight_layout() 76 | # plt.show() 77 | plt.savefig(figure_name, format="pdf", bbox_inches="tight") 78 | 79 | 80 | def plot_different_speedup_ratios_ego( 81 | path_to_rho35, 82 | path_to_rho40, 83 | path_to_rho50, 84 | path_to_rho60, 85 | path_to_rho70, 86 | figure_id, 87 | figure_name, 88 | ): 89 | rho35 = np.loadtxt(path_to_rho35) # [time(h), accuracy(%)] 90 | rho40 = np.loadtxt(path_to_rho40) 91 | rho50 = np.loadtxt(path_to_rho50) 92 | rho60 = np.loadtxt(path_to_rho60) 93 | rho70 = np.loadtxt(path_to_rho70) 94 | 95 | fig = plt.figure(figure_id) 96 | 97 | x = ['rho=35%', 'rho=40%', 'rho=50%', 'rho=60%', 'rho=70%'] 98 | 99 | y1 = np.array([rho35[1], rho40[1], rho50[1], rho60[1], rho70[1]]) 100 | y2 = np.array([rho35[0], rho40[0], rho50[0], rho60[0], rho70[0]]) 101 | pad = [0, 0, 0, 0, 0] 102 | 103 | def subcategorybar(X, vals, color, width=0.8): 104 | n = len(vals) 105 | _X = np.arange(len(X)) 106 | for i in range(n): 107 | plt.bar(_X - width/2. 
+ i/float(n)*width, vals[i], 108 | width=width/float(n), align="edge", color=color) 109 | plt.xticks(_X, X) 110 | plt.xticks(rotation=45, ha='right') 111 | plt.xticks(fontsize=16) 112 | plt.yticks(fontsize=16) 113 | 114 | subcategorybar(x, [y1, pad], [0, 0.4470, 0.7410]) 115 | plt.ylabel('Accuracy (%)', fontdict={'family': 'Arial', 116 | 'color': [0, 0.4470, 0.7410], 117 | 'weight': 'bold', 118 | 'size': 16, 119 | }) 120 | 121 | fig.axes[1] = fig.axes[0].twinx() 122 | 123 | subcategorybar(x, [pad, y2], [0.8500, 0.3250, 0.0980]) 124 | plt.ylabel('Wall-clock time (h)', fontdict={'family': 'Arial', 125 | 'color': [0.8500, 0.3250, 0.0980], 126 | 'weight': 'bold', 127 | 'size': 16, 128 | }) 129 | 130 | fig.axes[0].tick_params(axis='y', colors=[0, 0.4470, 0.7410]) 131 | fig.axes[0].spines['left'].set_color([0, 0.4470, 0.7410]) 132 | fig.axes[1].spines['left'].set_color([0, 0.4470, 0.7410]) 133 | 134 | fig.axes[1].tick_params(axis='y', colors=[0.8500, 0.3250, 0.0980]) 135 | fig.axes[1].spines['right'].set_color([0.8500, 0.3250, 0.0980]) 136 | fig.axes[0].spines['right'].set_color([0.8500, 0.3250, 0.0980]) 137 | 138 | plt.tight_layout() 139 | # plt.show() 140 | plt.savefig(figure_name, format="pdf", bbox_inches="tight") 141 | 142 | 143 | def main(): 144 | parser = argparse.ArgumentParser(description='Plot experiment results as bars') 145 | parser.add_argument('--path_to_rho35', type=str, default='TBD') 146 | parser.add_argument('--path_to_rho40', type=str, default='TBD') 147 | parser.add_argument('--path_to_rho50', type=str, default='TBD') 148 | parser.add_argument('--path_to_rho60', type=str, default='TBD') 149 | parser.add_argument('--path_to_rho70', type=str, default='TBD') 150 | parser.add_argument('--path_to_full_training', type=str, default='TBD') 151 | parser.add_argument('--path_to_traditional_tl', type=str, default='TBD') 152 | parser.add_argument('--path_to_bn_plus_bias', type=str, default='TBD') 153 | parser.add_argument('--path_to_prunetrain', type=str, default='TBD') 154 | parser.add_argument('--figure_id', type=int, default=1, help='figure id') 155 | parser.add_argument('--figure_name', type=str, default='TBD', help='figure name') 156 | parser.add_argument('--ego', type=my_bool, default=False, help='Whether to exclude baseline schemes') 157 | 158 | args = parser.parse_args() 159 | 160 | path_to_rho35 = args.path_to_rho35 161 | path_to_rho40 = args.path_to_rho40 162 | path_to_rho50 = args.path_to_rho50 163 | path_to_rho60 = args.path_to_rho60 164 | path_to_rho70 = args.path_to_rho70 165 | path_to_full_training = args.path_to_full_training 166 | path_to_traditional_tl = args.path_to_traditional_tl 167 | path_to_bn_plus_bias = args.path_to_bn_plus_bias 168 | path_to_prunetrain = args.path_to_prunetrain 169 | figure_id = args.figure_id 170 | figure_name = args.figure_name 171 | ego = args.ego 172 | 173 | if ego: 174 | plot_different_speedup_ratios_ego( 175 | 'logs/' + path_to_rho35 + '.txt', 176 | 'logs/' + path_to_rho40 + '.txt', 177 | 'logs/' + path_to_rho50 + '.txt', 178 | 'logs/' + path_to_rho60 + '.txt', 179 | 'logs/' + path_to_rho70 + '.txt', 180 | figure_id, 181 | 'figures/' + figure_name, 182 | ) 183 | else: 184 | plot_different_speedup_ratios( 185 | 'logs/' + path_to_rho35 + '.txt', 186 | 'logs/' + path_to_rho40 + '.txt', 187 | 'logs/' + path_to_rho50 + '.txt', 188 | 'logs/' + path_to_rho60 + '.txt', 189 | 'logs/' + path_to_rho70 + '.txt', 190 | 'logs/' + path_to_full_training + '.txt', 191 | 'logs/' + path_to_traditional_tl + '.txt', 192 | 'logs/' + path_to_bn_plus_bias + 
'.txt', 193 | 'logs/' + path_to_prunetrain + '.txt', 194 | figure_id, 195 | 'figures/' + figure_name, 196 | ) 197 | 198 | if __name__ == '__main__': 199 | main() 200 | 201 | # fig = plt.figure(1) 202 | 203 | # x = ['rho=35%', 'rho=40%', 'rho=50%', 'rho=60%', 'rho=70%', 'Full training', \ 204 | # 'Traditional TL', 'BN+Bias', 'PruneTrain'] 205 | 206 | # y1 = [10, 20, 30, 20, 20, 20, 20, 20, 20] 207 | # y2 = [40, 50, 20, 30, 30, 30, 30, 30, 30] 208 | # pad = [0, 0, 0, 0, 0, 0, 0, 0, 0] 209 | 210 | # def subcategorybar(X, vals, color, width=0.8): 211 | # n = len(vals) 212 | # _X = np.arange(len(X)) 213 | # for i in range(n): 214 | # plt.bar(_X - width/2. + i/float(n)*width, vals[i], 215 | # width=width/float(n), align="edge", color=color) 216 | # plt.xticks(_X, X) 217 | # plt.xticks(rotation=45, ha='right') 218 | # plt.xticks(fontsize=16) 219 | # plt.yticks(fontsize=16) 220 | 221 | # subcategorybar(x, [y1, pad], [0, 0.4470, 0.7410]) 222 | # plt.ylabel('Accuracy (%)', fontdict={'family': 'Arial', 223 | # 'color': [0, 0.4470, 0.7410], 224 | # 'weight': 'bold', 225 | # 'size': 16, 226 | # }) 227 | 228 | # fig.axes[1] = fig.axes[0].twinx() 229 | 230 | # subcategorybar(x, [pad, y2], [0.8500, 0.3250, 0.0980]) 231 | # plt.ylabel('Wall-clock time (h)', fontdict={'family': 'Arial', 232 | # 'color': [0.8500, 0.3250, 0.0980], 233 | # 'weight': 'bold', 234 | # 'size': 16, 235 | # }) 236 | 237 | # fig.axes[0].tick_params(axis='y', colors=[0, 0.4470, 0.7410]) 238 | # fig.axes[0].spines['left'].set_color([0, 0.4470, 0.7410]) 239 | # fig.axes[1].spines['left'].set_color([0, 0.4470, 0.7410]) 240 | 241 | # fig.axes[1].tick_params(axis='y', colors=[0.8500, 0.3250, 0.0980]) 242 | # fig.axes[1].spines['right'].set_color([0.8500, 0.3250, 0.0980]) 243 | # fig.axes[0].spines['right'].set_color([0.8500, 0.3250, 0.0980]) 244 | 245 | # plt.tight_layout() 246 | # plt.show() 247 | 248 | -------------------------------------------------------------------------------- /plot_bars_v2.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib as mpl 3 | mpl.use('Agg') 4 | import matplotlib.pyplot as plt 5 | import argparse 6 | from utils import my_bool 7 | 8 | 9 | def plot_different_models( 10 | path_to_elastic_trainer_resnet50, 11 | path_to_elastic_trainer_vgg16, 12 | path_to_elastic_trainer_mobilenetv2, 13 | path_to_full_training_resnet50, 14 | path_to_full_training_vgg16, 15 | path_to_full_training_mobilenetv2, 16 | path_to_traditional_tl_resnet50, 17 | path_to_traditional_tl_vgg16, 18 | path_to_traditional_tl_mobilenetv2, 19 | path_to_bn_plus_bias_resnet50, 20 | path_to_bn_plus_bias_vgg16, 21 | path_to_bn_plus_bias_mobilenetv2, 22 | figure_id, 23 | figure_name, 24 | ): 25 | et_r50 = np.loadtxt(path_to_elastic_trainer_resnet50) # [time(h), accuracy(%)] 26 | et_v16 = np.loadtxt(path_to_elastic_trainer_vgg16) 27 | et_mv2 = np.loadtxt(path_to_elastic_trainer_mobilenetv2) 28 | ft_r50 = np.loadtxt(path_to_full_training_resnet50) 29 | ft_v16 = np.loadtxt(path_to_full_training_vgg16) 30 | ft_mv2 = np.loadtxt(path_to_full_training_mobilenetv2) 31 | ttl_r50 = np.loadtxt(path_to_traditional_tl_resnet50) 32 | ttl_v16 = np.loadtxt(path_to_traditional_tl_vgg16) 33 | ttl_mv2 = np.loadtxt(path_to_traditional_tl_mobilenetv2) 34 | bpb_r50 = np.loadtxt(path_to_bn_plus_bias_resnet50) 35 | bpb_v16 = np.loadtxt(path_to_bn_plus_bias_vgg16) 36 | bpb_mv2 = np.loadtxt(path_to_bn_plus_bias_mobilenetv2) 37 | 38 | X = ['ResNet50','VGG16','MobileNetV2'] 39 | et = [et_r50[1], 
et_v16[1], et_mv2[1]] 40 | ft = [ft_r50[1], ft_v16[1], ft_mv2[1]] 41 | ttl = [ttl_r50[1], ttl_v16[1], ttl_mv2[1]] 42 | bpb = [bpb_r50[1], bpb_v16[1], bpb_mv2[1]] 43 | 44 | plt.figure(figure_id) 45 | 46 | def subcategorybar1(X, vals, width=0.8): 47 | n = len(vals) 48 | _X = np.arange(len(X)) 49 | for i in range(n): 50 | plt.bar(_X - width/2. + i/float(n)*width, vals[i], 51 | width=width/float(n), align="edge") 52 | plt.xticks(_X, X) 53 | plt.ylabel('Accuracy (%)', fontdict={'family': 'Arial', 54 | 'color': 'black', 55 | 'weight': 'bold', 56 | 'size': 16, 57 | }) 58 | plt.legend(['ElasticTrainer', 'Full training', 'Traditional TL', 'BN+Bias']) 59 | plt.xticks(fontsize=16) 60 | plt.yticks(fontsize=16) 61 | 62 | subcategorybar1(X, [et, ft, ttl, bpb]) 63 | # plt.show() 64 | plt.savefig(figure_name + '_accuracy.pdf', format="pdf", bbox_inches="tight") 65 | 66 | ######### 67 | 68 | et = [et_r50[0], et_v16[0], et_mv2[0]] 69 | ft = [ft_r50[0], ft_v16[0], ft_mv2[0]] 70 | ttl = [ttl_r50[0], ttl_v16[0], ttl_mv2[0]] 71 | bpb = [bpb_r50[0], bpb_v16[0], bpb_mv2[0]] 72 | 73 | plt.figure(figure_id) 74 | 75 | def subcategorybar2(X, vals, width=0.8): 76 | n = len(vals) 77 | _X = np.arange(len(X)) 78 | for i in range(n): 79 | plt.bar(_X - width/2. + i/float(n)*width, vals[i], 80 | width=width/float(n), align="edge") 81 | plt.xticks(_X, X) 82 | plt.ylabel('Wall-clock time (h)', fontdict={'family': 'Arial', 83 | 'color': 'black', 84 | 'weight': 'bold', 85 | 'size': 16, 86 | }) 87 | plt.legend(['ElasticTrainer', 'Full training', 'Traditional TL', 'BN+Bias']) 88 | plt.xticks(fontsize=16) 89 | plt.yticks(fontsize=16) 90 | 91 | subcategorybar2(X, [et, ft, ttl, bpb]) 92 | # plt.show() 93 | plt.savefig(figure_name + '_time.pdf', format="pdf", bbox_inches="tight") 94 | 95 | 96 | def plot_different_models_ego( 97 | path_to_elastic_trainer_resnet50, 98 | path_to_elastic_trainer_vgg16, 99 | path_to_elastic_trainer_mobilenetv2, 100 | figure_id, 101 | figure_name, 102 | ): 103 | et_r50 = np.loadtxt(path_to_elastic_trainer_resnet50) # [time(h), accuracy(%)] 104 | et_v16 = np.loadtxt(path_to_elastic_trainer_vgg16) 105 | et_mv2 = np.loadtxt(path_to_elastic_trainer_mobilenetv2) 106 | 107 | X = ['ResNet50','VGG16','MobileNetV2'] 108 | et = [et_r50[1], et_v16[1], et_mv2[1]] 109 | 110 | plt.figure(figure_id) 111 | 112 | def subcategorybar1(X, vals, width=0.8): 113 | n = len(vals) 114 | _X = np.arange(len(X)) 115 | for i in range(n): 116 | plt.bar(_X - width/2. + i/float(n)*width, vals[i], 117 | width=width/float(n), align="edge") 118 | plt.xticks(_X, X) 119 | plt.ylabel('Accuracy (%)', fontdict={'family': 'Arial', 120 | 'color': 'black', 121 | 'weight': 'bold', 122 | 'size': 16, 123 | }) 124 | # plt.legend(['ElasticTrainer', 'Full training', 'Traditional TL', 'BN+Bias']) 125 | plt.xticks(fontsize=16) 126 | plt.yticks(fontsize=16) 127 | 128 | subcategorybar1(X, [et]) 129 | # plt.show() 130 | plt.savefig(figure_name + '_accuracy.pdf', format="pdf", bbox_inches="tight") 131 | 132 | ########## 133 | et = [et_r50[0], et_v16[0], et_mv2[0]] 134 | 135 | plt.figure(figure_id + 1) 136 | 137 | def subcategorybar2(X, vals, width=0.8): 138 | n = len(vals) 139 | _X = np.arange(len(X)) 140 | for i in range(n): 141 | plt.bar(_X - width/2. 
+ i/float(n)*width, vals[i], 142 | width=width/float(n), align="edge") 143 | plt.xticks(_X, X) 144 | plt.ylabel('Wall-clock time (h)', fontdict={'family': 'Arial', 145 | 'color': 'black', 146 | 'weight': 'bold', 147 | 'size': 16, 148 | }) 149 | plt.legend(['ElasticTrainer', 'Full training', 'Traditional TL', 'BN+Bias']) 150 | plt.xticks(fontsize=16) 151 | plt.yticks(fontsize=16) 152 | 153 | subcategorybar2(X, [et]) 154 | # plt.show() 155 | plt.savefig(figure_name + '_time.pdf', format="pdf", bbox_inches="tight") 156 | 157 | 158 | def main(): 159 | parser = argparse.ArgumentParser(description='Plot experiment results as bars') 160 | parser.add_argument('--path_to_elastic_trainer_resnet50', type=str, default='TBD') 161 | parser.add_argument('--path_to_elastic_trainer_vgg16', type=str, default='TBD') 162 | parser.add_argument('--path_to_elastic_trainer_mobilenetv2', type=str, default='TBD') 163 | 164 | parser.add_argument('--path_to_full_training_resnet50', type=str, default='TBD') 165 | parser.add_argument('--path_to_full_training_vgg16', type=str, default='TBD') 166 | parser.add_argument('--path_to_full_training_mobilenetv2', type=str, default='TBD') 167 | 168 | parser.add_argument('--path_to_traditional_tl_resnet50', type=str, default='TBD') 169 | parser.add_argument('--path_to_traditional_tl_vgg16', type=str, default='TBD') 170 | parser.add_argument('--path_to_traditional_tl_mobilenetv2', type=str, default='TBD') 171 | 172 | parser.add_argument('--path_to_bn_plus_bias_resnet50', type=str, default='TBD') 173 | parser.add_argument('--path_to_bn_plus_bias_vgg16', type=str, default='TBD') 174 | parser.add_argument('--path_to_bn_plus_bias_mobilenetv2', type=str, default='TBD') 175 | parser.add_argument('--figure_id', type=int, default=1, help='figure id') 176 | parser.add_argument('--figure_name', type=str, default='TBD', help='figure name') 177 | parser.add_argument('--ego', type=my_bool, default=False, help='Whether to exclude baseline schemes') 178 | 179 | args = parser.parse_args() 180 | 181 | path_to_elastic_trainer_resnet50 = args.path_to_elastic_trainer_resnet50 182 | path_to_elastic_trainer_vgg16 = args.path_to_elastic_trainer_vgg16 183 | path_to_elastic_trainer_mobilenetv2 = args.path_to_elastic_trainer_mobilenetv2 184 | 185 | path_to_full_training_resnet50 = args.path_to_full_training_resnet50 186 | path_to_full_training_vgg16 = args.path_to_full_training_vgg16 187 | path_to_full_training_mobilenetv2 = args.path_to_full_training_mobilenetv2 188 | 189 | path_to_traditional_tl_resnet50 = args.path_to_traditional_tl_resnet50 190 | path_to_traditional_tl_vgg16 = args.path_to_traditional_tl_vgg16 191 | path_to_traditional_tl_mobilenetv2 = args.path_to_traditional_tl_mobilenetv2 192 | 193 | path_to_bn_plus_bias_resnet50 = args.path_to_bn_plus_bias_resnet50 194 | path_to_bn_plus_bias_vgg16 = args.path_to_bn_plus_bias_vgg16 195 | path_to_bn_plus_bias_mobilenetv2 = args.path_to_bn_plus_bias_mobilenetv2 196 | 197 | figure_id = args.figure_id 198 | figure_name = args.figure_name 199 | ego = args.ego 200 | 201 | if ego: 202 | plot_different_models_ego( 203 | 'logs/' + path_to_elastic_trainer_resnet50 + '.txt', 204 | 'logs/' + path_to_elastic_trainer_vgg16 + '.txt', 205 | 'logs/' + path_to_elastic_trainer_mobilenetv2 + '.txt', 206 | figure_id, 207 | 'figures/' + figure_name, 208 | ) 209 | else: 210 | plot_different_models( 211 | 'logs/' + path_to_elastic_trainer_resnet50 + '.txt', 212 | 'logs/' + path_to_elastic_trainer_vgg16 + '.txt', 213 | 'logs/' + path_to_elastic_trainer_mobilenetv2 + '.txt', 
214 | 'logs/' + path_to_full_training_resnet50 + '.txt', 215 | 'logs/' + path_to_full_training_vgg16 + '.txt', 216 | 'logs/' + path_to_full_training_mobilenetv2 + '.txt', 217 | 'logs/' + path_to_traditional_tl_resnet50 + '.txt', 218 | 'logs/' + path_to_traditional_tl_vgg16 + '.txt', 219 | 'logs/' + path_to_traditional_tl_mobilenetv2 + '.txt', 220 | 'logs/' + path_to_bn_plus_bias_resnet50 + '.txt', 221 | 'logs/' + path_to_bn_plus_bias_vgg16 + '.txt', 222 | 'logs/' + path_to_bn_plus_bias_mobilenetv2 + '.txt', 223 | figure_id, 224 | 'figures/' + figure_name, 225 | ) 226 | 227 | if __name__ == '__main__': 228 | main() -------------------------------------------------------------------------------- /plot_curves.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from tensorboard.backend.event_processing.event_accumulator import EventAccumulator 4 | from tensorboard.backend.event_processing import tag_types 5 | import matplotlib as mpl 6 | mpl.use('Agg') 7 | import matplotlib.pyplot as plt 8 | import argparse 9 | from utils import my_bool 10 | 11 | 12 | def read_data_from_tfboard_logs(path, x_tag, y_tag): 13 | """load logged training metrics into numpy arrays 14 | 15 | Args: 16 | path (str): path to the training log 17 | x_tag (str): one from ['wall_time', 'step'] 18 | y_tag (str): one from ['train/accuracy', 'train/classification_loss', 19 | 'train/learnig_rate', 'test/classification_loss', 20 | 'test/accuracy'] 21 | Returns: 22 | x, y: numpy arrays 23 | """ 24 | size_guidance = { 25 | tag_types.TENSORS: 20, 26 | } 27 | event_acc = EventAccumulator(path, size_guidance=size_guidance) 28 | event_acc.Reload() 29 | event_list = event_acc.Tensors(y_tag) 30 | if x_tag == 'wall_time': 31 | x = [e.wall_time for e in event_list] 32 | x = np.array(x) 33 | x = x - x[0] 34 | x = (x + x[1]) / 3600 # convert to hours 35 | else: 36 | x = [e.step for e in event_list] 37 | y = np.array([tf.make_ndarray(e.tensor_proto).item() for e in event_list]) 38 | if 'accuracy' in y_tag: 39 | y *= 100 # convert to % 40 | return x, y 41 | 42 | 43 | def plot_single_curve( 44 | x_tag, 45 | y_tag, 46 | path_to_elastic_training, 47 | figure_id, 48 | figure_name, 49 | ): 50 | """plot ElasticTrainer results as a single curve excluding baselines 51 | 52 | Args: 53 | x_tag (str): one from ['wall_time', 'step'] 54 | y_tag (str): one from ['train/accuracy', 'train/classification_loss', 55 | 'train/learnig_rate', 'test/classification_loss', 56 | 'test/accuracy'] 57 | path_to_elastic_training (str): path to ElasticTrainer's log 58 | figure_id (str): id of plotted figure 59 | """ 60 | et_x, et_y = read_data_from_tfboard_logs(path_to_elastic_training, x_tag, y_tag) 61 | font = {'family': 'Arial', 62 | 'color': 'black', 63 | 'weight': 'bold', 64 | 'size': 16, 65 | } 66 | plt.figure(figure_id) 67 | plt.plot(et_x, et_y, "ks-", label="ElasticTrainer", linewidth=3) 68 | plt.xlabel('Wall-clock time (h)', fontdict=font) 69 | if 'accuracy' in y_tag: 70 | plt.ylabel('Accuracy (%)', fontdict=font) 71 | else: 72 | plt.ylabel('Loss', fontdict=font) 73 | plt.xticks(fontsize=16) 74 | plt.yticks(fontsize=16) 75 | plt.legend(fontsize=16) 76 | plt.grid() 77 | plt.tight_layout() 78 | plt.savefig(figure_name, format="pdf", bbox_inches="tight") 79 | # plt.show() 80 | 81 | 82 | def plot_multiple_curves( 83 | x_tag, 84 | y_tag, 85 | path_to_elastic_training, 86 | path_to_full_training, 87 | path_to_traditional_tl, 88 | path_to_bn_plus_bias, 89 | figure_id, 90 | figure_name, 
91 | ): 92 | """plot training results as curves including baselines 93 | 94 | Args: 95 | x_tag (str): one from ['wall_time', 'step'] 96 | y_tag (str): one from ['train/accuracy', 'train/classification_loss', 97 | 'train/learnig_rate', 'test/classification_loss', 98 | 'test/accuracy'] 99 | path_to_elastic_training (str): path to ElasticTrainer's log 100 | path_to_full_training (str): path to Full Training's log 101 | path_to_traditional_tl (str): path to Traditional TL's log 102 | path_to_bn_plus_bias (str): path to BN+Bias's log 103 | figure_id (str): id of plotted figure 104 | """ 105 | et_x, et_y = read_data_from_tfboard_logs(path_to_elastic_training, x_tag, y_tag) 106 | ft_x, ft_y = read_data_from_tfboard_logs(path_to_full_training, x_tag, y_tag) 107 | ttl_x, ttl_y = read_data_from_tfboard_logs(path_to_traditional_tl, x_tag, y_tag) 108 | bpb_x, bpb_y = read_data_from_tfboard_logs(path_to_bn_plus_bias, x_tag, y_tag) 109 | 110 | font = {'family': 'Arial', 111 | 'color': 'black', 112 | 'weight': 'bold', 113 | 'size': 16, 114 | } 115 | 116 | plt.figure(figure_id) 117 | plt.plot(et_x, et_y, "ks-", label="ElasticTrainer", linewidth=3) 118 | plt.plot(ft_x, ft_y, "rs-", label="Full Training", linewidth=3) 119 | plt.plot(ttl_x, ttl_y, "bs-", label="Traditional TL", linewidth=3) 120 | plt.plot(bpb_x, bpb_y, "gs-", label="BN+Bias", linewidth=3) 121 | plt.xlabel('Wall-clock time (h)', fontdict=font) 122 | if 'accuracy' in y_tag: 123 | plt.ylabel('Accuracy (%)', fontdict=font) 124 | else: 125 | plt.ylabel('Loss', fontdict=font) 126 | plt.xticks(fontsize=16) 127 | plt.yticks(fontsize=16) 128 | plt.legend(fontsize=16) 129 | plt.grid() 130 | plt.tight_layout() 131 | plt.savefig(figure_name, format="pdf", bbox_inches="tight") 132 | # plt.show() 133 | 134 | def main(): 135 | parser = argparse.ArgumentParser(description='Plot experiment results as curves') 136 | parser.add_argument('--x_tag', type=str, default='wall_time', help="one from ['wall_time', 'step']") 137 | parser.add_argument('--y_tag', type=str, default='accuracy', help="['train/accuracy', 'train/classification_loss',\ 138 | 'train/learnig_rate', 'test/classification_loss', 'test/accuracy']") 139 | parser.add_argument('--single', type=my_bool, default=True, help='whether to exclude baseline schemes') 140 | parser.add_argument('--elastic_trainer_path', type=str, default='TBD', help='path to log of elastic_trainer') 141 | parser.add_argument('--full_training_path', type=str, default='TBD', help='path to log of full_training') 142 | parser.add_argument('--traditional_tl_path', type=str, default='TBD', help='path to log of elastic_trainer') 143 | parser.add_argument('--bn_plus_bias_path', type=str, default='TBD', help='path to log of bn_plus_bias') 144 | parser.add_argument('--figure_id', type=int, default=1, help='figure id') 145 | parser.add_argument('--figure_name', type=str, default='TBD', help='figure name') 146 | 147 | args = parser.parse_args() 148 | 149 | x_tag = args.x_tag 150 | y_tag = args.y_tag 151 | single = args.single 152 | elastic_trainer_path = args.elastic_trainer_path 153 | full_training_path = args.full_training_path 154 | traditional_tl_path = args.traditional_tl_path 155 | bn_plus_bias_path = args.bn_plus_bias_path 156 | figure_id = args.figure_id 157 | figure_name = args.figure_name 158 | 159 | if single: 160 | plot_single_curve( 161 | x_tag, 162 | y_tag, 163 | 'logs/' + elastic_trainer_path, 164 | figure_id, 165 | 'figures/' + figure_name, 166 | ) 167 | else: 168 | plot_multiple_curves( 169 | x_tag, 170 | y_tag, 171 | 
'logs/' + elastic_trainer_path, 172 | 'logs/' + full_training_path, 173 | 'logs/' + traditional_tl_path, 174 | 'logs/' + bn_plus_bias_path, 175 | figure_id, 176 | 'figures/' + figure_name, 177 | ) 178 | 179 | if __name__ == '__main__': 180 | main() -------------------------------------------------------------------------------- /profile_extracted/README.md: -------------------------------------------------------------------------------- 1 | This folder stores tensor timing files. -------------------------------------------------------------------------------- /profiler.py: -------------------------------------------------------------------------------- 1 | from utils import port_pretrained_models 2 | import tensorflow as tf 3 | import numpy as np 4 | import matplotlib as mpl 5 | mpl.use('Agg') 6 | import matplotlib.pyplot as plt 7 | from tensorboard_plugin_profile.protobuf import tf_stats_pb2 #, kernel_stats_pb2 8 | from tensorboard_plugin_profile.convert.tf_stats_proto_to_gviz import generate_chart_table 9 | # from tensorboard_plugin_profile.convert.kernel_stats_proto_to_gviz import generate_kernel_reports_table 10 | from tqdm import tqdm 11 | import argparse 12 | import time 13 | import csv 14 | import os 15 | 16 | 17 | def profile_backpropagation( 18 | model, 19 | input_shape, 20 | batch_size, 21 | num_iterations, 22 | logdir): 23 | """ 24 | This function profiles NN ops in backward pass. 25 | 26 | Args: 27 | model (tf.keras.Model): NN model to profile 28 | input_shape (tuple): input shape of NN model 29 | batch_size (int): batch size for backward pass 30 | num_iterations (int): number of backward passes to run 31 | logdir (str): path to where profile is recorded 32 | """ 33 | 34 | if os.path.exists(logdir): 35 | print(f"Profile '{logdir}' already exists") 36 | return 37 | 38 | loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) 39 | 40 | @tf.function 41 | def train_step(x, y): 42 | with tf.GradientTape() as tape: 43 | y_pred = model(x, training=True) 44 | loss = loss_fn(y, y_pred) 45 | gradients = tape.gradient(loss, model.trainable_weights) 46 | return gradients 47 | 48 | # dummy training data 49 | x = tf.random.normal((batch_size, input_shape[0], input_shape[1], input_shape[2])) 50 | y = tf.ones((batch_size,)) 51 | 52 | print("Warmup...") 53 | for k in tqdm(range(2)): 54 | train_step(x, y) 55 | 56 | t0 = time.time() 57 | 58 | print("Profiling the model...") 59 | tf.profiler.experimental.start(logdir) 60 | for k in range(num_iterations): 61 | with tf.profiler.experimental.Trace('train', step_num=k, _r=1): 62 | train_step(x, y) 63 | tf.profiler.experimental.stop() 64 | 65 | t1 = time.time() 66 | 67 | print("Finished profiling!") 68 | print("Elasped time (s):", t1 - t0) 69 | 70 | 71 | def convert_pb_to_csv(logdir, outdir): 72 | """ 73 | This function extracts timing-related info from the recorded profile. 
74 | 75 | Args: 76 | logdir (str): path to where profile is recorded 77 | outdir (str): path to where the extracted timing info is stored 78 | """ 79 | if os.path.exists(outdir): 80 | print(f"Extracted timing info '{outdir}' already exists") 81 | return 82 | 83 | tf_profile_path = logdir + '/plugins/profile/' 84 | tf_stats_path = '' 85 | for root, subdirs, files in os.walk(logdir): 86 | if tf_profile_path in root: 87 | fn = '' 88 | for f in files: 89 | if 'tensorflow_stats.pb' in f: 90 | fn = f 91 | tf_stats_path = root + '/' + fn 92 | 93 | with tf.io.gfile.GFile(tf_stats_path, 'rb') as f: 94 | tf_stats_db = tf_stats_pb2.TfStatsDatabase() 95 | tf_stats_db.ParseFromString(f.read()) 96 | 97 | csv_table = generate_chart_table(tf_stats_db.with_idle, 98 | tf_stats_db.device_type).ToCsv() 99 | with open(outdir, 'w') as f: 100 | f.write(csv_table) 101 | 102 | 103 | def profile_parser( 104 | model, 105 | model_type, 106 | num_iterations, 107 | filedir, 108 | draw_figure=False): 109 | """ 110 | This function constructs tensor timings from the profiled op timings. 111 | 112 | Args: 113 | model (tf.keras.Model): NN model 114 | model_type (str): type of NN model 115 | num_iterations (int): number of iterations of backward passes in profiling 116 | filedir (str): where the timing-related info is stored 117 | draw_figure (bool, optional): whether to plot the tensor timings. Defaults to False. 118 | 119 | Returns: 120 | tensor timings t_dw and t_dy 121 | """ 122 | 123 | if model_type in ('resnet50', 'vgg16'): 124 | 125 | all_stats = [] 126 | with open(filedir) as csv_file: 127 | csv_reader = csv.reader(csv_file, delimiter=',') 128 | csv_reader.__next__() 129 | for r in csv_reader: 130 | all_stats.append(r) 131 | 132 | op_total_time = [] 133 | op_names = [] 134 | # extract gradient related ops 135 | for op_stat in all_stats: 136 | if 'gradient_tape' in op_stat[3]: 137 | op_total_time.append(float(op_stat[5])) 138 | op_names.append(op_stat[3]) 139 | 140 | base_layers = model.layers[4].layers 141 | custom_layers = model.layers[5:] 142 | model_layers = [*base_layers, *custom_layers] 143 | 144 | t_dw = [0.0 for k in range(len(model.trainable_weights))] 145 | t_dy = [0.0 for k in range(len(model.trainable_weights))] 146 | 147 | weight_count = 0 148 | 149 | for l in model_layers: 150 | if '_conv' in l.name: 151 | if l.use_bias: 152 | for op, t in zip(op_names, op_total_time): 153 | if l.name in op and 'Conv2DBackpropFilter' in op: 154 | t_dw[weight_count] = t 155 | elif l.name in op and 'Conv2DBackpropInput' in op: 156 | t_dy[weight_count] += t # include TransposeNCHWToNHWC 157 | elif l.name in op and 'BiasAddGrad' in op: 158 | t_dw[weight_count + 1] = t 159 | t_dy[weight_count + 1] = 0 160 | weight_count += 2 161 | else: 162 | for op, t in zip(op_names, op_total_time): 163 | if l.name in op and 'Conv2DBackpropFilter' in op: 164 | t_dw[weight_count] = t 165 | elif l.name in op and 'Conv2DBackpropInput' in op: 166 | t_dy[weight_count] += t # include TransposeNCHWToNHWC 167 | weight_count += 1 168 | 169 | elif '_bn' in l.name: 170 | for op, t in zip(op_names, op_total_time): 171 | if l.name in op and 'FusedBatchNormGrad' in op: 172 | # for gamma 173 | t_dw[weight_count] = 0 174 | t_dy[weight_count] = 0 175 | # for beta 176 | t_dw[weight_count + 1] = 0 177 | t_dy[weight_count + 1] = t 178 | weight_count += 2 179 | elif 'dense' in l.name: 180 | if l.use_bias: 181 | for op, t in zip(op_names, op_total_time): 182 | if l.name in op and 'MatMul/MatMul_1' in op: 183 | t_dw[weight_count] = t 184 | elif l.name in op and 
'MatMul/MatMul' in op: 185 | t_dy[weight_count] = t 186 | elif l.name in op and 'BiasAddGrad' in op: 187 | t_dw[weight_count + 1] = t 188 | t_dy[weight_count + 1] = 0 189 | weight_count += 2 190 | else: 191 | for op, t in zip(op_names, op_total_time): 192 | if l.name in op and 'MatMul/MatMul_1' in op: 193 | t_dw[weight_count] = t 194 | elif l.name in op and 'MatMul/MatMul' in op: 195 | t_dy[weight_count] = t 196 | weight_count += 1 197 | 198 | else: 199 | # fuse backprop time of non-trainables to the previous trainable layer 200 | for op, t in zip(op_names, op_total_time): 201 | if l.name in op and weight_count > 0: 202 | t_dy[weight_count - 1] += t 203 | 204 | # the first layer never propagates input grads, just remove t_dy[0] 205 | # t_dy = t_dy[1:] # 1~N-1 206 | 207 | t_dw = np.array(t_dw) / num_iterations # (us) 208 | t_dy = np.array(t_dy) / num_iterations # (us) 209 | 210 | print(f'# model trainbles: {len(model.trainable_weights)}') 211 | print(f'# t_dw: {weight_count}, # t_dy: {weight_count}') 212 | 213 | if draw_figure: 214 | fig = plt.figure() 215 | plt.barh(np.arange(t_dw.shape[0]), t_dw, color ='navy') 216 | #plt.xticks(rotation=45) 217 | plt.xlabel('t_dw (us)', fontsize=20) 218 | plt.xticks(fontsize=20) 219 | plt.ylabel('Layer ID', fontsize=20) 220 | plt.yticks(fontsize=20) 221 | plt.show() 222 | 223 | fig = plt.figure() 224 | plt.barh(np.arange(t_dy.shape[0]), t_dy, color ='navy') 225 | #plt.xticks(rotation=45) 226 | plt.xlabel('t_dy (us)', fontsize=20) 227 | plt.xticks(fontsize=20) 228 | plt.ylabel('Layer ID', fontsize=20) 229 | plt.yticks(fontsize=20) 230 | plt.show() 231 | 232 | return tf.convert_to_tensor(t_dw/1000.0, tf.float32), tf.convert_to_tensor(t_dy/1000.0, tf.float32) # (ms) 233 | 234 | elif model_type == 'mobilenetv2': 235 | 236 | all_stats = [] 237 | with open(filedir) as csv_file: 238 | csv_reader = csv.reader(csv_file, delimiter=',') 239 | csv_reader.__next__() 240 | for r in csv_reader: 241 | all_stats.append(r) 242 | 243 | op_total_time = [] 244 | op_names = [] 245 | # extract gradient related ops 246 | for op_stat in all_stats: 247 | if 'gradient_tape' in op_stat[3]: 248 | op_total_time.append(float(op_stat[5])) 249 | op_names.append(op_stat[3]) 250 | 251 | base_layers = model.layers[4].layers 252 | custom_layers = model.layers[5:] 253 | model_layers = [*base_layers, *custom_layers] 254 | 255 | t_dw = [0.0 for k in range(len(model.trainable_weights))] 256 | t_dy = [0.0 for k in range(len(model.trainable_weights))] 257 | 258 | weight_count = 0 259 | 260 | for l in model_layers: 261 | # take care of the standard conv 262 | if ('Conv1' == l.name) or ('Conv_1' == l.name) or ((l.name).endswith('_project')) or ((l.name).endswith('_expand')): 263 | if l.use_bias: 264 | for op, t in zip(op_names, op_total_time): 265 | if l.name in op and 'Conv2DBackpropFilter' in op: 266 | t_dw[weight_count] = t 267 | elif l.name in op and 'Conv2DBackpropInput' in op: 268 | t_dy[weight_count] += t # include TransposeNCHWToNHWC 269 | elif l.name in op and 'BiasAddGrad' in op: 270 | t_dw[weight_count + 1] = t 271 | t_dy[weight_count + 1] = 0 272 | weight_count += 2 273 | else: 274 | for op, t in zip(op_names, op_total_time): 275 | if l.name in op and 'Conv2DBackpropFilter' in op: 276 | t_dw[weight_count] = t 277 | elif l.name in op and 'Conv2DBackpropInput' in op: 278 | t_dy[weight_count] += t # include TransposeNCHWToNHWC 279 | weight_count += 1 280 | # take care of the lightweight conv 281 | elif ((l.name).endswith('_depthwise')): 282 | if l.use_bias: 283 | for op, t in zip(op_names, 
op_total_time): 284 | if l.name in op and 'DepthwiseConv2dNativeBackpropFilter' in op: 285 | t_dw[weight_count] = t 286 | elif l.name in op and 'DepthwiseConv2dNativeBackpropInput' in op: 287 | t_dy[weight_count] += t # include TransposeNCHWToNHWC 288 | elif l.name in op and 'BiasAddGrad' in op: 289 | t_dw[weight_count + 1] = t 290 | t_dy[weight_count + 1] = 0 291 | weight_count += 2 292 | else: 293 | for op, t in zip(op_names, op_total_time): 294 | if l.name in op and 'DepthwiseConv2dNativeBackpropFilter' in op: 295 | t_dw[weight_count] = t 296 | elif l.name in op and 'DepthwiseConv2dNativeBackpropInput' in op: 297 | t_dy[weight_count] += t # include TransposeNCHWToNHWC 298 | weight_count += 1 299 | 300 | 301 | elif ('bn' in l.name) or ('BN' in l.name): 302 | for op, t in zip(op_names, op_total_time): 303 | if l.name in op and 'FusedBatchNormGrad' in op: 304 | # for gamma 305 | t_dw[weight_count] = 0 306 | t_dy[weight_count] = 0 307 | # for beta 308 | t_dw[weight_count + 1] = 0 309 | t_dy[weight_count + 1] = t 310 | weight_count += 2 311 | 312 | elif 'dense' in l.name: 313 | if l.use_bias: 314 | for op, t in zip(op_names, op_total_time): 315 | if l.name in op and 'MatMul/MatMul_1' in op: 316 | t_dw[weight_count] = t 317 | elif l.name in op and 'MatMul/MatMul' in op: 318 | t_dy[weight_count] = t 319 | elif l.name in op and 'BiasAddGrad' in op: 320 | t_dw[weight_count + 1] = t 321 | t_dy[weight_count + 1] = 0 322 | weight_count += 2 323 | else: 324 | for op, t in zip(op_names, op_total_time): 325 | if l.name in op and 'MatMul/MatMul_1' in op: 326 | t_dw[weight_count] = t 327 | elif l.name in op and 'MatMul/MatMul' in op: 328 | t_dy[weight_count] = t 329 | weight_count += 1 330 | 331 | else: 332 | # fuse backprop time of non-trainables to the previous trainable layer 333 | for op, t in zip(op_names, op_total_time): 334 | if l.name in op and weight_count > 0: 335 | t_dy[weight_count - 1] += t 336 | 337 | # the first layer never propagates input grads, just remove t_dy[0] 338 | # t_dy = t_dy[1:] # 1~N-1 339 | 340 | t_dw = np.array(t_dw) / num_iterations # (us) 341 | t_dy = np.array(t_dy) / num_iterations # (us) 342 | 343 | print(f'# model trainables: {len(model.trainable_weights)}') 344 | print(f'# t_dw: {weight_count}, # t_dy: {weight_count}') 345 | 346 | if draw_figure: 347 | fig = plt.figure() 348 | plt.barh(np.arange(t_dw.shape[0]), t_dw, color ='navy') 349 | #plt.xticks(rotation=45) 350 | plt.xlabel('t_dw (us)', fontsize=20) 351 | plt.xticks(fontsize=20) 352 | plt.ylabel('Layer ID', fontsize=20) 353 | plt.yticks(fontsize=20) 354 | plt.show() 355 | 356 | fig = plt.figure() 357 | plt.barh(np.arange(t_dy.shape[0]), t_dy, color ='navy') 358 | #plt.xticks(rotation=45) 359 | plt.xlabel('t_dy (us)', fontsize=20) 360 | plt.xticks(fontsize=20) 361 | plt.ylabel('Layer ID', fontsize=20) 362 | plt.yticks(fontsize=20) 363 | plt.show() 364 | 365 | return tf.convert_to_tensor(t_dw/1000.0, tf.float32), tf.convert_to_tensor(t_dy/1000.0, tf.float32) # (ms) 366 | 367 | elif model_type == 'vit': 368 | all_stats = [] 369 | with open(filedir) as csv_file: 370 | csv_reader = csv.reader(csv_file, delimiter=',') 371 | csv_reader.__next__() 372 | for r in csv_reader: 373 | all_stats.append(r) 374 | 375 | op_total_time = [] 376 | op_names = [] 377 | # extract gradient related ops 378 | for op_stat in all_stats: 379 | if 'gradient_tape' in op_stat[3]: 380 | op_total_time.append(float(op_stat[5])) 381 | op_names.append(op_stat[3]) 382 | 383 | t_dw = [0.0 for k in range(len(model.trainable_weights))] 384 | t_dy = [0.0 
for k in range(len(model.trainable_weights))] 385 | 386 | for weight_count, w in enumerate(model.trainable_weights): 387 | if 'embedding/kernel' in w.name: 388 | for op, t in zip(op_names, op_total_time): 389 | if 'embedding' in op and 'Conv2DBackpropFilter' in op: 390 | t_dw[weight_count] = t 391 | elif 'query/kernel' in w.name: 392 | s = w.name.split('/')[1] 393 | for op, t in zip(op_names, op_total_time): 394 | if s + '/MultiHeadDotProductAttention_1/query/Tensordot/MatMul/MatMul_1' in op: 395 | t_dw[weight_count] = t 396 | elif s + '/MultiHeadDotProductAttention_1/query/Tensordot/MatMul/MatMul' in op: 397 | t_dy[weight_count] = t 398 | elif 'key/kernel' in w.name: 399 | s = w.name.split('/')[1] 400 | for op, t in zip(op_names, op_total_time): 401 | if s + '/MultiHeadDotProductAttention_1/key/Tensordot/MatMul/MatMul_1' in op: 402 | t_dw[weight_count] = t 403 | elif s + '/MultiHeadDotProductAttention_1/key/Tensordot/MatMul/MatMul' in op: 404 | t_dy[weight_count] = t 405 | elif 'key/bias' in w.name: 406 | s = w.name.split('/')[1] 407 | for op, t in zip(op_names, op_total_time): 408 | if s + '/MultiHeadDotProductAttention_1/truediv/RealDiv' in op: 409 | t_dy[weight_count] += t # c = 1/sqrt(d) 410 | elif s + '/MultiHeadDotProductAttention_1/mul' in op: 411 | t_dy[weight_count] += t # *c 412 | elif s + '/MultiHeadDotProductAttention_1/MatMul/' in op: 413 | t_dy[weight_count] += t # query*key 414 | 415 | elif 'value/kernel' in w.name: 416 | s = w.name.split('/')[1] 417 | for op, t in zip(op_names, op_total_time): 418 | if s + '/MultiHeadDotProductAttention_1/value/Tensordot/MatMul/MatMul_1' in op: 419 | t_dw[weight_count] = t 420 | elif s + '/MultiHeadDotProductAttention_1/value/Tensordot/MatMul/MatMul' in op: 421 | t_dy[weight_count] = t 422 | 423 | elif 'value/bias' in w.name: 424 | s = w.name.split('/')[1] 425 | for op, t in zip(op_names, op_total_time): 426 | if s + '/MultiHeadDotProductAttention_1/MatMul_1/' in op: 427 | t_dy[weight_count] += t # att*value 428 | 429 | elif 'out/kernel' in w.name: 430 | s = w.name.split('/')[1] 431 | for op, t in zip(op_names, op_total_time): 432 | if s + '/MultiHeadDotProductAttention_1/out/Tensordot/MatMul/MatMul_1' in op: 433 | t_dw[weight_count] = t 434 | elif s + '/MultiHeadDotProductAttention_1/out/Tensordot/MatMul/MatMul' in op: 435 | t_dy[weight_count] = t 436 | 437 | elif 'Dense_0/kernel' in w.name: 438 | s = w.name.split('/')[1] 439 | for op, t in zip(op_names, op_total_time): 440 | if s + '/Dense_0/Tensordot/MatMul/MatMul_1' in op: 441 | t_dw[weight_count] = t 442 | elif s + '/Dense_0/Tensordot/MatMul/MatMul' in op: 443 | t_dy[weight_count] = t 444 | 445 | elif 'Dense_1/kernel' in w.name: 446 | s = w.name.split('/')[1] 447 | for op, t in zip(op_names, op_total_time): 448 | if s + '/Dense_1/Tensordot/MatMul/MatMul_1' in op: 449 | t_dw[weight_count] = t 450 | elif s + '/Dense_1/Tensordot/MatMul/MatMul' in op: 451 | t_dy[weight_count] = t 452 | 453 | t_dw = np.array(t_dw) / num_iterations # (us) 454 | t_dy = np.array(t_dy) / num_iterations # (us) 455 | 456 | print(f'# model trainables: {len(model.trainable_weights)}') 457 | print(f'# t_dw: {weight_count+1}, # t_dy: {weight_count+1}') 458 | 459 | if draw_figure: 460 | fig = plt.figure() 461 | plt.barh(np.arange(t_dw.shape[0]), t_dw, color ='navy') 462 | #plt.xticks(rotation=45) 463 | plt.xlabel('t_dw (us)', fontsize=20) 464 | plt.xticks(fontsize=20) 465 | plt.ylabel('Layer ID', fontsize=20) 466 | plt.yticks(fontsize=20) 467 | plt.show() 468 | 469 | fig = plt.figure() 470 | 
plt.barh(np.arange(t_dy.shape[0]), t_dy, color ='navy') 471 | #plt.xticks(rotation=45) 472 | plt.xlabel('t_dy (us)', fontsize=20) 473 | plt.xticks(fontsize=20) 474 | plt.ylabel('Layer ID', fontsize=20) 475 | plt.yticks(fontsize=20) 476 | plt.show() 477 | 478 | return tf.convert_to_tensor(t_dw/1000.0, tf.float32), tf.convert_to_tensor(t_dy/1000.0, tf.float32) # (ms) 479 | 480 | else: 481 | raise NotImplementedError("This model has not been implemented yet") 482 | 483 | 484 | def main(): 485 | parser = argparse.ArgumentParser(description='Tensor timing profiling') 486 | parser.add_argument('--model_name', type=str, default='resnet50', help='valid model names are resnet50, vgg16, mobilenetv2, vit') 487 | parser.add_argument('--input_size', type=int, default=224, help='input resolution, e.g., 224 stands for 224x224') 488 | parser.add_argument('--batch_size', type=int, default=4, help='batch size used to run during profiling') 489 | parser.add_argument('--num_classes', type=int, default=200, help='number of categories model can classify') 490 | # parser.add_argument('--num_iterations', type=int, default=5, help='number of backward passes to run during profiling') 491 | 492 | args = parser.parse_args() 493 | 494 | model_name = args.model_name 495 | input_size = args.input_size 496 | batch_size = args.batch_size 497 | num_classes = args.num_classes 498 | 499 | model = port_pretrained_models( 500 | model_type=model_name, 501 | input_shape=(input_size, input_size, 3), 502 | num_classes=num_classes, 503 | ) 504 | 505 | run_name = model_name + '_' + str(input_size) + '_' + str(num_classes) + '_' + str(batch_size) + '_' + 'profile' 506 | 507 | profile_backpropagation( 508 | model, 509 | (input_size, input_size, 3), 510 | batch_size, 511 | 5, 512 | 'logs/' + run_name, 513 | ) 514 | 515 | convert_pb_to_csv('logs/' + run_name, 'profile_extracted/' + run_name) 516 | 517 | t_dw, t_dy = profile_parser( 518 | model, 519 | model_name, 520 | 5, 521 | 'profile_extracted/' + run_name, 522 | draw_figure=False, 523 | ) 524 | 525 | # for w, t1, t2 in zip(model.trainable_weights, t_dw, t_dy): 526 | # print(w.name, t1.numpy(), t2.numpy()) 527 | 528 | if __name__ == '__main__': 529 | main() 530 | -------------------------------------------------------------------------------- /run_demo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # ResNet50, caltech_birds2011, ElasticTrainer 4 | python3 profiler.py --model_name resnet50 --num_classes 200 5 | 6 | python3 main.py --model_name resnet50 --dataset_name caltech_birds2011 --train_type elastic_training 7 | 8 | # ResNet50, oxford_iiit_pet, ElasticTrainer 9 | python3 profiler.py --model_name resnet50 --num_classes 37 10 | 11 | python3 main.py --model_name resnet50 --dataset_name oxford_iiit_pet --train_type elastic_training 12 | 13 | # ResNet50, stanford_dogs, ElasticTrainer 14 | python3 profiler.py --model_name resnet50 --num_classes 120 15 | 16 | python3 main.py --model_name resnet50 --dataset_name stanford_dogs --train_type elastic_training -------------------------------------------------------------------------------- /run_figure15.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # caltech_birds2011 4 | python3 profiler.py --model_name resnet50 \ 5 | --num_classes 200 6 | 7 | python3 main.py --model_name resnet50 \ 8 | --dataset_name caltech_birds2011 \ 9 | --train_type elastic_training \ 10 | --run_name CUB200_ElasticTrainer 11 | 12 | python3 main.py 
--model_name resnet50 \ 13 | --dataset_name caltech_birds2011 \ 14 | --train_type full_training \ 15 | --run_name CUB200_Full_training 16 | 17 | python3 main.py --model_name resnet50 \ 18 | --dataset_name caltech_birds2011 \ 19 | --train_type traditional_tl_training \ 20 | --run_name CUB200_Traditional_TL 21 | 22 | python3 main.py --model_name resnet50 \ 23 | --dataset_name caltech_birds2011 \ 24 | --train_type bn_plus_bias_training \ 25 | --run_name CUB200_BN+Bias 26 | 27 | python3 plot_curves.py --x_tag wall_time \ 28 | --y_tag test/accuracy \ 29 | --single False \ 30 | --elastic_trainer_path CUB200_ElasticTrainer \ 31 | --full_training_path CUB200_Full_training \ 32 | --traditional_tl_path CUB200_Traditional_TL \ 33 | --bn_plus_bias_path CUB200_BN+Bias \ 34 | --figure_id 1 \ 35 | --figure_name Figure_15_a.pdf 36 | 37 | python3 plot_curves.py --x_tag wall_time \ 38 | --y_tag test/classification_loss \ 39 | --single False \ 40 | --elastic_trainer_path CUB200_ElasticTrainer \ 41 | --full_training_path CUB200_Full_training \ 42 | --traditional_tl_path CUB200_Traditional_TL \ 43 | --bn_plus_bias_path CUB200_BN+Bias \ 44 | --figure_id 2 \ 45 | --figure_name Figure_15_d.pdf 46 | 47 | # oxford_iiit_pet 48 | python3 profiler.py --model_name resnet50 \ 49 | --num_classes 37 50 | 51 | python3 main.py --model_name resnet50 \ 52 | --dataset_name oxford_iiit_pet \ 53 | --train_type elastic_training \ 54 | --run_name PET37_ElasticTrainer 55 | 56 | python3 main.py --model_name resnet50 \ 57 | --dataset_name oxford_iiit_pet \ 58 | --train_type full_training \ 59 | --run_name PET37_Full_training 60 | 61 | python3 main.py --model_name resnet50 \ 62 | --dataset_name oxford_iiit_pet \ 63 | --train_type traditional_tl_training \ 64 | --run_name PET37_Traditional_TL 65 | 66 | python3 main.py --model_name resnet50 \ 67 | --dataset_name oxford_iiit_pet \ 68 | --train_type bn_plus_bias_training \ 69 | --run_name PET37_BN+Bias 70 | 71 | python3 plot_curves.py --x_tag wall_time \ 72 | --y_tag test/accuracy \ 73 | --single False \ 74 | --elastic_trainer_path PET37_ElasticTrainer \ 75 | --full_training_path PET37_Full_training \ 76 | --traditional_tl_path PET37_Traditional_TL \ 77 | --bn_plus_bias_path PET37_BN+Bias \ 78 | --figure_id 3 \ 79 | --figure_name Figure_15_b.pdf 80 | 81 | python3 plot_curves.py --x_tag wall_time \ 82 | --y_tag test/classification_loss \ 83 | --single False \ 84 | --elastic_trainer_path PET37_ElasticTrainer \ 85 | --full_training_path PET37_Full_training \ 86 | --traditional_tl_path PET37_Traditional_TL \ 87 | --bn_plus_bias_path PET37_BN+Bias \ 88 | --figure_id 4 \ 89 | --figure_name Figure_15_e.pdf 90 | 91 | # stanford_dogs 92 | python3 profiler.py --model_name resnet50 \ 93 | --num_classes 120 94 | 95 | python3 main.py --model_name resnet50 \ 96 | --dataset_name stanford_dogs \ 97 | --train_type elastic_training \ 98 | --run_name DOG120_ElasticTrainer 99 | 100 | python3 main.py --model_name resnet50 \ 101 | --dataset_name stanford_dogs \ 102 | --train_type full_training \ 103 | --run_name DOG120_Full_training 104 | 105 | python3 main.py --model_name resnet50 \ 106 | --dataset_name stanford_dogs \ 107 | --train_type traditional_tl_training \ 108 | --run_name DOG120_Traditional_TL 109 | 110 | python3 main.py --model_name resnet50 \ 111 | --dataset_name stanford_dogs \ 112 | --train_type bn_plus_bias_training \ 113 | --run_name DOG120_BN+Bias 114 | 115 | python3 plot_curves.py --x_tag wall_time \ 116 | --y_tag test/accuracy \ 117 | --single False \ 118 | --elastic_trainer_path 
DOG120_ElasticTrainer \ 119 | --full_training_path DOG120_Full_training \ 120 | --traditional_tl_path DOG120_Traditional_TL \ 121 | --bn_plus_bias_path DOG120_BN+Bias \ 122 | --figure_id 5 \ 123 | --figure_name Figure_15_c.pdf 124 | 125 | python3 plot_curves.py --x_tag wall_time \ 126 | --y_tag test/classification_loss \ 127 | --single False \ 128 | --elastic_trainer_path DOG120_ElasticTrainer \ 129 | --full_training_path DOG120_Full_training \ 130 | --traditional_tl_path DOG120_Traditional_TL \ 131 | --bn_plus_bias_path DOG120_BN+Bias \ 132 | --figure_id 6 \ 133 | --figure_name Figure_15_f.pdf -------------------------------------------------------------------------------- /run_figure15_ego.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # caltech_birds2011 4 | python3 profiler.py --model_name resnet50 \ 5 | --num_classes 200 6 | 7 | python3 main.py --model_name resnet50 \ 8 | --dataset_name caltech_birds2011 \ 9 | --train_type elastic_training \ 10 | --run_name CUB200_ElasticTrainer 11 | 12 | python3 plot_curves.py --x_tag wall_time \ 13 | --y_tag test/accuracy \ 14 | --single True \ 15 | --elastic_trainer_path CUB200_ElasticTrainer \ 16 | --figure_id 1 \ 17 | --figure_name Figure_15_a_ego.pdf 18 | 19 | python3 plot_curves.py --x_tag wall_time \ 20 | --y_tag test/classification_loss \ 21 | --single True \ 22 | --elastic_trainer_path CUB200_ElasticTrainer \ 23 | --figure_id 2 \ 24 | --figure_name Figure_15_d_ego.pdf 25 | 26 | # oxford_iiit_pet 27 | python3 profiler.py --model_name resnet50 \ 28 | --num_classes 37 29 | 30 | python3 main.py --model_name resnet50 \ 31 | --dataset_name oxford_iiit_pet \ 32 | --train_type elastic_training \ 33 | --run_name PET37_ElasticTrainer 34 | 35 | python3 plot_curves.py --x_tag wall_time \ 36 | --y_tag test/accuracy \ 37 | --single True \ 38 | --elastic_trainer_path PET37_ElasticTrainer \ 39 | --figure_id 3 \ 40 | --figure_name Figure_15_b_ego.pdf 41 | 42 | python3 plot_curves.py --x_tag wall_time \ 43 | --y_tag test/classification_loss \ 44 | --single True \ 45 | --elastic_trainer_path PET37_ElasticTrainer \ 46 | --figure_id 4 \ 47 | --figure_name Figure_15_e_ego.pdf 48 | 49 | # stanford_dogs 50 | python3 profiler.py --model_name resnet50 \ 51 | --num_classes 120 52 | 53 | python3 main.py --model_name resnet50 \ 54 | --dataset_name stanford_dogs \ 55 | --train_type elastic_training \ 56 | --run_name DOG120_ElasticTrainer 57 | 58 | python3 plot_curves.py --x_tag wall_time \ 59 | --y_tag test/accuracy \ 60 | --single True \ 61 | --elastic_trainer_path DOG120_ElasticTrainer \ 62 | --figure_id 5 \ 63 | --figure_name Figure_15_c_ego.pdf 64 | 65 | python3 plot_curves.py --x_tag wall_time \ 66 | --y_tag test/classification_loss \ 67 | --single True \ 68 | --elastic_trainer_path DOG120_ElasticTrainer \ 69 | --figure_id 6 \ 70 | --figure_name Figure_15_f_ego.pdf -------------------------------------------------------------------------------- /run_figure15ad.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # caltech_birds2011 4 | python3 profiler.py --model_name resnet50 \ 5 | --num_classes 200 6 | 7 | python3 main.py --model_name resnet50 \ 8 | --dataset_name caltech_birds2011 \ 9 | --train_type elastic_training \ 10 | --run_name CUB200_ElasticTrainer 11 | 12 | python3 main.py --model_name resnet50 \ 13 | --dataset_name caltech_birds2011 \ 14 | --train_type full_training \ 15 | --run_name CUB200_Full_training 16 | 17 | python3 main.py 
--model_name resnet50 \ 18 | --dataset_name caltech_birds2011 \ 19 | --train_type traditional_tl_training \ 20 | --run_name CUB200_Traditional_TL 21 | 22 | python3 main.py --model_name resnet50 \ 23 | --dataset_name caltech_birds2011 \ 24 | --train_type bn_plus_bias_training \ 25 | --run_name CUB200_BN+Bias 26 | 27 | python3 plot_curves.py --x_tag wall_time \ 28 | --y_tag test/accuracy \ 29 | --single False \ 30 | --elastic_trainer_path CUB200_ElasticTrainer \ 31 | --full_training_path CUB200_Full_training \ 32 | --traditional_tl_path CUB200_Traditional_TL \ 33 | --bn_plus_bias_path CUB200_BN+Bias \ 34 | --figure_id 1 \ 35 | --figure_name Figure_15_a.pdf 36 | 37 | python3 plot_curves.py --x_tag wall_time \ 38 | --y_tag test/classification_loss \ 39 | --single False \ 40 | --elastic_trainer_path CUB200_ElasticTrainer \ 41 | --full_training_path CUB200_Full_training \ 42 | --traditional_tl_path CUB200_Traditional_TL \ 43 | --bn_plus_bias_path CUB200_BN+Bias \ 44 | --figure_id 2 \ 45 | --figure_name Figure_15_d.pdf -------------------------------------------------------------------------------- /run_figure15ad_ego.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # caltech_birds2011 4 | python3 profiler.py --model_name resnet50 \ 5 | --num_classes 200 6 | 7 | python3 main.py --model_name resnet50 \ 8 | --dataset_name caltech_birds2011 \ 9 | --train_type elastic_training \ 10 | --run_name CUB200_ElasticTrainer 11 | 12 | python3 plot_curves.py --x_tag wall_time \ 13 | --y_tag test/accuracy \ 14 | --single True \ 15 | --elastic_trainer_path CUB200_ElasticTrainer \ 16 | --figure_id 1 \ 17 | --figure_name Figure_15_a_ego.pdf 18 | 19 | python3 plot_curves.py --x_tag wall_time \ 20 | --y_tag test/classification_loss \ 21 | --single True \ 22 | --elastic_trainer_path CUB200_ElasticTrainer \ 23 | --figure_id 2 \ 24 | --figure_name Figure_15_d_ego.pdf -------------------------------------------------------------------------------- /run_figure16.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # caltech_birds2011 4 | python3 profiler.py --model_name resnet50 \ 5 | --num_classes 200 6 | 7 | python3 main.py --model_name resnet50 \ 8 | --dataset_name caltech_birds2011 \ 9 | --train_type elastic_training \ 10 | --run_name CUB200_ElasticTrainer 11 | 12 | python3 main.py --model_name resnet50 \ 13 | --dataset_name caltech_birds2011 \ 14 | --train_type full_training \ 15 | --run_name CUB200_Full_training 16 | 17 | python3 main.py --model_name resnet50 \ 18 | --dataset_name caltech_birds2011 \ 19 | --train_type traditional_tl_training \ 20 | --run_name CUB200_Traditional_TL 21 | 22 | python3 main.py --model_name resnet50 \ 23 | --dataset_name caltech_birds2011 \ 24 | --train_type bn_plus_bias_training \ 25 | --run_name CUB200_BN+Bias 26 | 27 | python3 plot_curves.py --x_tag wall_time \ 28 | --y_tag test/accuracy \ 29 | --single False \ 30 | --elastic_trainer_path CUB200_ElasticTrainer \ 31 | --full_training_path CUB200_Full_training \ 32 | --traditional_tl_path CUB200_Traditional_TL \ 33 | --bn_plus_bias_path CUB200_BN+Bias \ 34 | --figure_id 1 \ 35 | --figure_name Figure_16_a.pdf 36 | 37 | python3 plot_curves.py --x_tag wall_time \ 38 | --y_tag test/classification_loss \ 39 | --single False \ 40 | --elastic_trainer_path CUB200_ElasticTrainer \ 41 | --full_training_path CUB200_Full_training \ 42 | --traditional_tl_path CUB200_Traditional_TL \ 43 | --bn_plus_bias_path CUB200_BN+Bias \ 44 | 
--figure_id 2 \ 45 | --figure_name Figure_16_c.pdf 46 | 47 | # oxford_iiit_pet 48 | python3 profiler.py --model_name resnet50 \ 49 | --num_classes 37 50 | 51 | python3 main.py --model_name resnet50 \ 52 | --dataset_name oxford_iiit_pet \ 53 | --train_type elastic_training \ 54 | --run_name PET37_ElasticTrainer 55 | 56 | python3 main.py --model_name resnet50 \ 57 | --dataset_name oxford_iiit_pet \ 58 | --train_type full_training \ 59 | --run_name PET37_Full_training 60 | 61 | python3 main.py --model_name resnet50 \ 62 | --dataset_name oxford_iiit_pet \ 63 | --train_type traditional_tl_training \ 64 | --run_name PET37_Traditional_TL 65 | 66 | python3 main.py --model_name resnet50 \ 67 | --dataset_name oxford_iiit_pet \ 68 | --train_type bn_plus_bias_training \ 69 | --run_name PET37_BN+Bias 70 | 71 | python3 plot_curves.py --x_tag wall_time \ 72 | --y_tag test/accuracy \ 73 | --single False \ 74 | --elastic_trainer_path PET37_ElasticTrainer \ 75 | --full_training_path PET37_Full_training \ 76 | --traditional_tl_path PET37_Traditional_TL \ 77 | --bn_plus_bias_path PET37_BN+Bias \ 78 | --figure_id 3 \ 79 | --figure_name Figure_16_b.pdf 80 | 81 | python3 plot_curves.py --x_tag wall_time \ 82 | --y_tag test/classification_loss \ 83 | --single False \ 84 | --elastic_trainer_path PET37_ElasticTrainer \ 85 | --full_training_path PET37_Full_training \ 86 | --traditional_tl_path PET37_Traditional_TL \ 87 | --bn_plus_bias_path PET37_BN+Bias \ 88 | --figure_id 4 \ 89 | --figure_name Figure_16_d.pdf -------------------------------------------------------------------------------- /run_figure16_ego.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # caltech_birds2011 4 | python3 profiler.py --model_name resnet50 \ 5 | --num_classes 200 6 | 7 | python3 main.py --model_name resnet50 \ 8 | --dataset_name caltech_birds2011 \ 9 | --train_type elastic_training \ 10 | --run_name CUB200_ElasticTrainer 11 | 12 | python3 plot_curves.py --x_tag wall_time \ 13 | --y_tag test/accuracy \ 14 | --single True \ 15 | --elastic_trainer_path CUB200_ElasticTrainer \ 16 | --figure_id 1 \ 17 | --figure_name Figure_16_a_ego.pdf 18 | 19 | python3 plot_curves.py --x_tag wall_time \ 20 | --y_tag test/classification_loss \ 21 | --single True \ 22 | --elastic_trainer_path CUB200_ElasticTrainer \ 23 | --figure_id 2 \ 24 | --figure_name Figure_16_c_ego.pdf 25 | 26 | # oxford_iiit_pet 27 | python3 profiler.py --model_name resnet50 \ 28 | --num_classes 37 29 | 30 | python3 main.py --model_name resnet50 \ 31 | --dataset_name oxford_iiit_pet \ 32 | --train_type elastic_training \ 33 | --run_name PET37_ElasticTrainer 34 | 35 | python3 plot_curves.py --x_tag wall_time \ 36 | --y_tag test/accuracy \ 37 | --single True \ 38 | --elastic_trainer_path PET37_ElasticTrainer \ 39 | --figure_id 3 \ 40 | --figure_name Figure_16_b_ego.pdf 41 | 42 | python3 plot_curves.py --x_tag wall_time \ 43 | --y_tag test/classification_loss \ 44 | --single True \ 45 | --elastic_trainer_path PET37_ElasticTrainer \ 46 | --figure_id 4 \ 47 | --figure_name Figure_16_d_ego.pdf -------------------------------------------------------------------------------- /run_figure17ac.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # caltech_birds2011 4 | python3 profiler.py --model_name resnet50 \ 5 | --num_classes 200 6 | 7 | python3 main.py --model_name resnet50 \ 8 | --dataset_name caltech_birds2011 \ 9 | --train_type elastic_training \ 10 | --run_name 
CUB200_ElasticTrainer_035 \ 11 | --rho 0.367 \ 12 | --save_txt True 13 | 14 | python3 main.py --model_name resnet50 \ 15 | --dataset_name caltech_birds2011 \ 16 | --train_type elastic_training \ 17 | --run_name CUB200_ElasticTrainer_040 \ 18 | --rho 0.4 \ 19 | --save_txt True 20 | 21 | python3 main.py --model_name resnet50 \ 22 | --dataset_name caltech_birds2011 \ 23 | --train_type elastic_training \ 24 | --run_name CUB200_ElasticTrainer_050 \ 25 | --rho 0.533 \ 26 | --save_txt True 27 | 28 | python3 main.py --model_name resnet50 \ 29 | --dataset_name caltech_birds2011 \ 30 | --train_type elastic_training \ 31 | --run_name CUB200_ElasticTrainer_060 \ 32 | --rho 0.6 \ 33 | --save_txt True 34 | 35 | python3 main.py --model_name resnet50 \ 36 | --dataset_name caltech_birds2011 \ 37 | --train_type elastic_training \ 38 | --run_name CUB200_ElasticTrainer_070 \ 39 | --rho 0.7 \ 40 | --save_txt True 41 | 42 | python3 main.py --model_name resnet50 \ 43 | --dataset_name caltech_birds2011 \ 44 | --train_type full_training \ 45 | --run_name CUB200_Full_training \ 46 | --save_txt True 47 | 48 | python3 main.py --model_name resnet50 \ 49 | --dataset_name caltech_birds2011 \ 50 | --train_type traditional_tl_training \ 51 | --run_name CUB200_Traditional_TL \ 52 | --save_txt True 53 | 54 | python3 main.py --model_name resnet50 \ 55 | --dataset_name caltech_birds2011 \ 56 | --train_type bn_plus_bias_training \ 57 | --run_name CUB200_BN+Bias \ 58 | --save_txt True 59 | 60 | python3 main.py --model_name resnet50 \ 61 | --dataset_name caltech_birds2011 \ 62 | --train_type prunetrain \ 63 | --run_name CUB200_PruneTrain \ 64 | --save_txt True 65 | 66 | python3 plot_bars_v1.py --path_to_rho35 CUB200_ElasticTrainer_035 \ 67 | --path_to_rho40 CUB200_ElasticTrainer_040 \ 68 | --path_to_rho50 CUB200_ElasticTrainer_050 \ 69 | --path_to_rho60 CUB200_ElasticTrainer_060 \ 70 | --path_to_rho70 CUB200_ElasticTrainer_070 \ 71 | --path_to_full_training CUB200_Full_training \ 72 | --path_to_traditional_tl CUB200_Traditional_TL \ 73 | --path_to_bn_plus_bias CUB200_BN+Bias \ 74 | --path_to_prunetrain CUB200_PruneTrain \ 75 | --figure_id 1 \ 76 | --figure_name Figure_17_ac.pdf \ 77 | --ego False -------------------------------------------------------------------------------- /run_figure17ac_ego.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # caltech_birds2011 4 | python3 profiler.py --model_name resnet50 \ 5 | --num_classes 200 6 | 7 | python3 main.py --model_name resnet50 \ 8 | --dataset_name caltech_birds2011 \ 9 | --train_type elastic_training \ 10 | --run_name CUB200_ElasticTrainer_035 \ 11 | --rho 0.367 \ 12 | --save_txt True 13 | 14 | python3 main.py --model_name resnet50 \ 15 | --dataset_name caltech_birds2011 \ 16 | --train_type elastic_training \ 17 | --run_name CUB200_ElasticTrainer_040 \ 18 | --rho 0.4 \ 19 | --save_txt True 20 | 21 | python3 main.py --model_name resnet50 \ 22 | --dataset_name caltech_birds2011 \ 23 | --train_type elastic_training \ 24 | --run_name CUB200_ElasticTrainer_050 \ 25 | --rho 0.533 \ 26 | --save_txt True 27 | 28 | python3 main.py --model_name resnet50 \ 29 | --dataset_name caltech_birds2011 \ 30 | --train_type elastic_training \ 31 | --run_name CUB200_ElasticTrainer_060 \ 32 | --rho 0.6 \ 33 | --save_txt True 34 | 35 | python3 main.py --model_name resnet50 \ 36 | --dataset_name caltech_birds2011 \ 37 | --train_type elastic_training \ 38 | --run_name CUB200_ElasticTrainer_070 \ 39 | --rho 0.7 \ 40 | --save_txt True 41 | 42 | 
python3 plot_bars_v1.py --path_to_rho35 CUB200_ElasticTrainer_035 \ 43 | --path_to_rho40 CUB200_ElasticTrainer_040 \ 44 | --path_to_rho50 CUB200_ElasticTrainer_050 \ 45 | --path_to_rho60 CUB200_ElasticTrainer_060 \ 46 | --path_to_rho70 CUB200_ElasticTrainer_070 \ 47 | --figure_id 1 \ 48 | --figure_name Figure_17_ac_ego.pdf \ 49 | --ego True -------------------------------------------------------------------------------- /run_figure19.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # caltech_birds2011 resnet50 4 | python3 profiler.py --model_name resnet50 \ 5 | --num_classes 200 6 | 7 | python3 main.py --model_name resnet50 \ 8 | --dataset_name caltech_birds2011 \ 9 | --train_type elastic_training \ 10 | --run_name resnet50_ElasticTrainer \ 11 | --rho 0.4 \ 12 | --save_txt True 13 | 14 | python3 main.py --model_name resnet50 \ 15 | --dataset_name caltech_birds2011 \ 16 | --train_type full_training \ 17 | --run_name resnet50_Full_training \ 18 | --save_txt True 19 | 20 | python3 main.py --model_name resnet50 \ 21 | --dataset_name caltech_birds2011 \ 22 | --train_type traditional_tl_training \ 23 | --run_name resnet50_Traditional_TL \ 24 | --save_txt True 25 | 26 | python3 main.py --model_name resnet50 \ 27 | --dataset_name caltech_birds2011 \ 28 | --train_type bn_plus_bias_training \ 29 | --run_name resnet50_BN+Bias \ 30 | --save_txt True 31 | 32 | # caltech_birds2011 vgg16 33 | python3 profiler.py --model_name vgg16 \ 34 | --num_classes 200 35 | 36 | python3 main.py --model_name vgg16 \ 37 | --dataset_name caltech_birds2011 \ 38 | --train_type elastic_training \ 39 | --run_name vgg16_ElasticTrainer \ 40 | --rho 0.4 \ 41 | --save_txt True 42 | 43 | python3 main.py --model_name vgg16 \ 44 | --dataset_name caltech_birds2011 \ 45 | --train_type full_training \ 46 | --run_name vgg16_Full_training \ 47 | --save_txt True 48 | 49 | python3 main.py --model_name vgg16 \ 50 | --dataset_name caltech_birds2011 \ 51 | --train_type traditional_tl_training \ 52 | --run_name vgg16_Traditional_TL \ 53 | --save_txt True 54 | 55 | python3 main.py --model_name vgg16 \ 56 | --dataset_name caltech_birds2011 \ 57 | --train_type bn_plus_bias_training \ 58 | --run_name vgg16_BN+Bias \ 59 | --save_txt True 60 | 61 | # caltech_birds2011 mobilenetv2 62 | python3 profiler.py --model_name mobilenetv2 \ 63 | --num_classes 200 64 | 65 | python3 main.py --model_name mobilenetv2 \ 66 | --dataset_name caltech_birds2011 \ 67 | --train_type elastic_training \ 68 | --run_name mobilenetv2_ElasticTrainer \ 69 | --rho 0.4 \ 70 | --save_txt True 71 | 72 | python3 main.py --model_name mobilenetv2 \ 73 | --dataset_name caltech_birds2011 \ 74 | --train_type full_training \ 75 | --run_name mobilenetv2_Full_training \ 76 | --num_epochs 24 \ 77 | --save_txt True 78 | 79 | python3 main.py --model_name mobilenetv2 \ 80 | --dataset_name caltech_birds2011 \ 81 | --train_type traditional_tl_training \ 82 | --run_name mobilenetv2_Traditional_TL \ 83 | --save_txt True 84 | 85 | python3 main.py --model_name mobilenetv2 \ 86 | --dataset_name caltech_birds2011 \ 87 | --train_type bn_plus_bias_training \ 88 | --run_name mobilenetv2_BN+Bias \ 89 | --save_txt True 90 | 91 | python3 plot_bars_v2.py --path_to_elastic_trainer_resnet50 resnet50_ElasticTrainer \ 92 | --path_to_elastic_trainer_vgg16 vgg16_ElasticTrainer \ 93 | --path_to_elastic_trainer_mobilenetv2 mobilenetv2_ElasticTrainer \ 94 | --path_to_full_training_resnet50 resnet50_Full_training \ 95 | --path_to_full_training_vgg16 
vgg16_Full_training \ 96 | --path_to_full_training_mobilenetv2 mobilenetv2_Full_training \ 97 | --path_to_traditional_tl_resnet50 resnet50_Traditional_TL \ 98 | --path_to_traditional_tl_vgg16 vgg16_Traditional_TL \ 99 | --path_to_traditional_tl_mobilenetv2 mobilenetv2_Traditional_TL \ 100 | --path_to_bn_plus_bias_resnet50 resnet50_BN+Bias \ 101 | --path_to_bn_plus_bias_vgg16 vgg16_BN+Bias \ 102 | --path_to_bn_plus_bias_mobilenetv2 mobilenetv2_BN+Bias \ 103 | --figure_id 1 \ 104 | --figure_name Figure_19 \ 105 | --ego False 106 | -------------------------------------------------------------------------------- /run_figure19_ego.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # caltech_birds2011 resnet50 4 | python3 profiler.py --model_name resnet50 \ 5 | --num_classes 200 6 | 7 | python3 main.py --model_name resnet50 \ 8 | --dataset_name caltech_birds2011 \ 9 | --train_type elastic_training \ 10 | --run_name resnet50_ElasticTrainer \ 11 | --rho 0.4 \ 12 | --save_txt True 13 | 14 | # caltech_birds2011 vgg16 15 | python3 profiler.py --model_name vgg16 \ 16 | --num_classes 200 17 | 18 | python3 main.py --model_name vgg16 \ 19 | --dataset_name caltech_birds2011 \ 20 | --train_type elastic_training \ 21 | --run_name vgg16_ElasticTrainer \ 22 | --rho 0.4 \ 23 | --save_txt True 24 | 25 | # caltech_birds2011 mobilenetv2 26 | python3 profiler.py --model_name mobilenetv2 \ 27 | --num_classes 200 28 | 29 | python3 main.py --model_name mobilenetv2 \ 30 | --dataset_name caltech_birds2011 \ 31 | --train_type elastic_training \ 32 | --run_name mobilenetv2_ElasticTrainer \ 33 | --rho 0.4 \ 34 | --save_txt True 35 | 36 | python3 plot_bars_v2.py --path_to_elastic_trainer_resnet50 resnet50_ElasticTrainer \ 37 | --path_to_elastic_trainer_vgg16 vgg16_ElasticTrainer \ 38 | --path_to_elastic_trainer_mobilenetv2 mobilenetv2_ElasticTrainer \ 39 | --figure_id 1 \ 40 | --figure_name Figure_19_ego \ 41 | --ego True -------------------------------------------------------------------------------- /saved_models/README.md: -------------------------------------------------------------------------------- 1 | This folder stores trained models. -------------------------------------------------------------------------------- /selection_solver_DP.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | 4 | 5 | def selection_DP(t_dy, t_dw, I, rho=0.3): 6 | """ 7 | Solving layer selection problem via dynamic programming 8 | 9 | Args: 10 | t_dy (np.array int16): downscaled t_dy [N,] 11 | t_dw (np.array int16): downscaled t_dw [N,] 12 | I (np.array float32): per-layer contribution to loss drop [N,] 13 | rho (float32): backprop timesaving ratio 14 | """ 15 | 16 | # Initialize the memo tables of subproblems 17 | N = t_dw.shape[0] # number of NN layers 18 | T = np.sum(t_dw + t_dy) # maximally possible BP time 19 | T_limit = int(rho * T) 20 | t_dy_cumsum = 0 21 | for k in range(N): 22 | t_dy_cumsum += t_dy[k] 23 | if t_dy_cumsum > T_limit: 24 | break 25 | N_limit = k 26 | # Infinite importance 27 | MINIMAL_IMPORTANCE = -99999.0 28 | # L[k, t] - maximum cumulative importance when: 29 | # 1. selectively training within last k layers, 30 | # 2. 
achieving BP time at most t 31 | L_memo = np.zeros(shape=(N_limit + 1, T_limit + 1), dtype=np.float32) 32 | L_memo[0, 0] = 0 33 | #L_memo[0, 1:] = MINIMAL_IMPORTANCE 34 | 35 | # M[k, t, :] - solution to subproblem L[k, t] 36 | M_memo = np.zeros(shape=(N_limit + 1, T_limit + 1, N), dtype=np.uint8) 37 | 38 | S_memo = np.zeros(shape=(N_limit + 1, T_limit + 1), dtype=np.uint8) 39 | S_memo[0, 0] = 1 40 | S_memo[1:, 0] = 1 41 | S_memo[0, 1:] = 1 42 | 43 | max_importance = MINIMAL_IMPORTANCE 44 | k_final, t_final = 0, 0 45 | # Solving all the subproblems recursively 46 | for k in range(1, N_limit + 1): 47 | for t in range(0, T_limit + 1): 48 | # Subproblem 1: 49 | # If layer k-1 is NOT selected 50 | # --> no BP time increase 51 | # --> no importance increase 52 | l_skip_curr_layer = L_memo[k - 1, t] 53 | 54 | # Subproblem 2: 55 | # If layer k-1 is selected 56 | # --> BP time increases dt = t_dw[k - 1] + sum(t_dy[k-2 : n]) 57 | opt_k = -1 58 | opt_t = -1 59 | l_max = l_skip_curr_layer 60 | t_p = t - t_dw[k - 1] 61 | # traverse from layer k-1 to the beginning 62 | for k_p in range(k - 1, -1, -1): 63 | t_p -= t_dy[k_p] 64 | if t_p >= 0 and S_memo[k_p, t_p] == 1: 65 | l_candidate = L_memo[k_p, t_p] + I[k - 1] 66 | if l_candidate > l_max: 67 | opt_k = k_p 68 | opt_t = t_p 69 | l_max = l_candidate 70 | 71 | # make sure valid solution found by checking integer variable 72 | if opt_k >= 0: 73 | L_memo[k, t] = l_max 74 | M_memo[k, t, :(k - 1)] = M_memo[opt_k, opt_t, :(k - 1)] 75 | M_memo[k, t, k - 1] = 1 76 | S_memo[k, t] = 1 77 | # no valid solution from backtrace or no larger than not selecting 78 | else: 79 | L_memo[k, t] = l_skip_curr_layer 80 | M_memo[k, t, :(k - 1)] = M_memo[k - 1, t, :(k - 1)] 81 | M_memo[k, t, k - 1] = 0 82 | S_memo[k, t] = 0 83 | 84 | if l_max > max_importance: 85 | max_importance = L_memo[k, t] 86 | k_final, t_final = k, t 87 | 88 | M_sol = M_memo[k_final, t_final, :] 89 | return max_importance, M_sol 90 | 91 | def downscale_t_dy_and_t_dw(t_dy, t_dw, Tq=1e3): 92 | T = np.sum(t_dw + t_dy) 93 | scale = Tq / T 94 | t_dy_q = np.floor(t_dy * scale).astype(np.int16) 95 | t_dw_q = np.floor(t_dw * scale).astype(np.int16) 96 | disco = 1.0 * np.sum(t_dy_q + t_dw_q) / Tq 97 | return t_dy_q, t_dw_q, disco 98 | 99 | def simple_test(): 100 | t_dy = np.array([0, 2, 1, 4, 0]) 101 | t_dw = np.array([5, 1, 7, 3, 1]) 102 | I = np.array([1., 3., 10., 5., 3.]) 103 | max_importance, M_sol = selection_DP(t_dy, t_dw, I, rho=0.6) 104 | print(max_importance) 105 | print(M_sol) 106 | 107 | def main(): 108 | I = np.loadtxt('importance.out') 109 | # print(I) 110 | t_dw = np.loadtxt('t_dw.out') 111 | t_dy = np.loadtxt('t_dy.out') 112 | t_dy_q, t_dw_q, disco = downscale_t_dy_and_t_dw(t_dy, t_dw, Tq=1e3) 113 | print('t_dy_q:', t_dy_q) 114 | print('t_dw_q:', t_dw_q) 115 | t_dy_q = np.flip(t_dy_q) 116 | t_dw_q = np.flip(t_dw_q) 117 | I = np.flip(I) 118 | t1 = time.time() 119 | max_importance, M_sol = selection_DP(t_dy_q, t_dw_q, I, rho=0.3*disco) 120 | t2 = time.time() 121 | print("t:", t2-t1) 122 | M_sol = np.flip(M_sol) 123 | print('max_I:', max_importance) 124 | print('m:', M_sol) 125 | print('%T_sel:', 100 * np.sum(np.maximum.accumulate(M_sol) * t_dy + M_sol * t_dw) / np.sum(t_dy + t_dw)) 126 | 127 | if __name__ == '__main__': 128 | main() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from selection_solver_DP import selection_DP, downscale_t_dy_and_t_dw 2 | from profiler import 
profile_parser 3 | import tensorflow as tf 4 | import tensorflow_datasets as tfds 5 | import tensorflow_addons as tfa 6 | import numpy as np 7 | import time 8 | from tqdm import tqdm 9 | import os 10 | from utils import clear_cache_and_rec_usage 11 | 12 | 13 | def full_training( 14 | model, 15 | ds_train, 16 | ds_test, 17 | run_name, 18 | logdir, 19 | optim='sgd', 20 | lr=1e-4, 21 | weight_decay=5e-4, 22 | epochs=12, 23 | disable_random_id=False, 24 | save_model=False, 25 | save_txt=False, 26 | ): 27 | """All NN weights will be trained""" 28 | 29 | if optim == 'sgd': 30 | decay_steps = len(tfds.as_numpy(ds_train)) * epochs 31 | 32 | lr_schedule = tf.keras.experimental.CosineDecay(lr, decay_steps=decay_steps) 33 | wd_schedule = tf.keras.experimental.CosineDecay(lr * weight_decay, decay_steps=decay_steps) 34 | optimizer = tfa.optimizers.SGDW(learning_rate=lr_schedule, weight_decay=wd_schedule, momentum=0.9, nesterov=False) 35 | else: 36 | optimizer = tf.keras.optimizers.Adam(lr) 37 | 38 | loss_fn_cls = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) 39 | 40 | if disable_random_id: 41 | runid = run_name 42 | else: 43 | runid = run_name + '_full_x' + str(np.random.randint(10000)) 44 | writer = tf.summary.create_file_writer(logdir + '/' + runid) 45 | accuracy = tf.metrics.SparseCategoricalAccuracy() 46 | cls_loss = tf.metrics.Mean() 47 | 48 | 49 | print(f"RUNID: {runid}") 50 | 51 | @tf.function 52 | def train_step(x, y): 53 | with tf.GradientTape() as tape: 54 | y_pred = model(x, training=True) 55 | loss = loss_fn_cls(y, y_pred) 56 | gradients = tape.gradient(loss, model.trainable_weights) 57 | optimizer.apply_gradients(zip(gradients, model.trainable_weights)) 58 | accuracy(y, y_pred) 59 | cls_loss(loss) 60 | 61 | @tf.function 62 | def test_step(x, y): 63 | y_pred = model(x, training=False) 64 | loss = loss_fn_cls(y, y_pred) 65 | accuracy(y, y_pred) 66 | cls_loss(loss) 67 | 68 | training_step = 0 69 | best_validation_acc = 0 70 | 71 | clear_cache_and_rec_usage() 72 | 73 | total_time_0 = 0 74 | total_time_1 = 0 75 | for epoch in range(epochs): 76 | 77 | t0 = time.time() 78 | for x, y in tqdm(ds_train, desc=f'epoch {epoch+1}/{epochs}', ascii=True): 79 | 80 | training_step += 1 81 | 82 | train_step(x, y) 83 | 84 | if training_step % 200 == 0: 85 | with writer.as_default(): 86 | c_loss, acc = cls_loss.result(), accuracy.result() 87 | tf.summary.scalar('train/accuracy', acc, training_step) 88 | tf.summary.scalar('train/classification_loss', c_loss, training_step) 89 | tf.summary.scalar('train/learnig_rate', optimizer._decayed_lr('float32'), training_step) 90 | cls_loss.reset_states() 91 | accuracy.reset_states() 92 | clear_cache_and_rec_usage() 93 | 94 | cls_loss.reset_states() 95 | accuracy.reset_states() 96 | 97 | t1 = time.time() 98 | print("per epoch time(s) excluding validation:", t1 - t0) 99 | total_time_0 += (t1 - t0) 100 | 101 | for x, y in ds_test: 102 | test_step(x, y) 103 | 104 | with writer.as_default(): 105 | tf.summary.scalar('test/classification_loss', cls_loss.result(), step=training_step) 106 | tf.summary.scalar('test/accuracy', accuracy.result(), step=training_step) 107 | 108 | if accuracy.result() > best_validation_acc: 109 | best_validation_acc = accuracy.result() 110 | if save_model: 111 | model.save_weights(os.path.join('saved_models', runid + '.tf')) 112 | print("=================================") 113 | print("acc: ", accuracy.result()) 114 | print("=================================") 115 | 116 | cls_loss.reset_states() 117 | accuracy.reset_states() 118 | 119 
| t2 = time.time() 120 | print("per epoch time(s) including validation:", t2 - t0) 121 | total_time_1 += (t2 - t0) 122 | 123 | clear_cache_and_rec_usage() 124 | 125 | # print("total time excluding validation (s):", total_time_0) 126 | # print("total time including validation (s):", total_time_1) 127 | best_validation_acc = best_validation_acc.numpy() * 100 128 | total_time_0 /= 3600 129 | print('===============================================') 130 | print('Training Type: Full training') 131 | print(f"Accuracy (%): {best_validation_acc:.2f}") 132 | print(f"Time (h): {total_time_0:.2f}") 133 | print('===============================================') 134 | if save_txt: 135 | np.savetxt(logdir + '/' + runid + '.txt', np.array([total_time_0, best_validation_acc])) 136 | # sig_stop_handler(None, None) 137 | 138 | def bn_plus_bias_training( 139 | model, 140 | ds_train, 141 | ds_test, 142 | run_name, 143 | logdir, 144 | optim='sgd', 145 | lr=1e-4, 146 | weight_decay=5e-4, 147 | epochs=12, 148 | disable_random_id=False, 149 | save_model=False, 150 | save_txt=False, 151 | ): 152 | """Only train normalization, bias, and last layer weights""" 153 | if optim == 'sgd': 154 | decay_steps = len(tfds.as_numpy(ds_train)) * epochs 155 | 156 | lr_schedule = tf.keras.experimental.CosineDecay(lr, decay_steps=decay_steps) 157 | wd_schedule = tf.keras.experimental.CosineDecay(lr * weight_decay, decay_steps=decay_steps) 158 | optimizer = tfa.optimizers.SGDW(learning_rate=lr_schedule, weight_decay=wd_schedule, momentum=0.9, nesterov=False) 159 | else: 160 | optimizer = tf.keras.optimizers.Adam(lr) 161 | 162 | loss_fn_cls = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) 163 | 164 | if disable_random_id: 165 | runid = run_name 166 | else: 167 | runid = run_name + '_bn+bias_x' + str(np.random.randint(10000)) 168 | 169 | writer = tf.summary.create_file_writer(logdir + '/' + runid) 170 | accuracy = tf.metrics.SparseCategoricalAccuracy() 171 | cls_loss = tf.metrics.Mean() 172 | 173 | bias_plus_bn_weights = list() 174 | # print('model.trainable_weights: {}'.format([x.name for x in model.trainable_weights])) 175 | for i in model.trainable_weights: 176 | tmp_name = i.name 177 | if 'gamma' in tmp_name or 'beta' in tmp_name or 'bias' in tmp_name or 'dense' in tmp_name or 'head' in tmp_name: 178 | bias_plus_bn_weights.append(i) 179 | # print('bias_plus_bn_weights: {}'.format([x.name for x in bias_plus_bn_weights])) 180 | model_trainable_weights = bias_plus_bn_weights 181 | 182 | print(f"RUNID: {runid}") 183 | 184 | @tf.function 185 | def train_step(x, y): 186 | with tf.GradientTape() as tape: 187 | y_pred = model(x, training=True) 188 | loss = loss_fn_cls(y, y_pred) 189 | gradients = tape.gradient(loss, model_trainable_weights) 190 | optimizer.apply_gradients(zip(gradients, model_trainable_weights)) 191 | accuracy(y, y_pred) 192 | cls_loss(loss) 193 | 194 | @tf.function 195 | def test_step(x, y): 196 | y_pred = model(x, training=False) 197 | loss = loss_fn_cls(y, y_pred) 198 | accuracy(y, y_pred) 199 | cls_loss(loss) 200 | 201 | training_step = 0 202 | best_validation_acc = 0 203 | 204 | clear_cache_and_rec_usage() 205 | 206 | total_time_0 = 0 207 | total_time_1 = 0 208 | for epoch in range(epochs): 209 | 210 | t0 = time.time() 211 | for x, y in tqdm(ds_train, desc=f'epoch {epoch+1}/{epochs}', ascii=True): 212 | 213 | training_step += 1 214 | 215 | 216 | train_step(x, y) 217 | 218 | if training_step % 200 == 0: 219 | with writer.as_default(): 220 | c_loss, acc = cls_loss.result(), accuracy.result() 221 | 
tf.summary.scalar('train/accuracy', acc, training_step) 222 | tf.summary.scalar('train/classification_loss', c_loss, training_step) 223 | tf.summary.scalar('train/learnig_rate', optimizer._decayed_lr('float32'), training_step) 224 | cls_loss.reset_states() 225 | accuracy.reset_states() 226 | clear_cache_and_rec_usage() 227 | 228 | cls_loss.reset_states() 229 | accuracy.reset_states() 230 | 231 | t1 = time.time() 232 | print("per epoch time(s) excluding validation:", t1 - t0) 233 | total_time_0 += (t1 - t0) 234 | 235 | for x, y in ds_test: 236 | test_step(x, y) 237 | 238 | with writer.as_default(): 239 | tf.summary.scalar('test/classification_loss', cls_loss.result(), step=training_step) 240 | tf.summary.scalar('test/accuracy', accuracy.result(), step=training_step) 241 | 242 | if accuracy.result() > best_validation_acc: 243 | best_validation_acc = accuracy.result() 244 | if save_model: 245 | model.save_weights(os.path.join('saved_models', runid + '.tf')) 246 | print("=================================") 247 | print("acc: ", accuracy.result()) 248 | print("=================================") 249 | 250 | cls_loss.reset_states() 251 | accuracy.reset_states() 252 | 253 | t2 = time.time() 254 | print("per epoch time(s) including validation:", t2 - t0) 255 | total_time_1 += (t2 - t0) 256 | 257 | clear_cache_and_rec_usage() 258 | 259 | # print("total time excluding validation (s):", total_time_0) 260 | # print("total time including validation (s):", total_time_1) 261 | best_validation_acc = best_validation_acc.numpy() * 100 262 | total_time_0 /= 3600 263 | print('===============================================') 264 | print('Training Type: BN+Bias') 265 | print(f"Accuracy (%): {best_validation_acc:.2f}") 266 | print(f"Time (h): {total_time_0:.2f}") 267 | print('===============================================') 268 | if save_txt: 269 | np.savetxt(logdir + '/' + runid + '.txt', np.array([total_time_0, best_validation_acc])) 270 | # sig_stop_handler(None, None) 271 | 272 | def traditional_tl_training( 273 | model, 274 | ds_train, 275 | ds_test, 276 | run_name, 277 | logdir, 278 | optim='sgd', 279 | lr=1e-4, 280 | weight_decay=5e-4, 281 | epochs=12, 282 | disable_random_id=False, 283 | save_model=False, 284 | save_txt=False, 285 | ): 286 | """Only train last layer weights""" 287 | 288 | if optim == 'sgd': 289 | decay_steps = len(tfds.as_numpy(ds_train)) * epochs 290 | 291 | lr_schedule = tf.keras.experimental.CosineDecay(lr, decay_steps=decay_steps) 292 | wd_schedule = tf.keras.experimental.CosineDecay(lr * weight_decay, decay_steps=decay_steps) 293 | optimizer = tfa.optimizers.SGDW(learning_rate=lr_schedule, weight_decay=wd_schedule, momentum=0.9, nesterov=False) 294 | else: 295 | optimizer = tf.keras.optimizers.Adam(lr) 296 | 297 | loss_fn_cls = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) 298 | 299 | if disable_random_id: 300 | runid = run_name 301 | else: 302 | runid = run_name + '_ttl_x' + str(np.random.randint(10000)) 303 | 304 | writer = tf.summary.create_file_writer(logdir + '/' + runid) 305 | accuracy = tf.metrics.SparseCategoricalAccuracy() 306 | cls_loss = tf.metrics.Mean() 307 | 308 | dense_weights = list() 309 | # print('model.trainable_weights: {}'.format([x.name for x in model.trainable_weights])) 310 | for i in model.trainable_weights: 311 | tmp_name = i.name 312 | if 'dense' in tmp_name or 'head' in tmp_name: 313 | dense_weights.append(i) 314 | # print('bias_plus_bn_weights: {}'.format([x.name for x in dense_weights])) 315 | model_trainable_weights = dense_weights 
316 | 317 | print(f"RUNID: {runid}") 318 | 319 | @tf.function 320 | def train_step(x, y): 321 | with tf.GradientTape() as tape: 322 | y_pred = model(x, training=True) 323 | loss = loss_fn_cls(y, y_pred) 324 | gradients = tape.gradient(loss, model_trainable_weights) 325 | optimizer.apply_gradients(zip(gradients, model_trainable_weights)) 326 | accuracy(y, y_pred) 327 | cls_loss(loss) 328 | 329 | @tf.function 330 | def test_step(x, y): 331 | y_pred = model(x, training=False) 332 | loss = loss_fn_cls(y, y_pred) 333 | accuracy(y, y_pred) 334 | cls_loss(loss) 335 | 336 | training_step = 0 337 | best_validation_acc = 0 338 | 339 | clear_cache_and_rec_usage() 340 | 341 | total_time_0 = 0 342 | total_time_1 = 0 343 | for epoch in range(epochs): 344 | 345 | t0 = time.time() 346 | for x, y in tqdm(ds_train, desc=f'epoch {epoch+1}/{epochs}', ascii=True): 347 | 348 | training_step += 1 349 | 350 | 351 | train_step(x, y) 352 | 353 | if training_step % 200 == 0: 354 | with writer.as_default(): 355 | c_loss, acc = cls_loss.result(), accuracy.result() 356 | tf.summary.scalar('train/accuracy', acc, training_step) 357 | tf.summary.scalar('train/classification_loss', c_loss, training_step) 358 | tf.summary.scalar('train/learnig_rate', optimizer._decayed_lr('float32'), training_step) 359 | cls_loss.reset_states() 360 | accuracy.reset_states() 361 | clear_cache_and_rec_usage() 362 | 363 | cls_loss.reset_states() 364 | accuracy.reset_states() 365 | 366 | t1 = time.time() 367 | print("per epoch time(s) excluding validation:", t1 - t0) 368 | total_time_0 += (t1 - t0) 369 | 370 | for x, y in ds_test: 371 | test_step(x, y) 372 | 373 | with writer.as_default(): 374 | tf.summary.scalar('test/classification_loss', cls_loss.result(), step=training_step) 375 | tf.summary.scalar('test/accuracy', accuracy.result(), step=training_step) 376 | 377 | if accuracy.result() > best_validation_acc: 378 | best_validation_acc = accuracy.result() 379 | if save_model: 380 | model.save_weights(os.path.join('saved_models', runid + '.tf')) 381 | print("=================================") 382 | print("acc: ", accuracy.result()) 383 | print("=================================") 384 | 385 | cls_loss.reset_states() 386 | accuracy.reset_states() 387 | 388 | t2 = time.time() 389 | print("per epoch time(s) including validation:", t2 - t0) 390 | total_time_1 += (t2 - t0) 391 | 392 | clear_cache_and_rec_usage() 393 | 394 | # print("total time excluding validation (s):", total_time_0) 395 | # print("total time including validation (s):", total_time_1) 396 | best_validation_acc = best_validation_acc.numpy() * 100 397 | total_time_0 /= 3600 398 | print('===============================================') 399 | print('Training Type: Traditional TL') 400 | print(f"Accuracy (%): {best_validation_acc:.2f}") 401 | print(f"Time (h): {total_time_0:.2f}") 402 | print('===============================================') 403 | if save_txt: 404 | np.savetxt(logdir + '/' + runid + '.txt', np.array([total_time_0, best_validation_acc])) 405 | # sig_stop_handler(None, None) 406 | 407 | def elastic_training( 408 | model, 409 | model_name, 410 | ds_train, 411 | ds_test, 412 | run_name, 413 | logdir, 414 | timing_info, 415 | optim='sgd', 416 | lr=1e-4, 417 | weight_decay=5e-4, 418 | epochs=12, 419 | interval=4, 420 | rho=0.533, 421 | disable_random_id=False, 422 | save_model=False, 423 | save_txt=False, 424 | ): 425 | """Train with ElasticTrainer""" 426 | 427 | def rho_for_backward_pass(rho): 428 | return (rho - 1/3)*3/2 429 | 430 | t_dw, t_dy = profile_parser( 431 | 
model, 432 | model_name, 433 | 5, 434 | 'profile_extracted/' + timing_info, 435 | draw_figure=False, 436 | ) 437 | #np.savetxt('t_dy.out', t_dy) 438 | #np.savetxt('t_dw.out', t_dw) 439 | t_dy_q, t_dw_q, disco = downscale_t_dy_and_t_dw(t_dy, t_dw, Tq=1e3) 440 | t_dy_q = np.flip(t_dy_q) 441 | t_dw_q = np.flip(t_dw_q) 442 | 443 | if optim == 'sgd': 444 | decay_steps = len(tfds.as_numpy(ds_train)) * epochs 445 | 446 | lr_schedule = tf.keras.experimental.CosineDecay(lr, decay_steps=decay_steps) 447 | wd_schedule = tf.keras.experimental.CosineDecay(lr * weight_decay, decay_steps=decay_steps) 448 | optimizer = tfa.optimizers.SGDW(learning_rate=lr_schedule, weight_decay=wd_schedule, momentum=0.9, nesterov=False) 449 | else: 450 | optimizer = tf.keras.optimizers.Adam(lr) 451 | 452 | loss_fn_cls = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) 453 | 454 | if disable_random_id: 455 | runid = run_name 456 | else: 457 | runid = run_name + '_elastic_x' + str(np.random.randint(10000)) 458 | 459 | writer = tf.summary.create_file_writer(logdir + '/' + runid) 460 | accuracy = tf.metrics.SparseCategoricalAccuracy() 461 | cls_loss = tf.metrics.Mean() 462 | 463 | print(f"RUNID: {runid}") 464 | 465 | var_list = [] 466 | 467 | def train_step(x, y): 468 | with tf.GradientTape() as tape: 469 | y_pred = model(x, training=True) 470 | loss = loss_fn_cls(y, y_pred) 471 | gradients = tape.gradient(loss, var_list) 472 | optimizer.apply_gradients(zip(gradients, var_list)) 473 | accuracy(y, y_pred) 474 | cls_loss(loss) 475 | 476 | @tf.function 477 | def test_step(x, y): 478 | y_pred = model(x, training=False) 479 | loss = loss_fn_cls(y, y_pred) 480 | accuracy(y, y_pred) 481 | cls_loss(loss) 482 | 483 | 484 | @tf.function 485 | def compute_dw(x, y): 486 | with tf.GradientTape() as tape: 487 | y_pred_0 = model(x, training=True) 488 | loss_0 = loss_fn_cls(y, y_pred_0) 489 | grad_0 = tape.gradient(loss_0, model.trainable_weights) 490 | w_0 = [w.value() for w in model.trainable_weights] # record initial weight values 491 | optimizer.apply_gradients(zip(grad_0, model.trainable_weights)) 492 | w_1 = [w.value() for w in model.trainable_weights] # record weight values after applying optimizer 493 | dw_0 = [w_1_k - w_0_k for (w_0_k, w_1_k) in zip(w_0, w_1)] # compute weight changes 494 | with tf.GradientTape() as tape: 495 | y_pred_1 = model(x, training=True) 496 | loss_1 = loss_fn_cls(y, y_pred_1) 497 | grad_1 = tape.gradient(loss_1, model.trainable_weights) 498 | I = [tf.reduce_sum((grad_1_k * dw_0_k)) for (grad_1_k, dw_0_k) in zip(grad_1, dw_0)] 499 | I = tf.convert_to_tensor(I) 500 | I = I / tf.reduce_max(tf.abs(I)) 501 | # restore weights 502 | for k, w in enumerate(model.trainable_weights): 503 | w.assign(w_0[k]) 504 | return dw_0, I 505 | 506 | training_step = 0 507 | best_validation_acc = 0 508 | 509 | total_time_0 = 0 510 | total_time_1 = 0 511 | for epoch in range(epochs): 512 | 513 | t0 = time.time() 514 | if epoch % interval == 0: 515 | for x_probe, y_probe in ds_train.take(1): 516 | dw, I = compute_dw(x_probe, y_probe) 517 | I = -I.numpy() 518 | I = np.flip(I) 519 | #np.savetxt('importance.out', I) 520 | rho_b = rho_for_backward_pass(rho) 521 | max_importance, m = selection_DP(t_dy_q, t_dw_q, I, rho=rho_b*disco) 522 | m = np.flip(m) 523 | print("m:", m) 524 | print("max importance:", max_importance) 525 | print("%T_sel:", 100 * np.sum(np.maximum.accumulate(m) * t_dy + m * t_dw) / np.sum(t_dy + t_dw)) 526 | var_list = [] 527 | all_vars = model.trainable_weights 528 | for k, m_k in enumerate(m): 529 | if 
tf.equal(m_k, 1): 530 | var_list.append(all_vars[k]) 531 | train_step_cpl = tf.function(train_step) 532 | 533 | for x, y in tqdm(ds_train, desc=f'epoch {epoch+1}/{epochs}', ascii=True): 534 | 535 | training_step += 1 536 | 537 | train_step_cpl(x, y) 538 | 539 | if training_step % 200 == 0: 540 | with writer.as_default(): 541 | c_loss, acc = cls_loss.result(), accuracy.result() 542 | tf.summary.scalar('train/accuracy', acc, training_step) 543 | tf.summary.scalar('train/classification_loss', c_loss, training_step) 544 | tf.summary.scalar('train/learnig_rate', optimizer._decayed_lr('float32'), training_step) 545 | cls_loss.reset_states() 546 | accuracy.reset_states() 547 | clear_cache_and_rec_usage() 548 | 549 | 550 | cls_loss.reset_states() 551 | accuracy.reset_states() 552 | 553 | t1 = time.time() 554 | print("per epoch time(s) excluding validation:", t1 - t0) 555 | total_time_0 += (t1 - t0) 556 | 557 | for x, y in ds_test: 558 | test_step(x, y) 559 | 560 | with writer.as_default(): 561 | tf.summary.scalar('test/classification_loss', cls_loss.result(), step=training_step) 562 | tf.summary.scalar('test/accuracy', accuracy.result(), step=training_step) 563 | 564 | if accuracy.result() > best_validation_acc: 565 | best_validation_acc = accuracy.result() 566 | if save_model: 567 | model.save_weights(os.path.join('saved_models', runid + '.tf')) 568 | print("=================================") 569 | print("acc: ", accuracy.result()) 570 | print("=================================") 571 | 572 | cls_loss.reset_states() 573 | accuracy.reset_states() 574 | 575 | t2 = time.time() 576 | print("per epoch time(s) including validation:", t2 - t0) 577 | total_time_1 += (t2 - t0) 578 | 579 | # print("total time excluding validation (s):", total_time_0) 580 | # print("total time including validation (s):", total_time_1) 581 | best_validation_acc = best_validation_acc.numpy() * 100 582 | total_time_0 /= 3600 583 | print('===============================================') 584 | print('Training Type: ElasticTrainer') 585 | print(f"Accuracy (%): {best_validation_acc:.2f}") 586 | print(f"Time (h): {total_time_0:.2f}") 587 | print('===============================================') 588 | if save_txt: 589 | np.savetxt(logdir + '/' + runid + '.txt', np.array([total_time_0, best_validation_acc])) 590 | # sig_stop_handler(None, None) 591 | 592 | def elastic_training_weight_magnitude( 593 | model, 594 | model_name, 595 | ds_train, 596 | ds_test, 597 | run_name, 598 | logdir, 599 | timing_info, 600 | optim='sgd', 601 | lr=1e-4, 602 | weight_decay=5e-4, 603 | epochs=12, 604 | interval=4, 605 | rho=0.533, 606 | disable_random_id=False, 607 | save_model=False, 608 | save_txt=False, 609 | ): 610 | """Train with ElasticTrainer but use weight magnitude as importance metric""" 611 | 612 | def rho_for_backward_pass(rho): 613 | return (rho - 1/3)*3/2 614 | 615 | t_dw, t_dy = profile_parser( 616 | model, 617 | model_name, 618 | 5, 619 | 'profile_extracted/' + timing_info, 620 | draw_figure=False, 621 | ) 622 | #np.savetxt('t_dy.out', t_dy) 623 | #np.savetxt('t_dw.out', t_dw) 624 | t_dy_q, t_dw_q, disco = downscale_t_dy_and_t_dw(t_dy, t_dw, Tq=1e3) 625 | t_dy_q = np.flip(t_dy_q) 626 | t_dw_q = np.flip(t_dw_q) 627 | 628 | if optim == 'sgd': 629 | decay_steps = len(tfds.as_numpy(ds_train)) * epochs 630 | 631 | lr_schedule = tf.keras.experimental.CosineDecay(lr, decay_steps=decay_steps) 632 | wd_schedule = tf.keras.experimental.CosineDecay(lr * weight_decay, decay_steps=decay_steps) 633 | optimizer = 
tfa.optimizers.SGDW(learning_rate=lr_schedule, weight_decay=wd_schedule, momentum=0.9, nesterov=False) 634 | else: 635 | optimizer = tf.keras.optimizers.Adam(lr) 636 | 637 | loss_fn_cls = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) 638 | 639 | if disable_random_id: 640 | runid = run_name 641 | else: 642 | runid = run_name + '_elastic_mag_x' + str(np.random.randint(10000)) 643 | 644 | writer = tf.summary.create_file_writer(logdir + '/' + runid) 645 | accuracy = tf.metrics.SparseCategoricalAccuracy() 646 | cls_loss = tf.metrics.Mean() 647 | 648 | print(f"RUNID: {runid}") 649 | 650 | var_list = [] 651 | 652 | def train_step(x, y): 653 | with tf.GradientTape() as tape: 654 | y_pred = model(x, training=True) 655 | loss = loss_fn_cls(y, y_pred) 656 | gradients = tape.gradient(loss, var_list) 657 | optimizer.apply_gradients(zip(gradients, var_list)) 658 | accuracy(y, y_pred) 659 | cls_loss(loss) 660 | 661 | @tf.function 662 | def test_step(x, y): 663 | y_pred = model(x, training=False) 664 | loss = loss_fn_cls(y, y_pred) 665 | accuracy(y, y_pred) 666 | cls_loss(loss) 667 | 668 | training_step = 0 669 | best_validation_acc = 0 670 | 671 | total_time_0 = 0 672 | total_time_1 = 0 673 | for epoch in range(epochs): 674 | 675 | t0 = time.time() 676 | if epoch % interval == 0: 677 | I = np.array([tf.reduce_sum(tf.abs(w.value())) for w in model.trainable_weights]) 678 | I = np.flip(I) 679 | #np.savetxt('importance.out', I) 680 | rho_b = rho_for_backward_pass(rho) 681 | max_importance, m = selection_DP(t_dy_q, t_dw_q, I, rho=rho_b*disco) 682 | m = np.flip(m) 683 | print("m:", m) 684 | print("max importance:", max_importance) 685 | print("%T_sel:", 100 * np.sum(np.maximum.accumulate(m) * t_dy + m * t_dw) / np.sum(t_dy + t_dw)) 686 | var_list = [] 687 | all_vars = model.trainable_weights 688 | for k, m_k in enumerate(m): 689 | if tf.equal(m_k, 1): 690 | var_list.append(all_vars[k]) 691 | train_step_cpl = tf.function(train_step) 692 | 693 | for x, y in tqdm(ds_train, desc=f'epoch {epoch+1}/{epochs}', ascii=True): 694 | 695 | training_step += 1 696 | 697 | train_step_cpl(x, y) 698 | 699 | if training_step % 200 == 0: 700 | with writer.as_default(): 701 | c_loss, acc = cls_loss.result(), accuracy.result() 702 | tf.summary.scalar('train/accuracy', acc, training_step) 703 | tf.summary.scalar('train/classification_loss', c_loss, training_step) 704 | tf.summary.scalar('train/learnig_rate', optimizer._decayed_lr('float32'), training_step) 705 | cls_loss.reset_states() 706 | accuracy.reset_states() 707 | clear_cache_and_rec_usage() 708 | 709 | 710 | cls_loss.reset_states() 711 | accuracy.reset_states() 712 | 713 | t1 = time.time() 714 | print("per epoch time(s) excluding validation:", t1 - t0) 715 | total_time_0 += (t1 - t0) 716 | 717 | for x, y in ds_test: 718 | test_step(x, y) 719 | 720 | with writer.as_default(): 721 | tf.summary.scalar('test/classification_loss', cls_loss.result(), step=training_step) 722 | tf.summary.scalar('test/accuracy', accuracy.result(), step=training_step) 723 | 724 | if accuracy.result() > best_validation_acc: 725 | best_validation_acc = accuracy.result() 726 | if save_model: 727 | model.save_weights(os.path.join('saved_models', runid + '.tf')) 728 | print("=================================") 729 | print("acc: ", accuracy.result()) 730 | print("=================================") 731 | 732 | cls_loss.reset_states() 733 | accuracy.reset_states() 734 | 735 | t2 = time.time() 736 | print("per epoch time(s) including validation:", t2 - t0) 737 | total_time_1 += (t2 - t0) 
738 | 739 | # print("total time excluding validation (s):", total_time_0) 740 | # print("total time including validation (s):", total_time_1) 741 | best_validation_acc = best_validation_acc.numpy() * 100 742 | total_time_0 /= 3600 743 | print('===============================================') 744 | print('Training Type: ElasticTrainer (W)') 745 | print(f"Accuracy (%): {best_validation_acc:.2f}") 746 | print(f"Time (h): {total_time_0:.2f}") 747 | print('===============================================') 748 | if save_txt: 749 | np.savetxt(logdir + '/' + runid + '.txt', np.array([total_time_0, best_validation_acc])) 750 | # sig_stop_handler(None, None) 751 | 752 | def elastic_training_grad_magnitude( 753 | model, 754 | model_name, 755 | ds_train, 756 | ds_test, 757 | run_name, 758 | logdir, 759 | timing_info, 760 | optim='sgd', 761 | lr=1e-4, 762 | weight_decay=5e-4, 763 | epochs=12, 764 | interval=4, 765 | rho=0.4, 766 | disable_random_id=False, 767 | save_model=False, 768 | save_txt=False, 769 | ): 770 | """Train with ElasticTrainer but use gradient magnitude as importance metric""" 771 | 772 | def rho_for_backward_pass(rho): 773 | return (rho - 1/3)*3/2 774 | 775 | t_dw, t_dy = profile_parser( 776 | model, 777 | model_name, 778 | 5, 779 | 'profile_extracted/' + timing_info, 780 | draw_figure=False, 781 | ) 782 | #np.savetxt('t_dy.out', t_dy) 783 | #np.savetxt('t_dw.out', t_dw) 784 | t_dy_q, t_dw_q, disco = downscale_t_dy_and_t_dw(t_dy, t_dw, Tq=1e3) 785 | t_dy_q = np.flip(t_dy_q) 786 | t_dw_q = np.flip(t_dw_q) 787 | 788 | if optim == 'sgd': 789 | decay_steps = len(tfds.as_numpy(ds_train)) * epochs 790 | 791 | lr_schedule = tf.keras.experimental.CosineDecay(lr, decay_steps=decay_steps) 792 | wd_schedule = tf.keras.experimental.CosineDecay(lr * weight_decay, decay_steps=decay_steps) 793 | optimizer = tfa.optimizers.SGDW(learning_rate=lr_schedule, weight_decay=wd_schedule, momentum=0.9, nesterov=False) 794 | else: 795 | optimizer = tf.keras.optimizers.Adam(lr) 796 | 797 | loss_fn_cls = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) 798 | 799 | if disable_random_id: 800 | runid = run_name 801 | else: 802 | runid = run_name + '_elastic_grad_x' + str(np.random.randint(10000)) 803 | 804 | writer = tf.summary.create_file_writer(logdir + '/' + runid) 805 | accuracy = tf.metrics.SparseCategoricalAccuracy() 806 | cls_loss = tf.metrics.Mean() 807 | 808 | print(f"RUNID: {runid}") 809 | 810 | var_list = [] 811 | 812 | def train_step(x, y): 813 | with tf.GradientTape() as tape: 814 | y_pred = model(x, training=True) 815 | loss = loss_fn_cls(y, y_pred) 816 | gradients = tape.gradient(loss, var_list) 817 | optimizer.apply_gradients(zip(gradients, var_list)) 818 | accuracy(y, y_pred) 819 | cls_loss(loss) 820 | 821 | @tf.function 822 | def test_step(x, y): 823 | y_pred = model(x, training=False) 824 | loss = loss_fn_cls(y, y_pred) 825 | accuracy(y, y_pred) 826 | cls_loss(loss) 827 | 828 | 829 | @tf.function 830 | def compute_dw(x, y): 831 | with tf.GradientTape() as tape: 832 | y_pred_0 = model(x, training=True) 833 | loss_0 = loss_fn_cls(y, y_pred_0) 834 | grad_0 = tape.gradient(loss_0, model.trainable_weights) 835 | I = [tf.reduce_sum(tf.abs(grad_0_k)) for grad_0_k in grad_0] 836 | I = tf.convert_to_tensor(I) 837 | return I 838 | 839 | training_step = 0 840 | best_validation_acc = 0 841 | 842 | total_time_0 = 0 843 | total_time_1 = 0 844 | for epoch in range(epochs): 845 | 846 | t0 = time.time() 847 | if epoch % interval == 0: 848 | for x_probe, y_probe in ds_train.take(1): 849 | I = 
compute_dw(x_probe, y_probe) 850 | I = I.numpy() 851 | I = np.flip(I) 852 | #np.savetxt('importance.out', I) 853 | rho_b = rho_for_backward_pass(rho) 854 | max_importance, m = selection_DP(t_dy_q, t_dw_q, I, rho=rho_b*disco) 855 | m = np.flip(m) 856 | print("m:", m) 857 | print("max importance:", max_importance) 858 | print("%T_sel:", 100 * np.sum(np.maximum.accumulate(m) * t_dy + m * t_dw) / np.sum(t_dy + t_dw)) 859 | var_list = [] 860 | all_vars = model.trainable_weights 861 | for k, m_k in enumerate(m): 862 | if tf.equal(m_k, 1): 863 | var_list.append(all_vars[k]) 864 | train_step_cpl = tf.function(train_step) 865 | 866 | for x, y in tqdm(ds_train, desc=f'epoch {epoch+1}/{epochs}', ascii=True): 867 | 868 | training_step += 1 869 | 870 | train_step_cpl(x, y) 871 | 872 | if training_step % 200 == 0: 873 | with writer.as_default(): 874 | c_loss, acc = cls_loss.result(), accuracy.result() 875 | tf.summary.scalar('train/accuracy', acc, training_step) 876 | tf.summary.scalar('train/classification_loss', c_loss, training_step) 877 | tf.summary.scalar('train/learnig_rate', optimizer._decayed_lr('float32'), training_step) 878 | cls_loss.reset_states() 879 | accuracy.reset_states() 880 | clear_cache_and_rec_usage() 881 | 882 | 883 | cls_loss.reset_states() 884 | accuracy.reset_states() 885 | 886 | t1 = time.time() 887 | print("per epoch time(s) excluding validation:", t1 - t0) 888 | total_time_0 += (t1 - t0) 889 | 890 | for x, y in ds_test: 891 | test_step(x, y) 892 | 893 | with writer.as_default(): 894 | tf.summary.scalar('test/classification_loss', cls_loss.result(), step=training_step) 895 | tf.summary.scalar('test/accuracy', accuracy.result(), step=training_step) 896 | 897 | if accuracy.result() > best_validation_acc: 898 | best_validation_acc = accuracy.result() 899 | if save_model: 900 | model.save_weights(os.path.join('saved_models', runid + '.tf')) 901 | print("=================================") 902 | print("acc: ", accuracy.result()) 903 | print("=================================") 904 | 905 | cls_loss.reset_states() 906 | accuracy.reset_states() 907 | 908 | t2 = time.time() 909 | print("per epoch time(s) including validation:", t2 - t0) 910 | total_time_1 += (t2 - t0) 911 | 912 | # print("total time excluding validation (s):", total_time_0) 913 | # print("total time including validation (s):", total_time_1) 914 | best_validation_acc = best_validation_acc.numpy() * 100 915 | total_time_0 /= 3600 916 | print('===============================================') 917 | print('Training Type: ElasticTrainer (G)') 918 | print(f"Accuracy (%): {best_validation_acc:.2f}") 919 | print(f"Time (h): {total_time_0:.2f}") 920 | print('===============================================') 921 | if save_txt: 922 | np.savetxt(logdir + '/' + runid + '.txt', np.array([total_time_0, best_validation_acc])) 923 | # sig_stop_handler(None, None) 924 | 925 | def prune_training( 926 | model, 927 | ds_train, 928 | ds_test, 929 | run_name, 930 | logdir, 931 | optim='sgd', 932 | lr=1e-4, 933 | weight_decay=5e-4, 934 | epochs=12, 935 | disable_random_id=False, 936 | save_model=False, 937 | save_txt=False, 938 | ): 939 | """All NN weights will be trained""" 940 | 941 | if optim == 'sgd': 942 | decay_steps = len(tfds.as_numpy(ds_train)) * epochs 943 | 944 | lr_schedule = tf.keras.experimental.CosineDecay(lr, decay_steps=decay_steps) 945 | wd_schedule = tf.keras.experimental.CosineDecay(lr * weight_decay, decay_steps=decay_steps) 946 | optimizer = tfa.optimizers.SGDW(learning_rate=lr_schedule, weight_decay=wd_schedule, 
momentum=0.9, nesterov=False) 947 | else: 948 | optimizer = tf.keras.optimizers.Adam(lr) 949 | 950 | loss_fn_cls = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) 951 | 952 | if disable_random_id: 953 | runid = run_name 954 | else: 955 | runid = run_name + '_prunetrain_x' + str(np.random.randint(10000)) 956 | 957 | writer = tf.summary.create_file_writer(logdir + '/' + runid) 958 | accuracy = tf.metrics.SparseCategoricalAccuracy() 959 | cls_loss = tf.metrics.Mean() 960 | 961 | print(f"RUNID: {runid}") 962 | 963 | kernel_weights = [] 964 | for w in model.trainable_weights: 965 | if 'kernel' in w.name: 966 | kernel_weights.append(w) 967 | 968 | @tf.function 969 | def train_step(x, y): 970 | with tf.GradientTape() as tape: 971 | y_pred = model(x, training=True) 972 | L1_penalty = 1e-4 * tf.reduce_sum([tf.reduce_sum(tf.abs(w)) for w in kernel_weights]) 973 | loss = loss_fn_cls(y, y_pred) + L1_penalty 974 | gradients = tape.gradient(loss, model.trainable_weights) 975 | optimizer.apply_gradients(zip(gradients, model.trainable_weights)) 976 | accuracy(y, y_pred) 977 | cls_loss(loss) 978 | 979 | @tf.function 980 | def test_step(x, y): 981 | y_pred = model(x, training=False) 982 | loss = loss_fn_cls(y, y_pred) 983 | accuracy(y, y_pred) 984 | cls_loss(loss) 985 | 986 | training_step = 0 987 | best_validation_acc = 0 988 | 989 | clear_cache_and_rec_usage() 990 | 991 | total_time_0 = 0 992 | total_time_1 = 0 993 | for epoch in range(epochs): 994 | 995 | t0 = time.time() 996 | for x, y in tqdm(ds_train, desc=f'epoch {epoch+1}/{epochs}', ascii=True): 997 | 998 | training_step += 1 999 | 1000 | train_step(x, y) 1001 | 1002 | if training_step % 200 == 0: 1003 | with writer.as_default(): 1004 | c_loss, acc = cls_loss.result(), accuracy.result() 1005 | tf.summary.scalar('train/accuracy', acc, training_step) 1006 | tf.summary.scalar('train/classification_loss', c_loss, training_step) 1007 | tf.summary.scalar('train/learnig_rate', optimizer._decayed_lr('float32'), training_step) 1008 | cls_loss.reset_states() 1009 | accuracy.reset_states() 1010 | clear_cache_and_rec_usage() 1011 | 1012 | cls_loss.reset_states() 1013 | accuracy.reset_states() 1014 | 1015 | t1 = time.time() 1016 | print("per epoch time(s) excluding validation:", t1 - t0) 1017 | total_time_0 += (t1 - t0) 1018 | 1019 | for x, y in ds_test: 1020 | test_step(x, y) 1021 | 1022 | with writer.as_default(): 1023 | tf.summary.scalar('test/classification_loss', cls_loss.result(), step=training_step) 1024 | tf.summary.scalar('test/accuracy', accuracy.result(), step=training_step) 1025 | 1026 | if accuracy.result() > best_validation_acc: 1027 | best_validation_acc = accuracy.result() 1028 | if save_model: 1029 | model.save_weights(os.path.join('saved_models', runid + '.tf')) 1030 | print("=================================") 1031 | print("acc: ", accuracy.result()) 1032 | print("=================================") 1033 | 1034 | cls_loss.reset_states() 1035 | accuracy.reset_states() 1036 | 1037 | t2 = time.time() 1038 | print("per epoch time(s) including validation:", t2 - t0) 1039 | total_time_1 += (t2 - t0) 1040 | 1041 | clear_cache_and_rec_usage() 1042 | 1043 | # print("total time excluding validation (s):", total_time_0) 1044 | # print("total time including validation (s):", total_time_1) 1045 | best_validation_acc = best_validation_acc.numpy() * 100 1046 | total_time_0 /= 3600 1047 | print('===============================================') 1048 | print('Training Type: PruneTrain') 1049 | print(f"Accuracy (%): {best_validation_acc:.2f}") 
1050 | print(f"Time (h): {total_time_0:.2f}") 1051 | print('===============================================') 1052 | if save_txt: 1053 | np.savetxt(logdir + '/' + runid + '.txt', np.array([total_time_0, best_validation_acc])) 1054 | # sig_stop_handler(None, None) 1055 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow_datasets as tfds 3 | import tensorflow_addons as tfa 4 | from vit_keras import vit 5 | from vit_utils import vit_b16 6 | import resource 7 | import gc 8 | from subprocess import Popen, PIPE 9 | from threading import Timer 10 | import sys 11 | import os 12 | 13 | 14 | def my_bool(s): 15 | return s != 'False' 16 | 17 | class RepeatTimer(Timer): 18 | def run(self): 19 | while not self.finished.wait(self.interval): 20 | self.function(*self.args, **self.kwargs) 21 | 22 | def _clear_mem_cache(): 23 | # os.system('/bin/bash -c "sync ; echo 1 > /proc/sys/vm/drop_caches ; "') 24 | # os.system('/bin/bash -c "sync ; echo 2 > /proc/sys/vm/drop_caches ; "') 25 | # os.system('/bin/bash -c "sync ; echo 3 > /proc/sys/vm/drop_caches ; "') 26 | return 27 | 28 | def _print_mem_free(): 29 | process_free = Popen(["free"], stdout=PIPE) 30 | (output, err) = process_free.communicate() 31 | exit_code = process_free.wait() 32 | output_string = output.decode('UTF-8') 33 | # output_file.write(str(time.time())) 34 | # output_file.write(output_string) 35 | # output_file.write('\n') 36 | # output_file.flush() 37 | 38 | def clear_cache_and_rec_usage(): 39 | # NOOP 40 | return 41 | 42 | def record_once(): 43 | # _clear_mem_cache() 44 | gc.collect() 45 | # _print_mem_free() 46 | 47 | 48 | # timer = RepeatTimer(15, record_once) 49 | # timer.start() 50 | 51 | def sig_stop_handler(sig, frame): 52 | global timer 53 | # timer.cancel() 54 | # sys.exit(0) 55 | os.abort() 56 | 57 | # signal.signal(signal.SIGINT, sig_stop_handler) 58 | # signal.signal(signal.SIGTERM, sig_stop_handler) 59 | 60 | 61 | ## ENDOF: record mem info ################################################ 62 | 63 | def port_pretrained_models( 64 | model_type='resnet50', 65 | input_shape=(224, 224, 3), 66 | num_classes=1000, 67 | ): 68 | """ 69 | This function loads the NN model for training 70 | 71 | Args: 72 | model_type (str, optional): type of NN model. Defaults to 'resnet50'. 73 | input_shape (tuple, optional): NN input shape excluding batch dim. Defaults to (224, 224, 3). 74 | num_classes (int, optional): number of classes of the classification task. Defaults to 1000. 
75 | 76 | Raises: 77 | NotImplementedError: The requested model is not implemented 78 | 79 | Returns: 80 | tf.keras.Model: The requested NN model 81 | """ 82 | 83 | if model_type == 'mobilenetv2': 84 | preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input 85 | base_model = tf.keras.applications.MobileNetV2(input_shape=input_shape, 86 | include_top=False, 87 | weights='imagenet') 88 | base_model.trainable = True 89 | data_augmentation = tf.keras.Sequential([ 90 | tf.keras.layers.RandomFlip('horizontal'), 91 | tf.keras.layers.RandomRotation(0.2), 92 | ]) 93 | global_average_layer = tf.keras.layers.GlobalAveragePooling2D() 94 | prediction_layer = tf.keras.layers.Dense(num_classes) 95 | inputs = tf.keras.Input(shape=input_shape) 96 | x = data_augmentation(inputs) 97 | x = preprocess_input(x) 98 | x = base_model(x, training=False) 99 | x = global_average_layer(x) 100 | x = tf.keras.layers.Dropout(0.2)(x) 101 | outputs = prediction_layer(x) 102 | model = tf.keras.Model(inputs, outputs) 103 | 104 | elif model_type == 'resnet50': 105 | preprocess_input = tf.keras.applications.resnet.preprocess_input 106 | base_model = tf.keras.applications.ResNet50(input_shape=input_shape, 107 | include_top=False, 108 | weights='imagenet') 109 | base_model.trainable = True 110 | data_augmentation = tf.keras.Sequential([ 111 | tf.keras.layers.RandomFlip('horizontal'), 112 | tf.keras.layers.RandomRotation(0.2), 113 | ]) 114 | global_average_layer = tf.keras.layers.GlobalAveragePooling2D() 115 | prediction_layer = tf.keras.layers.Dense(num_classes) 116 | inputs = tf.keras.Input(shape=input_shape) 117 | x = data_augmentation(inputs) 118 | x = preprocess_input(x) 119 | x = base_model(x, training=False) 120 | x = global_average_layer(x) 121 | x = tf.keras.layers.Dropout(0.2)(x) 122 | outputs = prediction_layer(x) 123 | model = tf.keras.Model(inputs, outputs) 124 | 125 | elif model_type == 'vgg16': 126 | preprocess_input = tf.keras.applications.vgg16.preprocess_input 127 | base_model = tf.keras.applications.VGG16(input_shape=input_shape, 128 | include_top=False, 129 | weights='imagenet') 130 | base_model.trainable = True 131 | data_augmentation = tf.keras.Sequential([ 132 | tf.keras.layers.RandomFlip('horizontal'), 133 | tf.keras.layers.RandomRotation(0.2), 134 | ]) 135 | global_average_layer = tf.keras.layers.GlobalAveragePooling2D() 136 | prediction_layer = tf.keras.layers.Dense(num_classes) 137 | inputs = tf.keras.Input(shape=input_shape) 138 | x = data_augmentation(inputs) 139 | x = preprocess_input(x) 140 | x = base_model(x, training=False) 141 | x = global_average_layer(x) 142 | x = tf.keras.layers.Dropout(0.2)(x) 143 | outputs = prediction_layer(x) 144 | model = tf.keras.Model(inputs, outputs) 145 | 146 | elif model_type == 'vit': 147 | # base_model = vit.vit_b16( 148 | # image_size=input_shape[0], 149 | # pretrained=True, 150 | # include_top=True, 151 | # pretrained_top=False, 152 | # weights='imagenet21k+imagenet2012', 153 | # classes=num_classes, 154 | # ) 155 | base_model = vit_b16( 156 | image_size=input_shape[0], 157 | pretrained=True, 158 | include_top=True, 159 | pretrained_top=False, 160 | weights='imagenet21k+imagenet2012', 161 | classes=num_classes, 162 | ) 163 | base_model.trainable = True 164 | # base_model.layers[4].layers[:-1] 165 | data_augmentation = tf.keras.Sequential([ 166 | tf.keras.layers.RandomFlip('horizontal'), 167 | tf.keras.layers.RandomRotation(0.2), 168 | ]) 169 | 170 | inputs = tf.keras.Input(shape=input_shape) 171 | x = data_augmentation(inputs) 172 | x = 
vit.preprocess_inputs(x) 173 | outputs = base_model(x, training=False) 174 | model = tf.keras.Model(inputs, outputs) 175 | 176 | else: 177 | raise NotImplementedError("This model has not been implemented yet") 178 | 179 | return model 180 | 181 | 182 | def port_datasets( 183 | dataset_name, 184 | input_shape, 185 | batch_size, 186 | ): 187 | """ 188 | This function loads the train and test splits of the requested dataset, and 189 | creates input pipelines for training. 190 | 191 | Args: 192 | dataset_name (str): name of the dataset 193 | input_shape (tuple): NN input shape excluding batch dim 194 | batch_size (int): batch size of training split, 195 | default batch size for testing split is batch_size*2 196 | 197 | Raises: 198 | NotImplementedError: The requested dataset is not implemented 199 | 200 | Returns: 201 | Train and test splits of the request dataset 202 | """ 203 | 204 | # maximize number limit of opened files 205 | low, high = resource.getrlimit(resource.RLIMIT_NOFILE) 206 | resource.setrlimit(resource.RLIMIT_NOFILE, (high, high)) 207 | 208 | def prep(x, y): 209 | x = tf.image.resize(x, [input_shape[0], input_shape[1]]) 210 | return x, y 211 | 212 | 213 | if dataset_name == 'caltech_birds2011': 214 | ds = tfds.load('caltech_birds2011', as_supervised=True) # 200 classes 215 | ds_train = ds['train'].map(prep, num_parallel_calls=tf.data.AUTOTUNE)\ 216 | .batch(batch_size)\ 217 | .prefetch(buffer_size=tf.data.AUTOTUNE) 218 | ds_test = ds['test'].map(prep, num_parallel_calls=tf.data.AUTOTUNE)\ 219 | .batch(batch_size*2)\ 220 | .prefetch(buffer_size=tf.data.AUTOTUNE) 221 | 222 | elif dataset_name == 'stanford_dogs': 223 | ds = tfds.load('stanford_dogs', as_supervised=True) # 120 classes 224 | ds_train = ds['train'].map(prep, num_parallel_calls=tf.data.AUTOTUNE)\ 225 | .batch(batch_size)\ 226 | .prefetch(buffer_size=tf.data.AUTOTUNE) 227 | ds_test = ds['test'].map(prep, num_parallel_calls=tf.data.AUTOTUNE)\ 228 | .batch(batch_size*2)\ 229 | .prefetch(buffer_size=tf.data.AUTOTUNE) 230 | 231 | elif dataset_name == 'oxford_iiit_pet': 232 | ds = tfds.load('oxford_iiit_pet', as_supervised=True) # 37 classes 233 | ds_train = ds['train'].map(prep, num_parallel_calls=tf.data.AUTOTUNE)\ 234 | .batch(batch_size)\ 235 | .prefetch(buffer_size=tf.data.AUTOTUNE) 236 | 237 | ds_test = ds['test'].map(prep, num_parallel_calls=tf.data.AUTOTUNE)\ 238 | .batch(batch_size*2)\ 239 | .prefetch(buffer_size=tf.data.AUTOTUNE) 240 | else: 241 | raise NotImplementedError("This dataset has not been implemented yet") 242 | 243 | return ds_train, ds_test 244 | -------------------------------------------------------------------------------- /vit_utils.py: -------------------------------------------------------------------------------- 1 | # This script is adapted from https://github.com/faustomorales/vit-keras 2 | # The implementation of TransformerBlock is refined so that the tensor ordering is consistent 3 | # with operation ordering. 
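# Context for the refinement noted above (a clarifying note; the rationale is inferred from the
# code in train.py and profiler.py): ElasticTrainer's selector applies the DP mask positionally
# over model.trainable_weights (all_vars[k] is trained when m_k == 1) and pairs it by index with
# the per-tensor timings t_dy/t_dw and the importance scores I, so the k-th trainable variable
# must correspond to the k-th backward-pass operation. Keeping the trainable-variable order
# aligned with the execution order is what the adapted TransformerBlock below is meant to ensure.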
4 | import tensorflow as tf 5 | import tensorflow_addons as tfa 6 | import typing 7 | import warnings 8 | import typing_extensions as tx 9 | import typing 10 | import warnings 11 | import numpy as np 12 | import scipy as sp 13 | 14 | 15 | # @tf.keras.utils.register_keras_serializable() 16 | class ClassToken(tf.keras.layers.Layer): 17 | """Append a class token to an input layer.""" 18 | 19 | def build(self, input_shape): 20 | cls_init = tf.zeros_initializer() 21 | self.hidden_size = input_shape[-1] 22 | self.cls = tf.Variable( 23 | name="cls", 24 | initial_value=cls_init(shape=(1, 1, self.hidden_size), dtype="float32"), 25 | trainable=True, 26 | ) 27 | 28 | def call(self, inputs): 29 | batch_size = tf.shape(inputs)[0] 30 | cls_broadcasted = tf.cast( 31 | tf.broadcast_to(self.cls, [batch_size, 1, self.hidden_size]), 32 | dtype=inputs.dtype, 33 | ) 34 | return tf.concat([cls_broadcasted, inputs], 1) 35 | 36 | def get_config(self): 37 | config = super().get_config() 38 | return config 39 | 40 | @classmethod 41 | def from_config(cls, config): 42 | return cls(**config) 43 | 44 | 45 | # @tf.keras.utils.register_keras_serializable() 46 | class AddPositionEmbs(tf.keras.layers.Layer): 47 | """Adds (optionally learned) positional embeddings to the inputs.""" 48 | 49 | def build(self, input_shape): 50 | assert ( 51 | len(input_shape) == 3 52 | ), f"Number of dimensions should be 3, got {len(input_shape)}" 53 | self.pe = tf.Variable( 54 | name="pos_embedding", 55 | initial_value=tf.random_normal_initializer(stddev=0.06)( 56 | shape=(1, input_shape[1], input_shape[2]) 57 | ), 58 | dtype="float32", 59 | trainable=True, 60 | ) 61 | 62 | def call(self, inputs): 63 | return inputs + tf.cast(self.pe, dtype=inputs.dtype) 64 | 65 | def get_config(self): 66 | config = super().get_config() 67 | return config 68 | 69 | @classmethod 70 | def from_config(cls, config): 71 | return cls(**config) 72 | 73 | 74 | # @tf.keras.utils.register_keras_serializable() 75 | class MultiHeadSelfAttention(tf.keras.layers.Layer): 76 | def __init__(self, *args, num_heads, **kwargs): 77 | super().__init__(*args, **kwargs) 78 | self.num_heads = num_heads 79 | 80 | def build(self, input_shape): 81 | hidden_size = input_shape[-1] 82 | num_heads = self.num_heads 83 | if hidden_size % num_heads != 0: 84 | raise ValueError( 85 | f"embedding dimension = {hidden_size} should be divisible by number of heads = {num_heads}" 86 | ) 87 | self.hidden_size = hidden_size 88 | self.projection_dim = hidden_size // num_heads 89 | self.query_dense = tf.keras.layers.Dense(hidden_size, name="query") 90 | self.key_dense = tf.keras.layers.Dense(hidden_size, name="key") 91 | self.value_dense = tf.keras.layers.Dense(hidden_size, name="value") 92 | self.combine_heads = tf.keras.layers.Dense(hidden_size, name="out") 93 | 94 | # pylint: disable=no-self-use 95 | def attention(self, query, key, value): 96 | score = tf.matmul(query, key, transpose_b=True) 97 | dim_key = tf.cast(tf.shape(key)[-1], score.dtype) 98 | scaled_score = score / tf.math.sqrt(dim_key) 99 | weights = tf.nn.softmax(scaled_score, axis=-1) 100 | output = tf.matmul(weights, value) 101 | return output, weights 102 | 103 | def separate_heads(self, x, batch_size): 104 | x = tf.reshape(x, (batch_size, -1, self.num_heads, self.projection_dim)) 105 | return tf.transpose(x, perm=[0, 2, 1, 3]) 106 | 107 | def call(self, inputs): 108 | batch_size = tf.shape(inputs)[0] 109 | query = self.query_dense(inputs) 110 | key = self.key_dense(inputs) 111 | value = self.value_dense(inputs) 112 | query = 
self.separate_heads(query, batch_size) 113 | key = self.separate_heads(key, batch_size) 114 | value = self.separate_heads(value, batch_size) 115 | 116 | attention, weights = self.attention(query, key, value) 117 | attention = tf.transpose(attention, perm=[0, 2, 1, 3]) 118 | concat_attention = tf.reshape(attention, (batch_size, -1, self.hidden_size)) 119 | output = self.combine_heads(concat_attention) 120 | return output, weights 121 | 122 | def get_config(self): 123 | config = super().get_config() 124 | config.update({"num_heads": self.num_heads}) 125 | return config 126 | 127 | @classmethod 128 | def from_config(cls, config): 129 | return cls(**config) 130 | 131 | 132 | # pylint: disable=too-many-instance-attributes 133 | # @tf.keras.utils.register_keras_serializable() 134 | class TransformerBlock(tf.keras.layers.Layer): 135 | """Implements a Transformer block. 136 | This implementation makes the order of trainables 137 | consistent with the execution order of operations. 138 | """ 139 | 140 | def __init__(self, *args, num_heads, mlp_dim, dropout, **kwargs): 141 | super().__init__(*args, **kwargs) 142 | self.num_heads = num_heads 143 | self.mlp_dim = mlp_dim 144 | self.dropout = dropout 145 | 146 | def build(self, input_shape): 147 | 148 | self.layernorm1 = tf.keras.layers.LayerNormalization( 149 | epsilon=1e-6, name="LayerNorm_0" 150 | ) 151 | 152 | self.att = MultiHeadSelfAttention( 153 | num_heads=self.num_heads, 154 | name="MultiHeadDotProductAttention_1", 155 | ) 156 | 157 | self.dropout_layer = tf.keras.layers.Dropout(self.dropout) 158 | 159 | self.layernorm2 = tf.keras.layers.LayerNormalization( 160 | epsilon=1e-6, name="LayerNorm_2" 161 | ) 162 | 163 | self.mlpblock = tf.keras.Sequential( 164 | [ 165 | tf.keras.layers.Dense( 166 | self.mlp_dim, 167 | activation="linear", 168 | name=f"{self.name}/Dense_0", 169 | ), 170 | tf.keras.layers.Lambda( 171 | lambda x: tf.keras.activations.gelu(x, approximate=False) 172 | ) 173 | if hasattr(tf.keras.activations, "gelu") 174 | else tf.keras.layers.Lambda( 175 | lambda x: tfa.activations.gelu(x, approximate=False) 176 | ), 177 | tf.keras.layers.Dropout(self.dropout), 178 | tf.keras.layers.Dense(input_shape[-1], name=f"{self.name}/Dense_1"), 179 | tf.keras.layers.Dropout(self.dropout), 180 | ], 181 | name="MlpBlock_3", 182 | ) 183 | 184 | def call(self, inputs, training): 185 | x = self.layernorm1(inputs) 186 | x, weights = self.att(x) 187 | x = self.dropout_layer(x, training=training) 188 | x = x + inputs 189 | y = self.layernorm2(x) 190 | y = self.mlpblock(y) 191 | return x + y, weights 192 | 193 | def get_config(self): 194 | config = super().get_config() 195 | config.update( 196 | { 197 | "num_heads": self.num_heads, 198 | "mlp_dim": self.mlp_dim, 199 | "dropout": self.dropout, 200 | } 201 | ) 202 | return config 203 | 204 | @classmethod 205 | def from_config(cls, config): 206 | return cls(**config) 207 | 208 | 209 | ###################################################################### 210 | 211 | def apply_embedding_weights(target_layer, source_weights, num_x_patches, num_y_patches): 212 | """Apply embedding weights to a target layer. 213 | 214 | Args: 215 | target_layer: The target layer to which weights will 216 | be applied. 217 | source_weights: The source weights, which will be 218 | resized as necessary. 219 | num_x_patches: Number of patches in width of image. 220 | num_y_patches: Number of patches in height of image. 
221 | """ 222 | expected_shape = target_layer.weights[0].shape 223 | if expected_shape != source_weights.shape: 224 | token, grid = source_weights[0, :1], source_weights[0, 1:] 225 | sin = int(np.sqrt(grid.shape[0])) 226 | sout_x = num_x_patches 227 | sout_y = num_y_patches 228 | warnings.warn( 229 | "Resizing position embeddings from " f"{sin}, {sin} to {sout_x}, {sout_y}", 230 | UserWarning, 231 | ) 232 | zoom = (sout_y / sin, sout_x / sin, 1) 233 | grid = sp.ndimage.zoom(grid.reshape(sin, sin, -1), zoom, order=1).reshape( 234 | sout_x * sout_y, -1 235 | ) 236 | source_weights = np.concatenate([token, grid], axis=0)[np.newaxis] 237 | target_layer.set_weights([source_weights]) 238 | 239 | 240 | def load_weights_numpy( 241 | model, params_path, pretrained_top, num_x_patches, num_y_patches 242 | ): 243 | """Load weights saved using Flax as a numpy array. 244 | 245 | Args: 246 | model: A Keras model to load the weights into. 247 | params_path: Filepath to a numpy archive. 248 | pretrained_top: Whether to load the top layer weights. 249 | num_x_patches: Number of patches in width of image. 250 | num_y_patches: Number of patches in height of image. 251 | """ 252 | params_dict = np.load( 253 | params_path, allow_pickle=False 254 | ) # pylint: disable=unexpected-keyword-arg 255 | source_keys = list(params_dict.keys()) 256 | pre_logits = any(l.name == "pre_logits" for l in model.layers) 257 | source_keys_used = [] 258 | n_transformers = len( 259 | set( 260 | "/".join(k.split("/")[:2]) 261 | for k in source_keys 262 | if k.startswith("Transformer/encoderblock_") 263 | ) 264 | ) 265 | n_transformers_out = sum( 266 | l.name.startswith("Transformer/encoderblock_") for l in model.layers 267 | ) 268 | assert n_transformers == n_transformers_out, ( 269 | f"Wrong number of transformers (" 270 | f"{n_transformers_out} in model vs. {n_transformers} in weights)." 
271 | ) 272 | 273 | matches = [] 274 | for tidx in range(n_transformers): 275 | encoder = model.get_layer(f"Transformer/encoderblock_{tidx}") 276 | source_prefix = f"Transformer/encoderblock_{tidx}" 277 | matches.extend( 278 | [ 279 | { 280 | "layer": layer, 281 | "keys": [ 282 | f"{source_prefix}/{norm}/{name}" for name in ["scale", "bias"] 283 | ], 284 | } 285 | for norm, layer in [ 286 | ("LayerNorm_0", encoder.layernorm1), 287 | ("LayerNorm_2", encoder.layernorm2), 288 | ] 289 | ] 290 | + [ 291 | { 292 | "layer": encoder.mlpblock.get_layer( 293 | f"{source_prefix}/Dense_{mlpdense}" 294 | ), 295 | "keys": [ 296 | f"{source_prefix}/MlpBlock_3/Dense_{mlpdense}/{name}" 297 | for name in ["kernel", "bias"] 298 | ], 299 | } 300 | for mlpdense in [0, 1] 301 | ] 302 | + [ 303 | { 304 | "layer": layer, 305 | "keys": [ 306 | f"{source_prefix}/MultiHeadDotProductAttention_1/{attvar}/{name}" 307 | for name in ["kernel", "bias"] 308 | ], 309 | "reshape": True, 310 | } 311 | for attvar, layer in [ 312 | ("query", encoder.att.query_dense), 313 | ("key", encoder.att.key_dense), 314 | ("value", encoder.att.value_dense), 315 | ("out", encoder.att.combine_heads), 316 | ] 317 | ] 318 | ) 319 | for layer_name in ["embedding", "head", "pre_logits"]: 320 | if layer_name == "head" and not pretrained_top: 321 | source_keys_used.extend(["head/kernel", "head/bias"]) 322 | continue 323 | if layer_name == "pre_logits" and not pre_logits: 324 | continue 325 | matches.append( 326 | { 327 | "layer": model.get_layer(layer_name), 328 | "keys": [f"{layer_name}/{name}" for name in ["kernel", "bias"]], 329 | } 330 | ) 331 | matches.append({"layer": model.get_layer("class_token"), "keys": ["cls"]}) 332 | matches.append( 333 | { 334 | "layer": model.get_layer("Transformer/encoder_norm"), 335 | "keys": [f"Transformer/encoder_norm/{name}" for name in ["scale", "bias"]], 336 | } 337 | ) 338 | apply_embedding_weights( 339 | target_layer=model.get_layer("Transformer/posembed_input"), 340 | source_weights=params_dict["Transformer/posembed_input/pos_embedding"], 341 | num_x_patches=num_x_patches, 342 | num_y_patches=num_y_patches, 343 | ) 344 | source_keys_used.append("Transformer/posembed_input/pos_embedding") 345 | for match in matches: 346 | source_keys_used.extend(match["keys"]) 347 | source_weights = [params_dict[k] for k in match["keys"]] 348 | if match.get("reshape", False): 349 | source_weights = [ 350 | source.reshape(expected.shape) 351 | for source, expected in zip( 352 | source_weights, match["layer"].get_weights() 353 | ) 354 | ] 355 | match["layer"].set_weights(source_weights) 356 | unused = set(source_keys).difference(source_keys_used) 357 | if unused: 358 | warnings.warn(f"Did not use the following weights: {unused}", UserWarning) 359 | target_keys_set = len(source_keys_used) 360 | target_keys_all = len(model.weights) 361 | if target_keys_set < target_keys_all: 362 | warnings.warn( 363 | f"Only set {target_keys_set} of {target_keys_all} weights.", UserWarning 364 | ) 365 | 366 | ###################################################################### 367 | 368 | ConfigDict = tx.TypedDict( 369 | "ConfigDict", 370 | { 371 | "dropout": float, 372 | "mlp_dim": int, 373 | "num_heads": int, 374 | "num_layers": int, 375 | "hidden_size": int, 376 | }, 377 | ) 378 | 379 | CONFIG_B: ConfigDict = { 380 | "dropout": 0.1, 381 | "mlp_dim": 3072, 382 | "num_heads": 12, 383 | "num_layers": 12, 384 | "hidden_size": 768, 385 | } 386 | 387 | CONFIG_L: ConfigDict = { 388 | "dropout": 0.1, 389 | "mlp_dim": 4096, 390 | "num_heads": 16, 
391 | "num_layers": 24, 392 | "hidden_size": 1024, 393 | } 394 | 395 | BASE_URL = "https://github.com/faustomorales/vit-keras/releases/download/dl" 396 | WEIGHTS = {"imagenet21k": 21_843, "imagenet21k+imagenet2012": 1_000} 397 | SIZES = {"B_16", "B_32", "L_16", "L_32"} 398 | 399 | ImageSizeArg = typing.Union[typing.Tuple[int, int], int] 400 | 401 | 402 | def preprocess_inputs(X): 403 | """Preprocess images""" 404 | return tf.keras.applications.imagenet_utils.preprocess_input( 405 | X, data_format=None, mode="tf" 406 | ) 407 | 408 | 409 | def interpret_image_size(image_size_arg: ImageSizeArg) -> typing.Tuple[int, int]: 410 | """Process the image_size argument whether a tuple or int.""" 411 | if isinstance(image_size_arg, int): 412 | return (image_size_arg, image_size_arg) 413 | if ( 414 | isinstance(image_size_arg, tuple) 415 | and len(image_size_arg) == 2 416 | and all(map(lambda v: isinstance(v, int), image_size_arg)) 417 | ): 418 | return image_size_arg 419 | raise ValueError( 420 | f"The image_size argument must be a tuple of 2 integers or a single integer. Received: {image_size_arg}" 421 | ) 422 | 423 | 424 | def build_model( 425 | image_size: ImageSizeArg, 426 | patch_size: int, 427 | num_layers: int, 428 | hidden_size: int, 429 | num_heads: int, 430 | name: str, 431 | mlp_dim: int, 432 | classes: int, 433 | dropout=0.1, 434 | activation="linear", 435 | include_top=True, 436 | representation_size=None, 437 | ): 438 | """Build a ViT model. 439 | 440 | Args: 441 | image_size: The size of input images. 442 | patch_size: The size of each patch (must fit evenly in image_size) 443 | classes: optional number of classes to classify images 444 | into, only to be specified if `include_top` is True, and 445 | if no `weights` argument is specified. 446 | num_layers: The number of transformer layers to use. 447 | hidden_size: The number of filters to use 448 | num_heads: The number of transformer heads 449 | mlp_dim: The number of dimensions for the MLP output in the transformers. 450 | dropout_rate: fraction of the units to drop for dense layers. 451 | activation: The activation to use for the final layer. 452 | include_top: Whether to include the final classification layer. If not, 453 | the output will have dimensions (batch_size, hidden_size). 454 | representation_size: The size of the representation prior to the 455 | classification layer. If None, no Dense layer is inserted. 
456 | """ 457 | image_size_tuple = interpret_image_size(image_size) 458 | assert (image_size_tuple[0] % patch_size == 0) and ( 459 | image_size_tuple[1] % patch_size == 0 460 | ), "image_size must be a multiple of patch_size" 461 | x = tf.keras.layers.Input(shape=(image_size_tuple[0], image_size_tuple[1], 3)) 462 | y = tf.keras.layers.Conv2D( 463 | filters=hidden_size, 464 | kernel_size=patch_size, 465 | strides=patch_size, 466 | padding="valid", 467 | name="embedding", 468 | )(x) 469 | y = tf.keras.layers.Reshape((y.shape[1] * y.shape[2], hidden_size))(y) 470 | y = ClassToken(name="class_token")(y) 471 | y = AddPositionEmbs(name="Transformer/posembed_input")(y) 472 | for n in range(num_layers): 473 | y, _ = TransformerBlock( 474 | num_heads=num_heads, 475 | mlp_dim=mlp_dim, 476 | dropout=dropout, 477 | name=f"Transformer/encoderblock_{n}", 478 | )(y) 479 | y = tf.keras.layers.LayerNormalization( 480 | epsilon=1e-6, name="Transformer/encoder_norm" 481 | )(y) 482 | y = tf.keras.layers.Lambda(lambda v: v[:, 0], name="ExtractToken")(y) 483 | if representation_size is not None: 484 | y = tf.keras.layers.Dense( 485 | representation_size, name="pre_logits", activation="tanh" 486 | )(y) 487 | if include_top: 488 | y = tf.keras.layers.Dense(classes, name="head", activation=activation)(y) 489 | return tf.keras.models.Model(inputs=x, outputs=y, name=name) 490 | 491 | 492 | def validate_pretrained_top( 493 | include_top: bool, pretrained: bool, classes: int, weights: str 494 | ): 495 | """Validate that the pretrained weight configuration makes sense.""" 496 | assert weights in WEIGHTS, f"Unexpected weights: {weights}." 497 | expected_classes = WEIGHTS[weights] 498 | if classes != expected_classes: 499 | warnings.warn( 500 | f"Can only use pretrained_top with {weights} if classes = {expected_classes}. Setting manually.", 501 | UserWarning, 502 | ) 503 | assert include_top, "Can only use pretrained_top with include_top." 504 | assert pretrained, "Can only use pretrained_top with pretrained." 505 | return expected_classes 506 | 507 | 508 | def load_pretrained( 509 | size: str, 510 | weights: str, 511 | pretrained_top: bool, 512 | model: tf.keras.models.Model, 513 | image_size: ImageSizeArg, 514 | patch_size: int, 515 | ): 516 | """Load model weights for a known configuration.""" 517 | image_size_tuple = interpret_image_size(image_size) 518 | fname = f"ViT-{size}_{weights}.npz" 519 | origin = f"{BASE_URL}/{fname}" 520 | local_filepath = tf.keras.utils.get_file(fname, origin, cache_subdir="weights") 521 | load_weights_numpy( 522 | model=model, 523 | params_path=local_filepath, 524 | pretrained_top=pretrained_top, 525 | num_x_patches=image_size_tuple[1] // patch_size, 526 | num_y_patches=image_size_tuple[0] // patch_size, 527 | ) 528 | 529 | 530 | def vit_b16( 531 | image_size: ImageSizeArg = (224, 224), 532 | classes=1000, 533 | activation="linear", 534 | include_top=True, 535 | pretrained=True, 536 | pretrained_top=True, 537 | weights="imagenet21k+imagenet2012", 538 | ): 539 | """Build ViT-B16. 
All arguments passed to build_model.""" 540 | if pretrained_top: 541 | classes = validate_pretrained_top( 542 | include_top=include_top, 543 | pretrained=pretrained, 544 | classes=classes, 545 | weights=weights, 546 | ) 547 | model = build_model( 548 | **CONFIG_B, 549 | name="vit-b16", 550 | patch_size=16, 551 | image_size=image_size, 552 | classes=classes, 553 | activation=activation, 554 | include_top=include_top, 555 | representation_size=768 if weights == "imagenet21k" else None, 556 | ) 557 | 558 | if pretrained: 559 | load_pretrained( 560 | size="B_16", 561 | weights=weights, 562 | model=model, 563 | pretrained_top=pretrained_top, 564 | image_size=image_size, 565 | patch_size=16, 566 | ) 567 | return model 568 | 569 | 570 | def vit_b32( 571 | image_size: ImageSizeArg = (224, 224), 572 | classes=1000, 573 | activation="linear", 574 | include_top=True, 575 | pretrained=True, 576 | pretrained_top=True, 577 | weights="imagenet21k+imagenet2012", 578 | ): 579 | """Build ViT-B32. All arguments passed to build_model.""" 580 | if pretrained_top: 581 | classes = validate_pretrained_top( 582 | include_top=include_top, 583 | pretrained=pretrained, 584 | classes=classes, 585 | weights=weights, 586 | ) 587 | model = build_model( 588 | **CONFIG_B, 589 | name="vit-b32", 590 | patch_size=32, 591 | image_size=image_size, 592 | classes=classes, 593 | activation=activation, 594 | include_top=include_top, 595 | representation_size=768 if weights == "imagenet21k" else None, 596 | ) 597 | if pretrained: 598 | load_pretrained( 599 | size="B_32", 600 | weights=weights, 601 | model=model, 602 | pretrained_top=pretrained_top, 603 | patch_size=32, 604 | image_size=image_size, 605 | ) 606 | return model 607 | 608 | 609 | def vit_l16( 610 | image_size: ImageSizeArg = (384, 384), 611 | classes=1000, 612 | activation="linear", 613 | include_top=True, 614 | pretrained=True, 615 | pretrained_top=True, 616 | weights="imagenet21k+imagenet2012", 617 | ): 618 | """Build ViT-L16. All arguments passed to build_model.""" 619 | if pretrained_top: 620 | classes = validate_pretrained_top( 621 | include_top=include_top, 622 | pretrained=pretrained, 623 | classes=classes, 624 | weights=weights, 625 | ) 626 | model = build_model( 627 | **CONFIG_L, 628 | patch_size=16, 629 | name="vit-l16", 630 | image_size=image_size, 631 | classes=classes, 632 | activation=activation, 633 | include_top=include_top, 634 | representation_size=1024 if weights == "imagenet21k" else None, 635 | ) 636 | if pretrained: 637 | load_pretrained( 638 | size="L_16", 639 | weights=weights, 640 | model=model, 641 | pretrained_top=pretrained_top, 642 | patch_size=16, 643 | image_size=image_size, 644 | ) 645 | return model 646 | 647 | 648 | def vit_l32( 649 | image_size: ImageSizeArg = (384, 384), 650 | classes=1000, 651 | activation="linear", 652 | include_top=True, 653 | pretrained=True, 654 | pretrained_top=True, 655 | weights="imagenet21k+imagenet2012", 656 | ): 657 | """Build ViT-L32. 
All arguments passed to build_model.""" 658 | if pretrained_top: 659 | classes = validate_pretrained_top( 660 | include_top=include_top, 661 | pretrained=pretrained, 662 | classes=classes, 663 | weights=weights, 664 | ) 665 | model = build_model( 666 | **CONFIG_L, 667 | patch_size=32, 668 | name="vit-l32", 669 | image_size=image_size, 670 | classes=classes, 671 | activation=activation, 672 | include_top=include_top, 673 | representation_size=1024 if weights == "imagenet21k" else None, 674 | ) 675 | if pretrained: 676 | load_pretrained( 677 | size="L_32", 678 | weights=weights, 679 | model=model, 680 | pretrained_top=pretrained_top, 681 | patch_size=32, 682 | image_size=image_size, 683 | ) 684 | return model 685 | 686 | --------------------------------------------------------------------------------