├── imgs
│   └── gradmax.png
├── growneuron
│   ├── cifar
│   │   ├── configs
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── baseline_big.py
│   │   │   ├── baseline_big_vgg.py
│   │   │   ├── baseline_small_vgg.py
│   │   │   ├── grow_all_at_once.py
│   │   │   ├── grow_round_robin.py
│   │   │   ├── grow_all_at_once_vgg.py
│   │   │   └── baseline_small.py
│   │   ├── __init__.py
│   │   ├── data.py
│   │   ├── vgg.py
│   │   ├── wide_resnet.py
│   │   └── main.py
│   ├── imagenet
│   │   ├── __init__.py
│   │   ├── configs
│   │   │   ├── __init__.py
│   │   │   ├── baseline_big.py
│   │   │   ├── grow_all_at_once.py
│   │   │   └── baseline_small.py
│   │   ├── data.py
│   │   ├── mb_v1.py
│   │   ├── data_util.py
│   │   └── main.py
│   ├── __init__.py
│   ├── updaters_test.py
│   ├── layers_test.py
│   ├── updaters.py
│   ├── layers.py
│   └── growers.py
├── CONTRIBUTING.md
├── run.sh
├── setup.py
├── README.md
└── LICENSE
/imgs/gradmax.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/growneuron/HEAD/imgs/gradmax.png
--------------------------------------------------------------------------------
/growneuron/cifar/configs/README.md:
--------------------------------------------------------------------------------
1 | Here we explain the main configs for WRN-28-1:
2 | 
3 | - `baseline_small` does dense training at 25% block width.
4 | - `baseline_big` does dense training at the regular width.
5 | - `grow` starts with the 25% block-width dense model and grows the model into
6 |   the regular width, 10% at a time.
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 | 
3 | We'd love to accept your patches and contributions to this project.
4 | 
5 | - If you want to contribute to the library, please check the `Issues` tab and
6 |   feel free to take on any problem/issue you find interesting.
7 | - If your `issue` is not reported yet, please create a new one. It is
8 |   important to discuss the problem/request before implementing the solution.
9 | 
10 | ## Code reviews
11 | 
12 | All submissions, including submissions by project members, require review. We
13 | use GitHub pull requests for this purpose. Consult
14 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
15 | information on using pull requests.
--------------------------------------------------------------------------------
/growneuron/cifar/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 GradMax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """This folder contains the code for cifar experiments."""
17 | 
--------------------------------------------------------------------------------
/growneuron/cifar/configs/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 GradMax Authors.
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """This folder involves configs for cifar experiments.""" 17 | -------------------------------------------------------------------------------- /growneuron/imagenet/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 GradMax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """This folder involves the code for imagenet experiments.""" 17 | -------------------------------------------------------------------------------- /growneuron/imagenet/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 GradMax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """This folder involves configs for imagenet experiments.""" 17 | -------------------------------------------------------------------------------- /growneuron/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 GradMax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """This repo involves the code for gradient maximizing growth.""" 17 | name = 'growneuron' 18 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2022 GradMax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | #!/bin/bash 16 | set -e 17 | set -x 18 | 19 | pip install . 20 | python -m growneuron.layers_test 21 | python -m growneuron.updaters_test 22 | -------------------------------------------------------------------------------- /growneuron/cifar/configs/baseline_big.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 GradMax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Default config for big baseline training.""" 17 | from growneuron.cifar.configs import baseline_small 18 | 19 | 20 | def get_config(): 21 | """Builds and returns config.""" 22 | config = baseline_small.get_config() 23 | config.model.block_width_multiplier = 1. 24 | return config 25 | -------------------------------------------------------------------------------- /growneuron/imagenet/configs/baseline_big.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 GradMax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Default config for big baseline training.""" 17 | from growneuron.imagenet.configs import baseline_small 18 | 19 | 20 | def get_config(): 21 | """Builds and returns config.""" 22 | config = baseline_small.get_config() 23 | config.model.width_multiplier = 1. 
24 | return config 25 | -------------------------------------------------------------------------------- /growneuron/cifar/configs/baseline_big_vgg.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 GradMax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Default config for big baseline training.""" 17 | from growneuron.cifar.configs import baseline_small_vgg 18 | 19 | 20 | def get_config(): 21 | """Builds and returns config.""" 22 | config = baseline_small_vgg.get_config() 23 | config.model.width_multiplier = 1 24 | return config 25 | -------------------------------------------------------------------------------- /growneuron/cifar/configs/baseline_small_vgg.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 GradMax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Default configs for baseline cifar-10 training.""" 17 | from growneuron.cifar.configs import baseline_small 18 | import ml_collections 19 | 20 | 21 | def get_config(): 22 | """Builds and returns config.""" 23 | config = baseline_small.get_config() 24 | 25 | config.architecture = 'vgg' 26 | config.model = ml_collections.ConfigDict() 27 | config.model.depth = 11 28 | config.model.normalization_type = 'none' 29 | config.model.width_multiplier = 0.25 30 | config.optimizer.base_learning_rate = 0.05 31 | 32 | return config 33 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """growneuron library setup.""" 2 | 3 | import pathlib 4 | from setuptools import find_packages 5 | from setuptools import setup 6 | 7 | here = pathlib.Path(__file__).parent.resolve() 8 | 9 | long_description = (here / 'README.md').read_text(encoding='utf-8') 10 | setup( 11 | name='growneuron', 12 | version='0.1', 13 | description='Gradmax, gradient maximizing neural network growth.', 14 | long_description=long_description, 15 | long_description_content_type='text/markdown', 16 | url='https://github.com/google-research/growneuron', 17 | author='Google LLC', 18 | license='Apache 2.0', 19 | packages=find_packages(), 20 | package_data={}, 21 | scripts=['growneuron/cifar/main.py', 'growneuron/imagenet/main.py'], 22 | classifiers=[ 23 | 'Development Status :: 4 - Beta', 24 | 'Intended Audience :: Developers', 25 | 'Intended Audience :: Science/Research', 26 | 'License :: OSI Approved :: Apache Software License', 27 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 28 | ], 29 | keywords=('neural networks tensorflow machine learning growing growth' 30 | 'gradmax google convolutional during training'), 31 | install_requires=[ 32 | 'absl-py', 33 | 'numpy', 34 | 'ml-collections', 35 | 'tensorflow==2.7', 36 | 'scipy==1.7.3', 37 | 'tfds-nightly', 38 | ('uncertainty_baselines @ git+https://github.com/google/' 39 | 'uncertainty-baselines.git#egg=uncertainty_baselines'), 40 | ], 41 | ) 42 | -------------------------------------------------------------------------------- /growneuron/imagenet/configs/grow_all_at_once.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 GradMax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Default config for random growing.""" 17 | from growneuron.imagenet.configs import baseline_small 18 | 19 | 20 | def get_config(): 21 | """Builds and returns config.""" 22 | config = baseline_small.get_config() 23 | config.updater_type = 'all_at_once' 24 | config.grow_type = 'add_random' 25 | config.grow_epsilon = 0. 
26 | config.is_outgoing_zero = False 27 | config.grow_scale_method = 'mean_norm' 28 | config.model.normalization_type = 'none' 29 | config.updater.carry_optimizer = True 30 | 31 | config.updater.update_frequency = 4000 32 | config.grow_frequency_multiplier = 1. 33 | config.updater.start_iteration = 10000 34 | # 1 cyle is 12 growth steps. 35 | config.updater.n_growth_steps = 12 36 | # Use one of the following 37 | # config.updater.n_grow = 2 38 | config.updater.n_grow_fraction = 0.25 39 | config.updater.scale = 0.5 40 | 41 | return config 42 | -------------------------------------------------------------------------------- /growneuron/cifar/configs/grow_all_at_once.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 GradMax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Default config for random growing.""" 17 | from growneuron.cifar.configs import baseline_small 18 | 19 | 20 | def get_config(): 21 | """Builds and returns config.""" 22 | config = baseline_small.get_config() 23 | config.updater_type = 'all_at_once' 24 | config.grow_type = 'add_random' 25 | config.grow_batch_size = 128 26 | config.grow_epsilon = 0. 27 | config.is_outgoing_zero = False 28 | config.grow_scale_method = 'mean_norm' 29 | config.model.normalization_type = 'none' 30 | config.updater.carry_optimizer = True 31 | 32 | # We are aiming 12*2500=30000 steps growth period. 33 | config.updater.update_frequency = 2500 34 | config.updater.start_iteration = 10000 35 | # 1 cyle is 12 growth steps. 36 | config.updater.n_growth_steps = 12 # 12 * 12 cycle 37 | # Use one of the following 38 | # config.updater.n_grow = 2 39 | config.updater.n_grow_fraction = 0.25 40 | config.updater.scale = 0.5 41 | 42 | return config 43 | -------------------------------------------------------------------------------- /growneuron/cifar/configs/grow_round_robin.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 GradMax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Default config for random growing.""" 17 | from growneuron.cifar.configs import baseline_small 18 | 19 | 20 | def get_config(): 21 | """Builds and returns config.""" 22 | config = baseline_small.get_config() 23 | config.updater_type = 'round_robin' 24 | config.grow_type = 'add_random' 25 | config.grow_batch_size = 128 26 | config.grow_epsilon = 0. 27 | config.is_outgoing_zero = False 28 | config.grow_scale_method = 'mean_norm' 29 | config.model.normalization_type = 'none' 30 | config.updater.carry_optimizer = True 31 | 32 | # We are aiming 144*200=28800 steps growth period 33 | config.updater.update_frequency = 150 34 | config.updater.start_iteration = 10000 35 | config.scale_epochs = True 36 | # 1 cyle is 12 growth steps. 37 | config.updater.n_growth_steps = 144 # 12 * 12 cycle 38 | # Use one of the following 39 | # config.updater.n_grow = 2 40 | config.updater.n_grow_fraction = 0.25 41 | config.updater.scale = 0.5 42 | 43 | return config 44 | -------------------------------------------------------------------------------- /growneuron/cifar/configs/grow_all_at_once_vgg.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 GradMax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Default config for random growing.""" 17 | from growneuron.cifar.configs import baseline_small_vgg 18 | 19 | 20 | def get_config(): 21 | """Builds and returns config.""" 22 | config = baseline_small_vgg.get_config() 23 | config.updater_type = 'all_at_once' 24 | config.grow_type = 'add_random' 25 | config.grow_batch_size = 128 26 | config.grow_epsilon = 0. 27 | config.is_outgoing_zero = False 28 | config.grow_scale_method = 'fixed' 29 | config.model.normalization_type = 'none' 30 | config.updater.carry_optimizer = True 31 | 32 | # We are aiming 12*2500=30000 steps growth period. 33 | config.updater.update_frequency = 2500 34 | config.updater.start_iteration = 10000 35 | config.scale_epochs = False 36 | # 1 cyle is 12 growth steps. 37 | config.updater.n_growth_steps = 12 # 12 cycle 38 | # Use one of the following 39 | # config.updater.n_grow = 2 40 | config.updater.n_grow_fraction = 0.25 41 | config.updater.scale = 0.5 42 | 43 | return config 44 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GradMax: Growing Neural Networks using Gradient Information 2 | 3 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-research/growneuron/blob/main/Student_Teacher.ipynb) 4 | 5 | Code for reproducing our results in the GradMax paper [[arxiv.org/abs/2201.05125](https://arxiv.org/abs/2201.05125)]. 6 | 7 | GradMax 8 | 9 | 10 | ## Setup 11 | First clone this repo. 
12 | ```bash
13 | git clone https://github.com/google-research/growneuron.git
14 | cd growneuron
15 | ```
16 | 
17 | The following script installs the necessary libraries and runs a few tests.
18 | ```bash
19 | bash run.sh
20 | ```
21 | 
22 | The following will download the data and run the baseline experiment. If the
23 | data is already downloaded, use the `--data_dir` flag to pass the path.
24 | ```bash
25 | python growneuron/cifar/main.py --output_dir=/tmp/cifar --download_data
26 | ```
27 | 
28 | ## Running GradMax
29 | The following command starts training with WRN-28-0.25x and grows it into
30 | WRN-28-1. Growth is done every 2500 steps starting from
31 | iteration 10000 at all convolutional layers at once.
32 | ```bash
33 | rm -rf /tmp/cifar
34 | python growneuron/cifar/main.py --output_dir=/tmp/cifar \
35 |   --config=growneuron/cifar/configs/grow_all_at_once.py \
36 |   --config.grow_type=add_gradmax
37 | ```
38 | 
39 | ## Other Experiments
40 | - Baselines for WRN-28 and VGG11 can be run using the corresponding configs in
41 |   `growneuron/cifar/configs/`.
42 | - Set the `--config.grow_type` argument to `add_gradmax`, `add_firefly`, `add_gradmax_opt` or
43 |   `add_random` to grow using different strategies.
44 | - Use `--config.is_outgoing_zero` to run experiments where outgoing weights
45 |   are set to zero.
46 | - Use `--config.model.normalization_type=batchnorm` to run experiments with
47 |   batch normalization layers.
48 | 
49 | 
50 | 
51 | ## Disclaimer
52 | This is not an officially supported Google product.
--------------------------------------------------------------------------------
/growneuron/imagenet/configs/baseline_small.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 GradMax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """Default configs for baseline imagenet training."""
17 | import ml_collections
18 | 
19 | 
20 | def get_config():
21 |   """Builds and returns config."""
22 |   config = ml_collections.ConfigDict()
23 | 
24 |   config.optimizer = ml_collections.ConfigDict()
25 |   # Base learning rate when total batch size is 128. It is scaled by the ratio
26 |   # of the total batch size to 128.
27 |   config.optimizer.base_learning_rate = 0.05
28 |   # One of 'step', 'cosine'
29 |   config.optimizer.decay_type = 'step'
30 |   config.optimizer.nesterov = False
31 |   # Amount to decay learning rate.
32 |   config.optimizer.lr_decay_ratio = 0.1
33 |   # Epochs to decay learning rate by.
34 |   config.optimizer.lr_decay_epochs = [0.3, 0.6, 0.8]
35 |   # Number of epochs for a linear warmup to the initial learning rate. Use 0 to
36 |   # do no warmup.
37 |   config.optimizer.lr_warmup_epochs = 5
38 |   # Optimizer momentum.
39 |   config.optimizer.momentum = 0.9
40 |   # Following is empty for the baselines and used by the growing algorithms.
41 | config.updater = ml_collections.ConfigDict() 42 | config.updater.carry_optimizer = False 43 | config.is_outgoing_zero = False 44 | config.scale_epochs = False 45 | 46 | config.model = ml_collections.ConfigDict() 47 | # L2 regularization coefficient. 48 | config.model.l2_coef = 1e-4 49 | config.model.width_multiplier = 0.25 50 | config.model.normalization_type = 'batchnorm' 51 | 52 | # Number of epochs between saving checkpoints. Use -1 for no checkpoints. 53 | config.checkpoint_interval = 25 54 | config.dataset = 'imagenet2012' 55 | # WBatch size per TPU core/GPU. The number of new datapoints gathered per 56 | # batch is this number divided by ensemble_size (we tile the batch by that # 57 | # of times). 58 | config.per_core_batch_size = 64 59 | config.num_cores = 1 60 | config.seed = 8 61 | config.train_epochs = 90 62 | config.log_freq = 200 63 | 64 | return config 65 | -------------------------------------------------------------------------------- /growneuron/cifar/configs/baseline_small.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 GradMax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Default configs for baseline cifar-10 training.""" 17 | import ml_collections 18 | 19 | 20 | def get_config(): 21 | """Builds and returns config.""" 22 | config = ml_collections.ConfigDict() 23 | 24 | config.optimizer = ml_collections.ConfigDict() 25 | # Base learning rate when total batch size is 128. It is scaled by the ratio 26 | # of the total batch size to 128. 27 | config.optimizer.base_learning_rate = 0.1 28 | # One of 'step', 'cosine' 29 | config.optimizer.decay_type = 'cosine' 30 | config.optimizer.nesterov = True 31 | # Amount to decay learning rate. 32 | config.optimizer.lr_decay_ratio = 0.2 33 | # Epochs to decay learning rate by. 34 | config.optimizer.lr_decay_epochs = [0.3, 0.6, 0.8] 35 | # Number of epochs for a linear warmup to the initial learning rate. Use 0 to' 36 | # do no warmup. 37 | config.optimizer.lr_warmup_epochs = 1 38 | # Optimizer momentum. 39 | config.optimizer.momentum = 0.9 40 | # Following is empty for the baselines and used by the growing algorithms. 41 | config.updater = ml_collections.ConfigDict() 42 | config.updater.carry_optimizer = False 43 | config.is_outgoing_zero = False 44 | config.scale_epochs = False 45 | 46 | config.model = ml_collections.ConfigDict() 47 | # L2 regularization coefficient. 48 | config.model.l2_coef = 2e-4 49 | config.model.depth = 28 50 | config.model.width_multiplier = 1 51 | config.model.normalization_type = 'none' 52 | config.model.block_width_multiplier = 0.25 53 | 54 | # Number of epochs between saving checkpoints. Use -1 for no checkpoints. 55 | config.checkpoint_interval = 25 56 | # One of ['cifar10', 'cifar100'] 57 | config.dataset = 'cifar10' 58 | # Whether to cache the dataset. 59 | config.cache_dataset = True 60 | # WBatch size per TPU core/GPU. 
The number of new datapoints gathered per 61 | # batch is this number divided by ensemble_size (we tile the batch by that # 62 | # of times). 63 | config.per_core_batch_size = 128 64 | config.num_cores = 1 65 | config.seed = 8 66 | config.train_epochs = 200 67 | config.log_freq = 200 68 | 69 | return config 70 | -------------------------------------------------------------------------------- /growneuron/cifar/data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 GradMax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Data pipeline. 17 | 18 | Forked from simclr/tf2 codebase. 19 | """ 20 | from typing import Optional 21 | from absl import logging 22 | 23 | import tensorflow.compat.v2 as tf 24 | import tensorflow_datasets as tfds 25 | 26 | 27 | def build_input_fn( 28 | builder, 29 | global_batch_size, 30 | topology, 31 | is_training, 32 | cache_dataset = True): 33 | """Build input function. 34 | 35 | Args: 36 | builder: TFDS builder for specified dataset. 37 | global_batch_size: Global batch size. 38 | topology: An instance of `tf.tpu.experimental.Topology` or None. 39 | is_training: Whether to build in training mode. 40 | cache_dataset: bool, whether to cache the dataset. 41 | 42 | Returns: 43 | A function that accepts a dict of params and returns a tuple of images and 44 | features, to be used as the input_fn in TPUEstimator. 45 | """ 46 | 47 | def _input_fn(input_context): 48 | """Inner input function.""" 49 | batch_size = input_context.get_per_replica_batch_size(global_batch_size) 50 | logging.info('Global batch size: %d', global_batch_size) 51 | logging.info('Per-replica batch size: %d', batch_size) 52 | 53 | def map_fn(image, label): 54 | """Produces multiple transformations of the same batch.""" 55 | if is_training: 56 | image_shape = tf.shape(image) 57 | # Expand the image by 2 pixels, then crop back down to 32x32. 58 | image = tf.image.resize_with_crop_or_pad( 59 | image, image_shape[0] + 4, image_shape[1] + 4) 60 | image = tf.image.random_crop(image, (image_shape[0], image_shape[0], 3)) 61 | image = tf.image.random_flip_left_right(image) 62 | image = tf.image.convert_image_dtype(image, tf.float32) 63 | return image, label 64 | 65 | dataset = builder.as_dataset( 66 | split='train' if is_training else 'test', 67 | shuffle_files=is_training, 68 | as_supervised=True) 69 | logging.info('num_input_pipelines: %d', input_context.num_input_pipelines) 70 | # The dataset is always sharded by number of hosts. 71 | # num_input_pipelines is the number of hosts rather than number of cores. 
72 | if input_context.num_input_pipelines > 1: 73 | dataset = dataset.shard(input_context.num_input_pipelines, 74 | input_context.input_pipeline_id) 75 | if cache_dataset: 76 | dataset = dataset.cache() 77 | if is_training: 78 | dataset = dataset.shuffle(50000) 79 | dataset = dataset.repeat(-1) 80 | dataset = dataset.map( 81 | map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE) 82 | dataset = dataset.batch(batch_size, drop_remainder=is_training) 83 | prefetch_buffer_size = 2 * topology.num_tpus_per_task if topology else 2 84 | dataset = dataset.prefetch(prefetch_buffer_size) 85 | return dataset 86 | 87 | return _input_fn 88 | 89 | 90 | -------------------------------------------------------------------------------- /growneuron/imagenet/data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 GradMax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Data pipeline. 17 | 18 | Forked from simclr/tf2 codebase. 19 | """ 20 | import functools 21 | from typing import Optional 22 | from absl import logging 23 | from growneuron.imagenet import data_util 24 | 25 | import tensorflow.compat.v2 as tf 26 | import tensorflow_datasets as tfds 27 | 28 | 29 | def build_input_fn( 30 | builder, 31 | global_batch_size, 32 | topology, 33 | is_training, 34 | image_size = 224): 35 | """Build input function. 36 | 37 | Args: 38 | builder: TFDS builder for specified dataset. 39 | global_batch_size: Global batch size. 40 | topology: An instance of `tf.tpu.experimental.Topology` or None. 41 | is_training: Whether to build in training mode. 42 | image_size: Size of the output images. 43 | 44 | Returns: 45 | A function that accepts a dict of params and returns a tuple of images and 46 | features, to be used as the input_fn in TPUEstimator. 47 | """ 48 | 49 | def _input_fn(input_context): 50 | """Inner input function.""" 51 | batch_size = input_context.get_per_replica_batch_size(global_batch_size) 52 | logging.info('Global batch size: %d', global_batch_size) 53 | logging.info('Per-replica batch size: %d', batch_size) 54 | 55 | preprocess_fn = get_preprocess_fn(is_training, image_size) 56 | def map_fn(image, label): 57 | """Produces multiple transformations of the same batch.""" 58 | image = preprocess_fn(image) 59 | return image, label 60 | 61 | dataset = builder.as_dataset( 62 | split='train' if is_training else 'validation', 63 | shuffle_files=is_training, 64 | as_supervised=True) 65 | logging.info('num_input_pipelines: %d', input_context.num_input_pipelines) 66 | # The dataset is always sharded by number of hosts. 67 | # num_input_pipelines is the number of hosts rather than number of cores. 
68 | if input_context.num_input_pipelines > 1: 69 | dataset = dataset.shard(input_context.num_input_pipelines, 70 | input_context.input_pipeline_id) 71 | if is_training: 72 | buffer_multiplier = 50 if image_size <= 32 else 10 73 | dataset = dataset.shuffle(batch_size * buffer_multiplier) 74 | dataset = dataset.repeat(-1) 75 | dataset = dataset.map( 76 | map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE) 77 | dataset = dataset.batch(batch_size, drop_remainder=is_training) 78 | prefetch_buffer_size = 2 * topology.num_tpus_per_task if topology else 2 79 | dataset = dataset.prefetch(prefetch_buffer_size) 80 | return dataset 81 | 82 | return _input_fn 83 | 84 | 85 | def get_preprocess_fn(is_training, image_size=224): 86 | """Get function that accepts an image and returns a preprocessed image.""" 87 | # Disable test cropping for small images (e.g. CIFAR) 88 | if image_size <= 32: 89 | test_crop = False 90 | else: 91 | test_crop = True 92 | return functools.partial( 93 | data_util.preprocess_image, 94 | image_size=image_size, 95 | is_training=is_training, 96 | test_crop=test_crop) 97 | -------------------------------------------------------------------------------- /growneuron/updaters_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 GradMax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for growneuron.updaters.""" 17 | 18 | import itertools 19 | import absl.testing.parameterized as parameterized 20 | from growneuron import growers 21 | from growneuron import updaters 22 | import tensorflow as tf 23 | 24 | 25 | class RoundRobinScheduleTest(parameterized.TestCase, tf.test.TestCase): 26 | 27 | def setUp(self): 28 | super().setUp() 29 | self.all_layers = [] 30 | for _ in range(5): 31 | layer = tf.keras.layers.Dense(2) 32 | layer.build((None, 8)) 33 | self.all_layers.append(layer) 34 | 35 | def get_random_grower(self): 36 | return growers.AddRandom() 37 | 38 | def get_grow_layers(self, n=1): 39 | return list(itertools.permutations(self.all_layers, 2))[:n] 40 | 41 | @parameterized.named_parameters( 42 | ('every_100_start0_3times', 100, 0, 3, 43 | [(0, True), (4, False), (100, True), (2300, True), (100, False)]), 44 | ('every_1_start0_21times', 1, 0, 21, 45 | [(0, True), (100, True), (2300, True), (3, True)]), 46 | ('every_1_start2_5times', 1, 2, 4, 47 | [(1, False), (2, True), (100, True), (1, False), (3, True)]), 48 | ('every_100_start50_1times', 100, 50, 1, 49 | [(20, False), (75, False), (200, True), (234, False), (300, False)]), 50 | ('none_start', 5, None, 25, 51 | [(2, False), (5, True), (7, False), (45, True)]), 52 | ) 53 | def test_update_iter(self, update_frequency, start_iteration, n_growth_steps, 54 | iterations): 55 | network_grower = self.get_random_grower() 56 | grow_layer_tuples = self.get_grow_layers() 57 | 58 | updater = updaters.RoundRobin( 59 | network_grower, 60 | grow_layer_tuples, 61 | update_frequency=update_frequency, 62 | n_grow=1, 63 | start_iteration=start_iteration, 64 | n_growth_steps=n_growth_steps) 65 | for iteration, bool_val in iterations: 66 | print(f'COUNT:{updater._growth_counter}, {iteration}, {bool_val}') 67 | self.assertEqual(updater.is_update_iteration(iteration), bool_val) 68 | if bool_val: 69 | updater._next_grow_layer_tuple(None) 70 | 71 | @parameterized.named_parameters(('n4', 4), ('n1', 1)) 72 | def test_next_grow_layers(self, n_tuples): 73 | network_grower = self.get_random_grower() 74 | grow_layer_tuples = self.get_grow_layers(n=n_tuples) 75 | updater = updaters.RoundRobin( 76 | network_grower, grow_layer_tuples, update_frequency=2) 77 | for tup in itertools.chain(grow_layer_tuples, grow_layer_tuples): 78 | self.assertEqual(tup, updater._next_grow_layer_tuple(None)) 79 | 80 | @parameterized.named_parameters(('1d', (2,), (3,)), 81 | ('2d_out', (2, 4), (2, 5)), 82 | ('2d_in', (2, 4), (4, 4)), 83 | # Deptwise kernel 84 | ('3d_in', (4, 4, 5), (4, 4, 8)), 85 | ('4d_in', (4, 4, 2, 3), (4, 4, 4, 3)), 86 | ('4d_out', (4, 4, 2, 3), (4, 4, 2, 5)) 87 | ) 88 | def test_pad_zeros_to(self, old_shape, new_shape): 89 | tensor = tf.random.uniform(old_shape) 90 | new_tensor = updaters.pad_zeros_to(tensor, new_shape) 91 | old_slice = tuple(slice(None, x) for x in tensor.shape) 92 | self.assertAllEqual(new_tensor[old_slice], tensor) 93 | 94 | @parameterized.named_parameters( 95 | ('dense_outgrown', (3, 4), (3, 5), lambda a, i: a[:, i]), 96 | ('dense_ingrown', (3, 4), (4, 4), lambda a, i: a[i, :]), 97 | ('conv2d_ingrown', (2, 2, 3, 4), (2, 2, 4, 4), 98 | lambda a, i: a[:, :, i, :]), 99 | ('conv2d_outgrown', (2, 2, 3, 4), (2, 2, 3, 5), lambda a, i: a[Ellipsis, i]), 100 | ('conv_dw', (2, 2, 3), (2, 2, 4), lambda a, i: a[:, :, i]) 101 | ) 102 | def test_copy_adam_slots(self, old_shape, new_shape, slice_fn): 103 | grow_layer_tuples = self.get_grow_layers() 104 | network_grower = self.get_random_grower() 105 | updater = 
updaters.RoundRobin( 106 | network_grower, grow_layer_tuples, update_frequency=2) 107 | old_var = tf.Variable(tf.ones(old_shape)) 108 | new_var = tf.Variable(tf.ones(new_shape)) 109 | optimizer = tf.keras.optimizers.Adam() 110 | optimizer._create_slots([old_var, new_var]) 111 | random_slot_vals = tf.random.uniform(old_shape) 112 | for s_name in optimizer.get_slot_names(): 113 | self.assertAllEqual(optimizer.get_slot(new_var, s_name), 114 | tf.zeros(new_shape)) 115 | optimizer.get_slot(old_var, s_name).assign(random_slot_vals) 116 | 117 | updater.copy_optimizer_slots(optimizer, [old_var], [new_var]) 118 | for s_name in optimizer.get_slot_names(): 119 | # Check new_values still have zeros. 120 | new_values_slice = slice_fn(optimizer.get_slot(new_var, s_name), -1) 121 | self.assertAllEqual(new_values_slice, tf.zeros_like(new_values_slice)) 122 | # Check old variables have their random values set correctly. 123 | old_values_slice = slice_fn(optimizer.get_slot(old_var, s_name), 0) 124 | new_values_slice = slice_fn(optimizer.get_slot(new_var, s_name), 0) 125 | self.assertAllEqual(new_values_slice, old_values_slice) 126 | 127 | 128 | if __name__ == '__main__': 129 | tf.test.main() 130 | -------------------------------------------------------------------------------- /growneuron/cifar/vgg.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 GradMax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """VGG Network.""" 17 | 18 | import functools 19 | from typing import Any, Dict 20 | from growneuron.cifar import wide_resnet 21 | import growneuron.layers as glayers 22 | import tensorflow as tf 23 | 24 | NormalizationType = wide_resnet.NormalizationType 25 | 26 | BatchNormalization = functools.partial( 27 | tf.keras.layers.BatchNormalization, 28 | epsilon=1e-5, # using epsilon and momentum defaults from Torch 29 | momentum=0.9) 30 | 31 | LayerNormalization = functools.partial( 32 | tf.keras.layers.LayerNormalization, 33 | epsilon=1e-5) # using epsilon and momentum defaults from Torch 34 | 35 | 36 | def Conv2D(filters, seed=None, **kwargs): 37 | """Conv2D layer that is deterministically initialized.""" 38 | default_kwargs = { 39 | "kernel_size": 3, 40 | "padding": "same", 41 | "use_bias": False, 42 | # Note that we need to use the class constructor for the initializer to 43 | # get deterministic initialization. 44 | "kernel_initializer": tf.keras.initializers.HeNormal(seed=seed), 45 | } 46 | # Override defaults with the passed kwargs. 47 | default_kwargs.update(kwargs) 48 | return tf.keras.layers.Conv2D(filters, **default_kwargs) 49 | 50 | 51 | class VGG(tf.keras.Model): 52 | """Builds a VGG CNN without the FC layers at the end. 53 | 54 | We don't add the FC layers to stay in sync with the implementation in the 55 | "Firefly Neural Architecture Descent" paper. 
56 | 57 | Attributes: 58 | depth: Use 11 for VGG11, 16 for VGG16, etc., 59 | width_multiplier: The number of filters in the first layer 60 | ("1" corresponds to 64 filters). 61 | num_classes: Number of output classes. 62 | normalization_type: NormalizationType, of the normalization used inside 63 | blocks. 64 | l2: L2 regularization coefficient. 65 | seed: random seed used for initialization. 66 | """ 67 | 68 | def __init__(self, 69 | depth, 70 | width_multiplier, 71 | num_classes, 72 | normalization_type, 73 | l2, 74 | seed = 42): 75 | super().__init__(name=F"VGG-{depth}-{width_multiplier}") 76 | l2_reg = tf.keras.regularizers.l2 77 | 78 | rng_seed = [seed, seed + 1] 79 | assert depth == 11, "Only supporting VGG11 right now" 80 | 81 | # VGG consists of blocks of convs separated by downsampling. 82 | # Within each block, each conv has of base_width * multiplier filters. 83 | # This dict maps VGG-xx to a list of blocks. 84 | architecture = { 85 | 11: [[1], [2], [4, 4], [8, 8], [8, 8]], 86 | 14: [[1, 1], [2, 2], [4, 4], [8, 8], [8, 8]], 87 | 16: [[1, 1], [2, 2], [4, 4, 4], [8, 8, 8], [8, 8, 8]], 88 | 19: [[1, 1], [2, 2], [4, 4, 4, 4], [8, 8, 8, 8], [8, 8, 8, 8]] 89 | } 90 | 91 | blocklist = architecture[depth] 92 | base_width = int(64 * width_multiplier) 93 | 94 | downsample = False 95 | self.layer_list = [] 96 | for block in blocklist: 97 | for multiplier in block: 98 | rng_seed, seed = tf.random.experimental.stateless_split(rng_seed) 99 | self.layer_list.append(glayers.GrowLayer(Conv2D( 100 | base_width*multiplier, strides=1 if not downsample else 2, 101 | seed=seed[0], 102 | kernel_regularizer=tf.keras.regularizers.l2(l2)))) 103 | downsample = False 104 | self.layer_list.append(tf.keras.layers.Activation( 105 | glayers.get_activation_fn("relu1")),) 106 | if normalization_type == NormalizationType.batchnorm: 107 | self.layer_list.append(glayers.GrowLayer( 108 | BatchNormalization( 109 | beta_regularizer=tf.keras.regularizers.l2(l2), 110 | gamma_regularizer=tf.keras.regularizers.l2(l2)))) 111 | elif normalization_type == NormalizationType.layernorm: 112 | self.layer_list.append(glayers.GrowLayer( 113 | LayerNormalization( 114 | beta_regularizer=tf.keras.regularizers.l2(l2), 115 | gamma_regularizer=tf.keras.regularizers.l2(l2)))) 116 | elif normalization_type == NormalizationType.none: 117 | pass 118 | else: 119 | raise ValueError 120 | downsample = True 121 | self.layer_list.append( 122 | glayers.GrowLayer( 123 | Conv2D(num_classes, strides=2, kernel_regularizer=l2_reg(l2)))) 124 | self.layer_list.append(tf.keras.layers.Flatten()) 125 | 126 | def call(self, x): 127 | for layer in self.layer_list: 128 | x = layer(x) 129 | return x 130 | 131 | def get_grow_layer_tuples(self): 132 | """Gets all groups of layers that need to grow together.""" 133 | 134 | grow_layers = [ 135 | i for i, l in enumerate(self.layer_list) 136 | if (isinstance(l, glayers.GrowLayer) and 137 | isinstance(l.layer, tf.keras.layers.Conv2D)) 138 | ] 139 | 140 | grow_layer_tuples = [] 141 | for i, j in zip(grow_layers[:-1], grow_layers[1:]): 142 | # Grow tuples should be in order. 
143 | grow_layer_tuples.append(self.layer_list[i:(j+1)]) 144 | return grow_layer_tuples 145 | 146 | 147 | def create_model( 148 | depth = 1, 149 | width_multiplier = 1, 150 | num_classes = 10, 151 | l2_coef = 0.0, 152 | normalization_type = "batchnorm", 153 | **unused_kwargs): 154 | """Creates model.""" 155 | normalization_type = NormalizationType[normalization_type] 156 | model = VGG( 157 | depth=depth, 158 | width_multiplier=width_multiplier, 159 | num_classes=num_classes, 160 | normalization_type=normalization_type, 161 | l2=l2_coef) 162 | return model 163 | -------------------------------------------------------------------------------- /growneuron/imagenet/mb_v1.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 GradMax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """VGG Network.""" 17 | 18 | import functools 19 | from typing import Any, Dict 20 | from growneuron.cifar import wide_resnet 21 | import growneuron.layers as glayers 22 | import tensorflow as tf 23 | 24 | NormalizationType = wide_resnet.NormalizationType 25 | 26 | BatchNormalization = functools.partial( 27 | tf.keras.layers.BatchNormalization, 28 | epsilon=1e-5, # using epsilon and momentum defaults from Torch 29 | momentum=0.9) 30 | 31 | LayerNormalization = functools.partial( 32 | tf.keras.layers.LayerNormalization, 33 | epsilon=1e-5) # using epsilon and momentum defaults from Torch 34 | 35 | 36 | def check_grow_layer(layer): 37 | return (isinstance(layer, glayers.GrowLayer) and 38 | isinstance(layer.layer, 39 | (tf.keras.layers.Dense, tf.keras.layers.Conv2D)) and 40 | not isinstance(layer.layer, tf.keras.layers.DepthwiseConv2D)) 41 | 42 | 43 | def Conv2D(filters, seed=None, **kwargs): 44 | """Conv2D layer that is deterministically initialized.""" 45 | default_kwargs = { 46 | 'kernel_size': 1, 47 | 'padding': 'same', 48 | 'strides': 1, 49 | 'use_bias': False, 50 | # Note that we need to use the class constructor for the initializer to 51 | # get deterministic initialization. 52 | 'kernel_initializer': tf.keras.initializers.HeNormal(seed=seed), 53 | } 54 | # Override defaults with the passed kwargs. 55 | default_kwargs.update(kwargs) 56 | return tf.keras.layers.Conv2D(filters, **default_kwargs) 57 | 58 | 59 | def DepthwiseConv2D(seed=None, **kwargs): 60 | """DepthwiseConv2D layer that is deterministically initialized.""" 61 | default_kwargs = { 62 | 'kernel_size': 3, 63 | 'padding': 'same', 64 | 'strides': 1, 65 | 'use_bias': False, 66 | # Note that we need to use the class constructor for the initializer to 67 | # get deterministic initialization. 68 | 'kernel_initializer': tf.keras.initializers.HeNormal(seed=seed), 69 | } 70 | # Override defaults with the passed kwargs. 71 | default_kwargs.update(kwargs) 72 | return tf.keras.layers.DepthwiseConv2D(**default_kwargs) 73 | 74 | 75 | class MobilenetV1(tf.keras.Model): 76 | """Builds a MobileNet-v1. 
77 | 78 | Attributes: 79 | width_multiplier: The number of filters in the first layer 80 | ("1" corresponds to 64 filters). 81 | num_classes: Number of output classes. 82 | normalization_type: NormalizationType, of the normalization used inside 83 | blocks. 84 | l2: L2 regularization coefficient. 85 | seed: random seed used for initialization. 86 | """ 87 | 88 | def __init__(self, 89 | width_multiplier, 90 | num_classes, 91 | normalization_type, 92 | l2, 93 | seed = 42): 94 | super().__init__(name=F'MBv1-{width_multiplier}') 95 | l2_reg = tf.keras.regularizers.l2 96 | 97 | rng_seed = [seed, seed + 1] 98 | rng_seed, seed = tf.random.experimental.stateless_split(rng_seed) 99 | self.layer_list = [ 100 | glayers.GrowLayer( 101 | Conv2D(32 * width_multiplier, 102 | strides=2, 103 | kernel_size=3, 104 | seed=seed[0], 105 | kernel_regularizer=l2_reg(l2))), 106 | glayers.GrowLayer(BatchNormalization()), 107 | tf.keras.layers.Activation(glayers.get_activation_fn('relu1')) 108 | ] 109 | 110 | # MobileNet consists of blocks of convs. 111 | # Within each block, each conv has of base_width * multiplier filters. 112 | blocklist = [[1], [2, 2], [4, 4], [8, 8, 8, 8, 8, 8], [16, 16]] 113 | base_width = int(64 * width_multiplier) 114 | downsample = False 115 | for i, block in enumerate(blocklist): 116 | for j, multiplier in enumerate(block): 117 | rng_seed, seed = tf.random.experimental.stateless_split(rng_seed) 118 | self.layer_list.append(glayers.GrowLayer(DepthwiseConv2D( 119 | seed=seed[0], 120 | kernel_regularizer=tf.keras.regularizers.l2(l2)))) 121 | self.layer_list.append(tf.keras.layers.Activation( 122 | glayers.get_activation_fn('relu1'))) 123 | if normalization_type == NormalizationType.batchnorm: 124 | self.layer_list.append(glayers.GrowLayer(BatchNormalization())) 125 | elif normalization_type == NormalizationType.layernorm: 126 | self.layer_list.append(glayers.GrowLayer(LayerNormalization())) 127 | elif normalization_type == NormalizationType.none: 128 | pass 129 | else: 130 | raise ValueError 131 | rng_seed, seed = tf.random.experimental.stateless_split(rng_seed) 132 | # We are doing strides at conv not at dw, as it's better for 133 | # decomposition. 134 | n_channels = base_width * multiplier 135 | if (i+1) == len(blocklist) and (j+1) == len(block): 136 | # We don't scale the last layer since we are not growing it. 137 | n_channels = 64 * multiplier 138 | self.layer_list.append(glayers.GrowLayer(Conv2D( 139 | n_channels, 140 | seed=seed[0], 141 | strides=1 if not downsample else 2, 142 | kernel_regularizer=tf.keras.regularizers.l2(l2)))) 143 | downsample = False 144 | self.layer_list.append(tf.keras.layers.Activation( 145 | glayers.get_activation_fn('relu1')),) 146 | self.layer_list.append(glayers.GrowLayer(BatchNormalization())) 147 | 148 | downsample = True 149 | # TODO make global pooling+dense a conv-layer, so that we can grow. 
150 | self.layer_list.append( 151 | tf.keras.layers.GlobalAveragePooling2D()) 152 | rng_seed, seed = tf.random.experimental.stateless_split(rng_seed) 153 | self.layer_list.append( 154 | tf.keras.layers.Dense( 155 | num_classes, 156 | kernel_initializer=tf.keras.initializers.HeNormal( 157 | seed=seed[0]), 158 | kernel_regularizer=l2_reg(l2) 159 | ) 160 | ) 161 | 162 | def call(self, x): 163 | for layer in self.layer_list: 164 | x = layer(x) 165 | return x 166 | 167 | def get_grow_layer_tuples(self): 168 | """Gets all groups of layers that need to grow together.""" 169 | grow_layers = [i for i, l in enumerate(self.layer_list) 170 | if check_grow_layer(l)] 171 | 172 | grow_layer_tuples = [] 173 | for i, j in zip(grow_layers[:-1], grow_layers[1:]): 174 | # Grow tuples should be in order. 175 | grow_layer_tuples.append(self.layer_list[i:(j+1)]) 176 | return grow_layer_tuples 177 | 178 | 179 | def create_model( 180 | width_multiplier = 1, 181 | num_classes = 1000, 182 | l2_coef = 0.0, 183 | normalization_type = 'batchnorm', 184 | **unused_kwargs): 185 | """Creates model.""" 186 | normalization_type = NormalizationType[normalization_type] 187 | model = MobilenetV1( 188 | width_multiplier=width_multiplier, 189 | num_classes=num_classes, 190 | normalization_type=normalization_type, 191 | l2=l2_coef) 192 | return model 193 | -------------------------------------------------------------------------------- /growneuron/cifar/wide_resnet.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 GradMax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Wide Residual Network.""" 17 | 18 | import enum 19 | import functools 20 | from typing import Any, Dict 21 | import growneuron.layers as glayers 22 | import tensorflow as tf 23 | 24 | BatchNormalization = functools.partial( 25 | tf.keras.layers.BatchNormalization, 26 | epsilon=1e-5, # using epsilon and momentum defaults from Torch 27 | momentum=0.9) 28 | 29 | LayerNormalization = functools.partial( 30 | tf.keras.layers.LayerNormalization, 31 | epsilon=1e-5) # using epsilon and momentum defaults from Torch 32 | 33 | 34 | @enum.unique 35 | class NormalizationType(enum.Enum): 36 | """Direction along the z-axis.""" 37 | layernorm = 'layernorm' 38 | batchnorm = ' batchnorm' 39 | none = 'none' 40 | 41 | 42 | def Conv2D(filters, seed=None, **kwargs): 43 | """Conv2D layer that is deterministically initialized.""" 44 | default_kwargs = { 45 | 'kernel_size': 3, 46 | 'padding': 'same', 47 | 'use_bias': False, 48 | # Note that we need to use the class constructor for the initializer to 49 | # get deterministic initialization. 50 | 'kernel_initializer': tf.keras.initializers.HeNormal(seed=seed), 51 | } 52 | # Override defaults with the passed kwargs. 
53 | default_kwargs.update(kwargs) 54 | return tf.keras.layers.Conv2D(filters, **default_kwargs) 55 | 56 | 57 | def basic_block( 58 | filters, 59 | block_width, 60 | normalization_type, 61 | strides, 62 | l2, 63 | seed): 64 | """Basic residual block of two 3x3 convs. 65 | 66 | Args: 67 | filters: Number of filters for Conv2D. 68 | block_width: Multiplies the first filter. 69 | normalization_type: NormalizationType 70 | strides: Stride dimensions for Conv2D. 71 | l2: L2 regularization coefficient. 72 | seed: random seed used for initialization. 73 | 74 | Returns: 75 | block_layers: list of sequential layers for the main branch. 76 | skip_layer: tf.keras.Conv2D or None. 77 | """ 78 | seeds = tf.random.experimental.stateless_split([seed, seed + 1], 3)[:, 0] 79 | 80 | block_layers = [ 81 | BatchNormalization(beta_regularizer=tf.keras.regularizers.l2(l2), 82 | gamma_regularizer=tf.keras.regularizers.l2(l2)), 83 | tf.keras.layers.Activation('relu'), 84 | glayers.GrowLayer( 85 | Conv2D(int(filters*block_width), strides=strides, seed=seeds[0], 86 | kernel_regularizer=tf.keras.regularizers.l2(l2))) 87 | ] 88 | # Maybe add normalization in between the layers. 89 | if normalization_type == NormalizationType.batchnorm: 90 | block_layers.append(glayers.GrowLayer( 91 | BatchNormalization(beta_regularizer=tf.keras.regularizers.l2(l2), 92 | gamma_regularizer=tf.keras.regularizers.l2(l2)))) 93 | elif normalization_type == NormalizationType.layernorm: 94 | block_layers.append(glayers.GrowLayer( 95 | LayerNormalization(beta_regularizer=tf.keras.regularizers.l2(l2), 96 | gamma_regularizer=tf.keras.regularizers.l2(l2)))) 97 | elif normalization_type == NormalizationType.none: 98 | pass 99 | else: 100 | raise ValueError 101 | 102 | block_layers += [ 103 | # This is to ensure gradient is 1 at 0 for relu. 104 | tf.keras.layers.Activation(glayers.get_activation_fn('relu1')), 105 | glayers.GrowLayer( 106 | Conv2D(filters, strides=1, seed=seeds[1], 107 | kernel_regularizer=tf.keras.regularizers.l2(l2))) 108 | ] 109 | 110 | if strides > 1: 111 | skip_layer = Conv2D(filters, kernel_size=1, strides=strides, seed=seeds[2], 112 | kernel_regularizer=tf.keras.regularizers.l2(l2)) 113 | else: 114 | skip_layer = None 115 | return (block_layers, skip_layer) 116 | 117 | 118 | class WideResnet(tf.keras.Model): 119 | """Builds Wide ResNet. 120 | 121 | Following Zagoruyko and Komodakis (2016), it accepts a width multiplier on the 122 | number of filters. Using three groups of residual blocks, the network maps 123 | spatial features of size 32x32 -> 16x16 -> 8x8. 124 | 125 | Attributes: 126 | depth: Total number of convolutional layers. "n" in WRN-n-k. It differs from 127 | He et al. (2015)'s notation which uses the maximum depth of the network 128 | counting non-conv layers like dense. 129 | width_multiplier: Integer to multiply the number of typical filters by. "k" 130 | in WRN-n-k. 131 | block_width_multiplier: Multiplies the filters in the first conv for each 132 | block. 133 | normalization_type: NormalizationType, of the normalization used inside 134 | blocks. 135 | num_classes: Number of output classes. 136 | l2: L2 regularization coefficient. 137 | seed: random seed used for initialization. 
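    For example (illustrative): `depth=28` with `width_multiplier=10` gives the
    standard WRN-28-10, where the three groups of residual blocks use 16*10,
    32*10 and 64*10 filters respectively.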
138 | 139 | """ 140 | 141 | def __init__( 142 | self, 143 | depth, 144 | width_multiplier, 145 | block_width_multiplier, 146 | normalization_type, 147 | num_classes, 148 | l2, 149 | seed = 42 150 | ): 151 | super().__init__(name='wide_resnet-{}-{}'.format(depth, width_multiplier)) 152 | l2_reg = tf.keras.regularizers.l2 153 | 154 | seeds = tf.random.experimental.stateless_split([seed, seed + 1], 5)[:, 0] 155 | if (depth - 4) % 6 != 0: 156 | raise ValueError('depth should be 6n+4 (e.g., 16, 22, 28, 40).') 157 | num_blocks = (depth - 4) // 6 158 | 159 | self.conv_stem = Conv2D(16, 160 | strides=1, 161 | seed=seeds[0], 162 | kernel_regularizer=l2_reg(l2)) 163 | self.group_seq = [] 164 | for i, (filters, strides, seed) in enumerate( 165 | zip([16, 32, 64], [1, 2, 2], seeds[1:4])): 166 | block_seq = [] 167 | group_seeds = tf.random.experimental.stateless_split( 168 | [seed, seed + 1], num_blocks)[:, 0] 169 | for j, group_seed in enumerate(group_seeds): 170 | block_strides = strides if j == 0 else 1 171 | block_seq.append( 172 | basic_block(filters=filters*width_multiplier, 173 | block_width=block_width_multiplier, 174 | normalization_type=normalization_type, 175 | strides=block_strides, l2=l2, seed=group_seed) 176 | ) 177 | self.group_seq.append(block_seq) 178 | 179 | self.final_layers = [ 180 | BatchNormalization(beta_regularizer=l2_reg(l2), 181 | gamma_regularizer=l2_reg(l2)), 182 | tf.keras.layers.Activation('relu'), 183 | tf.keras.layers.AveragePooling2D(pool_size=8), 184 | tf.keras.layers.Flatten(), 185 | tf.keras.layers.Dense( 186 | num_classes, 187 | kernel_initializer=tf.keras.initializers.HeNormal(seed=seeds[4]), 188 | kernel_regularizer=l2_reg(l2), 189 | bias_regularizer=l2_reg(l2)) 190 | ] 191 | 192 | def call(self, inputs): 193 | x = self.conv_stem(inputs) 194 | for block_seq in self.group_seq: 195 | for block_layers, skip_layer in block_seq: 196 | y = x 197 | # Main branch. 198 | for layer in block_layers: 199 | y = layer(y) 200 | # Skip branch 201 | if skip_layer: 202 | x = skip_layer(x) 203 | x = x + y 204 | for layer in self.final_layers: 205 | x = layer(x) 206 | return x 207 | 208 | 209 | def create_model( 210 | depth = 22, 211 | width_multiplier = 1, 212 | block_width_multiplier = 1., 213 | normalization_type = 'batchnorm', 214 | num_classes = 10, 215 | l2_coef = 0.0, 216 | **unused_kwargs): 217 | """Creates model.""" 218 | normalization_type = NormalizationType[normalization_type] 219 | model = WideResnet(depth=depth, 220 | width_multiplier=width_multiplier, 221 | block_width_multiplier=block_width_multiplier, 222 | num_classes=num_classes, 223 | normalization_type=normalization_type, 224 | l2=l2_coef) 225 | return model 226 | -------------------------------------------------------------------------------- /growneuron/layers_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 GradMax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for growneuron.layers.""" 17 | import absl.testing.parameterized as parameterized 18 | import growneuron.layers as glayers 19 | import tensorflow as tf 20 | 21 | 22 | class LayerTest(parameterized.TestCase, tf.test.TestCase): 23 | 24 | @parameterized.named_parameters( 25 | ('dense', tf.keras.layers.Dense(3), (3, 4)), 26 | ('batchnorm', tf.keras.layers.BatchNormalization(), (2, 4)), 27 | ('conv2d', tf.keras.layers.Conv2D(3, 3), (3, 5, 5, 4)) 28 | ) 29 | def test_consistency(self, layer, input_shape): 30 | wrapped_layer = glayers.GrowLayer(layer) 31 | x = tf.random.uniform(input_shape) 32 | original_out = layer(x) 33 | new_out = wrapped_layer(x) 34 | self.assertAllEqual(original_out, new_out) 35 | 36 | @parameterized.named_parameters( 37 | ('dense', tf.keras.layers.Dense(3), (3, 4), 1), 38 | ('dense_5neuron', tf.keras.layers.Dense(3), (3, 4), 5), 39 | ('conv2d', tf.keras.layers.Conv2D(3, 3), (3, 5, 5, 4), 1), 40 | ('conv2d_5neuron', tf.keras.layers.Conv2D(3, 3), (3, 5, 5, 4), 5), 41 | ) 42 | def test_add_neurons_incoming_zeros(self, layer, input_shape, n_new): 43 | wrapped_layer = glayers.GrowLayer(layer) 44 | x = tf.random.uniform(input_shape) 45 | original_out = wrapped_layer(x) 46 | old_output_shape = original_out.get_shape() 47 | n_neurons_old = old_output_shape[-1] 48 | wrapped_layer.add_neurons(n_new, new_weights='zeros', is_outgoing=False) 49 | new_out = wrapped_layer(x) 50 | # Check the output has the expected shape 51 | new_shape = old_output_shape[:-1] + [n_neurons_old+n_new] 52 | self.assertAllEqual(new_shape, new_out.get_shape()) 53 | # Check the old neurons create same output 54 | self.assertAllClose(original_out, new_out[Ellipsis, :n_neurons_old]) 55 | # Check the new neurons create zero output 56 | self.assertEqual(0, tf.math.count_nonzero(new_out[Ellipsis, n_neurons_old:])) 57 | new_weights, new_biases = wrapped_layer.get_weights() 58 | # Check the new weights are zero 59 | added_weights = new_weights[Ellipsis, n_neurons_old:] 60 | self.assertAllEqual(added_weights, tf.zeros_like(added_weights)) 61 | # Check the new biases are zero 62 | added_biases = new_biases[n_neurons_old:] 63 | self.assertAllEqual(added_biases, tf.zeros_like(added_biases)) 64 | 65 | @parameterized.named_parameters( 66 | ('dense', tf.keras.layers.Dense(3), (3, 4), 1), 67 | ('dense_5neuron', tf.keras.layers.Dense(3), (3, 4), 5), 68 | ('conv2d', tf.keras.layers.Conv2D(3, 3), (3, 5, 5, 4), 1), 69 | ('conv2d_5neuron', tf.keras.layers.Conv2D(3, 3), (3, 5, 5, 4), 5), 70 | ) 71 | def test_add_neurons_outgoing_zeros(self, layer, input_shape, n_new): 72 | wrapped_layer = glayers.GrowLayer(layer) 73 | n_features = input_shape[-1] 74 | x = tf.random.uniform(input_shape) 75 | # New input after growing would have more features 76 | new_input_shape = input_shape[:-1] + (n_new,) 77 | new_x = tf.concat([x, tf.random.uniform(new_input_shape)], axis=-1) 78 | original_out = layer(x) 79 | old_weights, old_biases = wrapped_layer.get_weights() 80 | wrapped_layer.add_neurons(n_new, new_weights='zeros', is_outgoing=True) 81 | new_out = wrapped_layer(new_x) 82 | new_weights, new_biases = wrapped_layer.get_weights() 83 | print(new_weights, new_biases) 84 | # Output of the layer shouldn't change. 
85 | self.assertAllClose(original_out, new_out) 86 | # Check biases are unchanged 87 | self.assertAllEqual(old_biases, new_biases) 88 | # Check the new weights are zero 89 | added_weights = new_weights[Ellipsis, n_features:, :] 90 | self.assertAllEqual(added_weights, tf.zeros_like(added_weights)) 91 | # Check the old weights are same 92 | kept_weights = new_weights[Ellipsis, :n_features, :] 93 | self.assertAllEqual(old_weights, kept_weights) 94 | 95 | @parameterized.named_parameters( 96 | ('dense_kernel', 'dense', ('kernel',)), 97 | ('dense_bias', 'dense', ('bias',)), 98 | ('dense_activity', 'dense', ('activity',)), 99 | ('dense_all', 'dense', ('kernel', 'bias', 'activity')), 100 | ('conv2d_kernel', 'conv2d', ('kernel',)), 101 | ('conv2d_bias', 'conv2d', ('bias',)), 102 | ('conv2d_activity', 'conv2d', ('activity',)), 103 | ('conv2d_all', 'conv2d', ('kernel', 'bias', 'activity')), 104 | ) 105 | def test_regularizer_incoming(self, layer_type, regularizer_types): 106 | reg_kwargs = {f'{r_type}_regularizer': tf.keras.regularizers.L2(0.1) 107 | for r_type in regularizer_types} 108 | print(reg_kwargs) 109 | if layer_type == 'dense': 110 | layer = tf.keras.layers.Dense(3, **reg_kwargs) 111 | input_shape = (3, 4) 112 | elif layer_type == 'conv2d': 113 | layer = tf.keras.layers.Conv2D(3, 3, **reg_kwargs) 114 | input_shape = (3, 5, 5, 4) 115 | else: 116 | raise ValueError('not supported') 117 | wrapped_layer = glayers.GrowLayer(layer) 118 | x = tf.random.uniform(input_shape) 119 | _ = wrapped_layer(x) 120 | old_losses = wrapped_layer.losses 121 | wrapped_layer.add_neurons(1, new_weights='zeros', is_outgoing=False) 122 | _ = wrapped_layer(x) 123 | new_losses = wrapped_layer.losses 124 | for old_loss, new_loss in zip(old_losses, new_losses): 125 | self.assertAllClose(old_loss, new_loss) 126 | 127 | @parameterized.named_parameters( 128 | ('dense_kernel', 'dense', ('kernel',)), 129 | ('dense_bias', 'dense', ('bias',)), 130 | ('dense_activity', 'dense', ('activity',)), 131 | ('dense_all', 'dense', ('kernel', 'bias', 'activity')), 132 | ('conv2d_kernel', 'conv2d', ('kernel',)), 133 | ('conv2d_bias', 'conv2d', ('bias',)), 134 | ('conv2d_activity', 'conv2d', ('activity',)), 135 | ('conv2d_all', 'conv2d', ('kernel', 'bias', 'activity')), 136 | ('bn_beta', 'bn', ('beta',)), 137 | ) 138 | def test_regularizer_outgoing(self, layer_type, regularizer_types): 139 | reg_kwargs = {f'{r_type}_regularizer': tf.keras.regularizers.L2(0.1) 140 | for r_type in regularizer_types} 141 | print(reg_kwargs) 142 | if layer_type == 'dense': 143 | layer = tf.keras.layers.Dense(3, **reg_kwargs) 144 | input_shape = (3, 4) 145 | elif layer_type == 'conv2d': 146 | layer = tf.keras.layers.Conv2D(3, 3, **reg_kwargs) 147 | input_shape = (3, 5, 5, 4) 148 | elif layer_type == 'bn': 149 | layer = tf.keras.layers.BatchNormalization(**reg_kwargs) 150 | input_shape = (3, 4) 151 | else: 152 | raise ValueError('not supported') 153 | wrapped_layer = glayers.GrowLayer(layer) 154 | x = tf.random.uniform(input_shape) 155 | _ = wrapped_layer(x) 156 | old_losses = wrapped_layer.losses 157 | if layer_type == 'bn': 158 | wrapped_layer.add_neurons_identity(1) 159 | else: 160 | wrapped_layer.add_neurons(1, new_weights='zeros', is_outgoing=True) 161 | new_input_shape = input_shape[:-1] + (1,) 162 | new_x = tf.concat([x, tf.random.uniform(new_input_shape)], axis=-1) 163 | _ = wrapped_layer(new_x) 164 | new_losses = wrapped_layer.losses 165 | for old_loss, new_loss in zip(old_losses, new_losses): 166 | self.assertAllClose(old_loss, new_loss) 167 | 168 | 
@parameterized.named_parameters( 169 | ('2d_axis1', (4, 5), -1), 170 | ('3d_axis1', (3, 3, 1), -1), 171 | ('4d_axis1', (3, 3, 4, 5), -1), 172 | ('2d_axis2', (4, 5), -2), 173 | ('3d_axis2', (3, 3, 1), -2), 174 | ('4d_axis2', (3, 3, 4, 5), -2), 175 | ) 176 | def test_norm_l2(self, shape, axis): 177 | tensor = tf.reshape(tf.range(tf.math.reduce_prod(shape), 178 | dtype=tf.float32), shape) 179 | calculated_norm = glayers.norm_l2(tensor, axis) 180 | if axis == -2: 181 | tensor = tf.einsum('...ij->...ji', tensor) 182 | # L2 norm should be 1 over axis 1 183 | flat_tensor = tf.reshape(tensor, 184 | [-1, tensor.shape[-1]]) 185 | expected_norms = tf.norm(flat_tensor, axis=-2) 186 | self.assertAllClose(expected_norms, calculated_norm) 187 | pass 188 | 189 | @parameterized.named_parameters( 190 | ('2d_axis1', (4, 5), -1), 191 | ('3d_axis1', (3, 3, 1), -1), 192 | ('4d_axis1', (3, 3, 4, 5), -1), 193 | ('2d_axis2', (4, 5), -2), 194 | ('3d_axis2', (3, 3, 1), -2), 195 | ('4d_axis2', (3, 3, 4, 5), -2), 196 | ) 197 | def test_normalize_l2(self, shape, axis): 198 | tensor = tf.reshape(tf.range(tf.math.reduce_prod(shape), 199 | dtype=tf.float32), shape) 200 | normalized_tensor = glayers.normalize_l2(tensor, axis) 201 | if axis == -2: 202 | normalized_tensor = tf.einsum('...ij->...ji', normalized_tensor) 203 | # L2 norm should be 1 over axis 1 204 | flat_tensor = tf.reshape(normalized_tensor, 205 | [-1, normalized_tensor.shape[-1]]) 206 | norms = tf.norm(flat_tensor, axis=-2) 207 | self.assertAllClose(norms, tf.ones_like(norms)) 208 | 209 | 210 | if __name__ == '__main__': 211 | tf.test.main() 212 | -------------------------------------------------------------------------------- /growneuron/imagenet/data_util.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 GradMax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Data preprocessing and augmentation. 17 | 18 | Forked from simclr/tf2 codebase. 19 | """ 20 | import tensorflow.compat.v2 as tf 21 | 22 | CROP_PROPORTION = 0.875 # Standard for ImageNet. 23 | 24 | 25 | def random_apply(func, p, x): 26 | """Randomly apply function func to x with probability p.""" 27 | return tf.cond( 28 | tf.less( 29 | tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32), 30 | tf.cast(p, tf.float32)), lambda: func(x), lambda: x) 31 | 32 | 33 | def _compute_crop_shape( 34 | image_height, image_width, aspect_ratio, crop_proportion): 35 | """Compute aspect ratio-preserving shape for central crop. 36 | 37 | The resulting shape retains `crop_proportion` along one side and a proportion 38 | less than or equal to `crop_proportion` along the other side. 39 | 40 | Args: 41 | image_height: Height of image to be cropped. 42 | image_width: Width of image to be cropped. 43 | aspect_ratio: Desired aspect ratio (width / height) of output. 44 | crop_proportion: Proportion of image to retain along the less-cropped side. 
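    For example (illustrative numbers): for a 400x600 (height x width) image
    with `aspect_ratio=1.0` and `crop_proportion=0.875`, the image is wider
    than the requested aspect ratio, so the computed crop is 350x350
    (0.875 * 400 along each side).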
45 | 46 | Returns: 47 | crop_height: Height of image after cropping. 48 | crop_width: Width of image after cropping. 49 | """ 50 | image_width_float = tf.cast(image_width, tf.float32) 51 | image_height_float = tf.cast(image_height, tf.float32) 52 | 53 | def _requested_aspect_ratio_wider_than_image(): 54 | crop_height = tf.cast( 55 | tf.math.rint(crop_proportion / aspect_ratio * image_width_float), 56 | tf.int32) 57 | crop_width = tf.cast( 58 | tf.math.rint(crop_proportion * image_width_float), tf.int32) 59 | return crop_height, crop_width 60 | 61 | def _image_wider_than_requested_aspect_ratio(): 62 | crop_height = tf.cast( 63 | tf.math.rint(crop_proportion * image_height_float), tf.int32) 64 | crop_width = tf.cast( 65 | tf.math.rint(crop_proportion * aspect_ratio * image_height_float), 66 | tf.int32) 67 | return crop_height, crop_width 68 | 69 | return tf.cond( 70 | aspect_ratio > image_width_float / image_height_float, 71 | _requested_aspect_ratio_wider_than_image, 72 | _image_wider_than_requested_aspect_ratio) 73 | 74 | 75 | def center_crop(image, height, width, crop_proportion): 76 | """Crops to center of image and rescales to desired size. 77 | 78 | Args: 79 | image: Image Tensor to crop. 80 | height: Height of image to be cropped. 81 | width: Width of image to be cropped. 82 | crop_proportion: Proportion of image to retain along the less-cropped side. 83 | 84 | Returns: 85 | A `height` x `width` x channels Tensor holding a central crop of `image`. 86 | """ 87 | shape = tf.shape(image) 88 | image_height = shape[0] 89 | image_width = shape[1] 90 | crop_height, crop_width = _compute_crop_shape( 91 | image_height, image_width, height / width, crop_proportion) 92 | offset_height = ((image_height - crop_height) + 1) // 2 93 | offset_width = ((image_width - crop_width) + 1) // 2 94 | image = tf.image.crop_to_bounding_box( 95 | image, offset_height, offset_width, crop_height, crop_width) 96 | 97 | image = tf.image.resize([image], [height, width], 98 | method=tf.image.ResizeMethod.BICUBIC)[0] 99 | 100 | return image 101 | 102 | 103 | def distorted_bounding_box_crop(image, 104 | bbox, 105 | min_object_covered=0.1, 106 | aspect_ratio_range=(0.75, 1.33), 107 | area_range=(0.05, 1.0), 108 | max_attempts=100, 109 | scope=None): 110 | """Generates cropped_image using one of the bboxes randomly distorted. 111 | 112 | See `tf.image.sample_distorted_bounding_box` for more documentation. 113 | 114 | Args: 115 | image: `Tensor` of image data. 116 | bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]` 117 | where each coordinate is [0, 1) and the coordinates are arranged 118 | as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole 119 | image. 120 | min_object_covered: An optional `float`. Defaults to `0.1`. The cropped 121 | area of the image must contain at least this fraction of any bounding 122 | box supplied. 123 | aspect_ratio_range: An optional list of `float`s. The cropped area of the 124 | image must have an aspect ratio = width / height within this range. 125 | area_range: An optional list of `float`s. The cropped area of the image 126 | must contain a fraction of the supplied image within in this range. 127 | max_attempts: An optional `int`. Number of attempts at generating a cropped 128 | region of the image of the specified constraints. After `max_attempts` 129 | failures, return the entire image. 130 | scope: Optional `str` for name scope. 131 | Returns: 132 | (cropped image `Tensor`, distorted bbox `Tensor`). 
133 | """ 134 | with tf.name_scope(scope or 'distorted_bounding_box_crop'): 135 | shape = tf.shape(image) 136 | sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( 137 | shape, 138 | bounding_boxes=bbox, 139 | min_object_covered=min_object_covered, 140 | aspect_ratio_range=aspect_ratio_range, 141 | area_range=area_range, 142 | max_attempts=max_attempts, 143 | use_image_if_no_bounding_boxes=True) 144 | bbox_begin, bbox_size, _ = sample_distorted_bounding_box 145 | 146 | # Crop the image to the specified bounding box. 147 | offset_y, offset_x, _ = tf.unstack(bbox_begin) 148 | target_height, target_width, _ = tf.unstack(bbox_size) 149 | image = tf.image.crop_to_bounding_box( 150 | image, offset_y, offset_x, target_height, target_width) 151 | 152 | return image 153 | 154 | 155 | def crop_and_resize(image, height, width): 156 | """Make a random crop and resize it to height `height` and width `width`. 157 | 158 | Args: 159 | image: Tensor representing the image. 160 | height: Desired image height. 161 | width: Desired image width. 162 | 163 | Returns: 164 | A `height` x `width` x channels Tensor holding a random crop of `image`. 165 | """ 166 | bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4]) 167 | aspect_ratio = width / height 168 | image = distorted_bounding_box_crop( 169 | image, 170 | bbox, 171 | min_object_covered=0.1, 172 | aspect_ratio_range=(3. / 4 * aspect_ratio, 4. / 3. * aspect_ratio), 173 | area_range=(0.08, 1.0), 174 | max_attempts=100, 175 | scope=None) 176 | return tf.image.resize([image], [height, width], 177 | method=tf.image.ResizeMethod.BICUBIC)[0] 178 | 179 | 180 | def random_crop_with_resize(image, height, width, p=1.0): 181 | """Randomly crop and resize an image. 182 | 183 | Args: 184 | image: `Tensor` representing an image of arbitrary size. 185 | height: Height of output image. 186 | width: Width of output image. 187 | p: Probability of applying this transformation. 188 | 189 | Returns: 190 | A preprocessed image `Tensor`. 191 | """ 192 | def _transform(image): 193 | image = crop_and_resize(image, height, width) 194 | return image 195 | return random_apply(_transform, p=p, x=image) 196 | 197 | 198 | def preprocess_for_train(image, 199 | height, 200 | width, 201 | crop=True, 202 | flip=True): 203 | """Preprocesses the given image for training. 204 | 205 | Args: 206 | image: `Tensor` representing an image of arbitrary size. 207 | height: Height of output image. 208 | width: Width of output image. 209 | crop: Whether to crop the image. 210 | flip: Whether or not to flip left and right of an image. 211 | 212 | Returns: 213 | A preprocessed image `Tensor`. 214 | """ 215 | if crop: 216 | image = random_crop_with_resize(image, height, width) 217 | if flip: 218 | image = tf.image.random_flip_left_right(image) 219 | image = tf.reshape(image, [height, width, 3]) 220 | image = tf.clip_by_value(image, 0., 1.) 221 | return image 222 | 223 | 224 | def preprocess_for_eval(image, height, width, crop=True): 225 | """Preprocesses the given image for evaluation. 226 | 227 | Args: 228 | image: `Tensor` representing an image of arbitrary size. 229 | height: Height of output image. 230 | width: Width of output image. 231 | crop: Whether or not to (center) crop the test images. 232 | 233 | Returns: 234 | A preprocessed image `Tensor`. 235 | """ 236 | if crop: 237 | image = center_crop(image, height, width, crop_proportion=CROP_PROPORTION) 238 | image = tf.reshape(image, [height, width, 3]) 239 | image = tf.clip_by_value(image, 0., 1.) 
240 | return image 241 | 242 | 243 | def preprocess_image(image, image_size, is_training=False, test_crop=True): 244 | """Preprocesses the given image. 245 | 246 | Args: 247 | image: `Tensor` representing an image of arbitrary size. 248 | image_size: Size of output image. 249 | is_training: `bool` for whether the preprocessing is for training. 250 | test_crop: whether or not to extract a central crop of the images 251 | (as for standard ImageNet evaluation) during the evaluation. 252 | 253 | Returns: 254 | A preprocessed image `Tensor` of range [0, 1]. 255 | """ 256 | image = tf.image.convert_image_dtype(image, dtype=tf.float32) 257 | if is_training: 258 | return preprocess_for_train(image, image_size, image_size) 259 | else: 260 | return preprocess_for_eval(image, image_size, image_size, crop=test_crop) 261 | -------------------------------------------------------------------------------- /growneuron/updaters.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 GradMax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Implements controllers for updating networks. 17 | """ 18 | import itertools 19 | from growneuron import growers 20 | from growneuron import layers 21 | import tensorflow as tf 22 | 23 | 24 | def pad_zeros_to(tensor, new_shape): 25 | """Pads a tensor with zeros such that final shape is new_shape. 26 | 27 | It expects the new_shape to be larger than the tensor.shape. 28 | Zeros are added to the end of each dimension. 29 | Args: 30 | tensor: 1d, 2d, 3d tensor. 31 | new_shape: list of dimensions where len(new_shape) == len(tensor.shape) 32 | Returns: 33 | new tensor of shape `new_shape`. 34 | """ 35 | old_shape = tensor.shape 36 | 37 | if len(old_shape) == 1: 38 | # Batchnorm or bias. 39 | diff_shape = [new_shape[-1] - old_shape[-1]] 40 | concat_axis = -1 41 | else: 42 | if old_shape[-2] == new_shape[-2]: 43 | # Input features are same, padding at axis=-1. 44 | concat_axis = -1 45 | else: 46 | concat_axis = -2 47 | diff_shape = list(old_shape) 48 | diff_shape[concat_axis] = new_shape[concat_axis] - old_shape[concat_axis] 49 | return tf.concat([tensor, tf.zeros(diff_shape)], axis=concat_axis) 50 | 51 | 52 | class Updater(): 53 | """Implements common methods. 54 | 55 | Updaters should be created under strategy scope or strategy should be passed 56 | directly. 57 | Attr: 58 | network_grower: growers.LayerGrower 59 | grow_layer_tuples: list of lists, candidates to be 60 | grown together with their outgoing weights. 61 | loss_fn: fn, Used to calculate loss. This function should get inputs 62 | as input and return loss. 63 | compile_fn: fn, Called to compile the model. 64 | update_frequency: int, Number of iterations before neurons are added. 65 | n_grow: int, number of neurons to grow at each growth step. 66 | n_grow_fraction: float, must be positive. 
Used together with initial width 67 | of candidate layers to decide n_neurons to grow at each growth step for 68 | each candidate separately. This approach is helpful when predicting the 69 | final architecture from the start as number of neurons added are fixed at 70 | the beginning for each layer. 71 | start_iteration: int, to start growing 72 | n_growth_steps: int, number of times the network is grown. 73 | scale: int, passed to the grower.grow_neurons 74 | carry_optimizer: bool, If true the running averages are carried to the new 75 | optimizer after the growth. Since variables are recreated after growth 76 | this is necessary. 77 | """ 78 | 79 | def __init__(self, network_grower, grow_layer_tuples, loss_fn=lambda x: x, 80 | compile_fn=lambda: None, update_frequency=1, n_grow=1, 81 | n_grow_fraction=None, start_iteration=None, n_growth_steps=None, 82 | scale=1., carry_optimizer=True): 83 | assert update_frequency > 0 84 | assert n_grow > 0 85 | self._update_frequency = update_frequency 86 | self._carry_optimizer = carry_optimizer 87 | self._n_grow = n_grow 88 | self._n_grow_fraction = n_grow_fraction 89 | self._scale = scale 90 | if start_iteration is None: 91 | start_iteration = update_frequency 92 | self._start_iteration = start_iteration 93 | self.loss_fn = loss_fn 94 | self.compile_fn = compile_fn 95 | self.strategy = tf.distribute.get_strategy() 96 | self.network_grower = self._prepare_grower(network_grower) 97 | self._n_growth_steps = n_growth_steps 98 | self._growth_counter = 0 99 | self._set_grow_layer_tuples(grow_layer_tuples) 100 | 101 | def _prepare_grower(self, grower): 102 | if grower: 103 | grower.loss_fn = self.loss_fn 104 | grower.compile_fn = self.compile_fn 105 | grower.strategy = self.strategy 106 | return grower 107 | 108 | def copy_optimizer_slots(self, optimizer, old_variables, new_variables): 109 | """Copy old slots and pad with zeros for new neurons.""" 110 | for old_var, new_var in zip(old_variables, new_variables): 111 | for s_name in sorted(optimizer.get_slot_names()): 112 | old_slot_var = optimizer.get_slot(old_var, s_name) 113 | new_slot_var = optimizer.get_slot(new_var, s_name) 114 | # This is used to retrieve the part of the new slot used for the 115 | # old variables. This assumes new variables are appended to the end. 116 | new_slot_values = pad_zeros_to(old_slot_var, new_slot_var.shape) 117 | new_slot_var.assign(new_slot_values) 118 | 119 | def delete_optimizer_slots(self, optimizer, variables): 120 | """Deleted old variable slots from the optimizer.""" 121 | for old_var in variables: 122 | key = (old_var._shared_name if old_var._in_graph_mode 123 | else old_var._unique_id) 124 | optimizer._slots.pop(key, None) 125 | 126 | def _set_grow_layer_tuples(self, grow_layer_tuples): 127 | """Sets the tuple of layers for growing.""" 128 | if not grow_layer_tuples: 129 | raise ValueError("grow_layer_tuples argument can't be empty.") 130 | self.grow_layer_tuples = grow_layer_tuples 131 | 132 | def get_n_neuron(n_neuron_initial): 133 | if self._n_grow_fraction: 134 | return int(max(1, n_neuron_initial * self._n_grow_fraction)) 135 | else: 136 | return self._n_grow 137 | # Used to calculate n_grow per layer using grow_fraction. 138 | # n_neurons are decided using the initial architecture. 
139 | self._n_grow_dict = { 140 | tpl[0].name: get_n_neuron(tpl[0].weights[0].shape[-1]) 141 | for tpl in grow_layer_tuples 142 | } 143 | 144 | def is_update_iteration(self, iteration): 145 | assert iteration >= 0 146 | return ((self.network_grower is not None) and 147 | (iteration % self._update_frequency == 0) and 148 | (self._start_iteration <= iteration) and 149 | ((self._n_growth_steps is None) or 150 | (self._growth_counter < self._n_growth_steps))) 151 | 152 | def get_variable_list(self, grow_layer_tuple): 153 | return list(itertools.chain.from_iterable( 154 | [layer.trainable_weights for layer in grow_layer_tuple])) 155 | 156 | def get_grow_layer_stats(self): 157 | all_stats = [] 158 | for grow_layer_tuple in self.grow_layer_tuples: 159 | first_layer = grow_layer_tuple[0] 160 | n_neuron = first_layer.get_weights()[0].shape[-1] 161 | all_stats.append((first_layer.layer.name, n_neuron)) 162 | return all_stats 163 | 164 | def update_network(self, batch_data, optimizer=None): 165 | raise NotImplementedError() 166 | 167 | 168 | class DummyUpdater(Updater): 169 | """Implements common methods. 170 | 171 | Attr: 172 | network_grower: growers.LayerGrower 173 | grow_layer_tuples: list of lists, candidates to be 174 | grown together with their outgoing weights. 175 | update_frequency: int, Number of iterations before neurons are added. 176 | """ 177 | 178 | def __init__(self, grow_layer_tuples): 179 | super().__init__(None, grow_layer_tuples, None, None) 180 | 181 | def update_network(self, **kwargs): 182 | pass 183 | 184 | def is_update_iteration(self, epoch): 185 | del epoch 186 | return False 187 | 188 | def get_grow_layer_stats(self): 189 | return [] 190 | 191 | 192 | class RoundRobin(Updater): 193 | """Updates provided candidate layers in a round robin fashion.""" 194 | 195 | def _next_grow_layer_tuple(self, unused_batch_data): 196 | next_tuple_id = self._growth_counter % len(self.grow_layer_tuples) 197 | self._growth_counter += 1 198 | return self.grow_layer_tuples[next_tuple_id] 199 | 200 | def update_network(self, batch_data, optimizer=None): 201 | """Updates the network and optimizer slots.""" 202 | grow_layer_tuple = self._next_grow_layer_tuple(batch_data) 203 | old_variables = self.get_variable_list(grow_layer_tuple) 204 | n_new = self._n_grow_dict[grow_layer_tuple[0].name] 205 | self.network_grower.grow_neurons(grow_layer_tuple, batch_data, 206 | n_new=n_new, scale=self._scale) 207 | # Run the loss function to create new variables. 208 | self.compile_fn() 209 | new_variables = self.get_variable_list(grow_layer_tuple) 210 | optimizer._create_slots(new_variables) 211 | if self._carry_optimizer and optimizer: 212 | self.copy_optimizer_slots(optimizer, old_variables, new_variables) 213 | self.delete_optimizer_slots(optimizer, old_variables) 214 | 215 | 216 | class AllAtOnce(Updater): 217 | """Grows all candidate layers at once.""" 218 | 219 | def _get_all_grow_layer_tuples(self): 220 | self._growth_counter += 1 221 | return self.grow_layer_tuples[:] 222 | 223 | def update_network(self, batch_data, optimizer=None): 224 | """Updates the network and optimizer slots.""" 225 | grow_layer_tuples = self._get_all_grow_layer_tuples() 226 | for grow_layer_tuple in grow_layer_tuples: 227 | old_variables = self.get_variable_list(grow_layer_tuple) 228 | n_new = self._n_grow_dict[grow_layer_tuple[0].name] 229 | self.network_grower.grow_neurons(grow_layer_tuple, batch_data, 230 | n_new=n_new, scale=self._scale) 231 | # Run the loss function to create new variables. 
232 |       self.compile_fn()
233 |       new_variables = self.get_variable_list(grow_layer_tuple)
234 |       optimizer._create_slots(new_variables)
235 |       if self._carry_optimizer and optimizer:
236 |         self.copy_optimizer_slots(optimizer, old_variables, new_variables)
237 |         self.delete_optimizer_slots(optimizer, old_variables)
238 | 
239 | 
240 | 
241 | 
242 | def adjust_epochs(train_epochs, width_scale, update_frequency,
243 |                   start_iteration, n_growth_steps, steps_per_epoch):
244 |   """Adjusts the epochs such that the total FLOPs match the big-baseline."""
245 |   # Here we extend training according to the FLOPs saved by starting with
246 |   # a smaller width.
247 |   saved_fraction = (1 - width_scale)
248 |   # Saved before growth.
249 |   saved_steps = saved_fraction * start_iteration
250 |   growth_duration = (update_frequency * (n_growth_steps - 1))
251 |   # Saved during growth (2 is because of the triangle area).
252 |   saved_steps += saved_fraction/2 * growth_duration
253 |   new_epochs = train_epochs + int(saved_steps / steps_per_epoch)
254 |   return new_epochs
255 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | 
2 |                                  Apache License
3 |                            Version 2.0, January 2004
4 |                         http://www.apache.org/licenses/
5 | 
6 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 | 
8 |    1. Definitions.
9 | 
10 |       "License" shall mean the terms and conditions for use, reproduction,
11 |       and distribution as defined by Sections 1 through 9 of this document.
12 | 
13 |       "Licensor" shall mean the copyright owner or entity authorized by
14 |       the copyright owner that is granting the License.
15 | 
16 |       "Legal Entity" shall mean the union of the acting entity and all
17 |       other entities that control, are controlled by, or are under common
18 |       control with that entity. For the purposes of this definition,
19 |       "control" means (i) the power, direct or indirect, to cause the
20 |       direction or management of such entity, whether by contract or
21 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 |       outstanding shares, or (iii) beneficial ownership of such entity.
23 | 
24 |       "You" (or "Your") shall mean an individual or Legal Entity
25 |       exercising permissions granted by this License.
26 | 
27 |       "Source" form shall mean the preferred form for making modifications,
28 |       including but not limited to software source code, documentation
29 |       source, and configuration files.
30 | 
31 |       "Object" form shall mean any form resulting from mechanical
32 |       transformation or translation of a Source form, including but
33 |       not limited to compiled object code, generated documentation,
34 |       and conversions to other media types.
35 | 
36 |       "Work" shall mean the work of authorship, whether in Source or
37 |       Object form, made available under the License, as indicated by a
38 |       copyright notice that is included in or attached to the work
39 |       (an example is provided in the Appendix below).
40 | 
41 |       "Derivative Works" shall mean any work, whether in Source or Object
42 |       form, that is based on (or derived from) the Work and for which the
43 |       editorial revisions, annotations, elaborations, or other modifications
44 |       represent, as a whole, an original work of authorship. For the purposes
45 |       of this License, Derivative Works shall not include works that remain
46 |       separable from, or merely link (or bind by name) to the interfaces of,
47 |       the Work and Derivative Works thereof.
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /growneuron/layers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 GradMax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """GrowLayer wrapper module.""" 17 | import numpy as np 18 | import tensorflow as tf 19 | 20 | SUPPORTED_LAYERS = (tf.keras.layers.Dense, tf.keras.layers.Conv2D) 21 | 22 | 23 | def get_activation_fn(actv_fn): 24 | """Activation choices for the layer. 25 | 26 | Args: 27 | actv_fn: str. 28 | 29 | Returns: 30 | activation fn 31 | """ 32 | if actv_fn == 'relu1': 33 | # This has grad(f(0))=1 instead of 0 (default implementation). 34 | return lambda x: tf.math.maximum(x, 0) 35 | elif actv_fn == 'relu2': 36 | # This has grad(f(0))=1. 37 | return lambda x: tf.math.maximum(x, -1) 38 | else: 39 | return tf.keras.activations.get(actv_fn) 40 | 41 | 42 | class GrowLayer(tf.keras.layers.Wrapper): 43 | """This layer wraps keras.layers in order to support growing. 44 | 45 | This layer allows adding callbacks to the forward pass of a layer that will be 46 | called with the inputs and outputs of the underlying layer before the 47 | activations. 48 | 49 | Example Usage: 50 | ``` 51 | first_layer = GrowLayer(tf.keras.layers.Dense(32)) 52 | ``` 53 | 54 | This layer can be used for growing neurons. 55 | `first_layer.add_neurons_incoming(1, new_weights='zeros')` 56 | """ 57 | 58 | def __init__(self, *args, activation=None, **kwargs): 59 | if 'name' not in kwargs: 60 | # args[0] is the wrapped layer 61 | kwargs['name'] = f'glayer_{args[0].name}' 62 | super().__init__(*args, **kwargs) 63 | self.activation = get_activation_fn(activation) 64 | self.reset_callbacks() 65 | 66 | def add_callback(self, name, fn): 67 | self._callbacks[name] = fn 68 | 69 | def remove_callback(self, name): 70 | del self._callbacks[name] 71 | 72 | def reset_callbacks(self): 73 | self._callbacks = {} 74 | 75 | def __call__(self, inputs, *args, **kwargs): 76 | outputs = self.layer.__call__(inputs, *args, **kwargs) 77 | for _, callback_fn in self._callbacks.items(): 78 | inputs, outputs = callback_fn(inputs, outputs) 79 | if self.activation: 80 | outputs = self.activation(outputs) 81 | return outputs 82 | 83 | def add_neurons(self, n_new, new_weights='zeros', scale=1., 84 | is_outgoing=False, scale_method='mean_norm', 85 | new_bias='zeros'): 86 | """Adds new neurons and creates a new layer. 87 | 88 | New weights are scaled (if not zero) to have l2-norm equal to the mean 89 | l2-norm of the existing weights. 90 | TODO Unify splitting and adding neurons. 91 | Args: 92 | n_new: number of neurons to add. 93 | new_weights: 'zeros', 'random' or np.ndarray. 94 | scale: float, scales the new_weights multiplied with the mean norm of 95 | the existing weights. 96 | is_outgoing: bool, if true adds outgoing connections from the new neurons 97 | coming from previous layers. In other words number of neurons in current 98 | layer stays constant, but they aggregate information from n_new many 99 | new neurons. 100 | scale_method: str, Type of scaling to be used when initializing new 101 | neurons. 102 | - `mean_norm` means they are normalized using the mean norm of 103 | existing weights. 104 | - `fixed` means the weights are multiplied with scale directly. 105 | new_bias: str, 'zeros' or 'ones'. 
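      Example (illustrative; assumes a built `GrowLayer(tf.keras.layers.Dense(4))`
      with 8 input features): `add_neurons(2, new_weights='zeros',
      is_outgoing=False)` widens the output from 4 to 6 units, whereas
      `is_outgoing=True` instead appends 2 zero-initialized incoming
      connections for 2 new neurons in the previous layer and keeps the
      output width at 4.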
106 | """ 107 | old_module = self.layer 108 | assert old_module.built 109 | assert new_bias in ('zeros', 'ones') 110 | assert isinstance(old_module, SUPPORTED_LAYERS) 111 | self.layer = grow_new_layer(old_module, n_new, new_weights, scale, 112 | is_outgoing=is_outgoing, new_bias=new_bias, 113 | scale_method=scale_method) 114 | 115 | def add_neurons_identity(self, n_new): 116 | """Adds identity neurons for various layer types. 117 | 118 | Args: 119 | n_new: number of neurons to add. 120 | """ 121 | old_module = self.layer 122 | assert old_module.built 123 | if isinstance(old_module, tf.keras.layers.BatchNormalization): 124 | self.layer = grow_new_bn_layer(old_module, n_new) 125 | elif isinstance(old_module, tf.keras.layers.LayerNormalization): 126 | self.layer = grow_new_ln_layer(old_module, n_new) 127 | elif isinstance(old_module, tf.keras.layers.DepthwiseConv2D): 128 | self.layer = grow_new_dw_layer(old_module, n_new) 129 | else: 130 | raise ValueError(f'layer: {old_module} of {type(old_module)} is not ' 131 | 'supported.') 132 | 133 | 134 | def grow_new_layer(old_module, n_new, new_weights, scale, is_outgoing=False, 135 | scale_method='mean_norm', new_bias='zeros'): 136 | """Creates new layer after adding incoming our outgoing connections. 137 | 138 | Args: 139 | old_module: Old layer to grow from. One of layers.SUPPORTED_LAYERS. 140 | n_new: number of neurons to add. 141 | new_weights: 'zeros', 'random' or np.ndarray. 142 | scale: float, scales the new_weights multiplied with the mean norm of 143 | the existing weights. 144 | is_outgoing: bool, True if the outgoing connections of the new neurons are 145 | being added to the next layer. In this case, no new neurons are generated; 146 | instead existing neurons receive new incoming connections. 147 | scale_method: str, Type of scaling to be used when initializing new 148 | neurons. 149 | - `mean_norm` means they are normalized using the mean norm of 150 | existing weights. 151 | - `fixed` means the weights are multiplied with scale directly. 152 | new_bias: str, zeros or ones. 153 | Returns: 154 | layer of same type as the old_module. 155 | """ 156 | old_weights = old_module.get_weights()[0] 157 | shape_axis = -2 if is_outgoing else -1 158 | 159 | if scale_method == 'mean_norm': 160 | magnitude_new = np.mean(norm_l2(old_weights, keep_dim=shape_axis).numpy()) 161 | magnitude_new *= scale 162 | elif scale_method == 'fixed': 163 | # We don't use the scale of existing weights for initialization. 164 | magnitude_new = scale 165 | else: 166 | raise ValueError(f'Not supported scale_method, {scale_method}') 167 | 168 | shape_new = list(old_weights.shape) 169 | shape_new[shape_axis] = n_new 170 | 171 | if isinstance(new_weights, np.ndarray): 172 | assert new_weights.shape == tuple(shape_new) 173 | # Normalize to unit norm and then scale. 174 | normalized_w = normalize_l2(new_weights, axis=shape_axis).numpy() 175 | new_neurons = normalized_w * magnitude_new 176 | elif new_weights == 'random': 177 | normalized_w = normalize_l2(np.random.uniform(size=shape_new), 178 | axis=shape_axis).numpy() 179 | # Normalize to unit norm and then scale. 180 | new_neurons = normalized_w * magnitude_new 181 | elif new_weights == 'zeros': 182 | new_neurons = np.zeros(shape_new) 183 | else: 184 | raise ValueError('new_weights: %s is not valid' % new_weights) 185 | new_layer_weights = [np.concatenate((old_weights, new_neurons), 186 | axis=shape_axis)] 187 | 188 | # Assuming bias is the second weight. 
189 | if old_module.use_bias: 190 | new_bias_weights = old_module.get_weights()[1] 191 | if not is_outgoing: 192 | new_neuron_bias = (np.zeros([n_new]) if (new_bias == 'zeros') else 193 | np.ones([n_new])) 194 | new_bias_weights = np.concatenate((new_bias_weights, new_neuron_bias), 195 | axis=0) 196 | new_layer_weights.append(new_bias_weights) 197 | 198 | common_kwargs = { 199 | 'name': old_module.name, 200 | 'activation': old_module.activation, 201 | 'use_bias': old_module.use_bias 202 | } 203 | for r_name in ('kernel_regularizer', 'bias_regularizer', 204 | 'activity_regularizer'): 205 | regularizer = getattr(old_module, r_name) 206 | if regularizer is not None: 207 | common_kwargs[r_name] = regularizer 208 | n_out_new = new_layer_weights[0].shape[-1] 209 | if isinstance(old_module, tf.keras.layers.Dense): 210 | new_module = tf.keras.layers.Dense( 211 | n_out_new, 212 | weights=new_layer_weights, 213 | **common_kwargs) 214 | elif isinstance(old_module, tf.keras.layers.Conv2D): 215 | new_module = tf.keras.layers.Conv2D( 216 | n_out_new, 217 | kernel_size=old_module.kernel_size, 218 | strides=old_module.strides, 219 | padding=old_module.padding, 220 | weights=new_layer_weights, 221 | **common_kwargs) 222 | else: 223 | raise ValueError(f'Unexpected module: {old_module}') 224 | 225 | return new_module 226 | 227 | 228 | def grow_new_ln_layer(old_module, n_new): 229 | """Grows a new identity LayerNormalization layer.""" 230 | new_ln_weights = [] 231 | # One for gamma, beta 232 | for i in range(2): 233 | old_w = old_module.get_weights()[i] 234 | if i == 0: # gamma 235 | new_w = np.ones([n_new]) 236 | else: # beta 237 | new_w = np.zeros([n_new]) 238 | w = np.concatenate((old_w, new_w), axis=0) 239 | new_ln_weights.append(w) 240 | common_kwargs = { 241 | 'epsilon': old_module.epsilon 242 | } 243 | for r_name in ('gamma_regularizer', 'beta_regularizer'): 244 | regularizer = getattr(old_module, r_name) 245 | if regularizer is not None: 246 | common_kwargs[r_name] = regularizer 247 | return tf.keras.layers.LayerNormalization(weights=new_ln_weights, 248 | **common_kwargs) 249 | 250 | 251 | def grow_new_bn_layer(old_module, n_new): 252 | """Grows a new identity BatchNormalization layer.""" 253 | new_bn_weights = [] 254 | # One for gamma, beta, moving_mean and moving_variance 255 | for i in range(4): 256 | old_w = old_module.get_weights()[i] 257 | if i in (1, 2): # beta, moving_mean 258 | new_w = np.zeros([n_new]) 259 | else: # gamma, moving variance 260 | new_w = np.ones([n_new]) 261 | w = np.concatenate((old_w, new_w), axis=0) 262 | new_bn_weights.append(w) 263 | common_kwargs = { 264 | 'epsilon': old_module.epsilon 265 | } 266 | for r_name in ('gamma_regularizer', 'beta_regularizer'): 267 | regularizer = getattr(old_module, r_name) 268 | if regularizer is not None: 269 | common_kwargs[r_name] = regularizer 270 | return tf.keras.layers.BatchNormalization(weights=new_bn_weights, 271 | **common_kwargs) 272 | 273 | 274 | def grow_new_dw_layer(old_module, n_new): 275 | """Adds identity neurosn to the depthwise convolutional layers.""" 276 | old_weights = old_module.get_weights()[0] 277 | shape_new = list(old_weights.shape) 278 | shape_new[-2] = n_new 279 | new_weights = np.zeros(shape_new, dtype=old_weights.dtype) 280 | mid_index_x = new_weights.shape[0] // 2 281 | mid_index_y = new_weights.shape[1] // 2 282 | new_weights[mid_index_x, mid_index_y, Ellipsis] = 1. 283 | new_layer_weights = [np.concatenate((old_weights, new_weights), 284 | axis=-2)] 285 | 286 | # Assuming bias is the second weight. 
287 | if old_module.use_bias: 288 | new_bias = old_module.get_weights()[1] 289 | new_neuron_bias = np.zeros([n_new]) 290 | new_bias = np.concatenate((new_bias, new_neuron_bias), axis=0) 291 | new_layer_weights.append(new_bias) 292 | 293 | regularizer_kwargs = {} 294 | for r_name in ('kernel_regularizer', 'bias_regularizer', 295 | 'activity_regularizer'): 296 | regularizer = getattr(old_module, r_name) 297 | if regularizer is not None: 298 | regularizer_kwargs[r_name] = regularizer 299 | new_module = tf.keras.layers.DepthwiseConv2D( 300 | kernel_size=old_module.kernel_size, 301 | name=old_module.name, 302 | activation=old_module.activation, 303 | use_bias=old_module.use_bias, 304 | strides=old_module.strides, 305 | padding=old_module.padding, 306 | weights=new_layer_weights, 307 | **regularizer_kwargs) 308 | return new_module 309 | 310 | 311 | def norm_l2(tensor, keep_dim): 312 | norm_axes = list(range(len(tensor.shape))) 313 | del norm_axes[keep_dim] 314 | return tf.sqrt(tf.reduce_sum(tf.pow(tensor, 2), axis=norm_axes)) 315 | 316 | 317 | def normalize_l2(tensor, axis): 318 | assert axis in (-2, -1) 319 | norm = norm_l2(tensor, axis) 320 | scale_recipe = '...ij,i->...ij' if (axis == -2) else '...ij,j->...ij' 321 | return tf.einsum(scale_recipe, tensor, 1 / norm) 322 | -------------------------------------------------------------------------------- /growneuron/imagenet/main.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 GradMax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | r"""MobileNet-v1 on ImageNet trained with maximum likelihood. 17 | 18 | """ 19 | 20 | import itertools 21 | import os 22 | import time 23 | 24 | from absl import app 25 | from absl import flags 26 | from absl import logging 27 | from growneuron import growers 28 | from growneuron import updaters 29 | from growneuron.imagenet import data 30 | from growneuron.imagenet import mb_v1 31 | from ml_collections import config_flags 32 | import tensorflow as tf 33 | import tensorflow_datasets as tfds 34 | import uncertainty_baselines.schedules as ub_schedules 35 | from tensorboard.plugins.hparams import api as hp 36 | 37 | 38 | config_flags.DEFINE_config_file( 39 | name='config', 40 | default='growneuron/imagenet/configs/' 41 | 'baseline_big.py', 42 | help_string='training config file.') 43 | # common flags 44 | flags.DEFINE_string( 45 | 'tpu', '', 46 | 'TPU address. If empty MirroredStrategy is used.') 47 | flags.DEFINE_string('data_dir', None, 48 | 'data_dir to be used for tfds dataset construction.' 
49 | 'It is required when training with cloud TPUs') 50 | flags.DEFINE_bool('download_data', False, 51 | 'Whether to download data locally when initializing a ' 52 | 'dataset.') 53 | flags.DEFINE_string('output_dir', '/tmp/cifar', 'Output directory.') 54 | FLAGS = flags.FLAGS 55 | 56 | 57 | def get_optimizer(optimizer_config, train_epochs, batch_size, steps_per_epoch): 58 | """Given the config and training arguments returns an optimizer.""" 59 | # Linearly scale learning rate and the decay epochs by vanilla settings. 60 | base_lr = optimizer_config.base_learning_rate * batch_size / 128 61 | lr_decay_epochs = [int(fraction * train_epochs) 62 | for fraction in optimizer_config.lr_decay_epochs] 63 | if optimizer_config.decay_type == 'step': 64 | lr_schedule = ub_schedules.WarmUpPiecewiseConstantSchedule( 65 | steps_per_epoch, 66 | base_lr, 67 | decay_ratio=optimizer_config.lr_decay_ratio, 68 | decay_epochs=lr_decay_epochs, 69 | warmup_epochs=optimizer_config.lr_warmup_epochs) 70 | elif optimizer_config.decay_type == 'cosine': 71 | lr_schedule = tf.keras.optimizers.schedules.CosineDecay( 72 | base_lr, train_epochs * steps_per_epoch, alpha=0.0) 73 | else: 74 | lr_schedule = base_lr / 100. 75 | logging.info('No decay used') 76 | optimizer = tf.keras.optimizers.SGD(lr_schedule, 77 | momentum=optimizer_config.momentum, 78 | nesterov=optimizer_config.nesterov) 79 | return optimizer 80 | 81 | 82 | def main(argv): 83 | fmt = '[%(filename)s:%(lineno)s] %(message)s' 84 | formatter = logging.PythonFormatter(fmt) 85 | logging.get_absl_handler().setFormatter(formatter) 86 | del argv # unused arg 87 | config = FLAGS.config 88 | if (hasattr(config, 'grow_frequency_multiplier') and 89 | config.grow_frequency_multiplier != 1): 90 | # Scale the frequency of the growth steps 91 | factor = config.grow_frequency_multiplier 92 | config.updater.update_frequency = int( 93 | config.updater.update_frequency * factor) 94 | config.updater.n_growth_steps = int( 95 | config.updater.n_growth_steps / factor) 96 | config.updater.n_grow_fraction *= factor 97 | 98 | tf.io.gfile.makedirs(FLAGS.output_dir) 99 | logging.info('Saving checkpoints at %s', FLAGS.output_dir) 100 | tf.random.set_seed(config.seed) 101 | if FLAGS.tpu: 102 | resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu) 103 | tf.config.experimental_connect_to_cluster(resolver) 104 | topology = tf.tpu.experimental.initialize_tpu_system(resolver) 105 | logging.info('Topology:') 106 | logging.info('num_tasks: %d', topology.num_tasks) 107 | logging.info('num_tpus_per_task: %d', topology.num_tpus_per_task) 108 | strategy = tf.distribute.TPUStrategy(resolver) 109 | else: 110 | strategy = tf.distribute.MirroredStrategy() 111 | topology = None 112 | 113 | ds_builder = tfds.builder(config.dataset) 114 | if FLAGS.download_data: 115 | ds_builder.download_and_prepare() 116 | ds_info = ds_builder.info 117 | batch_size = config.per_core_batch_size * config.num_cores 118 | 119 | # Scale arguments that depend on 128 batch size total training iterations. 120 | multiplier = 512. 
/ batch_size 121 | if hasattr(config.updater, 'update_frequency'): 122 | config.updater.update_frequency = int( 123 | config.updater.update_frequency * multiplier) 124 | config.updater.start_iteration = int( 125 | config.updater.start_iteration * multiplier) 126 | 127 | train_dataset_size = ds_info.splits['train'].num_examples 128 | steps_per_epoch = train_dataset_size // batch_size 129 | logging.info('Steps per epoch %s', steps_per_epoch) 130 | logging.info('Size of the dataset %s', train_dataset_size) 131 | steps_per_eval = ds_info.splits['validation'].num_examples // batch_size 132 | num_classes = ds_info.features['label'].num_classes 133 | 134 | train_dataset = strategy.distribute_datasets_from_function( 135 | data.build_input_fn(ds_builder, batch_size, topology=topology, 136 | is_training=True)) 137 | test_dataset = strategy.distribute_datasets_from_function( 138 | data.build_input_fn(ds_builder, batch_size, topology=topology, 139 | is_training=False)) 140 | # Maybe create a grower. 141 | grow_type = getattr(config, 'grow_type', None) 142 | if grow_type == 'add_random': 143 | grower = growers.AddRandom() 144 | elif grow_type == 'add_firefly': 145 | grower = growers.AddFirefly() 146 | elif grow_type == 'add_gradmax_opt': 147 | grower = growers.AddGradmaxOptim() 148 | elif grow_type == 'add_gradmax': 149 | grower = growers.AddGradmax() 150 | else: 151 | logging.info('No growing') 152 | grower = None 153 | 154 | if grower: 155 | grower.epsilon = config.grow_epsilon 156 | grower.scale_method = config.grow_scale_method 157 | grower.is_outgoing_zero = config.is_outgoing_zero 158 | 159 | if config.scale_epochs: 160 | old_epochs = config.train_epochs 161 | # Adjust the total epochs to match big-baseline training FLOPs. 162 | if grower: 163 | config.train_epochs = updaters.adjust_epochs( 164 | config.train_epochs, 165 | config.model.width_multiplier, 166 | config.updater.update_frequency, 167 | config.updater.start_iteration, 168 | config.updater.n_growth_steps, 169 | steps_per_epoch 170 | ) 171 | else: 172 | # baseline 173 | config.train_epochs = config.train_epochs / config.model.width_multiplier 174 | logging.info('Extended training from %s to %s', old_epochs, 175 | config.train_epochs) 176 | 177 | summary_writer = tf.summary.create_file_writer( 178 | os.path.join(FLAGS.output_dir, 'summaries')) 179 | 180 | with summary_writer.as_default(): 181 | flat_param_dict = {} 182 | def flat_fn(dic): 183 | for k, v in dic.items(): 184 | if isinstance(v, dict): 185 | flat_fn({f'{k}.{k2}': v2 for k2, v2 in v.items()}) 186 | else: 187 | flat_param_dict[k] = str(v) if isinstance(v, list) else v 188 | flat_fn(config.to_dict()) 189 | hp.hparams(flat_param_dict) 190 | 191 | grow_layer_tuples = [] 192 | architecture = config.get('architecture', 'mb_v1') 193 | with strategy.scope(): 194 | if architecture == 'mb_v1': 195 | logging.info('Building MobileNet-v1 model') 196 | model = mb_v1.create_model( 197 | num_classes=num_classes, 198 | seed=config.seed, 199 | **config.model) 200 | grow_layer_tuples = model.get_grow_layer_tuples() 201 | else: 202 | raise ValueError(f'Unknown architecture: {architecture}') 203 | logging.info('#grow_layer_tuples: %s', len(grow_layer_tuples)) 204 | logging.info('grow_layer_tuples[0]: %s', grow_layer_tuples[0]) 205 | grow_metrics = {layers[0]: tf.keras.metrics.Sum() 206 | for layers in grow_layer_tuples} 207 | # Initialize the parameters.
208 | def compile_model_fn(): 209 | model(tf.keras.Input((224, 224, 3))) 210 | compile_model_fn() 211 | logging.info('Model input shape: %s', model.input_shape) 212 | logging.info('Model output shape: %s', model.output_shape) 213 | logging.info('Model number of weights: %s', model.count_params()) 214 | optimizer = get_optimizer(config.optimizer, config.train_epochs, batch_size, 215 | steps_per_epoch) 216 | train_metrics = { 217 | 'train/negative_log_likelihood': 218 | tf.keras.metrics.Mean(), 219 | 'train/accuracy': 220 | tf.keras.metrics.SparseCategoricalAccuracy(), 221 | 'train/loss': 222 | tf.keras.metrics.Mean(), 223 | } 224 | 225 | eval_metrics = { 226 | 'test/negative_log_likelihood': 227 | tf.keras.metrics.Mean(), 228 | 'test/accuracy': 229 | tf.keras.metrics.SparseCategoricalAccuracy(), 230 | } 231 | model.summary() 232 | checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer) 233 | latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir) 234 | initial_epoch = 0 235 | if latest_checkpoint: 236 | # TODO This probably wouldn't work if the network is grown; 237 | # so we may need to switch to saved models. 238 | # checkpoint.restore must be within a strategy.scope() so that optimizer 239 | # slot variables are mirrored. 240 | checkpoint.restore(latest_checkpoint) 241 | logging.info('Loaded checkpoint %s', latest_checkpoint) 242 | initial_epoch = optimizer.iterations.numpy() // steps_per_epoch 243 | 244 | # Create the updater. 245 | def loss_fn(inputs): 246 | images, labels = inputs 247 | logits = model(images, training=True) 248 | one_hot_labels = tf.one_hot(tf.cast(labels, tf.int32), num_classes) 249 | loss = tf.reduce_mean( 250 | tf.keras.losses.categorical_crossentropy( 251 | one_hot_labels, 252 | logits, 253 | from_logits=True)) 254 | scaled_loss = loss / strategy.num_replicas_in_sync 255 | # Don't add the regularization loss; it is unnecessary for zero variables.
256 | return scaled_loss 257 | 258 | updater_type = getattr(config, 'updater_type', None) 259 | if updater_type == 'round_robin': 260 | updater = updaters.RoundRobin(grower, grow_layer_tuples, loss_fn, 261 | compile_model_fn, **config.updater) 262 | elif updater_type == 'all_at_once': 263 | updater = updaters.AllAtOnce(grower, grow_layer_tuples, loss_fn, 264 | compile_model_fn, **config.updater) 265 | logging.info(message) 266 | 267 | if (epoch % 20 == 0) or (config.train_epochs == (epoch + 1)): 268 | test_iterator = iter(test_dataset) 269 | logging.info('Starting to run eval at epoch: %s', epoch) 270 | test_start_time = time.time() 271 | test_step(test_iterator, 'test', steps_per_eval) 272 | test_ms_per_example = (time.time() - test_start_time) * 1e6 / batch_size 273 | logging.info('Done with eval on') 274 | 275 | logging.info('Test NLL: %.4f, Accuracy: %.2f%%', 276 | eval_metrics['test/negative_log_likelihood'].result(), 277 | eval_metrics['test/accuracy'].result() * 100) 278 | total_results = {name: metric.result() for name, metric 279 | in eval_metrics.items()} 280 | total_results['train/ms_per_example'] = train_ms_per_example 281 | total_results['test/ms_per_example'] = test_ms_per_example 282 | with summary_writer.as_default(): 283 | for name, result in total_results.items(): 284 | tf.summary.scalar(name, result, step=epoch + 1) 285 | 286 | for metric in eval_metrics.values(): 287 | metric.reset_states() 288 | 289 | if (config.checkpoint_interval > 0 and 290 | (epoch + 1) % config.checkpoint_interval == 0): 291 | checkpoint_name = checkpoint.save( 292 | os.path.join(FLAGS.output_dir, 'checkpoint')) 293 | logging.info('Saved checkpoint to %s', checkpoint_name) 294 | 295 | final_checkpoint_name = checkpoint.save( 296 | os.path.join(FLAGS.output_dir, 'checkpoint')) 297 | logging.info('Saved last checkpoint to %s', final_checkpoint_name) 298 | 299 | 300 | if __name__ == '__main__': 301 | app.run(main) 302 | -------------------------------------------------------------------------------- /growneuron/cifar/main.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 GradMax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | r"""Wide ResNet 28-10 on CIFAR-10/100 trained with maximum likelihood. 17 | 18 | Hyperparameters differ slightly from the original paper's code 19 | (https://github.com/szagoruyko/wide-residual-networks) as TensorFlow uses, for 20 | example, l2 instead of weight decay, and a different parameterization for SGD's 21 | momentum. 
22 | 23 | """ 24 | 25 | import itertools 26 | import os 27 | import time 28 | 29 | from absl import app 30 | from absl import flags 31 | from absl import logging 32 | from growneuron import growers 33 | from growneuron import updaters 34 | from growneuron.cifar import data 35 | from growneuron.cifar import vgg 36 | from growneuron.cifar import wide_resnet 37 | import growneuron.layers as glayers 38 | from ml_collections import config_flags 39 | import tensorflow as tf 40 | import tensorflow_datasets as tfds 41 | import uncertainty_baselines.schedules as ub_schedules 42 | from tensorboard.plugins.hparams import api as hp 43 | 44 | config_flags.DEFINE_config_file( 45 | name='config', 46 | default='growneuron/cifar/configs/' 47 | 'baseline_big.py', 48 | help_string='training config file.') 49 | # common flags 50 | 51 | flags.DEFINE_string('data_dir', None, 52 | 'data_dir to be used for tfds dataset construction.' 53 | 'It is required when training with cloud TPUs') 54 | flags.DEFINE_bool('download_data', False, 55 | 'Whether to download data locally when initializing a ' 56 | 'dataset.') 57 | flags.DEFINE_string('output_dir', '/tmp/cifar', 'Output directory.') 58 | flags.DEFINE_bool('collect_profile', False, 59 | 'Whether to trace a profile with tensorboard') 60 | 61 | FLAGS = flags.FLAGS 62 | 63 | 64 | def get_optimizer(optimizer_config, train_epochs, batch_size, steps_per_epoch): 65 | """Given the config and training arguments returns an optimizer.""" 66 | # Linearly scale learning rate and the decay epochs by vanilla settings. 67 | base_lr = optimizer_config.base_learning_rate * batch_size / 128 68 | lr_decay_epochs = [int(fraction * train_epochs) 69 | for fraction in optimizer_config.lr_decay_epochs] 70 | if optimizer_config.decay_type == 'step': 71 | lr_schedule = ub_schedules.WarmUpPiecewiseConstantSchedule( 72 | steps_per_epoch, 73 | base_lr, 74 | decay_ratio=optimizer_config.lr_decay_ratio, 75 | decay_epochs=lr_decay_epochs, 76 | warmup_epochs=optimizer_config.lr_warmup_epochs) 77 | elif optimizer_config.decay_type == 'cosine': 78 | lr_schedule = tf.keras.optimizers.schedules.CosineDecay( 79 | base_lr, train_epochs * steps_per_epoch, alpha=0.0) 80 | else: 81 | lr_schedule = base_lr / 100. 82 | logging.info('No decay used') 83 | optimizer = tf.keras.optimizers.SGD(lr_schedule, 84 | momentum=optimizer_config.momentum, 85 | nesterov=optimizer_config.nesterov) 86 | return optimizer 87 | 88 | 89 | def main(argv): 90 | fmt = '[%(filename)s:%(lineno)s] %(message)s' 91 | formatter = logging.PythonFormatter(fmt) 92 | logging.get_absl_handler().setFormatter(formatter) 93 | del argv # unused arg 94 | config = FLAGS.config 95 | 96 | tf.io.gfile.makedirs(FLAGS.output_dir) 97 | logging.info('Saving checkpoints at %s', FLAGS.output_dir) 98 | tf.random.set_seed(config.seed) 99 | 100 | strategy = tf.distribute.MirroredStrategy() 101 | 102 | ds_builder = tfds.builder(config.dataset) 103 | if FLAGS.download_data: 104 | ds_builder.download_and_prepare() 105 | ds_info = ds_builder.info 106 | batch_size = config.per_core_batch_size * config.num_cores 107 | 108 | # Scale arguments that depend on 128 batch size total training iterations. 109 | multiplier = 128. 
/ batch_size 110 | if hasattr(config.updater, 'update_frequency'): 111 | config.updater.update_frequency = int( 112 | config.updater.update_frequency * multiplier) 113 | config.updater.start_iteration = int( 114 | config.updater.start_iteration * multiplier) 115 | 116 | train_dataset_size = ds_info.splits['train'].num_examples 117 | steps_per_epoch = train_dataset_size // batch_size 118 | logging.info('Steps per epoch %s', steps_per_epoch) 119 | logging.info('Size of the dataset %s', train_dataset_size) 120 | steps_per_eval = ds_info.splits['test'].num_examples // batch_size 121 | num_classes = ds_info.features['label'].num_classes 122 | 123 | train_dataset = strategy.distribute_datasets_from_function( 124 | data.build_input_fn(ds_builder, batch_size, topology=None, 125 | is_training=True, 126 | cache_dataset=config.cache_dataset)) 127 | # Grow batches might differ in size. 128 | grow_batch_size = getattr(config, 'grow_batch_size', batch_size) 129 | grow_dataset = strategy.distribute_datasets_from_function( 130 | data.build_input_fn(ds_builder, grow_batch_size, topology=None, 131 | is_training=True, 132 | cache_dataset=config.cache_dataset)) 133 | test_dataset = strategy.distribute_datasets_from_function( 134 | data.build_input_fn(ds_builder, batch_size, topology=None, 135 | is_training=False, 136 | cache_dataset=config.cache_dataset)) 137 | # Scale the training epochs to roughly match the big-baseline cost. 138 | arch_name = config.get('architecture', 'wide-resnet') 139 | # Maybe create a grower. 140 | grow_type = getattr(config, 'grow_type', None) 141 | if grow_type == 'add_random': 142 | grower = growers.AddRandom() 143 | elif grow_type == 'add_firefly': 144 | grower = growers.AddFirefly() 145 | elif grow_type == 'add_gradmax_opt': 146 | grower = growers.AddGradmaxOptim() 147 | elif grow_type == 'add_gradmax': 148 | grower = growers.AddGradmax() 149 | else: 150 | logging.info('No growing') 151 | grower = None 152 | 153 | if grower: 154 | grower.epsilon = config.grow_epsilon 155 | grower.scale_method = config.grow_scale_method 156 | grower.is_outgoing_zero = config.is_outgoing_zero 157 | 158 | if config.scale_epochs: 159 | if arch_name == 'wide-resnet': 160 | width_scale = config.model.block_width_multiplier 161 | elif arch_name == 'vgg': 162 | width_scale = config.model.width_multiplier 163 | else: 164 | raise ValueError(f'Unknown architecture: {arch_name}') 165 | old_epochs = config.train_epochs 166 | # Adjust the total epochs to match big-baseline training FLOPs.
167 | if grower: 168 | config.train_epochs = updaters.adjust_epochs( 169 | config.train_epochs, 170 | width_scale, 171 | config.updater.update_frequency, 172 | config.updater.start_iteration, 173 | config.updater.n_growth_steps, 174 | steps_per_epoch 175 | ) 176 | else: 177 | # baseline 178 | config.train_epochs = config.train_epochs / width_scale 179 | logging.info('Extended training from %s to %s', old_epochs, 180 | config.train_epochs) 181 | summary_writer = tf.summary.create_file_writer( 182 | os.path.join(FLAGS.output_dir, 'summaries')) 183 | 184 | with summary_writer.as_default(): 185 | flat_param_dict = {} 186 | def flat_fn(dic): 187 | for k, v in dic.items(): 188 | if isinstance(v, dict): 189 | flat_fn({f'{k}.{k2}': v2 for k2, v2 in v.items()}) 190 | else: 191 | flat_param_dict[k] = str(v) if isinstance(v, list) else v 192 | flat_fn(config.to_dict()) 193 | hp.hparams(flat_param_dict) 194 | 195 | grow_layer_tuples = [] 196 | with strategy.scope(): 197 | if arch_name == 'wide-resnet': 198 | logging.info('Building ResNet model') 199 | model = wide_resnet.create_model( 200 | num_classes=num_classes, 201 | seed=config.seed, 202 | **config.model) 203 | for block_seq in model.group_seq: 204 | for block_layers, _ in block_seq: 205 | # We need to get all layers between the two grow layers. 206 | glayer_indices = [i for i, l in enumerate(block_layers) 207 | if isinstance(l, glayers.GrowLayer)] 208 | start_index, end_index = glayer_indices[0], glayer_indices[-1] 209 | grow_layer_tuples.append(block_layers[start_index:(end_index+1)]) 210 | elif arch_name == 'vgg': 211 | logging.info('Building VGG model') 212 | model = vgg.create_model( 213 | num_classes=num_classes, 214 | seed=config.seed, 215 | **config.model) 216 | grow_layer_tuples = model.get_grow_layer_tuples() 217 | else: 218 | raise ValueError(f'Unknown architecture: {arch_name}') 219 | logging.info('grow_layer_tuples: %s', grow_layer_tuples) 220 | grow_metrics = {layers[0]: tf.keras.metrics.Sum() 221 | for layers in grow_layer_tuples} 222 | # Initialize the parameters. 223 | def compile_model_fn(): 224 | model(tf.keras.Input((32, 32, 3))) 225 | compile_model_fn() 226 | logging.info('Model input shape: %s', model.input_shape) 227 | logging.info('Model output shape: %s', model.output_shape) 228 | logging.info('Model number of weights: %s', model.count_params()) 229 | optimizer = get_optimizer(config.optimizer, config.train_epochs, batch_size, 230 | steps_per_epoch) 231 | train_metrics = { 232 | 'train/negative_log_likelihood': 233 | tf.keras.metrics.Mean(), 234 | 'train/accuracy': 235 | tf.keras.metrics.SparseCategoricalAccuracy(), 236 | 'train/loss': 237 | tf.keras.metrics.Mean(), 238 | } 239 | 240 | eval_metrics = { 241 | 'test/negative_log_likelihood': 242 | tf.keras.metrics.Mean(), 243 | 'test/accuracy': 244 | tf.keras.metrics.SparseCategoricalAccuracy(), 245 | } 246 | model.summary() 247 | 248 | checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer) 249 | latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir) 250 | initial_epoch = 0 251 | if latest_checkpoint: 252 | # TODO This probably wouldn't work if the networks is grown; 253 | # so we need to switch to saved models maybe. 254 | # checkpoint.restore must be within a strategy.scope() so that optimizer 255 | # slot variables are mirrored. 
256 | checkpoint.restore(latest_checkpoint) 257 | logging.info('Loaded checkpoint %s', latest_checkpoint) 258 | initial_epoch = optimizer.iterations.numpy() // steps_per_epoch 259 | 260 | def loss_fn(inputs): 261 | images, labels = inputs 262 | logits = model(images, training=True) 263 | one_hot_labels = tf.one_hot(tf.cast(labels, tf.int32), num_classes) 264 | loss = tf.reduce_mean( 265 | tf.keras.losses.categorical_crossentropy( 266 | one_hot_labels, 267 | logits, 268 | from_logits=True)) 269 | scaled_loss = loss / strategy.num_replicas_in_sync 270 | # Don't add the regularization loss; it is unnecessary for zero variables. 271 | return scaled_loss 272 | 273 | updater_type = getattr(config, 'updater_type', None) 274 | if updater_type == 'round_robin': 275 | updater = updaters.RoundRobin(grower, grow_layer_tuples, loss_fn, 276 | compile_model_fn, **config.updater) 277 | elif updater_type == 'all_at_once': 278 | updater = updaters.AllAtOnce(grower, grow_layer_tuples, loss_fn, 279 | compile_model_fn, **config.updater) 280 | else: 281 | updater = updaters.DummyUpdater(grow_layer_tuples) 282 | 283 | def get_update_fn(model): 284 | """Returns Per-Replica update function.""" 285 | # We need to remap this as variable names change when the network is grown. 286 | variable_mapped_grow_metrics = { 287 | l.weights[0].name: metric for l, metric in grow_metrics.items() 288 | } 289 | 290 | @tf.function 291 | def _update_fn(inputs): 292 | images, labels = inputs 293 | with tf.GradientTape() as tape: 294 | logits = model(images, training=True) 295 | one_hot_labels = tf.one_hot(tf.cast(labels, tf.int32), num_classes) 296 | nll_loss = tf.reduce_mean( 297 | tf.keras.losses.categorical_crossentropy( 298 | one_hot_labels, 299 | logits, 300 | from_logits=True)) 301 | l2_loss = sum(model.losses) 302 | loss = nll_loss + l2_loss 303 | # Scale the loss given that the TPUStrategy will reduce-sum all gradients. 304 | scaled_loss = loss / strategy.num_replicas_in_sync 305 | grads = tape.gradient(scaled_loss, model.trainable_variables) 306 | # Log some gradient norms. 307 | for grad, var in zip(grads, model.trainable_variables): 308 | if var.name in variable_mapped_grow_metrics: 309 | sq_grad = tf.math.pow(grad, 2) 310 | variable_mapped_grow_metrics[var.name].update_state(sq_grad) 311 | optimizer.apply_gradients(zip(grads, model.trainable_variables)) 312 | train_metrics['train/loss'].update_state(loss) 313 | train_metrics['train/negative_log_likelihood'].update_state(nll_loss) 314 | train_metrics['train/accuracy'].update_state(labels, logits) 315 | 316 | return _update_fn 317 | 318 | def train_step(iterator, grow_iterator): 319 | """Training StepFn.""" 320 | # This allows retracing; we need to retrace because the model changes. 321 | update_fn = get_update_fn(model) 322 | for _ in tf.range(tf.cast(steps_per_epoch, tf.int32)): 323 | # Maybe grow. 324 | is_update = updater.is_update_iteration(optimizer.iterations) 325 | if is_update: 326 | logging.info('Growing on iteration: %s', optimizer.iterations.numpy()) 327 | with strategy.scope(): 328 | updater.update_network(batch_data=next(grow_iterator), 329 | optimizer=optimizer) 330 | compile_model_fn() 331 | # Regenerate the function so that the model is retraced after growing.
332 | update_fn = get_update_fn(model) 333 | logging.info('Model number of weights: %s', model.count_params()) 334 | with summary_writer.as_default(): 335 | logging.info('Widths after growth') 336 | for name, n_neuron in updater.get_grow_layer_stats(): 337 | logging.info('%s: %d', name, n_neuron) 338 | tf.summary.scalar(f'n_neurons/{name}', n_neuron, 339 | step=optimizer.iterations) 340 | # Gradient Step. 341 | strategy.run(update_fn, args=(next(iterator),)) 342 | # Logging 343 | if is_update or optimizer.iterations % config.get('log_freq', 100) == 1: 344 | logging.info('Train Loss: %.4f, Accuracy: %.2f%%', 345 | train_metrics['train/loss'].result(), 346 | train_metrics['train/accuracy'].result() * 100) 347 | total_results = {name: metric.result() for name, metric in 348 | train_metrics.items()} 349 | total_results['lr'] = optimizer.learning_rate(optimizer.iterations) 350 | total_results['params/total'] = model.count_params() 351 | for layer, metric in grow_metrics.items(): 352 | total_results[f'grad/{layer.name}'] = metric.result() 353 | with summary_writer.as_default(): 354 | for name, result in total_results.items(): 355 | tf.summary.scalar(name, result, step=optimizer.iterations) 356 | for metric in itertools.chain(train_metrics.values(), 357 | grow_metrics.values()): 358 | metric.reset_states() 359 | 360 | def test_step(iterator, dataset_split, num_steps): 361 | """Evaluation StepFn.""" 362 | @tf.function 363 | def step_fn(inputs): 364 | """Per-Replica StepFn.""" 365 | images, labels = inputs 366 | logits = model(images, training=False) 367 | probs = tf.nn.softmax(logits) 368 | negative_log_likelihood = tf.reduce_mean( 369 | tf.keras.losses.sparse_categorical_crossentropy(labels, probs)) 370 | 371 | eval_metrics[f'{dataset_split}/negative_log_likelihood'].update_state( 372 | negative_log_likelihood) 373 | eval_metrics[f'{dataset_split}/accuracy'].update_state(labels, probs) 374 | 375 | for _ in tf.range(tf.cast(num_steps, tf.int32)): 376 | strategy.run(step_fn, args=(next(iterator),)) 377 | 378 | train_iterator = iter(train_dataset) 379 | grow_iterator = iter(grow_dataset) 380 | 381 | start_time = time.time() 382 | tb_callback = None 383 | if FLAGS.collect_profile: 384 | tb_callback = tf.keras.callbacks.TensorBoard( 385 | profile_batch=(100, 2000), 386 | log_dir=os.path.join(FLAGS.output_dir, 'logs')) 387 | tb_callback.set_model(model) 388 | 389 | for epoch in range(initial_epoch, config.train_epochs): 390 | logging.info('Starting to run epoch: %s', epoch) 391 | if tb_callback: 392 | tb_callback.on_epoch_begin(epoch) 393 | train_start_time = time.time() 394 | train_step(train_iterator, grow_iterator) 395 | train_ms_per_example = (time.time() - train_start_time) * 1e6 / batch_size 396 | 397 | current_step = (epoch + 1) * steps_per_epoch 398 | max_steps = steps_per_epoch * config.train_epochs 399 | time_elapsed = time.time() - start_time 400 | steps_per_sec = float(current_step) / time_elapsed 401 | eta_seconds = (max_steps - current_step) / steps_per_sec 402 | message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. ' 403 | 'ETA: {:.0f} min. 
Time elapsed: {:.0f} min'.format( 404 | current_step / max_steps, 405 | epoch + 1, 406 | config.train_epochs, 407 | steps_per_sec, 408 | eta_seconds / 60, 409 | time_elapsed / 60)) 410 | logging.info(message) 411 | if tb_callback: 412 | tb_callback.on_epoch_end(epoch) 413 | 414 | test_iterator = iter(test_dataset) 415 | logging.info('Starting to run eval at epoch: %s', epoch) 416 | test_start_time = time.time() 417 | test_step(test_iterator, 'test', steps_per_eval) 418 | test_ms_per_example = (time.time() - test_start_time) * 1e6 / batch_size 419 | logging.info('Done with eval on') 420 | 421 | logging.info('Test NLL: %.4f, Accuracy: %.2f%%', 422 | eval_metrics['test/negative_log_likelihood'].result(), 423 | eval_metrics['test/accuracy'].result() * 100) 424 | total_results = {name: metric.result() for name, metric 425 | in eval_metrics.items()} 426 | total_results['train/ms_per_example'] = train_ms_per_example 427 | total_results['test/ms_per_example'] = test_ms_per_example 428 | with summary_writer.as_default(): 429 | for name, result in total_results.items(): 430 | tf.summary.scalar(name, result, step=epoch + 1) 431 | 432 | for metric in eval_metrics.values(): 433 | metric.reset_states() 434 | 435 | if (config.checkpoint_interval > 0 and 436 | (epoch + 1) % config.checkpoint_interval == 0): 437 | checkpoint_name = checkpoint.save( 438 | os.path.join(FLAGS.output_dir, 'checkpoint')) 439 | logging.info('Saved checkpoint to %s', checkpoint_name) 440 | 441 | final_checkpoint_name = checkpoint.save( 442 | os.path.join(FLAGS.output_dir, 'checkpoint')) 443 | logging.info('Saved last checkpoint to %s', final_checkpoint_name) 444 | 445 | 446 | if __name__ == '__main__': 447 | app.run(main) 448 | -------------------------------------------------------------------------------- /growneuron/growers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 GradMax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """This module implements various growing algorithms. 17 | """ 18 | import functools 19 | import logging 20 | import growneuron.layers as glayers 21 | import numpy as np 22 | from scipy.sparse.linalg.eigen import arpack 23 | import tensorflow as tf 24 | 25 | 26 | class LayerGrower(): 27 | """Base class for growing layer algorithms. 28 | 29 | Subclasses should implement grow_neurons. 30 | grad_fn: Should take a list of variables and return aggregated gradients. 31 | grow_layers: list of GrowLayers. There are often 2 layers. The first is 32 | the one we are adding neurons to and the second is the layer that consumes 33 | the neurons from the first layer. However, in some architectures there could 34 | be some layers in between that transform channel-wise information 35 | independently: like BatchNorm and depth-wise convolutions. In such cases, 36 | we grow identity neurons for the layers in between the first and last. 37 | """ 38 | epsilon = 0.
39 | scale_method = 'mean_norm' 40 | strategy = None 41 | compile_fn = lambda: None 42 | loss_fn = lambda x: x 43 | 44 | def grow_neurons(self, grow_layers, batch_data, **kwargs): 45 | raise NotImplementedError() 46 | 47 | 48 | class AddRandom(LayerGrower): 49 | """Implements random growing.""" 50 | is_outgoing_zero = False 51 | is_all_zero = False 52 | 53 | def grow_neurons(self, grow_layers, batch_data, n_new=1, scale=1.): 54 | del batch_data 55 | scales = (self.epsilon, scale) 56 | new_bias = 'zeros' 57 | if self.is_all_zero: 58 | scales = (self.epsilon, self.epsilon) 59 | new_bias = 'ones' 60 | elif self.is_outgoing_zero: 61 | scales = (scale, self.epsilon) 62 | for i, layer in enumerate(grow_layers): 63 | if i == 0: 64 | # First layer 65 | layer.add_neurons(n_new, new_weights='random', is_outgoing=False, 66 | scale=scales[0], scale_method=self.scale_method, 67 | new_bias=new_bias) 68 | elif i == (len(grow_layers) - 1): 69 | # Last layer 70 | layer.add_neurons(n_new, new_weights='random', is_outgoing=True, 71 | scale=scales[1], scale_method=self.scale_method) 72 | else: 73 | if isinstance(layer, glayers.GrowLayer): 74 | layer.add_neurons_identity(n_new) 75 | 76 | 77 | class AddFirefly(AddRandom): 78 | """Implements Firefly style growing using direct optimization. 79 | 80 | Implements Eq:4 from the paper without extra candidates and splitting. 81 | https://arxiv.org/abs/2102.08574 82 | """ 83 | optim_n_step = 100 84 | optim_fn = lambda self: tf.keras.optimizers.Adam() 85 | 86 | def grow_neurons(self, grow_layers, batch_data, n_new=1, scale=1.): 87 | n_old_neuron = grow_layers[0].weights[0].shape[-1] 88 | # First add neurons randomly 89 | super().grow_neurons(grow_layers, batch_data, n_new=n_new, scale=scale) 90 | self.compile_fn() 91 | # Now optimize the random initialization 92 | layer_tuple = grow_layers[0], grow_layers[-1] 93 | 94 | optimizer = self.optim_fn() 95 | target_magnitudes = [] 96 | # Record the magnitude of the new_weights. 97 | for concat_axis, layer in zip([-1, -2], layer_tuple): 98 | _, new_weights = tf.split(layer.weights[0], [n_old_neuron, -1], 99 | axis=concat_axis) 100 | target_magnitudes.append( 101 | np.mean(glayers.norm_l2(new_weights, keep_dim=concat_axis))) 102 | logging.info('Minimizing loss.') 103 | weights = [l.weights[0] for l in layer_tuple] 104 | 105 | @tf.function 106 | def update_fn(inputs): 107 | with tf.GradientTape() as tape: 108 | loss = self.loss_fn(inputs) 109 | grads = tape.gradient(loss, weights) 110 | masked_grads = [] 111 | for concat_axis, grad in zip([-1, -2], grads): 112 | # Apply gradient only on new weights, zero out the rest. 113 | old_wgrad, new_wgrad = tf.split(grad, [n_old_neuron, -1], 114 | axis=concat_axis) 115 | masked_grad = tf.concat([tf.zeros_like(old_wgrad), new_wgrad], 116 | axis=concat_axis) 117 | masked_grads.append(masked_grad) 118 | optimizer.apply_gradients(zip(masked_grads, weights)) 119 | # Project new weights back to the target magnitude. 
120 | return loss 121 | 122 | log_freq = self.optim_n_step // 10 123 | for i in range(self.optim_n_step): 124 | per_replica_losses = self.strategy.run(update_fn, args=(batch_data,)) 125 | loss = self.strategy.reduce( 126 | tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None) 127 | if i % log_freq == 0: 128 | logging.info('Firefly iter: %d, loss: %s', i, loss) 129 | for concat_axis, weight, target_magnitude in zip( 130 | [-1, -2], weights, target_magnitudes): 131 | old_w, new_w = tf.split(weight, [n_old_neuron, -1], 132 | axis=concat_axis) 133 | normalized_w = glayers.normalize_l2(new_w, axis=concat_axis) 134 | normalized_new_w = normalized_w * target_magnitude 135 | weight.assign( 136 | tf.concat([old_w, normalized_new_w], axis=concat_axis)) 137 | logging.info('Firefly final loss: %s', loss.numpy()) 138 | 139 | 140 | class AddGradmaxOptim(AddRandom): 141 | """Implements Gradmax using direct optimization.""" 142 | optim_n_step = 100 143 | optim_fn = lambda self: tf.keras.optimizers.Adam() 144 | 145 | def grow_neurons(self, grow_layers, batch_data, n_new=1, scale=1.): 146 | # For simplicity we do full backward and forward pass here, but note that 147 | # only thing we need here is inputs at l-1 and gradients at l+1. Those stay 148 | # same and don't need to be re-calculated each time. 149 | n_old_neuron = grow_layers[0].weights[0].shape[-1] 150 | # First add neurons randomly 151 | super().grow_neurons(grow_layers, batch_data, n_new=n_new, scale=scale) 152 | self.compile_fn() 153 | # Now optimize the random initialization 154 | if self.is_outgoing_zero: 155 | # We optimize incoming weights 156 | optim_layer, grad_layer = grow_layers[0], grow_layers[-1] 157 | concat_axis = -1 158 | grad_slic_fn = lambda a: a[Ellipsis, n_old_neuron:, :] 159 | else: 160 | # We optimize outgoing weights 161 | optim_layer, grad_layer = grow_layers[-1], grow_layers[0] 162 | concat_axis = -2 163 | grad_slic_fn = lambda a: a[Ellipsis, n_old_neuron:] 164 | 165 | optimizer = self.optim_fn() 166 | target_magnitude = None 167 | # Record the magnitude of the new_weights. 168 | _, new_weights = tf.split(optim_layer.weights[0], [n_old_neuron, -1], 169 | axis=concat_axis) 170 | target_magnitude = np.mean(glayers.norm_l2(new_weights, 171 | keep_dim=concat_axis)) 172 | logging.info('Target magnitude: %s', target_magnitude) 173 | optim_layer_weight = optim_layer.weights[0] 174 | logging.info('Minimizing loss.') 175 | 176 | @tf.function 177 | def update_fn(inputs): 178 | with tf.GradientTape(persistent=True) as tape: 179 | loss = self.loss_fn(inputs) 180 | grad_layer_weight = grad_layer.weights[0] 181 | inner_grad = tape.gradient(loss, grad_layer_weight) 182 | # Maximize gradient norm. 183 | final_loss = -tf.norm(grad_slic_fn(inner_grad)) 184 | grad = tape.gradient(final_loss, optim_layer_weight) 185 | # Apply gradient only on new weights, zero out the rest. 
186 | old_wgrad, new_wgrad = tf.split(grad, [n_old_neuron, -1], 187 | axis=concat_axis) 188 | masked_grad = tf.concat([tf.zeros_like(old_wgrad), new_wgrad], 189 | axis=concat_axis) 190 | optimizer.apply_gradients([(masked_grad, optim_layer_weight)]) 191 | return final_loss 192 | 193 | log_freq = self.optim_n_step // 10 194 | for i in range(self.optim_n_step): 195 | per_replica_losses = self.strategy.run(update_fn, args=(batch_data,)) 196 | loss = self.strategy.reduce( 197 | tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None) 198 | if i % log_freq == 0: 199 | logging.info('Gradmax-opt: %d, loss: %s', i, loss) 200 | # Project new weights back to the target magnitude. 201 | # Record the magnitude of the new_weights. 202 | old_w, new_w = tf.split(optim_layer_weight, [n_old_neuron, -1], 203 | axis=concat_axis) 204 | normalized_w = glayers.normalize_l2(new_w, axis=concat_axis) 205 | normalized_new_w = normalized_w * target_magnitude 206 | optim_layer_weight.assign( 207 | tf.concat([old_w, normalized_new_w], axis=concat_axis)) 208 | logging.info('Grad-max-opt final loss: %s', loss.numpy()) 209 | 210 | def _grow_neurons_legacy(self, grow_layers, batch_data, n_new=1, scale=1.): 211 | """Old function to calculate gradmax-opt initialization efficiently.""" 212 | # Note that this version doesn't work currently. 213 | # The issue here is that the inputs are shared, thus it is a challenge to 214 | # obtain them and reshard. This path is efficient but might be 215 | # unnecessarily complicated, thus we do a full pass like above. 216 | logging.warning('This function is not doing the right thing in multi-worker ' 217 | 'setting.') 218 | n_old_neuron = grow_layers[0].weights[0].shape[-1] 219 | # First get the output gradient at l+1 and the input at l-1. 220 | aux_tensor = [] 221 | # For simplicity we do a full backward and forward pass here, but note that 222 | # the only thing we need here is the inputs at l-1 and gradients at l+1. Those stay 223 | # the same and don't need to be re-calculated each time. 224 | def next_layer_callback(next_inputs, next_outputs): 225 | aux_tensor.append(tf.zeros_like(next_outputs)) 226 | return next_inputs, (next_outputs + aux_tensor[-1]) 227 | grow_layers[-1].add_callback('add_zeros', next_layer_callback) 228 | inp_tensor = [] 229 | def first_layer_callback(next_inputs, next_outputs): 230 | inp_tensor.append(next_inputs) 231 | return next_inputs, next_outputs 232 | grow_layers[0].add_callback('collect_inp', first_layer_callback) 233 | 234 | def grad_fn(inputs): 235 | with tf.GradientTape() as tape: 236 | loss = self.loss_fn(inputs) 237 | return tape.gradient(loss, aux_tensor[0]) 238 | per_replica_grads = self.strategy.run(grad_fn, args=(batch_data,)) 239 | out_grads = self.strategy.reduce( 240 | tf.distribute.ReduceOp.SUM, per_replica_grads, axis=None) 241 | 242 | # Second, add neurons randomly. 243 | super().grow_neurons(grow_layers, batch_data, n_new=n_new, 244 | scale=scale) 245 | self.compile_fn() 246 | # Now optimize the random initialization 247 | if self.is_outgoing_zero: 248 | # We optimize incoming weights 249 | optim_layer, grad_layer = grow_layers[0], grow_layers[-1] 250 | concat_axis = -1 251 | grad_slic_fn = lambda a: a[Ellipsis, n_old_neuron:, :] 252 | else: 253 | # We optimize outgoing weights 254 | optim_layer, grad_layer = grow_layers[-1], grow_layers[0] 255 | concat_axis = -2 256 | grad_slic_fn = lambda a: a[Ellipsis, n_old_neuron:] 257 | 258 | optimizer = self.optim_fn() 259 | target_magnitude = None 260 | # Record the magnitude of the new_weights.
261 | _, new_weights = tf.split(optim_layer.weights[0], [n_old_neuron, -1], 262 | axis=concat_axis) 263 | target_magnitude = np.mean(glayers.norm_l2(new_weights, 264 | keep_dim=concat_axis)) 265 | logging.info('Target magnitude: %s', target_magnitude) 266 | optim_layer_weight = optim_layer.weights[0] 267 | 268 | @tf.function 269 | def update_fn(inp_tensor, out_grads): 270 | with tf.GradientTape(persistent=True) as tape: 271 | x = inp_tensor 272 | for l in grow_layers: 273 | x = l(x, training=True) 274 | # This simulates having output grads at the end. But it is way more 275 | # efficient as we don't need to run the input again through the whole 276 | # network. 277 | # dL/dx = out_grads because grad_x(x * y) = y 278 | loss = tf.reduce_sum(x*out_grads) 279 | grad_layer_weight = grad_layer.weights[0] 280 | inner_grad = tape.gradient(loss, grad_layer_weight) 281 | # Maximize gradient norm. 282 | final_loss = -tf.norm(grad_slic_fn(inner_grad)) 283 | grad = tape.gradient(final_loss, optim_layer_weight) 284 | # Apply gradient only on new weights, zero out the rest. 285 | old_wgrad, new_wgrad = tf.split(grad, [n_old_neuron, -1], 286 | axis=concat_axis) 287 | masked_grad = tf.concat([tf.zeros_like(old_wgrad), new_wgrad], 288 | axis=concat_axis) 289 | optimizer.apply_gradients([(masked_grad, optim_layer_weight)]) 290 | return final_loss 291 | logging.info('Maximizing gradients') 292 | log_freq = self.optim_n_step // 10 293 | for i in range(self.optim_n_step): 294 | per_replica_losses = self.strategy.run(update_fn, 295 | args=(inp_tensor[0], out_grads)) 296 | loss = self.strategy.reduce( 297 | tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None) 298 | if i % log_freq == 0: 299 | logging.info('GradmaxOptim iter: %d, loss: %s', i, loss.numpy()) 300 | # Project new weights back to the target magnitude. 301 | # Record the magnitude of the new_weights. 302 | old_w, new_w = tf.split(optim_layer_weight, [n_old_neuron, -1], 303 | axis=concat_axis) 304 | normalized_w = glayers.normalize_l2(new_w, axis=concat_axis) 305 | normalized_new_w = normalized_w * target_magnitude 306 | optim_layer_weight.assign( 307 | tf.concat([old_w, normalized_new_w], axis=concat_axis)) 308 | logging.info('Final Grad-Norm: %f', -loss.numpy()) 309 | 310 | 311 | class AddGradmax(LayerGrower): 312 | """Implements Gradmax using auxiliary layer formulation.""" 313 | 314 | def grow_neurons(self, grow_layers, batch_data, n_new=1, scale=1.): 315 | if len(grow_layers) == 2: 316 | current_layer, next_layer = grow_layers 317 | identity_layers = [] 318 | else: 319 | assert len(grow_layers) > 2 320 | current_layer, next_layer = grow_layers[0], grow_layers[-1] 321 | identity_layers = grow_layers[1:-1] 322 | # There is only one candidate 323 | growth_candidates = [(current_layer, next_layer)] 324 | if not isinstance(current_layer.layer, type(next_layer.layer)): 325 | # This is a temporary fix for dealing with heteregonous layers. 326 | # When two consecutive layers are different we grow randomly. 327 | # For example when a convolutional layer is followed by a fully connected 328 | # layer. 
329 | logging.info('Growing randomly layers: %s %s, %s %s', 330 | current_layer.layer.name, type(current_layer.layer), 331 | next_layer.layer.name, type(next_layer.layer)) 332 | tmp_grower = AddRandom() 333 | tmp_grower.epsilon = self.epsilon 334 | tmp_grower.scale_method = self.scale_method 335 | tmp_grower.is_outgoing_zero = False 336 | tmp_grower.grow_neurons(grow_layers, batch_data, n_new=n_new, scale=scale) 337 | return 338 | unused_eigenvals, eigenvecs = self.get_growth_directions( 339 | batch_data, growth_candidates, [n_new])[0] 340 | # Grow incoming connections 341 | current_layer.add_neurons(n_new, new_weights='random', is_outgoing=False, 342 | scale=self.epsilon, 343 | scale_method=self.scale_method) 344 | # Initialize intermediate layers as identity. 345 | for layer in identity_layers: 346 | if isinstance(layer, glayers.GrowLayer): 347 | layer.add_neurons_identity(n_new) 348 | # Top-k Eigenvectors 349 | new_weights = eigenvecs[:, :n_new] 350 | c_shape = next_layer.weights[0].shape 351 | if len(c_shape) == 4: 352 | # First reshape each neuron and then transpose last 2 dimensions. 353 | new_filter_shape = c_shape[:2] + [c_shape[-1], n_new] 354 | new_weights = np.reshape(new_weights, new_filter_shape) 355 | new_weights = np.transpose(new_weights, axes=(0, 1, 3, 2)) 356 | elif len(c_shape) == 2: 357 | new_weights = new_weights.T 358 | 359 | next_layer.add_neurons(n_new, new_weights=new_weights, scale=scale, 360 | is_outgoing=True, 361 | scale_method=self.scale_method) 362 | 363 | def get_growth_directions(self, batch_data, growth_candidates, n_grows): 364 | """Efficiently retrieves eigen-decomposition for a set of candidates.""" 365 | # Adding all callbacks. 366 | aux_layers = [] 367 | post_process_fns = [] 368 | for current_layer, next_layer in growth_candidates: 369 | aux_layer, post_process_fn = self.get_aux_layer(current_layer.layer, 370 | next_layer.layer) 371 | post_process_fns.append(post_process_fn) 372 | def grow_layer_callback(inputs, outputs, aux_layer=aux_layer, 373 | next_layer=next_layer): 374 | add_h = aux_layer(inputs) 375 | def next_layer_callback(next_inputs, next_outputs): 376 | return next_inputs, (next_outputs + add_h) 377 | next_layer.add_callback('add_aux', next_layer_callback) 378 | return inputs, outputs 379 | current_layer.add_callback('pass_aux', grow_layer_callback) 380 | aux_layers.append(aux_layer) 381 | 382 | def grad_fn(inputs): 383 | with tf.GradientTape() as tape: 384 | loss = self.loss_fn(inputs) 385 | grad_vars = [aux_layer.weights[0] for aux_layer in aux_layers] 386 | return tape.gradient(loss, grad_vars) 387 | per_replica_grads = self.strategy.run(grad_fn, 388 | args=(batch_data,)) 389 | aux_grads = self.strategy.reduce( 390 | tf.distribute.ReduceOp.SUM, per_replica_grads, axis=None) 391 | grow_matrices = [ 392 | post_process_fn(g) 393 | for g, post_process_fn in zip(aux_grads, post_process_fns)] 394 | # Reset Callbacks 395 | for current_layer, next_layer in growth_candidates: 396 | current_layer.reset_callbacks() 397 | next_layer.reset_callbacks() 398 | results = [] 399 | # Calculate eigenvalues 400 | for grow_matrix, n_grow in zip(grow_matrices, n_grows): 401 | # M^{l+1} by M^{l+1} 402 | if n_grow > 0: 403 | # svds is equivalent to calling eigsh on M.T @ M (without materialiazing 404 | # this matrix) which is faster or slower (depending on the shape of M) 405 | _, s, vh = arpack.svds(grow_matrix, k=n_grow, 406 | return_singular_vectors='vh') 407 | eigenvals, eigenvecs = (s**2)[::-1], vh[::-1].T 408 | else: 409 | s, _, v = 
tf.linalg.svd(grow_matrix) 410 | eigenvals, eigenvecs = s**2, v 411 | results.append((eigenvals, eigenvecs)) 412 | return results 413 | 414 | def get_aux_layer(self, first_layer, second_layer): 415 | """Creates auxilarly layers for growing new neurons between layers.""" 416 | l = tf.keras.layers 417 | if isinstance(first_layer, l.Dense) and isinstance(second_layer, l.Dense): 418 | aux_layer = l.Dense(second_layer.units, activation=None, use_bias=False, 419 | kernel_initializer='zeros') 420 | post_process_fn = lambda a: a 421 | elif (isinstance(first_layer, l.Conv2D) and 422 | isinstance(second_layer, l.Conv2D)): 423 | # Combined auxiliary kernel would be the size of k1+k2-1. 424 | kernel_size = [k1+k2-1 for k1, k2 in 425 | zip(first_layer.kernel_size, second_layer.kernel_size)] 426 | # The auxiliary layer should have the combined stride. 427 | # Current implementation assumes tuple strides. 428 | strides = [(s1 + s2) if ((s1 > 1) and (s2 > 1)) else (s1 + s2 -1) 429 | for s1, s2 in zip(first_layer.strides, second_layer.strides)] 430 | # Current implementation assumes paddings are same for the 2 layers. 431 | aux_layer = l.Conv2D(second_layer.filters, kernel_size, activation=None, 432 | use_bias=False, padding=first_layer.padding, 433 | kernel_initializer='zeros', strides=strides) 434 | post_process_fn = functools.partial( 435 | process_conv_aux_gradient, 436 | second_kernel_size=second_layer.kernel_size) 437 | else: 438 | raise ValueError('Not Supported') 439 | 440 | return aux_layer, post_process_fn 441 | 442 | 443 | def process_conv_aux_gradient(grad, second_kernel_size): 444 | """Process the gradients of convolutional layer to generate grow matrix.""" 445 | # shape(grad): ksize X ksize X m0 X m2 ; ksize=k1+k2-1 446 | # second_kernel_size == k2 447 | grad = tf.transpose(grad, perm=(2, 0, 1, 3)) 448 | # shape(grad): m0 X ksize X ksize X m2 449 | patched_grow_matrix = extract_image_patches(grad, second_kernel_size) 450 | # shape(patched_grow_matrix): m0 X k1 X k1 X (m2 * k2 * k2) 451 | grow_matrix = tf.reshape(patched_grow_matrix, 452 | [-1, patched_grow_matrix.shape[-1]]) 453 | # shape(patched_grow_matrix): (m0 * k1 * k1) X (m2 * k2 * k2) 454 | return grow_matrix 455 | 456 | 457 | def extract_image_patches(x, kernel_size, stride=(1, 1)): 458 | """Extract convolutional patches from the layer. 459 | 460 | Manual replacement of tf.extract_image_patches, since its gradient cannot 461 | be evaluated on TPU. 462 | 463 | Args: 464 | x: batched input data. Size: [batch, in_height, in_width, in_channels] 465 | kernel_size: Tuple of two integers. Size of kernel. 466 | stride: Tuple of two integers. Stride size. 467 | 468 | Returns: 469 | 4D Tensor (batch, in_rows, in_cols, patch_size) of extracted patches. 470 | """ 471 | in_channels = x.get_shape()[3] 472 | kh, kw = kernel_size 473 | tile_filter = np.zeros(shape=[kh, kw, in_channels, kh * kw], dtype=np.float32) 474 | for i in range(kh): 475 | for j in range(kw): 476 | tile_filter[i, j, :, i * kw + j] = 1.0 477 | 478 | tile_filter_op = tf.constant(tile_filter, dtype=tf.float32) 479 | output = tf.nn.depthwise_conv2d( 480 | x, tile_filter_op, strides=[1, *stride, 1], padding='VALID') 481 | # reshaping below is needed so that 4th dimension of the output can be 482 | # reshaped into kernel[0] * kernel[1] * in_channels. 
483 | batch, in_rows, in_cols, _ = output.get_shape() 484 | output = tf.reshape( 485 | output, shape=[batch, in_rows, in_cols, in_channels, kh * kw]) 486 | output = tf.transpose(output, perm=[0, 1, 2, 4, 3]) 487 | output = tf.reshape(output, [batch, in_rows, in_cols, -1]) 488 | 489 | return output 490 | --------------------------------------------------------------------------------
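Below is a minimal, standalone sketch (not part of the repository) of how one might sanity-check the `extract_image_patches` helper in `growneuron/growers.py` against TensorFlow's built-in `tf.image.extract_patches`. The toy input shape, the kernel size, and the assumption that the package and its pinned dependencies are installed are illustrative, not prescribed by the code above.

    # Hypothetical check: compare the manual TPU-friendly patch extraction
    # with tf.image.extract_patches on a small random input.
    import numpy as np
    import tensorflow as tf
    from growneuron import growers

    x = tf.constant(np.random.uniform(size=(2, 8, 8, 3)), dtype=tf.float32)
    manual = growers.extract_image_patches(x, kernel_size=(3, 3))
    reference = tf.image.extract_patches(
        x, sizes=[1, 3, 3, 1], strides=[1, 1, 1, 1],
        rates=[1, 1, 1, 1], padding='VALID')
    # Both produce [batch, out_rows, out_cols, kh * kw * in_channels] patches
    # with the same (kernel row, kernel col, channel) ordering in the last axis.
    np.testing.assert_allclose(manual.numpy(), reference.numpy(), rtol=1e-5)

Both calls use 'VALID' padding and unit strides, so their outputs should match element-wise; the manual version exists only because the gradient of `tf.extract_image_patches` cannot be evaluated on TPU.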