├── imgs
│   └── gradmax.png
├── growneuron
│   ├── cifar
│   │   ├── configs
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── baseline_big.py
│   │   │   ├── baseline_big_vgg.py
│   │   │   ├── baseline_small_vgg.py
│   │   │   ├── grow_all_at_once.py
│   │   │   ├── grow_round_robin.py
│   │   │   ├── grow_all_at_once_vgg.py
│   │   │   └── baseline_small.py
│   │   ├── __init__.py
│   │   ├── data.py
│   │   ├── vgg.py
│   │   ├── wide_resnet.py
│   │   └── main.py
│   ├── imagenet
│   │   ├── __init__.py
│   │   ├── configs
│   │   │   ├── __init__.py
│   │   │   ├── baseline_big.py
│   │   │   ├── grow_all_at_once.py
│   │   │   └── baseline_small.py
│   │   ├── data.py
│   │   ├── mb_v1.py
│   │   ├── data_util.py
│   │   └── main.py
│   ├── __init__.py
│   ├── updaters_test.py
│   ├── layers_test.py
│   ├── updaters.py
│   ├── layers.py
│   └── growers.py
├── CONTRIBUTING.md
├── run.sh
├── setup.py
├── README.md
└── LICENSE
/imgs/gradmax.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/growneuron/HEAD/imgs/gradmax.png
--------------------------------------------------------------------------------
/growneuron/cifar/configs/README.md:
--------------------------------------------------------------------------------
1 | Here we explain the main configs for WRN-28-1:
2 |
3 | - `baseline_small` trains a dense model at 25% of the regular block width.
4 | - `baseline_big` trains a dense model at the regular width.
5 | - `grow` starts with a 25% block-width dense model and grows it into the
6 |     regular width, 10% at a time (see the example run below).
7 |
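For example, a config is passed to the CIFAR training script through the
`--config` flag; a minimal sketch, assuming the flags documented in the
top-level README (the output directories are illustrative):

```bash
# Train the regular-width dense baseline (WRN-28-1).
python growneuron/cifar/main.py --output_dir=/tmp/cifar_big \
  --config=growneuron/cifar/configs/baseline_big.py

# Start from the 25% block-width model and grow it during training.
python growneuron/cifar/main.py --output_dir=/tmp/cifar_grow \
  --config=growneuron/cifar/configs/grow_all_at_once.py \
  --config.grow_type=add_gradmax
```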
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project.
4 |
5 | - If you want to contribute to the library, please check the `Issues` tab and feel
6 | free to take on any problem/issue you find interesting.
7 | - If your `issue` is not reported yet, please create a new one. It is
8 | important to discuss the problem/request before implementing the solution.
9 |
10 | ## Code reviews
11 |
12 | All submissions, including submissions by project members, require review. We
13 | use GitHub pull requests for this purpose. Consult
14 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
15 | information on using pull requests.
16 |
--------------------------------------------------------------------------------
/growneuron/cifar/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 GradMax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """This folder involves the code for cifar experiments."""
17 |
--------------------------------------------------------------------------------
/growneuron/cifar/configs/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 GradMax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """This folder involves configs for cifar experiments."""
17 |
--------------------------------------------------------------------------------
/growneuron/imagenet/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 GradMax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """This folder involves the code for imagenet experiments."""
17 |
--------------------------------------------------------------------------------
/growneuron/imagenet/configs/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 GradMax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """This folder involves configs for imagenet experiments."""
17 |
--------------------------------------------------------------------------------
/growneuron/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 GradMax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """This repo involves the code for gradient maximizing growth."""
17 | name = 'growneuron'
18 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2022 GradMax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | set -e
17 | set -x
18 |
19 | pip install .
20 | python -m growneuron.layers_test
21 | python -m growneuron.updaters_test
22 |
--------------------------------------------------------------------------------
/growneuron/cifar/configs/baseline_big.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 GradMax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Default config for big baseline training."""
17 | from growneuron.cifar.configs import baseline_small
18 |
19 |
20 | def get_config():
21 | """Builds and returns config."""
22 | config = baseline_small.get_config()
23 | config.model.block_width_multiplier = 1.
24 | return config
25 |
--------------------------------------------------------------------------------
/growneuron/imagenet/configs/baseline_big.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 GradMax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Default config for big baseline training."""
17 | from growneuron.imagenet.configs import baseline_small
18 |
19 |
20 | def get_config():
21 | """Builds and returns config."""
22 | config = baseline_small.get_config()
23 | config.model.width_multiplier = 1.
24 | return config
25 |
--------------------------------------------------------------------------------
/growneuron/cifar/configs/baseline_big_vgg.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 GradMax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Default config for big baseline training."""
17 | from growneuron.cifar.configs import baseline_small_vgg
18 |
19 |
20 | def get_config():
21 | """Builds and returns config."""
22 | config = baseline_small_vgg.get_config()
23 | config.model.width_multiplier = 1
24 | return config
25 |
--------------------------------------------------------------------------------
/growneuron/cifar/configs/baseline_small_vgg.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 GradMax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Default configs for baseline cifar-10 training."""
17 | from growneuron.cifar.configs import baseline_small
18 | import ml_collections
19 |
20 |
21 | def get_config():
22 | """Builds and returns config."""
23 | config = baseline_small.get_config()
24 |
25 | config.architecture = 'vgg'
26 | config.model = ml_collections.ConfigDict()
27 | config.model.depth = 11
28 | config.model.normalization_type = 'none'
29 | config.model.width_multiplier = 0.25
30 | config.optimizer.base_learning_rate = 0.05
31 |
32 | return config
33 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """growneuron library setup."""
2 |
3 | import pathlib
4 | from setuptools import find_packages
5 | from setuptools import setup
6 |
7 | here = pathlib.Path(__file__).parent.resolve()
8 |
9 | long_description = (here / 'README.md').read_text(encoding='utf-8')
10 | setup(
11 | name='growneuron',
12 | version='0.1',
13 | description='Gradmax, gradient maximizing neural network growth.',
14 | long_description=long_description,
15 | long_description_content_type='text/markdown',
16 | url='https://github.com/google-research/growneuron',
17 | author='Google LLC',
18 | license='Apache 2.0',
19 | packages=find_packages(),
20 | package_data={},
21 | scripts=['growneuron/cifar/main.py', 'growneuron/imagenet/main.py'],
22 | classifiers=[
23 | 'Development Status :: 4 - Beta',
24 | 'Intended Audience :: Developers',
25 | 'Intended Audience :: Science/Research',
26 | 'License :: OSI Approved :: Apache Software License',
27 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
28 | ],
29 | keywords=('neural networks tensorflow machine learning growing growth '
30 | 'gradmax google convolutional during training'),
31 | install_requires=[
32 | 'absl-py',
33 | 'numpy',
34 | 'ml-collections',
35 | 'tensorflow==2.7',
36 | 'scipy==1.7.3',
37 | 'tfds-nightly',
38 | ('uncertainty_baselines @ git+https://github.com/google/'
39 | 'uncertainty-baselines.git#egg=uncertainty_baselines'),
40 | ],
41 | )
42 |
--------------------------------------------------------------------------------
/growneuron/imagenet/configs/grow_all_at_once.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 GradMax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Default config for random growing."""
17 | from growneuron.imagenet.configs import baseline_small
18 |
19 |
20 | def get_config():
21 | """Builds and returns config."""
22 | config = baseline_small.get_config()
23 | config.updater_type = 'all_at_once'
24 | config.grow_type = 'add_random'
25 | config.grow_epsilon = 0.
26 | config.is_outgoing_zero = False
27 | config.grow_scale_method = 'mean_norm'
28 | config.model.normalization_type = 'none'
29 | config.updater.carry_optimizer = True
30 |
31 | config.updater.update_frequency = 4000
32 | config.grow_frequency_multiplier = 1.
33 | config.updater.start_iteration = 10000
34 | # One cycle is 12 growth steps.
35 | config.updater.n_growth_steps = 12
36 | # Use one of the following
37 | # config.updater.n_grow = 2
38 | config.updater.n_grow_fraction = 0.25
39 | config.updater.scale = 0.5
40 |
41 | return config
42 |
--------------------------------------------------------------------------------
/growneuron/cifar/configs/grow_all_at_once.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 GradMax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Default config for random growing."""
17 | from growneuron.cifar.configs import baseline_small
18 |
19 |
20 | def get_config():
21 | """Builds and returns config."""
22 | config = baseline_small.get_config()
23 | config.updater_type = 'all_at_once'
24 | config.grow_type = 'add_random'
25 | config.grow_batch_size = 128
26 | config.grow_epsilon = 0.
27 | config.is_outgoing_zero = False
28 | config.grow_scale_method = 'mean_norm'
29 | config.model.normalization_type = 'none'
30 | config.updater.carry_optimizer = True
31 |
32 | # We are aiming for a 12*2500=30000 step growth period.
33 | config.updater.update_frequency = 2500
34 | config.updater.start_iteration = 10000
35 | # One cycle is 12 growth steps.
36 | config.updater.n_growth_steps = 12 # 12 * 12 cycle
37 | # Use one of the following
38 | # config.updater.n_grow = 2
39 | config.updater.n_grow_fraction = 0.25
40 | config.updater.scale = 0.5
41 |
42 | return config
43 |
--------------------------------------------------------------------------------
/growneuron/cifar/configs/grow_round_robin.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 GradMax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Default config for random growing."""
17 | from growneuron.cifar.configs import baseline_small
18 |
19 |
20 | def get_config():
21 | """Builds and returns config."""
22 | config = baseline_small.get_config()
23 | config.updater_type = 'round_robin'
24 | config.grow_type = 'add_random'
25 | config.grow_batch_size = 128
26 | config.grow_epsilon = 0.
27 | config.is_outgoing_zero = False
28 | config.grow_scale_method = 'mean_norm'
29 | config.model.normalization_type = 'none'
30 | config.updater.carry_optimizer = True
31 |
32 | # We are aiming for a 144*200=28800 step growth period.
33 | config.updater.update_frequency = 150
34 | config.updater.start_iteration = 10000
35 | config.scale_epochs = True
36 | # One cycle is 12 growth steps.
37 | config.updater.n_growth_steps = 144 # 12 * 12 cycle
38 | # Use one of the following
39 | # config.updater.n_grow = 2
40 | config.updater.n_grow_fraction = 0.25
41 | config.updater.scale = 0.5
42 |
43 | return config
44 |
--------------------------------------------------------------------------------
/growneuron/cifar/configs/grow_all_at_once_vgg.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 GradMax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Default config for random growing."""
17 | from growneuron.cifar.configs import baseline_small_vgg
18 |
19 |
20 | def get_config():
21 | """Builds and returns config."""
22 | config = baseline_small_vgg.get_config()
23 | config.updater_type = 'all_at_once'
24 | config.grow_type = 'add_random'
25 | config.grow_batch_size = 128
26 | config.grow_epsilon = 0.
27 | config.is_outgoing_zero = False
28 | config.grow_scale_method = 'fixed'
29 | config.model.normalization_type = 'none'
30 | config.updater.carry_optimizer = True
31 |
32 | # We are aiming for a 12*2500=30000 step growth period.
33 | config.updater.update_frequency = 2500
34 | config.updater.start_iteration = 10000
35 | config.scale_epochs = False
36 | # One cycle is 12 growth steps.
37 | config.updater.n_growth_steps = 12 # 12 cycle
38 | # Use one of the following
39 | # config.updater.n_grow = 2
40 | config.updater.n_grow_fraction = 0.25
41 | config.updater.scale = 0.5
42 |
43 | return config
44 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GradMax: Growing Neural Networks using Gradient Information
2 |
3 | [](https://colab.research.google.com/github/google-research/growneuron/blob/main/Student_Teacher.ipynb)
4 |
5 | Code for reproducing our results in the GradMax paper [[arxiv.org/abs/2201.05125](https://arxiv.org/abs/2201.05125)].
6 |
7 |
8 |
9 |
10 | ## Setup
11 | First clone this repo.
12 | ```bash
13 | git clone https://github.com/google-research/growneuron.git
14 | cd growneuron
15 | ```
16 |
17 | The following script installs the necessary libraries and runs a few tests.
18 | ```bash
19 | bash run.sh
20 | ```
21 |
22 | The following command downloads the data and runs the baseline experiment. If
23 | the data is already downloaded, use the `--data_dir` flag to pass its path.
24 | ```bash
25 | python growneuron/cifar/main.py --output_dir=/tmp/cifar --download_data
26 | ```
27 |
28 | ## Running GradMax
29 | The following command starts training with WRN-28-0.25x and grows it into
30 | WRN-28-1. Growth is performed every 2500 steps, starting from
31 | iteration 10000, at all convolutional layers at once.
32 | ```bash
33 | rm -rf /tmp/cifar
34 | python growneuron/cifar/main.py --output_dir=/tmp/cifar \
35 | --config=growneuron/cifar/configs/grow_all_at_once.py \
36 | --config.grow_type=add_gradmax
37 | ```
38 |
39 | ## Other Experiments
40 | - Baselines for WRN-28 and VGG11 can be run using the corresponding configs in
41 | `growneuron/cifar/configs/`.
42 | - Set `--config.grow_type` argument to `add_gradmax`, `add_firefly`, `add_gradmax_opt` or
43 | `add_random` to grow using different strategies.
44 | - Use `--config.is_outgoing_zero` to run experiments where outgoing weights
45 | are set to zero.
46 | - Use `--config.model.normalization_type=batchnorm` to run experiments with
47 | batch normalization layers (see the combined example below).
48 |
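These options can be combined with any of the grow configs; a minimal sketch,
assuming the same flags as above (the output directory is illustrative):

```bash
python growneuron/cifar/main.py --output_dir=/tmp/cifar_random \
  --config=growneuron/cifar/configs/grow_all_at_once.py \
  --config.grow_type=add_random \
  --config.is_outgoing_zero=True \
  --config.model.normalization_type=batchnorm
```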
49 |
50 |
51 | ## Disclaimer
52 | This is not an officially supported Google product.
53 |
--------------------------------------------------------------------------------
/growneuron/imagenet/configs/baseline_small.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 GradMax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Default configs for baseline cifar-10 training."""
17 | import ml_collections
18 |
19 |
20 | def get_config():
21 | """Builds and returns config."""
22 | config = ml_collections.ConfigDict()
23 |
24 | config.optimizer = ml_collections.ConfigDict()
25 | # Base learning rate when total batch size is 128. It is scaled by the ratio
26 | # of the total batch size to 128.
27 | config.optimizer.base_learning_rate = 0.05
28 | # One of 'step', 'cosine'
29 | config.optimizer.decay_type = 'step'
30 | config.optimizer.nesterov = False
31 | # Amount to decay learning rate.
32 | config.optimizer.lr_decay_ratio = 0.1
33 | # Epochs to decay learning rate by.
34 | config.optimizer.lr_decay_epochs = [0.3, 0.6, 0.8]
35 | # Number of epochs for a linear warmup to the initial learning rate. Use 0 to
36 | # do no warmup.
37 | config.optimizer.lr_warmup_epochs = 5
38 | # Optimizer momentum.
39 | config.optimizer.momentum = 0.9
40 | # Following is empty for the baselines and used by the growing algorithms.
41 | config.updater = ml_collections.ConfigDict()
42 | config.updater.carry_optimizer = False
43 | config.is_outgoing_zero = False
44 | config.scale_epochs = False
45 |
46 | config.model = ml_collections.ConfigDict()
47 | # L2 regularization coefficient.
48 | config.model.l2_coef = 1e-4
49 | config.model.width_multiplier = 0.25
50 | config.model.normalization_type = 'batchnorm'
51 |
52 | # Number of epochs between saving checkpoints. Use -1 for no checkpoints.
53 | config.checkpoint_interval = 25
54 | config.dataset = 'imagenet2012'
55 | # Batch size per TPU core/GPU. The number of new datapoints gathered per
56 | # batch is this number divided by ensemble_size (we tile the batch by that #
57 | # of times).
58 | config.per_core_batch_size = 64
59 | config.num_cores = 1
60 | config.seed = 8
61 | config.train_epochs = 90
62 | config.log_freq = 200
63 |
64 | return config
65 |
--------------------------------------------------------------------------------
/growneuron/cifar/configs/baseline_small.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 GradMax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Default configs for baseline cifar-10 training."""
17 | import ml_collections
18 |
19 |
20 | def get_config():
21 | """Builds and returns config."""
22 | config = ml_collections.ConfigDict()
23 |
24 | config.optimizer = ml_collections.ConfigDict()
25 | # Base learning rate when total batch size is 128. It is scaled by the ratio
26 | # of the total batch size to 128.
27 | config.optimizer.base_learning_rate = 0.1
28 | # One of 'step', 'cosine'
29 | config.optimizer.decay_type = 'cosine'
30 | config.optimizer.nesterov = True
31 | # Amount to decay learning rate.
32 | config.optimizer.lr_decay_ratio = 0.2
33 | # Epochs to decay learning rate by.
34 | config.optimizer.lr_decay_epochs = [0.3, 0.6, 0.8]
35 | # Number of epochs for a linear warmup to the initial learning rate. Use 0 to
36 | # do no warmup.
37 | config.optimizer.lr_warmup_epochs = 1
38 | # Optimizer momentum.
39 | config.optimizer.momentum = 0.9
40 | # Following is empty for the baselines and used by the growing algorithms.
41 | config.updater = ml_collections.ConfigDict()
42 | config.updater.carry_optimizer = False
43 | config.is_outgoing_zero = False
44 | config.scale_epochs = False
45 |
46 | config.model = ml_collections.ConfigDict()
47 | # L2 regularization coefficient.
48 | config.model.l2_coef = 2e-4
49 | config.model.depth = 28
50 | config.model.width_multiplier = 1
51 | config.model.normalization_type = 'none'
52 | config.model.block_width_multiplier = 0.25
53 |
54 | # Number of epochs between saving checkpoints. Use -1 for no checkpoints.
55 | config.checkpoint_interval = 25
56 | # One of ['cifar10', 'cifar100']
57 | config.dataset = 'cifar10'
58 | # Whether to cache the dataset.
59 | config.cache_dataset = True
60 | # Batch size per TPU core/GPU. The number of new datapoints gathered per
61 | # batch is this number divided by ensemble_size (we tile the batch by that #
62 | # of times).
63 | config.per_core_batch_size = 128
64 | config.num_cores = 1
65 | config.seed = 8
66 | config.train_epochs = 200
67 | config.log_freq = 200
68 |
69 | return config
70 |
--------------------------------------------------------------------------------
/growneuron/cifar/data.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 GradMax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Data pipeline.
17 |
18 | Forked from simclr/tf2 codebase.
19 | """
20 | from typing import Optional
21 | from absl import logging
22 |
23 | import tensorflow.compat.v2 as tf
24 | import tensorflow_datasets as tfds
25 |
26 |
27 | def build_input_fn(
28 | builder,
29 | global_batch_size,
30 | topology,
31 | is_training,
32 | cache_dataset = True):
33 | """Build input function.
34 |
35 | Args:
36 | builder: TFDS builder for specified dataset.
37 | global_batch_size: Global batch size.
38 | topology: An instance of `tf.tpu.experimental.Topology` or None.
39 | is_training: Whether to build in training mode.
40 | cache_dataset: bool, whether to cache the dataset.
41 |
42 | Returns:
43 | A function that accepts a dict of params and returns a tuple of images and
44 | features, to be used as the input_fn in TPUEstimator.
45 | """
46 |
47 | def _input_fn(input_context):
48 | """Inner input function."""
49 | batch_size = input_context.get_per_replica_batch_size(global_batch_size)
50 | logging.info('Global batch size: %d', global_batch_size)
51 | logging.info('Per-replica batch size: %d', batch_size)
52 |
53 | def map_fn(image, label):
54 | """Produces multiple transformations of the same batch."""
55 | if is_training:
56 | image_shape = tf.shape(image)
57 | # Pad the image by 2 pixels on each side, then randomly crop back to the original size.
58 | image = tf.image.resize_with_crop_or_pad(
59 | image, image_shape[0] + 4, image_shape[1] + 4)
60 | image = tf.image.random_crop(image, (image_shape[0], image_shape[0], 3))
61 | image = tf.image.random_flip_left_right(image)
62 | image = tf.image.convert_image_dtype(image, tf.float32)
63 | return image, label
64 |
65 | dataset = builder.as_dataset(
66 | split='train' if is_training else 'test',
67 | shuffle_files=is_training,
68 | as_supervised=True)
69 | logging.info('num_input_pipelines: %d', input_context.num_input_pipelines)
70 | # The dataset is always sharded by number of hosts.
71 | # num_input_pipelines is the number of hosts rather than number of cores.
72 | if input_context.num_input_pipelines > 1:
73 | dataset = dataset.shard(input_context.num_input_pipelines,
74 | input_context.input_pipeline_id)
75 | if cache_dataset:
76 | dataset = dataset.cache()
77 | if is_training:
78 | dataset = dataset.shuffle(50000)
79 | dataset = dataset.repeat(-1)
80 | dataset = dataset.map(
81 | map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
82 | dataset = dataset.batch(batch_size, drop_remainder=is_training)
83 | prefetch_buffer_size = 2 * topology.num_tpus_per_task if topology else 2
84 | dataset = dataset.prefetch(prefetch_buffer_size)
85 | return dataset
86 |
87 | return _input_fn
88 |
89 |
90 |
--------------------------------------------------------------------------------
/growneuron/imagenet/data.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 GradMax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Data pipeline.
17 |
18 | Forked from simclr/tf2 codebase.
19 | """
20 | import functools
21 | from typing import Optional
22 | from absl import logging
23 | from growneuron.imagenet import data_util
24 |
25 | import tensorflow.compat.v2 as tf
26 | import tensorflow_datasets as tfds
27 |
28 |
29 | def build_input_fn(
30 | builder,
31 | global_batch_size,
32 | topology,
33 | is_training,
34 | image_size = 224):
35 | """Build input function.
36 |
37 | Args:
38 | builder: TFDS builder for specified dataset.
39 | global_batch_size: Global batch size.
40 | topology: An instance of `tf.tpu.experimental.Topology` or None.
41 | is_training: Whether to build in training mode.
42 | image_size: Size of the output images.
43 |
44 | Returns:
45 | A function that accepts a `tf.distribute.InputContext` and returns a
46 | `tf.data.Dataset` yielding batches of (image, label) pairs.
47 | """
48 |
49 | def _input_fn(input_context):
50 | """Inner input function."""
51 | batch_size = input_context.get_per_replica_batch_size(global_batch_size)
52 | logging.info('Global batch size: %d', global_batch_size)
53 | logging.info('Per-replica batch size: %d', batch_size)
54 |
55 | preprocess_fn = get_preprocess_fn(is_training, image_size)
56 | def map_fn(image, label):
57 | """Produces multiple transformations of the same batch."""
58 | image = preprocess_fn(image)
59 | return image, label
60 |
61 | dataset = builder.as_dataset(
62 | split='train' if is_training else 'validation',
63 | shuffle_files=is_training,
64 | as_supervised=True)
65 | logging.info('num_input_pipelines: %d', input_context.num_input_pipelines)
66 | # The dataset is always sharded by number of hosts.
67 | # num_input_pipelines is the number of hosts rather than number of cores.
68 | if input_context.num_input_pipelines > 1:
69 | dataset = dataset.shard(input_context.num_input_pipelines,
70 | input_context.input_pipeline_id)
71 | if is_training:
72 | buffer_multiplier = 50 if image_size <= 32 else 10
73 | dataset = dataset.shuffle(batch_size * buffer_multiplier)
74 | dataset = dataset.repeat(-1)
75 | dataset = dataset.map(
76 | map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
77 | dataset = dataset.batch(batch_size, drop_remainder=is_training)
78 | prefetch_buffer_size = 2 * topology.num_tpus_per_task if topology else 2
79 | dataset = dataset.prefetch(prefetch_buffer_size)
80 | return dataset
81 |
82 | return _input_fn
83 |
84 |
85 | def get_preprocess_fn(is_training, image_size=224):
86 | """Get function that accepts an image and returns a preprocessed image."""
87 | # Disable test cropping for small images (e.g. CIFAR)
88 | if image_size <= 32:
89 | test_crop = False
90 | else:
91 | test_crop = True
92 | return functools.partial(
93 | data_util.preprocess_image,
94 | image_size=image_size,
95 | is_training=is_training,
96 | test_crop=test_crop)
97 |
--------------------------------------------------------------------------------
/growneuron/updaters_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 GradMax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for growneuron.updaters."""
17 |
18 | import itertools
19 | import absl.testing.parameterized as parameterized
20 | from growneuron import growers
21 | from growneuron import updaters
22 | import tensorflow as tf
23 |
24 |
25 | class RoundRobinScheduleTest(parameterized.TestCase, tf.test.TestCase):
26 |
27 | def setUp(self):
28 | super().setUp()
29 | self.all_layers = []
30 | for _ in range(5):
31 | layer = tf.keras.layers.Dense(2)
32 | layer.build((None, 8))
33 | self.all_layers.append(layer)
34 |
35 | def get_random_grower(self):
36 | return growers.AddRandom()
37 |
38 | def get_grow_layers(self, n=1):
39 | return list(itertools.permutations(self.all_layers, 2))[:n]
40 |
41 | @parameterized.named_parameters(
42 | ('every_100_start0_3times', 100, 0, 3,
43 | [(0, True), (4, False), (100, True), (2300, True), (100, False)]),
44 | ('every_1_start0_21times', 1, 0, 21,
45 | [(0, True), (100, True), (2300, True), (3, True)]),
46 | ('every_1_start2_5times', 1, 2, 4,
47 | [(1, False), (2, True), (100, True), (1, False), (3, True)]),
48 | ('every_100_start50_1times', 100, 50, 1,
49 | [(20, False), (75, False), (200, True), (234, False), (300, False)]),
50 | ('none_start', 5, None, 25,
51 | [(2, False), (5, True), (7, False), (45, True)]),
52 | )
53 | def test_update_iter(self, update_frequency, start_iteration, n_growth_steps,
54 | iterations):
55 | network_grower = self.get_random_grower()
56 | grow_layer_tuples = self.get_grow_layers()
57 |
58 | updater = updaters.RoundRobin(
59 | network_grower,
60 | grow_layer_tuples,
61 | update_frequency=update_frequency,
62 | n_grow=1,
63 | start_iteration=start_iteration,
64 | n_growth_steps=n_growth_steps)
65 | for iteration, bool_val in iterations:
66 | print(f'COUNT:{updater._growth_counter}, {iteration}, {bool_val}')
67 | self.assertEqual(updater.is_update_iteration(iteration), bool_val)
68 | if bool_val:
69 | updater._next_grow_layer_tuple(None)
70 |
71 | @parameterized.named_parameters(('n4', 4), ('n1', 1))
72 | def test_next_grow_layers(self, n_tuples):
73 | network_grower = self.get_random_grower()
74 | grow_layer_tuples = self.get_grow_layers(n=n_tuples)
75 | updater = updaters.RoundRobin(
76 | network_grower, grow_layer_tuples, update_frequency=2)
77 | for tup in itertools.chain(grow_layer_tuples, grow_layer_tuples):
78 | self.assertEqual(tup, updater._next_grow_layer_tuple(None))
79 |
80 | @parameterized.named_parameters(('1d', (2,), (3,)),
81 | ('2d_out', (2, 4), (2, 5)),
82 | ('2d_in', (2, 4), (4, 4)),
84 | # Depthwise kernel
84 | ('3d_in', (4, 4, 5), (4, 4, 8)),
85 | ('4d_in', (4, 4, 2, 3), (4, 4, 4, 3)),
86 | ('4d_out', (4, 4, 2, 3), (4, 4, 2, 5))
87 | )
88 | def test_pad_zeros_to(self, old_shape, new_shape):
89 | tensor = tf.random.uniform(old_shape)
90 | new_tensor = updaters.pad_zeros_to(tensor, new_shape)
91 | old_slice = tuple(slice(None, x) for x in tensor.shape)
92 | self.assertAllEqual(new_tensor[old_slice], tensor)
93 |
94 | @parameterized.named_parameters(
95 | ('dense_outgrown', (3, 4), (3, 5), lambda a, i: a[:, i]),
96 | ('dense_ingrown', (3, 4), (4, 4), lambda a, i: a[i, :]),
97 | ('conv2d_ingrown', (2, 2, 3, 4), (2, 2, 4, 4),
98 | lambda a, i: a[:, :, i, :]),
99 | ('conv2d_outgrown', (2, 2, 3, 4), (2, 2, 3, 5), lambda a, i: a[Ellipsis, i]),
100 | ('conv_dw', (2, 2, 3), (2, 2, 4), lambda a, i: a[:, :, i])
101 | )
102 | def test_copy_adam_slots(self, old_shape, new_shape, slice_fn):
103 | grow_layer_tuples = self.get_grow_layers()
104 | network_grower = self.get_random_grower()
105 | updater = updaters.RoundRobin(
106 | network_grower, grow_layer_tuples, update_frequency=2)
107 | old_var = tf.Variable(tf.ones(old_shape))
108 | new_var = tf.Variable(tf.ones(new_shape))
109 | optimizer = tf.keras.optimizers.Adam()
110 | optimizer._create_slots([old_var, new_var])
111 | random_slot_vals = tf.random.uniform(old_shape)
112 | for s_name in optimizer.get_slot_names():
113 | self.assertAllEqual(optimizer.get_slot(new_var, s_name),
114 | tf.zeros(new_shape))
115 | optimizer.get_slot(old_var, s_name).assign(random_slot_vals)
116 |
117 | updater.copy_optimizer_slots(optimizer, [old_var], [new_var])
118 | for s_name in optimizer.get_slot_names():
119 | # Check new_values still have zeros.
120 | new_values_slice = slice_fn(optimizer.get_slot(new_var, s_name), -1)
121 | self.assertAllEqual(new_values_slice, tf.zeros_like(new_values_slice))
122 | # Check old variables have their random values set correctly.
123 | old_values_slice = slice_fn(optimizer.get_slot(old_var, s_name), 0)
124 | new_values_slice = slice_fn(optimizer.get_slot(new_var, s_name), 0)
125 | self.assertAllEqual(new_values_slice, old_values_slice)
126 |
127 |
128 | if __name__ == '__main__':
129 | tf.test.main()
130 |
--------------------------------------------------------------------------------
/growneuron/cifar/vgg.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 GradMax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """VGG Network."""
17 |
18 | import functools
19 | from typing import Any, Dict
20 | from growneuron.cifar import wide_resnet
21 | import growneuron.layers as glayers
22 | import tensorflow as tf
23 |
24 | NormalizationType = wide_resnet.NormalizationType
25 |
26 | BatchNormalization = functools.partial(
27 | tf.keras.layers.BatchNormalization,
28 | epsilon=1e-5, # using epsilon and momentum defaults from Torch
29 | momentum=0.9)
30 |
31 | LayerNormalization = functools.partial(
32 | tf.keras.layers.LayerNormalization,
33 | epsilon=1e-5) # using epsilon and momentum defaults from Torch
34 |
35 |
36 | def Conv2D(filters, seed=None, **kwargs):
37 | """Conv2D layer that is deterministically initialized."""
38 | default_kwargs = {
39 | "kernel_size": 3,
40 | "padding": "same",
41 | "use_bias": False,
42 | # Note that we need to use the class constructor for the initializer to
43 | # get deterministic initialization.
44 | "kernel_initializer": tf.keras.initializers.HeNormal(seed=seed),
45 | }
46 | # Override defaults with the passed kwargs.
47 | default_kwargs.update(kwargs)
48 | return tf.keras.layers.Conv2D(filters, **default_kwargs)
49 |
50 |
51 | class VGG(tf.keras.Model):
52 | """Builds a VGG CNN without the FC layers at the end.
53 |
54 | We don't add the FC layers to stay in sync with the implementation in the
55 | "Firefly Neural Architecture Descent" paper.
56 |
57 | Attributes:
58 | depth: Use 11 for VGG11, 16 for VGG16, etc.,
59 | width_multiplier: Multiplier for the number of filters in each layer
60 | ("1" corresponds to 64 filters in the first layer).
61 | num_classes: Number of output classes.
62 | normalization_type: NormalizationType, of the normalization used inside
63 | blocks.
64 | l2: L2 regularization coefficient.
65 | seed: random seed used for initialization.
66 | """
67 |
68 | def __init__(self,
69 | depth,
70 | width_multiplier,
71 | num_classes,
72 | normalization_type,
73 | l2,
74 | seed = 42):
75 | super().__init__(name=F"VGG-{depth}-{width_multiplier}")
76 | l2_reg = tf.keras.regularizers.l2
77 |
78 | rng_seed = [seed, seed + 1]
79 | assert depth == 11, "Only supporting VGG11 right now"
80 |
81 | # VGG consists of blocks of convs separated by downsampling.
82 | # Within each block, each conv has base_width * multiplier filters.
83 | # This dict maps VGG-xx to a list of blocks.
84 | architecture = {
85 | 11: [[1], [2], [4, 4], [8, 8], [8, 8]],
86 | 14: [[1, 1], [2, 2], [4, 4], [8, 8], [8, 8]],
87 | 16: [[1, 1], [2, 2], [4, 4, 4], [8, 8, 8], [8, 8, 8]],
88 | 19: [[1, 1], [2, 2], [4, 4, 4, 4], [8, 8, 8, 8], [8, 8, 8, 8]]
89 | }
90 |
91 | blocklist = architecture[depth]
92 | base_width = int(64 * width_multiplier)
93 |
94 | downsample = False
95 | self.layer_list = []
96 | for block in blocklist:
97 | for multiplier in block:
98 | rng_seed, seed = tf.random.experimental.stateless_split(rng_seed)
99 | self.layer_list.append(glayers.GrowLayer(Conv2D(
100 | base_width*multiplier, strides=1 if not downsample else 2,
101 | seed=seed[0],
102 | kernel_regularizer=tf.keras.regularizers.l2(l2))))
103 | downsample = False
104 | self.layer_list.append(tf.keras.layers.Activation(
105 | glayers.get_activation_fn("relu1")),)
106 | if normalization_type == NormalizationType.batchnorm:
107 | self.layer_list.append(glayers.GrowLayer(
108 | BatchNormalization(
109 | beta_regularizer=tf.keras.regularizers.l2(l2),
110 | gamma_regularizer=tf.keras.regularizers.l2(l2))))
111 | elif normalization_type == NormalizationType.layernorm:
112 | self.layer_list.append(glayers.GrowLayer(
113 | LayerNormalization(
114 | beta_regularizer=tf.keras.regularizers.l2(l2),
115 | gamma_regularizer=tf.keras.regularizers.l2(l2))))
116 | elif normalization_type == NormalizationType.none:
117 | pass
118 | else:
119 | raise ValueError
120 | downsample = True
121 | self.layer_list.append(
122 | glayers.GrowLayer(
123 | Conv2D(num_classes, strides=2, kernel_regularizer=l2_reg(l2))))
124 | self.layer_list.append(tf.keras.layers.Flatten())
125 |
126 | def call(self, x):
127 | for layer in self.layer_list:
128 | x = layer(x)
129 | return x
130 |
131 | def get_grow_layer_tuples(self):
132 | """Gets all groups of layers that need to grow together."""
133 |
134 | grow_layers = [
135 | i for i, l in enumerate(self.layer_list)
136 | if (isinstance(l, glayers.GrowLayer) and
137 | isinstance(l.layer, tf.keras.layers.Conv2D))
138 | ]
139 |
140 | grow_layer_tuples = []
141 | for i, j in zip(grow_layers[:-1], grow_layers[1:]):
142 | # Grow tuples should be in order.
143 | grow_layer_tuples.append(self.layer_list[i:(j+1)])
144 | return grow_layer_tuples
145 |
146 |
147 | def create_model(
148 | depth = 1,
149 | width_multiplier = 1,
150 | num_classes = 10,
151 | l2_coef = 0.0,
152 | normalization_type = "batchnorm",
153 | **unused_kwargs):
154 | """Creates model."""
155 | normalization_type = NormalizationType[normalization_type]
156 | model = VGG(
157 | depth=depth,
158 | width_multiplier=width_multiplier,
159 | num_classes=num_classes,
160 | normalization_type=normalization_type,
161 | l2=l2_coef)
162 | return model
163 |
--------------------------------------------------------------------------------
/growneuron/imagenet/mb_v1.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 GradMax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """VGG Network."""
17 |
18 | import functools
19 | from typing import Any, Dict
20 | from growneuron.cifar import wide_resnet
21 | import growneuron.layers as glayers
22 | import tensorflow as tf
23 |
24 | NormalizationType = wide_resnet.NormalizationType
25 |
26 | BatchNormalization = functools.partial(
27 | tf.keras.layers.BatchNormalization,
28 | epsilon=1e-5, # using epsilon and momentum defaults from Torch
29 | momentum=0.9)
30 |
31 | LayerNormalization = functools.partial(
32 | tf.keras.layers.LayerNormalization,
33 | epsilon=1e-5) # using epsilon and momentum defaults from Torch
34 |
35 |
36 | def check_grow_layer(layer):
37 | return (isinstance(layer, glayers.GrowLayer) and
38 | isinstance(layer.layer,
39 | (tf.keras.layers.Dense, tf.keras.layers.Conv2D)) and
40 | not isinstance(layer.layer, tf.keras.layers.DepthwiseConv2D))
41 |
42 |
43 | def Conv2D(filters, seed=None, **kwargs):
44 | """Conv2D layer that is deterministically initialized."""
45 | default_kwargs = {
46 | 'kernel_size': 1,
47 | 'padding': 'same',
48 | 'strides': 1,
49 | 'use_bias': False,
50 | # Note that we need to use the class constructor for the initializer to
51 | # get deterministic initialization.
52 | 'kernel_initializer': tf.keras.initializers.HeNormal(seed=seed),
53 | }
54 | # Override defaults with the passed kwargs.
55 | default_kwargs.update(kwargs)
56 | return tf.keras.layers.Conv2D(filters, **default_kwargs)
57 |
58 |
59 | def DepthwiseConv2D(seed=None, **kwargs):
60 | """DepthwiseConv2D layer that is deterministically initialized."""
61 | default_kwargs = {
62 | 'kernel_size': 3,
63 | 'padding': 'same',
64 | 'strides': 1,
65 | 'use_bias': False,
66 | # Note that we need to use the class constructor for the initializer to
67 | # get deterministic initialization.
68 | 'kernel_initializer': tf.keras.initializers.HeNormal(seed=seed),
69 | }
70 | # Override defaults with the passed kwargs.
71 | default_kwargs.update(kwargs)
72 | return tf.keras.layers.DepthwiseConv2D(**default_kwargs)
73 |
74 |
75 | class MobilenetV1(tf.keras.Model):
76 | """Builds a MobileNet-v1.
77 |
78 | Attributes:
79 | width_multiplier: Multiplier for the number of filters in each layer
80 | ("1" corresponds to 32 filters in the first conv).
81 | num_classes: Number of output classes.
82 | normalization_type: NormalizationType, of the normalization used inside
83 | blocks.
84 | l2: L2 regularization coefficient.
85 | seed: random seed used for initialization.
86 | """
87 |
88 | def __init__(self,
89 | width_multiplier,
90 | num_classes,
91 | normalization_type,
92 | l2,
93 | seed = 42):
94 | super().__init__(name=F'MBv1-{width_multiplier}')
95 | l2_reg = tf.keras.regularizers.l2
96 |
97 | rng_seed = [seed, seed + 1]
98 | rng_seed, seed = tf.random.experimental.stateless_split(rng_seed)
99 | self.layer_list = [
100 | glayers.GrowLayer(
101 | Conv2D(32 * width_multiplier,
102 | strides=2,
103 | kernel_size=3,
104 | seed=seed[0],
105 | kernel_regularizer=l2_reg(l2))),
106 | glayers.GrowLayer(BatchNormalization()),
107 | tf.keras.layers.Activation(glayers.get_activation_fn('relu1'))
108 | ]
109 |
110 | # MobileNet consists of blocks of convs.
111 | # Within each block, each conv has base_width * multiplier filters.
112 | blocklist = [[1], [2, 2], [4, 4], [8, 8, 8, 8, 8, 8], [16, 16]]
113 | base_width = int(64 * width_multiplier)
114 | downsample = False
115 | for i, block in enumerate(blocklist):
116 | for j, multiplier in enumerate(block):
117 | rng_seed, seed = tf.random.experimental.stateless_split(rng_seed)
118 | self.layer_list.append(glayers.GrowLayer(DepthwiseConv2D(
119 | seed=seed[0],
120 | kernel_regularizer=tf.keras.regularizers.l2(l2))))
121 | self.layer_list.append(tf.keras.layers.Activation(
122 | glayers.get_activation_fn('relu1')))
123 | if normalization_type == NormalizationType.batchnorm:
124 | self.layer_list.append(glayers.GrowLayer(BatchNormalization()))
125 | elif normalization_type == NormalizationType.layernorm:
126 | self.layer_list.append(glayers.GrowLayer(LayerNormalization()))
127 | elif normalization_type == NormalizationType.none:
128 | pass
129 | else:
130 | raise ValueError
131 | rng_seed, seed = tf.random.experimental.stateless_split(rng_seed)
132 | # We apply strides at the conv rather than the depthwise conv, as it's
133 | # better for decomposition.
134 | n_channels = base_width * multiplier
135 | if (i+1) == len(blocklist) and (j+1) == len(block):
136 | # We don't scale the last layer since we are not growing it.
137 | n_channels = 64 * multiplier
138 | self.layer_list.append(glayers.GrowLayer(Conv2D(
139 | n_channels,
140 | seed=seed[0],
141 | strides=1 if not downsample else 2,
142 | kernel_regularizer=tf.keras.regularizers.l2(l2))))
143 | downsample = False
144 | self.layer_list.append(tf.keras.layers.Activation(
145 | glayers.get_activation_fn('relu1')),)
146 | self.layer_list.append(glayers.GrowLayer(BatchNormalization()))
147 |
148 | downsample = True
149 | # TODO make global pooling+dense a conv-layer, so that we can grow.
150 | self.layer_list.append(
151 | tf.keras.layers.GlobalAveragePooling2D())
152 | rng_seed, seed = tf.random.experimental.stateless_split(rng_seed)
153 | self.layer_list.append(
154 | tf.keras.layers.Dense(
155 | num_classes,
156 | kernel_initializer=tf.keras.initializers.HeNormal(
157 | seed=seed[0]),
158 | kernel_regularizer=l2_reg(l2)
159 | )
160 | )
161 |
162 | def call(self, x):
163 | for layer in self.layer_list:
164 | x = layer(x)
165 | return x
166 |
167 | def get_grow_layer_tuples(self):
168 | """Gets all groups of layers that need to grow together."""
169 | grow_layers = [i for i, l in enumerate(self.layer_list)
170 | if check_grow_layer(l)]
171 |
172 | grow_layer_tuples = []
173 | for i, j in zip(grow_layers[:-1], grow_layers[1:]):
174 | # Grow tuples should be in order.
175 | grow_layer_tuples.append(self.layer_list[i:(j+1)])
176 | return grow_layer_tuples
177 |
178 |
179 | def create_model(
180 | width_multiplier = 1,
181 | num_classes = 1000,
182 | l2_coef = 0.0,
183 | normalization_type = 'batchnorm',
184 | **unused_kwargs):
185 | """Creates model."""
186 | normalization_type = NormalizationType[normalization_type]
187 | model = MobilenetV1(
188 | width_multiplier=width_multiplier,
189 | num_classes=num_classes,
190 | normalization_type=normalization_type,
191 | l2=l2_coef)
192 | return model
193 |
--------------------------------------------------------------------------------
/growneuron/cifar/wide_resnet.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 GradMax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Wide Residual Network."""
17 |
18 | import enum
19 | import functools
20 | from typing import Any, Dict
21 | import growneuron.layers as glayers
22 | import tensorflow as tf
23 |
24 | BatchNormalization = functools.partial(
25 | tf.keras.layers.BatchNormalization,
26 | epsilon=1e-5, # using epsilon and momentum defaults from Torch
27 | momentum=0.9)
28 |
29 | LayerNormalization = functools.partial(
30 | tf.keras.layers.LayerNormalization,
31 | epsilon=1e-5) # using epsilon and momentum defaults from Torch
32 |
33 |
34 | @enum.unique
35 | class NormalizationType(enum.Enum):
36 | """Direction along the z-axis."""
37 | layernorm = 'layernorm'
38 | batchnorm = 'batchnorm'
39 | none = 'none'
40 |
41 |
42 | def Conv2D(filters, seed=None, **kwargs):
43 | """Conv2D layer that is deterministically initialized."""
44 | default_kwargs = {
45 | 'kernel_size': 3,
46 | 'padding': 'same',
47 | 'use_bias': False,
48 | # Note that we need to use the class constructor for the initializer to
49 | # get deterministic initialization.
50 | 'kernel_initializer': tf.keras.initializers.HeNormal(seed=seed),
51 | }
52 | # Override defaults with the passed kwargs.
53 | default_kwargs.update(kwargs)
54 | return tf.keras.layers.Conv2D(filters, **default_kwargs)
55 |
56 |
57 | def basic_block(
58 | filters,
59 | block_width,
60 | normalization_type,
61 | strides,
62 | l2,
63 | seed):
64 | """Basic residual block of two 3x3 convs.
65 |
66 | Args:
67 | filters: Number of filters for Conv2D.
68 | block_width: Multiplies the first filter.
69 | normalization_type: NormalizationType
70 | strides: Stride dimensions for Conv2D.
71 | l2: L2 regularization coefficient.
72 | seed: random seed used for initialization.
73 |
74 | Returns:
75 | block_layers: list of sequential layers for the main branch.
76 | skip_layer: tf.keras.Conv2D or None.
77 | """
78 | seeds = tf.random.experimental.stateless_split([seed, seed + 1], 3)[:, 0]
79 |
80 | block_layers = [
81 | BatchNormalization(beta_regularizer=tf.keras.regularizers.l2(l2),
82 | gamma_regularizer=tf.keras.regularizers.l2(l2)),
83 | tf.keras.layers.Activation('relu'),
84 | glayers.GrowLayer(
85 | Conv2D(int(filters*block_width), strides=strides, seed=seeds[0],
86 | kernel_regularizer=tf.keras.regularizers.l2(l2)))
87 | ]
88 | # Maybe add normalization in between the layers.
89 | if normalization_type == NormalizationType.batchnorm:
90 | block_layers.append(glayers.GrowLayer(
91 | BatchNormalization(beta_regularizer=tf.keras.regularizers.l2(l2),
92 | gamma_regularizer=tf.keras.regularizers.l2(l2))))
93 | elif normalization_type == NormalizationType.layernorm:
94 | block_layers.append(glayers.GrowLayer(
95 | LayerNormalization(beta_regularizer=tf.keras.regularizers.l2(l2),
96 | gamma_regularizer=tf.keras.regularizers.l2(l2))))
97 | elif normalization_type == NormalizationType.none:
98 | pass
99 | else:
100 |     raise ValueError(f'Unsupported normalization_type: {normalization_type}')
101 |
102 | block_layers += [
103 | # This is to ensure gradient is 1 at 0 for relu.
104 | tf.keras.layers.Activation(glayers.get_activation_fn('relu1')),
105 | glayers.GrowLayer(
106 | Conv2D(filters, strides=1, seed=seeds[1],
107 | kernel_regularizer=tf.keras.regularizers.l2(l2)))
108 | ]
109 |
110 | if strides > 1:
111 | skip_layer = Conv2D(filters, kernel_size=1, strides=strides, seed=seeds[2],
112 | kernel_regularizer=tf.keras.regularizers.l2(l2))
113 | else:
114 | skip_layer = None
115 | return (block_layers, skip_layer)
116 |
117 |
118 | class WideResnet(tf.keras.Model):
119 | """Builds Wide ResNet.
120 |
121 | Following Zagoruyko and Komodakis (2016), it accepts a width multiplier on the
122 | number of filters. Using three groups of residual blocks, the network maps
123 | spatial features of size 32x32 -> 16x16 -> 8x8.
124 |
125 | Attributes:
126 | depth: Total number of convolutional layers. "n" in WRN-n-k. It differs from
127 | He et al. (2015)'s notation which uses the maximum depth of the network
128 | counting non-conv layers like dense.
129 | width_multiplier: Integer to multiply the number of typical filters by. "k"
130 | in WRN-n-k.
131 | block_width_multiplier: Multiplies the filters in the first conv for each
132 | block.
133 | normalization_type: NormalizationType, of the normalization used inside
134 | blocks.
135 | num_classes: Number of output classes.
136 | l2: L2 regularization coefficient.
137 | seed: random seed used for initialization.
138 |
139 | """
140 |
141 | def __init__(
142 | self,
143 | depth,
144 | width_multiplier,
145 | block_width_multiplier,
146 | normalization_type,
147 | num_classes,
148 | l2,
149 | seed = 42
150 | ):
151 | super().__init__(name='wide_resnet-{}-{}'.format(depth, width_multiplier))
152 | l2_reg = tf.keras.regularizers.l2
153 |
154 | seeds = tf.random.experimental.stateless_split([seed, seed + 1], 5)[:, 0]
155 | if (depth - 4) % 6 != 0:
156 | raise ValueError('depth should be 6n+4 (e.g., 16, 22, 28, 40).')
157 | num_blocks = (depth - 4) // 6
158 |
159 | self.conv_stem = Conv2D(16,
160 | strides=1,
161 | seed=seeds[0],
162 | kernel_regularizer=l2_reg(l2))
163 | self.group_seq = []
164 | for i, (filters, strides, seed) in enumerate(
165 | zip([16, 32, 64], [1, 2, 2], seeds[1:4])):
166 | block_seq = []
167 | group_seeds = tf.random.experimental.stateless_split(
168 | [seed, seed + 1], num_blocks)[:, 0]
169 | for j, group_seed in enumerate(group_seeds):
170 | block_strides = strides if j == 0 else 1
171 | block_seq.append(
172 | basic_block(filters=filters*width_multiplier,
173 | block_width=block_width_multiplier,
174 | normalization_type=normalization_type,
175 | strides=block_strides, l2=l2, seed=group_seed)
176 | )
177 | self.group_seq.append(block_seq)
178 |
179 | self.final_layers = [
180 | BatchNormalization(beta_regularizer=l2_reg(l2),
181 | gamma_regularizer=l2_reg(l2)),
182 | tf.keras.layers.Activation('relu'),
183 | tf.keras.layers.AveragePooling2D(pool_size=8),
184 | tf.keras.layers.Flatten(),
185 | tf.keras.layers.Dense(
186 | num_classes,
187 | kernel_initializer=tf.keras.initializers.HeNormal(seed=seeds[4]),
188 | kernel_regularizer=l2_reg(l2),
189 | bias_regularizer=l2_reg(l2))
190 | ]
191 |
192 | def call(self, inputs):
193 | x = self.conv_stem(inputs)
194 | for block_seq in self.group_seq:
195 | for block_layers, skip_layer in block_seq:
196 | y = x
197 | # Main branch.
198 | for layer in block_layers:
199 | y = layer(y)
200 | # Skip branch
201 | if skip_layer:
202 | x = skip_layer(x)
203 | x = x + y
204 | for layer in self.final_layers:
205 | x = layer(x)
206 | return x
207 |
208 |
209 | def create_model(
210 | depth = 22,
211 | width_multiplier = 1,
212 | block_width_multiplier = 1.,
213 | normalization_type = 'batchnorm',
214 | num_classes = 10,
215 | l2_coef = 0.0,
216 | **unused_kwargs):
217 | """Creates model."""
218 | normalization_type = NormalizationType[normalization_type]
219 | model = WideResnet(depth=depth,
220 | width_multiplier=width_multiplier,
221 | block_width_multiplier=block_width_multiplier,
222 | num_classes=num_classes,
223 | normalization_type=normalization_type,
224 | l2=l2_coef)
225 | return model
226 |
--------------------------------------------------------------------------------
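A minimal sketch of building a WRN-28-1 with the factory above; the specific hyperparameter values are illustrative, with blocks starting at a quarter of their final width so the model can later be grown:

```python
import tensorflow as tf
from growneuron.cifar import wide_resnet

# depth must satisfy depth = 6n + 4, e.g. 16, 22, 28, 40.
model = wide_resnet.create_model(
    depth=28,
    width_multiplier=1,
    block_width_multiplier=0.25,  # first conv of each block starts at 25% width
    normalization_type='batchnorm',
    num_classes=10,
    l2_coef=5e-4)
_ = model(tf.zeros([1, 32, 32, 3]))  # CIFAR-sized input builds the variables.
```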
/growneuron/layers_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 GradMax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for growneuron.layers."""
17 | import absl.testing.parameterized as parameterized
18 | import growneuron.layers as glayers
19 | import tensorflow as tf
20 |
21 |
22 | class LayerTest(parameterized.TestCase, tf.test.TestCase):
23 |
24 | @parameterized.named_parameters(
25 | ('dense', tf.keras.layers.Dense(3), (3, 4)),
26 | ('batchnorm', tf.keras.layers.BatchNormalization(), (2, 4)),
27 | ('conv2d', tf.keras.layers.Conv2D(3, 3), (3, 5, 5, 4))
28 | )
29 | def test_consistency(self, layer, input_shape):
30 | wrapped_layer = glayers.GrowLayer(layer)
31 | x = tf.random.uniform(input_shape)
32 | original_out = layer(x)
33 | new_out = wrapped_layer(x)
34 | self.assertAllEqual(original_out, new_out)
35 |
36 | @parameterized.named_parameters(
37 | ('dense', tf.keras.layers.Dense(3), (3, 4), 1),
38 | ('dense_5neuron', tf.keras.layers.Dense(3), (3, 4), 5),
39 | ('conv2d', tf.keras.layers.Conv2D(3, 3), (3, 5, 5, 4), 1),
40 | ('conv2d_5neuron', tf.keras.layers.Conv2D(3, 3), (3, 5, 5, 4), 5),
41 | )
42 | def test_add_neurons_incoming_zeros(self, layer, input_shape, n_new):
43 | wrapped_layer = glayers.GrowLayer(layer)
44 | x = tf.random.uniform(input_shape)
45 | original_out = wrapped_layer(x)
46 | old_output_shape = original_out.get_shape()
47 | n_neurons_old = old_output_shape[-1]
48 | wrapped_layer.add_neurons(n_new, new_weights='zeros', is_outgoing=False)
49 | new_out = wrapped_layer(x)
50 | # Check the output has the expected shape
51 | new_shape = old_output_shape[:-1] + [n_neurons_old+n_new]
52 | self.assertAllEqual(new_shape, new_out.get_shape())
53 | # Check the old neurons create same output
54 | self.assertAllClose(original_out, new_out[Ellipsis, :n_neurons_old])
55 | # Check the new neurons create zero output
56 | self.assertEqual(0, tf.math.count_nonzero(new_out[Ellipsis, n_neurons_old:]))
57 | new_weights, new_biases = wrapped_layer.get_weights()
58 | # Check the new weights are zero
59 | added_weights = new_weights[Ellipsis, n_neurons_old:]
60 | self.assertAllEqual(added_weights, tf.zeros_like(added_weights))
61 | # Check the new biases are zero
62 | added_biases = new_biases[n_neurons_old:]
63 | self.assertAllEqual(added_biases, tf.zeros_like(added_biases))
64 |
65 | @parameterized.named_parameters(
66 | ('dense', tf.keras.layers.Dense(3), (3, 4), 1),
67 | ('dense_5neuron', tf.keras.layers.Dense(3), (3, 4), 5),
68 | ('conv2d', tf.keras.layers.Conv2D(3, 3), (3, 5, 5, 4), 1),
69 | ('conv2d_5neuron', tf.keras.layers.Conv2D(3, 3), (3, 5, 5, 4), 5),
70 | )
71 | def test_add_neurons_outgoing_zeros(self, layer, input_shape, n_new):
72 | wrapped_layer = glayers.GrowLayer(layer)
73 | n_features = input_shape[-1]
74 | x = tf.random.uniform(input_shape)
75 | # New input after growing would have more features
76 | new_input_shape = input_shape[:-1] + (n_new,)
77 | new_x = tf.concat([x, tf.random.uniform(new_input_shape)], axis=-1)
78 | original_out = layer(x)
79 | old_weights, old_biases = wrapped_layer.get_weights()
80 | wrapped_layer.add_neurons(n_new, new_weights='zeros', is_outgoing=True)
81 | new_out = wrapped_layer(new_x)
82 | new_weights, new_biases = wrapped_layer.get_weights()
83 | print(new_weights, new_biases)
84 | # Output of the layer shouldn't change.
85 | self.assertAllClose(original_out, new_out)
86 | # Check biases are unchanged
87 | self.assertAllEqual(old_biases, new_biases)
88 | # Check the new weights are zero
89 | added_weights = new_weights[Ellipsis, n_features:, :]
90 | self.assertAllEqual(added_weights, tf.zeros_like(added_weights))
91 | # Check the old weights are same
92 | kept_weights = new_weights[Ellipsis, :n_features, :]
93 | self.assertAllEqual(old_weights, kept_weights)
94 |
95 | @parameterized.named_parameters(
96 | ('dense_kernel', 'dense', ('kernel',)),
97 | ('dense_bias', 'dense', ('bias',)),
98 | ('dense_activity', 'dense', ('activity',)),
99 | ('dense_all', 'dense', ('kernel', 'bias', 'activity')),
100 | ('conv2d_kernel', 'conv2d', ('kernel',)),
101 | ('conv2d_bias', 'conv2d', ('bias',)),
102 | ('conv2d_activity', 'conv2d', ('activity',)),
103 | ('conv2d_all', 'conv2d', ('kernel', 'bias', 'activity')),
104 | )
105 | def test_regularizer_incoming(self, layer_type, regularizer_types):
106 | reg_kwargs = {f'{r_type}_regularizer': tf.keras.regularizers.L2(0.1)
107 | for r_type in regularizer_types}
108 | print(reg_kwargs)
109 | if layer_type == 'dense':
110 | layer = tf.keras.layers.Dense(3, **reg_kwargs)
111 | input_shape = (3, 4)
112 | elif layer_type == 'conv2d':
113 | layer = tf.keras.layers.Conv2D(3, 3, **reg_kwargs)
114 | input_shape = (3, 5, 5, 4)
115 | else:
116 | raise ValueError('not supported')
117 | wrapped_layer = glayers.GrowLayer(layer)
118 | x = tf.random.uniform(input_shape)
119 | _ = wrapped_layer(x)
120 | old_losses = wrapped_layer.losses
121 | wrapped_layer.add_neurons(1, new_weights='zeros', is_outgoing=False)
122 | _ = wrapped_layer(x)
123 | new_losses = wrapped_layer.losses
124 | for old_loss, new_loss in zip(old_losses, new_losses):
125 | self.assertAllClose(old_loss, new_loss)
126 |
127 | @parameterized.named_parameters(
128 | ('dense_kernel', 'dense', ('kernel',)),
129 | ('dense_bias', 'dense', ('bias',)),
130 | ('dense_activity', 'dense', ('activity',)),
131 | ('dense_all', 'dense', ('kernel', 'bias', 'activity')),
132 | ('conv2d_kernel', 'conv2d', ('kernel',)),
133 | ('conv2d_bias', 'conv2d', ('bias',)),
134 | ('conv2d_activity', 'conv2d', ('activity',)),
135 | ('conv2d_all', 'conv2d', ('kernel', 'bias', 'activity')),
136 | ('bn_beta', 'bn', ('beta',)),
137 | )
138 | def test_regularizer_outgoing(self, layer_type, regularizer_types):
139 | reg_kwargs = {f'{r_type}_regularizer': tf.keras.regularizers.L2(0.1)
140 | for r_type in regularizer_types}
141 | print(reg_kwargs)
142 | if layer_type == 'dense':
143 | layer = tf.keras.layers.Dense(3, **reg_kwargs)
144 | input_shape = (3, 4)
145 | elif layer_type == 'conv2d':
146 | layer = tf.keras.layers.Conv2D(3, 3, **reg_kwargs)
147 | input_shape = (3, 5, 5, 4)
148 | elif layer_type == 'bn':
149 | layer = tf.keras.layers.BatchNormalization(**reg_kwargs)
150 | input_shape = (3, 4)
151 | else:
152 | raise ValueError('not supported')
153 | wrapped_layer = glayers.GrowLayer(layer)
154 | x = tf.random.uniform(input_shape)
155 | _ = wrapped_layer(x)
156 | old_losses = wrapped_layer.losses
157 | if layer_type == 'bn':
158 | wrapped_layer.add_neurons_identity(1)
159 | else:
160 | wrapped_layer.add_neurons(1, new_weights='zeros', is_outgoing=True)
161 | new_input_shape = input_shape[:-1] + (1,)
162 | new_x = tf.concat([x, tf.random.uniform(new_input_shape)], axis=-1)
163 | _ = wrapped_layer(new_x)
164 | new_losses = wrapped_layer.losses
165 | for old_loss, new_loss in zip(old_losses, new_losses):
166 | self.assertAllClose(old_loss, new_loss)
167 |
168 | @parameterized.named_parameters(
169 | ('2d_axis1', (4, 5), -1),
170 | ('3d_axis1', (3, 3, 1), -1),
171 | ('4d_axis1', (3, 3, 4, 5), -1),
172 | ('2d_axis2', (4, 5), -2),
173 | ('3d_axis2', (3, 3, 1), -2),
174 | ('4d_axis2', (3, 3, 4, 5), -2),
175 | )
176 | def test_norm_l2(self, shape, axis):
177 | tensor = tf.reshape(tf.range(tf.math.reduce_prod(shape),
178 | dtype=tf.float32), shape)
179 | calculated_norm = glayers.norm_l2(tensor, axis)
180 | if axis == -2:
181 | tensor = tf.einsum('...ij->...ji', tensor)
182 |     # Expected norm per kept dimension, computed over the flattened other axes.
183 | flat_tensor = tf.reshape(tensor,
184 | [-1, tensor.shape[-1]])
185 | expected_norms = tf.norm(flat_tensor, axis=-2)
186 | self.assertAllClose(expected_norms, calculated_norm)
187 | pass
188 |
189 | @parameterized.named_parameters(
190 | ('2d_axis1', (4, 5), -1),
191 | ('3d_axis1', (3, 3, 1), -1),
192 | ('4d_axis1', (3, 3, 4, 5), -1),
193 | ('2d_axis2', (4, 5), -2),
194 | ('3d_axis2', (3, 3, 1), -2),
195 | ('4d_axis2', (3, 3, 4, 5), -2),
196 | )
197 | def test_normalize_l2(self, shape, axis):
198 | tensor = tf.reshape(tf.range(tf.math.reduce_prod(shape),
199 | dtype=tf.float32), shape)
200 | normalized_tensor = glayers.normalize_l2(tensor, axis)
201 | if axis == -2:
202 | normalized_tensor = tf.einsum('...ij->...ji', normalized_tensor)
203 | # L2 norm should be 1 over axis 1
204 | flat_tensor = tf.reshape(normalized_tensor,
205 | [-1, normalized_tensor.shape[-1]])
206 | norms = tf.norm(flat_tensor, axis=-2)
207 | self.assertAllClose(norms, tf.ones_like(norms))
208 |
209 |
210 | if __name__ == '__main__':
211 | tf.test.main()
212 |
--------------------------------------------------------------------------------
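The invariants these tests check can be reproduced directly on a toy layer; a short sketch using the same API:

```python
import tensorflow as tf
import growneuron.layers as glayers

layer = glayers.GrowLayer(tf.keras.layers.Dense(4))
x = tf.random.uniform((2, 8))
before = layer(x)

# Grow two neurons with zero incoming weights: the first four outputs stay
# identical and the two new outputs are exactly zero.
layer.add_neurons(2, new_weights='zeros', is_outgoing=False)
after = layer(x)
assert after.shape == (2, 6)
tf.debugging.assert_near(before, after[:, :4])
```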
/growneuron/imagenet/data_util.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 GradMax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Data preprocessing and augmentation.
17 |
18 | Forked from simclr/tf2 codebase.
19 | """
20 | import tensorflow.compat.v2 as tf
21 |
22 | CROP_PROPORTION = 0.875 # Standard for ImageNet.
23 |
24 |
25 | def random_apply(func, p, x):
26 | """Randomly apply function func to x with probability p."""
27 | return tf.cond(
28 | tf.less(
29 | tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32),
30 | tf.cast(p, tf.float32)), lambda: func(x), lambda: x)
31 |
32 |
33 | def _compute_crop_shape(
34 | image_height, image_width, aspect_ratio, crop_proportion):
35 | """Compute aspect ratio-preserving shape for central crop.
36 |
37 | The resulting shape retains `crop_proportion` along one side and a proportion
38 | less than or equal to `crop_proportion` along the other side.
39 |
40 | Args:
41 | image_height: Height of image to be cropped.
42 | image_width: Width of image to be cropped.
43 | aspect_ratio: Desired aspect ratio (width / height) of output.
44 | crop_proportion: Proportion of image to retain along the less-cropped side.
45 |
46 | Returns:
47 | crop_height: Height of image after cropping.
48 | crop_width: Width of image after cropping.
49 | """
50 | image_width_float = tf.cast(image_width, tf.float32)
51 | image_height_float = tf.cast(image_height, tf.float32)
52 |
53 | def _requested_aspect_ratio_wider_than_image():
54 | crop_height = tf.cast(
55 | tf.math.rint(crop_proportion / aspect_ratio * image_width_float),
56 | tf.int32)
57 | crop_width = tf.cast(
58 | tf.math.rint(crop_proportion * image_width_float), tf.int32)
59 | return crop_height, crop_width
60 |
61 | def _image_wider_than_requested_aspect_ratio():
62 | crop_height = tf.cast(
63 | tf.math.rint(crop_proportion * image_height_float), tf.int32)
64 | crop_width = tf.cast(
65 | tf.math.rint(crop_proportion * aspect_ratio * image_height_float),
66 | tf.int32)
67 | return crop_height, crop_width
68 |
69 | return tf.cond(
70 | aspect_ratio > image_width_float / image_height_float,
71 | _requested_aspect_ratio_wider_than_image,
72 | _image_wider_than_requested_aspect_ratio)
73 |
74 |
75 | def center_crop(image, height, width, crop_proportion):
76 | """Crops to center of image and rescales to desired size.
77 |
78 | Args:
79 | image: Image Tensor to crop.
80 | height: Height of image to be cropped.
81 | width: Width of image to be cropped.
82 | crop_proportion: Proportion of image to retain along the less-cropped side.
83 |
84 | Returns:
85 | A `height` x `width` x channels Tensor holding a central crop of `image`.
86 | """
87 | shape = tf.shape(image)
88 | image_height = shape[0]
89 | image_width = shape[1]
90 | crop_height, crop_width = _compute_crop_shape(
91 | image_height, image_width, height / width, crop_proportion)
92 | offset_height = ((image_height - crop_height) + 1) // 2
93 | offset_width = ((image_width - crop_width) + 1) // 2
94 | image = tf.image.crop_to_bounding_box(
95 | image, offset_height, offset_width, crop_height, crop_width)
96 |
97 | image = tf.image.resize([image], [height, width],
98 | method=tf.image.ResizeMethod.BICUBIC)[0]
99 |
100 | return image
101 |
102 |
103 | def distorted_bounding_box_crop(image,
104 | bbox,
105 | min_object_covered=0.1,
106 | aspect_ratio_range=(0.75, 1.33),
107 | area_range=(0.05, 1.0),
108 | max_attempts=100,
109 | scope=None):
110 | """Generates cropped_image using one of the bboxes randomly distorted.
111 |
112 | See `tf.image.sample_distorted_bounding_box` for more documentation.
113 |
114 | Args:
115 | image: `Tensor` of image data.
116 | bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]`
117 | where each coordinate is [0, 1) and the coordinates are arranged
118 | as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole
119 | image.
120 | min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
121 | area of the image must contain at least this fraction of any bounding
122 | box supplied.
123 | aspect_ratio_range: An optional list of `float`s. The cropped area of the
124 | image must have an aspect ratio = width / height within this range.
125 | area_range: An optional list of `float`s. The cropped area of the image
126 |       must contain a fraction of the supplied image within this range.
127 | max_attempts: An optional `int`. Number of attempts at generating a cropped
128 | region of the image of the specified constraints. After `max_attempts`
129 | failures, return the entire image.
130 | scope: Optional `str` for name scope.
131 | Returns:
132 | (cropped image `Tensor`, distorted bbox `Tensor`).
133 | """
134 | with tf.name_scope(scope or 'distorted_bounding_box_crop'):
135 | shape = tf.shape(image)
136 | sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
137 | shape,
138 | bounding_boxes=bbox,
139 | min_object_covered=min_object_covered,
140 | aspect_ratio_range=aspect_ratio_range,
141 | area_range=area_range,
142 | max_attempts=max_attempts,
143 | use_image_if_no_bounding_boxes=True)
144 | bbox_begin, bbox_size, _ = sample_distorted_bounding_box
145 |
146 | # Crop the image to the specified bounding box.
147 | offset_y, offset_x, _ = tf.unstack(bbox_begin)
148 | target_height, target_width, _ = tf.unstack(bbox_size)
149 | image = tf.image.crop_to_bounding_box(
150 | image, offset_y, offset_x, target_height, target_width)
151 |
152 | return image
153 |
154 |
155 | def crop_and_resize(image, height, width):
156 | """Make a random crop and resize it to height `height` and width `width`.
157 |
158 | Args:
159 | image: Tensor representing the image.
160 | height: Desired image height.
161 | width: Desired image width.
162 |
163 | Returns:
164 | A `height` x `width` x channels Tensor holding a random crop of `image`.
165 | """
166 | bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
167 | aspect_ratio = width / height
168 | image = distorted_bounding_box_crop(
169 | image,
170 | bbox,
171 | min_object_covered=0.1,
172 | aspect_ratio_range=(3. / 4 * aspect_ratio, 4. / 3. * aspect_ratio),
173 | area_range=(0.08, 1.0),
174 | max_attempts=100,
175 | scope=None)
176 | return tf.image.resize([image], [height, width],
177 | method=tf.image.ResizeMethod.BICUBIC)[0]
178 |
179 |
180 | def random_crop_with_resize(image, height, width, p=1.0):
181 | """Randomly crop and resize an image.
182 |
183 | Args:
184 | image: `Tensor` representing an image of arbitrary size.
185 | height: Height of output image.
186 | width: Width of output image.
187 | p: Probability of applying this transformation.
188 |
189 | Returns:
190 | A preprocessed image `Tensor`.
191 | """
192 | def _transform(image):
193 | image = crop_and_resize(image, height, width)
194 | return image
195 | return random_apply(_transform, p=p, x=image)
196 |
197 |
198 | def preprocess_for_train(image,
199 | height,
200 | width,
201 | crop=True,
202 | flip=True):
203 | """Preprocesses the given image for training.
204 |
205 | Args:
206 | image: `Tensor` representing an image of arbitrary size.
207 | height: Height of output image.
208 | width: Width of output image.
209 | crop: Whether to crop the image.
210 | flip: Whether or not to flip left and right of an image.
211 |
212 | Returns:
213 | A preprocessed image `Tensor`.
214 | """
215 | if crop:
216 | image = random_crop_with_resize(image, height, width)
217 | if flip:
218 | image = tf.image.random_flip_left_right(image)
219 | image = tf.reshape(image, [height, width, 3])
220 | image = tf.clip_by_value(image, 0., 1.)
221 | return image
222 |
223 |
224 | def preprocess_for_eval(image, height, width, crop=True):
225 | """Preprocesses the given image for evaluation.
226 |
227 | Args:
228 | image: `Tensor` representing an image of arbitrary size.
229 | height: Height of output image.
230 | width: Width of output image.
231 | crop: Whether or not to (center) crop the test images.
232 |
233 | Returns:
234 | A preprocessed image `Tensor`.
235 | """
236 | if crop:
237 | image = center_crop(image, height, width, crop_proportion=CROP_PROPORTION)
238 | image = tf.reshape(image, [height, width, 3])
239 | image = tf.clip_by_value(image, 0., 1.)
240 | return image
241 |
242 |
243 | def preprocess_image(image, image_size, is_training=False, test_crop=True):
244 | """Preprocesses the given image.
245 |
246 | Args:
247 | image: `Tensor` representing an image of arbitrary size.
248 | image_size: Size of output image.
249 | is_training: `bool` for whether the preprocessing is for training.
250 | test_crop: whether or not to extract a central crop of the images
251 | (as for standard ImageNet evaluation) during the evaluation.
252 |
253 | Returns:
254 | A preprocessed image `Tensor` of range [0, 1].
255 | """
256 | image = tf.image.convert_image_dtype(image, dtype=tf.float32)
257 | if is_training:
258 | return preprocess_for_train(image, image_size, image_size)
259 | else:
260 | return preprocess_for_eval(image, image_size, image_size, crop=test_crop)
261 |
--------------------------------------------------------------------------------
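These preprocessing helpers are typically applied per-example inside a `tf.data` pipeline; a small sketch (the dataset and the 224 image size are assumptions):

```python
import tensorflow as tf
from growneuron.imagenet import data_util

def train_map_fn(image, label, image_size=224):
  # Random crop + resize + horizontal flip; output values lie in [0, 1].
  image = data_util.preprocess_image(image, image_size, is_training=True)
  return image, label

# Assuming `ds` yields (uint8 image, label) pairs, e.g. from TFDS:
# ds = ds.map(train_map_fn, num_parallel_calls=tf.data.AUTOTUNE)
```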
/growneuron/updaters.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 GradMax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Implements controllers for updating networks.
17 | """
18 | import itertools
19 | from growneuron import growers
20 | from growneuron import layers
21 | import tensorflow as tf
22 |
23 |
24 | def pad_zeros_to(tensor, new_shape):
25 | """Pads a tensor with zeros such that final shape is new_shape.
26 |
27 | It expects the new_shape to be larger than the tensor.shape.
28 | Zeros are added to the end of each dimension.
29 | Args:
30 | tensor: 1d, 2d, 3d tensor.
31 | new_shape: list of dimensions where len(new_shape) == len(tensor.shape)
32 | Returns:
33 | new tensor of shape `new_shape`.
34 | """
35 | old_shape = tensor.shape
36 |
37 | if len(old_shape) == 1:
38 | # Batchnorm or bias.
39 | diff_shape = [new_shape[-1] - old_shape[-1]]
40 | concat_axis = -1
41 | else:
42 | if old_shape[-2] == new_shape[-2]:
43 | # Input features are same, padding at axis=-1.
44 | concat_axis = -1
45 | else:
46 | concat_axis = -2
47 | diff_shape = list(old_shape)
48 | diff_shape[concat_axis] = new_shape[concat_axis] - old_shape[concat_axis]
49 | return tf.concat([tensor, tf.zeros(diff_shape)], axis=concat_axis)
50 |
51 |
52 | class Updater():
53 | """Implements common methods.
54 |
55 | Updaters should be created under strategy scope or strategy should be passed
56 | directly.
57 | Attr:
58 | network_grower: growers.LayerGrower
59 | grow_layer_tuples: list of lists, candidates to be
60 | grown together with their outgoing weights.
61 | loss_fn: fn, Used to calculate loss. This function should get inputs
62 | as input and return loss.
63 | compile_fn: fn, Called to compile the model.
64 | update_frequency: int, Number of iterations before neurons are added.
65 | n_grow: int, number of neurons to grow at each growth step.
66 | n_grow_fraction: float, must be positive. Used together with initial width
67 | of candidate layers to decide n_neurons to grow at each growth step for
68 |       each candidate separately. This is useful for fixing the final
69 |       architecture from the start, since the number of neurons added to each
70 |       layer is determined at the beginning.
71 |     start_iteration: int, iteration at which growing starts.
72 | n_growth_steps: int, number of times the network is grown.
73 |     scale: float, passed to the grower.grow_neurons
74 | carry_optimizer: bool, If true the running averages are carried to the new
75 | optimizer after the growth. Since variables are recreated after growth
76 | this is necessary.
77 | """
78 |
79 | def __init__(self, network_grower, grow_layer_tuples, loss_fn=lambda x: x,
80 | compile_fn=lambda: None, update_frequency=1, n_grow=1,
81 | n_grow_fraction=None, start_iteration=None, n_growth_steps=None,
82 | scale=1., carry_optimizer=True):
83 | assert update_frequency > 0
84 | assert n_grow > 0
85 | self._update_frequency = update_frequency
86 | self._carry_optimizer = carry_optimizer
87 | self._n_grow = n_grow
88 | self._n_grow_fraction = n_grow_fraction
89 | self._scale = scale
90 | if start_iteration is None:
91 | start_iteration = update_frequency
92 | self._start_iteration = start_iteration
93 | self.loss_fn = loss_fn
94 | self.compile_fn = compile_fn
95 | self.strategy = tf.distribute.get_strategy()
96 | self.network_grower = self._prepare_grower(network_grower)
97 | self._n_growth_steps = n_growth_steps
98 | self._growth_counter = 0
99 | self._set_grow_layer_tuples(grow_layer_tuples)
100 |
101 | def _prepare_grower(self, grower):
102 | if grower:
103 | grower.loss_fn = self.loss_fn
104 | grower.compile_fn = self.compile_fn
105 | grower.strategy = self.strategy
106 | return grower
107 |
108 | def copy_optimizer_slots(self, optimizer, old_variables, new_variables):
109 | """Copy old slots and pad with zeros for new neurons."""
110 | for old_var, new_var in zip(old_variables, new_variables):
111 | for s_name in sorted(optimizer.get_slot_names()):
112 | old_slot_var = optimizer.get_slot(old_var, s_name)
113 | new_slot_var = optimizer.get_slot(new_var, s_name)
114 |         # Pad the old slot values with zeros up to the new slot's shape.
115 |         # This assumes new neurons are appended to the end.
116 | new_slot_values = pad_zeros_to(old_slot_var, new_slot_var.shape)
117 | new_slot_var.assign(new_slot_values)
118 |
119 | def delete_optimizer_slots(self, optimizer, variables):
120 | """Deleted old variable slots from the optimizer."""
121 | for old_var in variables:
122 | key = (old_var._shared_name if old_var._in_graph_mode
123 | else old_var._unique_id)
124 | optimizer._slots.pop(key, None)
125 |
126 | def _set_grow_layer_tuples(self, grow_layer_tuples):
127 | """Sets the tuple of layers for growing."""
128 | if not grow_layer_tuples:
129 | raise ValueError("grow_layer_tuples argument can't be empty.")
130 | self.grow_layer_tuples = grow_layer_tuples
131 |
132 | def get_n_neuron(n_neuron_initial):
133 | if self._n_grow_fraction:
134 | return int(max(1, n_neuron_initial * self._n_grow_fraction))
135 | else:
136 | return self._n_grow
137 | # Used to calculate n_grow per layer using grow_fraction.
138 | # n_neurons are decided using the initial architecture.
139 | self._n_grow_dict = {
140 | tpl[0].name: get_n_neuron(tpl[0].weights[0].shape[-1])
141 | for tpl in grow_layer_tuples
142 | }
143 |
144 | def is_update_iteration(self, iteration):
145 | assert iteration >= 0
146 | return ((self.network_grower is not None) and
147 | (iteration % self._update_frequency == 0) and
148 | (self._start_iteration <= iteration) and
149 | ((self._n_growth_steps is None) or
150 | (self._growth_counter < self._n_growth_steps)))
151 |
152 | def get_variable_list(self, grow_layer_tuple):
153 | return list(itertools.chain.from_iterable(
154 | [layer.trainable_weights for layer in grow_layer_tuple]))
155 |
156 | def get_grow_layer_stats(self):
157 | all_stats = []
158 | for grow_layer_tuple in self.grow_layer_tuples:
159 | first_layer = grow_layer_tuple[0]
160 | n_neuron = first_layer.get_weights()[0].shape[-1]
161 | all_stats.append((first_layer.layer.name, n_neuron))
162 | return all_stats
163 |
164 | def update_network(self, batch_data, optimizer=None):
165 | raise NotImplementedError()
166 |
167 |
168 | class DummyUpdater(Updater):
169 |   """Dummy updater that never grows the network.
170 |
171 | Attr:
172 | network_grower: growers.LayerGrower
173 | grow_layer_tuples: list of lists, candidates to be
174 | grown together with their outgoing weights.
175 | update_frequency: int, Number of iterations before neurons are added.
176 | """
177 |
178 | def __init__(self, grow_layer_tuples):
179 | super().__init__(None, grow_layer_tuples, None, None)
180 |
181 | def update_network(self, **kwargs):
182 | pass
183 |
184 | def is_update_iteration(self, epoch):
185 | del epoch
186 | return False
187 |
188 | def get_grow_layer_stats(self):
189 | return []
190 |
191 |
192 | class RoundRobin(Updater):
193 | """Updates provided candidate layers in a round robin fashion."""
194 |
195 | def _next_grow_layer_tuple(self, unused_batch_data):
196 | next_tuple_id = self._growth_counter % len(self.grow_layer_tuples)
197 | self._growth_counter += 1
198 | return self.grow_layer_tuples[next_tuple_id]
199 |
200 | def update_network(self, batch_data, optimizer=None):
201 | """Updates the network and optimizer slots."""
202 | grow_layer_tuple = self._next_grow_layer_tuple(batch_data)
203 | old_variables = self.get_variable_list(grow_layer_tuple)
204 | n_new = self._n_grow_dict[grow_layer_tuple[0].name]
205 | self.network_grower.grow_neurons(grow_layer_tuple, batch_data,
206 | n_new=n_new, scale=self._scale)
207 | # Run the loss function to create new variables.
208 | self.compile_fn()
209 | new_variables = self.get_variable_list(grow_layer_tuple)
210 |     if optimizer:
211 |       optimizer._create_slots(new_variables)
212 |       if self._carry_optimizer:
213 |         self.copy_optimizer_slots(optimizer, old_variables, new_variables)
214 |         self.delete_optimizer_slots(optimizer, old_variables)
215 |
216 | class AllAtOnce(Updater):
217 | """Grows all candidate layers at once."""
218 |
219 | def _get_all_grow_layer_tuples(self):
220 | self._growth_counter += 1
221 | return self.grow_layer_tuples[:]
222 |
223 | def update_network(self, batch_data, optimizer=None):
224 | """Updates the network and optimizer slots."""
225 | grow_layer_tuples = self._get_all_grow_layer_tuples()
226 | for grow_layer_tuple in grow_layer_tuples:
227 | old_variables = self.get_variable_list(grow_layer_tuple)
228 | n_new = self._n_grow_dict[grow_layer_tuple[0].name]
229 | self.network_grower.grow_neurons(grow_layer_tuple, batch_data,
230 | n_new=n_new, scale=self._scale)
231 | # Run the loss function to create new variables.
232 | self.compile_fn()
233 | new_variables = self.get_variable_list(grow_layer_tuple)
234 |       if optimizer:
235 |         optimizer._create_slots(new_variables)
236 |         if self._carry_optimizer:
237 |           self.copy_optimizer_slots(optimizer, old_variables, new_variables)
238 |           self.delete_optimizer_slots(optimizer, old_variables)
239 |
240 |
241 |
242 | def adjust_epochs(train_epochs, width_scale, update_frequency,
243 | start_iteration, n_growth_steps, steps_per_epoch):
244 |   """Adjusts the epochs so that the total FLOPs match the big baseline."""
245 |   # Here we extend training according to the FLOPs saved by starting with
246 |   # a smaller width.
247 | saved_fraction = (1 - width_scale)
248 | # Saved before growth.
249 | saved_steps = saved_fraction * start_iteration
250 | growth_duration = (update_frequency * (n_growth_steps - 1))
251 |   # Saved during growth (the 1/2 comes from the triangular ramp-up area).
252 | saved_steps += saved_fraction/2 * growth_duration
253 | new_epochs = train_epochs + int(saved_steps / steps_per_epoch)
254 | return new_epochs
255 |
--------------------------------------------------------------------------------
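As a worked example of the FLOP-matching arithmetic in `adjust_epochs` (the numbers are invented for illustration):

```python
from growneuron import updaters

# Start at 25% width, first growth at step 10000, then 10 growth steps spaced
# 1000 iterations apart, with 1000 steps per epoch and 100 nominal epochs.
new_epochs = updaters.adjust_epochs(
    train_epochs=100, width_scale=0.25, update_frequency=1000,
    start_iteration=10000, n_growth_steps=10, steps_per_epoch=1000)
# saved_fraction = 0.75
# before growth:  0.75 * 10000         = 7500 saved steps
# during growth:  0.75 / 2 * 1000 * 9  = 3375 saved steps (triangular ramp)
# total 10875 steps ~ 10 extra epochs, so training is extended to 110 epochs.
assert new_epochs == 110
```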
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/growneuron/layers.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 GradMax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """GrowLayer wrapper module."""
17 | import numpy as np
18 | import tensorflow as tf
19 |
20 | SUPPORTED_LAYERS = (tf.keras.layers.Dense, tf.keras.layers.Conv2D)
21 |
22 |
23 | def get_activation_fn(actv_fn):
24 | """Activation choices for the layer.
25 |
26 | Args:
27 | actv_fn: str.
28 |
29 | Returns:
30 | activation fn
31 | """
32 | if actv_fn == 'relu1':
33 | # This has grad(f(0))=1 instead of 0 (default implementation).
34 | return lambda x: tf.math.maximum(x, 0)
35 | elif actv_fn == 'relu2':
36 | # This has grad(f(0))=1.
37 | return lambda x: tf.math.maximum(x, -1)
38 | else:
39 | return tf.keras.activations.get(actv_fn)
40 |
41 |
42 | class GrowLayer(tf.keras.layers.Wrapper):
43 | """This layer wraps keras.layers in order to support growing.
44 |
45 | This layer allows adding callbacks to the forward pass of a layer that will be
46 | called with the inputs and outputs of the underlying layer before the
47 | activations.
48 |
49 | Example Usage:
50 | ```
51 | first_layer = GrowLayer(tf.keras.layers.Dense(32))
52 | ```
53 |
54 | This layer can be used for growing neurons.
55 |   `first_layer.add_neurons(1, new_weights='zeros', is_outgoing=False)`
56 | """
57 |
58 | def __init__(self, *args, activation=None, **kwargs):
59 | if 'name' not in kwargs:
60 | # args[0] is the wrapped layer
61 | kwargs['name'] = f'glayer_{args[0].name}'
62 | super().__init__(*args, **kwargs)
63 | self.activation = get_activation_fn(activation)
64 | self.reset_callbacks()
65 |
66 | def add_callback(self, name, fn):
67 | self._callbacks[name] = fn
68 |
69 | def remove_callback(self, name):
70 | del self._callbacks[name]
71 |
72 | def reset_callbacks(self):
73 | self._callbacks = {}
74 |
75 | def __call__(self, inputs, *args, **kwargs):
76 | outputs = self.layer.__call__(inputs, *args, **kwargs)
77 | for _, callback_fn in self._callbacks.items():
78 | inputs, outputs = callback_fn(inputs, outputs)
79 | if self.activation:
80 | outputs = self.activation(outputs)
81 | return outputs
82 |
83 | def add_neurons(self, n_new, new_weights='zeros', scale=1.,
84 | is_outgoing=False, scale_method='mean_norm',
85 | new_bias='zeros'):
86 | """Adds new neurons and creates a new layer.
87 |
88 | New weights are scaled (if not zero) to have l2-norm equal to the mean
89 | l2-norm of the existing weights.
90 | TODO Unify splitting and adding neurons.
91 | Args:
92 | n_new: number of neurons to add.
93 | new_weights: 'zeros', 'random' or np.ndarray.
94 | scale: float, scales the new_weights multiplied with the mean norm of
95 | the existing weights.
96 |       is_outgoing: bool, if true, adds connections coming from n_new new
97 |         neurons in the previous layer. In other words, the number of neurons
98 |         in this layer stays constant, but each one now aggregates information
99 |         from n_new additional inputs.
100 | scale_method: str, Type of scaling to be used when initializing new
101 | neurons.
102 | - `mean_norm` means they are normalized using the mean norm of
103 | existing weights.
104 | - `fixed` means the weights are multiplied with scale directly.
105 | new_bias: str, 'zeros' or 'ones'.
106 | """
107 | old_module = self.layer
108 | assert old_module.built
109 | assert new_bias in ('zeros', 'ones')
110 | assert isinstance(old_module, SUPPORTED_LAYERS)
111 | self.layer = grow_new_layer(old_module, n_new, new_weights, scale,
112 | is_outgoing=is_outgoing, new_bias=new_bias,
113 | scale_method=scale_method)
114 |
115 | def add_neurons_identity(self, n_new):
116 | """Adds identity neurons for various layer types.
117 |
118 | Args:
119 | n_new: number of neurons to add.
120 | """
121 | old_module = self.layer
122 | assert old_module.built
123 | if isinstance(old_module, tf.keras.layers.BatchNormalization):
124 | self.layer = grow_new_bn_layer(old_module, n_new)
125 | elif isinstance(old_module, tf.keras.layers.LayerNormalization):
126 | self.layer = grow_new_ln_layer(old_module, n_new)
127 | elif isinstance(old_module, tf.keras.layers.DepthwiseConv2D):
128 | self.layer = grow_new_dw_layer(old_module, n_new)
129 | else:
130 | raise ValueError(f'layer: {old_module} of {type(old_module)} is not '
131 | 'supported.')
132 |
133 |
134 | def grow_new_layer(old_module, n_new, new_weights, scale, is_outgoing=False,
135 | scale_method='mean_norm', new_bias='zeros'):
136 |   """Creates a new layer after adding incoming or outgoing connections.
137 |
138 | Args:
139 | old_module: Old layer to grow from. One of layers.SUPPORTED_LAYERS.
140 | n_new: number of neurons to add.
141 | new_weights: 'zeros', 'random' or np.ndarray.
142 | scale: float, scales the new_weights multiplied with the mean norm of
143 | the existing weights.
144 | is_outgoing: bool, True if the outgoing connections of the new neurons are
145 | being added to the next layer. In this case, no new neurons are generated;
146 | instead existing neurons receive new incoming connections.
147 | scale_method: str, Type of scaling to be used when initializing new
148 | neurons.
149 | - `mean_norm` means they are normalized using the mean norm of
150 | existing weights.
151 | - `fixed` means the weights are multiplied with scale directly.
152 | new_bias: str, zeros or ones.
153 | Returns:
154 | layer of same type as the old_module.
155 | """
156 | old_weights = old_module.get_weights()[0]
157 | shape_axis = -2 if is_outgoing else -1
158 |
159 | if scale_method == 'mean_norm':
160 | magnitude_new = np.mean(norm_l2(old_weights, keep_dim=shape_axis).numpy())
161 | magnitude_new *= scale
162 | elif scale_method == 'fixed':
163 | # We don't use the scale of existing weights for initialization.
164 | magnitude_new = scale
165 | else:
166 | raise ValueError(f'Not supported scale_method, {scale_method}')
167 |
168 | shape_new = list(old_weights.shape)
169 | shape_new[shape_axis] = n_new
170 |
171 | if isinstance(new_weights, np.ndarray):
172 | assert new_weights.shape == tuple(shape_new)
173 | # Normalize to unit norm and then scale.
174 | normalized_w = normalize_l2(new_weights, axis=shape_axis).numpy()
175 | new_neurons = normalized_w * magnitude_new
176 | elif new_weights == 'random':
177 | normalized_w = normalize_l2(np.random.uniform(size=shape_new),
178 | axis=shape_axis).numpy()
179 | # Normalize to unit norm and then scale.
180 | new_neurons = normalized_w * magnitude_new
181 | elif new_weights == 'zeros':
182 | new_neurons = np.zeros(shape_new)
183 | else:
184 | raise ValueError('new_weights: %s is not valid' % new_weights)
185 | new_layer_weights = [np.concatenate((old_weights, new_neurons),
186 | axis=shape_axis)]
187 |
188 | # Assuming bias is the second weight.
189 | if old_module.use_bias:
190 | new_bias_weights = old_module.get_weights()[1]
191 | if not is_outgoing:
192 | new_neuron_bias = (np.zeros([n_new]) if (new_bias == 'zeros') else
193 | np.ones([n_new]))
194 | new_bias_weights = np.concatenate((new_bias_weights, new_neuron_bias),
195 | axis=0)
196 | new_layer_weights.append(new_bias_weights)
197 |
198 | common_kwargs = {
199 | 'name': old_module.name,
200 | 'activation': old_module.activation,
201 | 'use_bias': old_module.use_bias
202 | }
203 | for r_name in ('kernel_regularizer', 'bias_regularizer',
204 | 'activity_regularizer'):
205 | regularizer = getattr(old_module, r_name)
206 | if regularizer is not None:
207 | common_kwargs[r_name] = regularizer
208 | n_out_new = new_layer_weights[0].shape[-1]
209 | if isinstance(old_module, tf.keras.layers.Dense):
210 | new_module = tf.keras.layers.Dense(
211 | n_out_new,
212 | weights=new_layer_weights,
213 | **common_kwargs)
214 | elif isinstance(old_module, tf.keras.layers.Conv2D):
215 | new_module = tf.keras.layers.Conv2D(
216 | n_out_new,
217 | kernel_size=old_module.kernel_size,
218 | strides=old_module.strides,
219 | padding=old_module.padding,
220 | weights=new_layer_weights,
221 | **common_kwargs)
222 | else:
223 | raise ValueError(f'Unexpected module: {old_module}')
224 |
225 | return new_module
226 |
227 |
228 | def grow_new_ln_layer(old_module, n_new):
229 | """Grows a new identity LayerNormalization layer."""
230 | new_ln_weights = []
231 | # One for gamma, beta
232 | for i in range(2):
233 | old_w = old_module.get_weights()[i]
234 | if i == 0: # gamma
235 | new_w = np.ones([n_new])
236 | else: # beta
237 | new_w = np.zeros([n_new])
238 | w = np.concatenate((old_w, new_w), axis=0)
239 | new_ln_weights.append(w)
240 | common_kwargs = {
241 | 'epsilon': old_module.epsilon
242 | }
243 | for r_name in ('gamma_regularizer', 'beta_regularizer'):
244 | regularizer = getattr(old_module, r_name)
245 | if regularizer is not None:
246 | common_kwargs[r_name] = regularizer
247 | return tf.keras.layers.LayerNormalization(weights=new_ln_weights,
248 | **common_kwargs)
249 |
250 |
251 | def grow_new_bn_layer(old_module, n_new):
252 | """Grows a new identity BatchNormalization layer."""
253 | new_bn_weights = []
254 | # One for gamma, beta, moving_mean and moving_variance
255 | for i in range(4):
256 | old_w = old_module.get_weights()[i]
257 | if i in (1, 2): # beta, moving_mean
258 | new_w = np.zeros([n_new])
259 | else: # gamma, moving variance
260 | new_w = np.ones([n_new])
261 | w = np.concatenate((old_w, new_w), axis=0)
262 | new_bn_weights.append(w)
263 | common_kwargs = {
264 | 'epsilon': old_module.epsilon
265 | }
266 | for r_name in ('gamma_regularizer', 'beta_regularizer'):
267 | regularizer = getattr(old_module, r_name)
268 | if regularizer is not None:
269 | common_kwargs[r_name] = regularizer
270 | return tf.keras.layers.BatchNormalization(weights=new_bn_weights,
271 | **common_kwargs)
272 |
273 |
274 | def grow_new_dw_layer(old_module, n_new):
275 |   """Adds identity neurons to the depthwise convolutional layers."""
276 | old_weights = old_module.get_weights()[0]
277 | shape_new = list(old_weights.shape)
278 | shape_new[-2] = n_new
279 | new_weights = np.zeros(shape_new, dtype=old_weights.dtype)
280 | mid_index_x = new_weights.shape[0] // 2
281 | mid_index_y = new_weights.shape[1] // 2
282 | new_weights[mid_index_x, mid_index_y, Ellipsis] = 1.
283 | new_layer_weights = [np.concatenate((old_weights, new_weights),
284 | axis=-2)]
285 |
286 | # Assuming bias is the second weight.
287 | if old_module.use_bias:
288 | new_bias = old_module.get_weights()[1]
289 | new_neuron_bias = np.zeros([n_new])
290 | new_bias = np.concatenate((new_bias, new_neuron_bias), axis=0)
291 | new_layer_weights.append(new_bias)
292 |
293 | regularizer_kwargs = {}
294 | for r_name in ('kernel_regularizer', 'bias_regularizer',
295 | 'activity_regularizer'):
296 | regularizer = getattr(old_module, r_name)
297 | if regularizer is not None:
298 | regularizer_kwargs[r_name] = regularizer
299 | new_module = tf.keras.layers.DepthwiseConv2D(
300 | kernel_size=old_module.kernel_size,
301 | name=old_module.name,
302 | activation=old_module.activation,
303 | use_bias=old_module.use_bias,
304 | strides=old_module.strides,
305 | padding=old_module.padding,
306 | weights=new_layer_weights,
307 | **regularizer_kwargs)
308 | return new_module
309 |
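# Illustrative sketch (not part of the original file): a depthwise kernel that
# is zero everywhere except for a 1 at the spatial centre passes its channel
# through unchanged, which is what makes the neurons grown above an identity.
example_kernel = np.zeros((3, 3, 1, 1), dtype=np.float32)
example_kernel[1, 1, 0, 0] = 1.0
example_x = tf.random.uniform((1, 8, 8, 1))
example_y = tf.nn.depthwise_conv2d(example_x, tf.constant(example_kernel),
                                   strides=[1, 1, 1, 1], padding='SAME')
assert np.allclose(example_x.numpy(), example_y.numpy())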
310 |
311 | def norm_l2(tensor, keep_dim):
312 | norm_axes = list(range(len(tensor.shape)))
313 | del norm_axes[keep_dim]
314 | return tf.sqrt(tf.reduce_sum(tf.pow(tensor, 2), axis=norm_axes))
315 |
316 |
317 | def normalize_l2(tensor, axis):
318 | assert axis in (-2, -1)
319 | norm = norm_l2(tensor, axis)
320 | scale_recipe = '...ij,i->...ij' if (axis == -2) else '...ij,j->...ij'
321 | return tf.einsum(scale_recipe, tensor, 1 / norm)
322 |
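# Illustrative sketch (not part of the original file, assumes eager execution):
# with axis=-1 each last-axis slice (e.g. each output unit of a dense kernel)
# is rescaled to unit L2 norm, so norm_l2 of the result is a vector of ones.
example_w = tf.random.uniform((5, 3))
example_unit = normalize_l2(example_w, axis=-1)
assert np.allclose(norm_l2(example_unit, keep_dim=-1).numpy(), [1., 1., 1.])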
--------------------------------------------------------------------------------
/growneuron/imagenet/main.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 GradMax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | r"""MobileNet-v1 on ImageNet trained with maximum likelihood.
17 |
18 | """
19 |
20 | import itertools
21 | import os
22 | import time
23 |
24 | from absl import app
25 | from absl import flags
26 | from absl import logging
27 | from growneuron import growers
28 | from growneuron import updaters
29 | from growneuron.imagenet import data
30 | from growneuron.imagenet import mb_v1
31 | from ml_collections import config_flags
32 | import tensorflow as tf
33 | import tensorflow_datasets as tfds
34 | import uncertainty_baselines.schedules as ub_schedules
35 | from tensorboard.plugins.hparams import api as hp
36 |
37 |
38 | config_flags.DEFINE_config_file(
39 | name='config',
40 | default='growneuron/imagenet/configs/'
41 | 'baseline_big.py',
42 | help_string='training config file.')
43 | # common flags
44 | flags.DEFINE_string(
45 | 'tpu', '',
46 | 'TPU address. If empty MirroredStrategy is used.')
47 | flags.DEFINE_string('data_dir', None,
48 |                     'data_dir to be used for tfds dataset construction. '
49 |                     'It is required when training with cloud TPUs.')
50 | flags.DEFINE_bool('download_data', False,
51 | 'Whether to download data locally when initializing a '
52 | 'dataset.')
53 | flags.DEFINE_string('output_dir', '/tmp/cifar', 'Output directory.')
54 | FLAGS = flags.FLAGS
55 |
56 |
57 | def get_optimizer(optimizer_config, train_epochs, batch_size, steps_per_epoch):
58 | """Given the config and training arguments returns an optimizer."""
59 |   # Linearly scale the learning rate and decay epochs from the vanilla settings.
60 | base_lr = optimizer_config.base_learning_rate * batch_size / 128
61 | lr_decay_epochs = [int(fraction * train_epochs)
62 | for fraction in optimizer_config.lr_decay_epochs]
63 | if optimizer_config.decay_type == 'step':
64 | lr_schedule = ub_schedules.WarmUpPiecewiseConstantSchedule(
65 | steps_per_epoch,
66 | base_lr,
67 | decay_ratio=optimizer_config.lr_decay_ratio,
68 | decay_epochs=lr_decay_epochs,
69 | warmup_epochs=optimizer_config.lr_warmup_epochs)
70 | elif optimizer_config.decay_type == 'cosine':
71 | lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
72 | base_lr, train_epochs * steps_per_epoch, alpha=0.0)
73 | else:
74 | lr_schedule = base_lr / 100.
75 | logging.info('No decay used')
76 | optimizer = tf.keras.optimizers.SGD(lr_schedule,
77 | momentum=optimizer_config.momentum,
78 | nesterov=optimizer_config.nesterov)
79 | return optimizer
80 |
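# Illustrative arithmetic (a sketch with assumed numbers, not from the file):
# with base_learning_rate=0.1, lr_decay_epochs=[0.3, 0.6, 0.9], a total batch
# size of 512 and train_epochs=90, the function above builds its schedule from
#   base_lr = 0.1 * 512 / 128 = 0.4
#   lr_decay_epochs = [27, 54, 81]  (fractions of the 90 training epochs)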
81 |
82 | def main(argv):
83 | fmt = '[%(filename)s:%(lineno)s] %(message)s'
84 | formatter = logging.PythonFormatter(fmt)
85 | logging.get_absl_handler().setFormatter(formatter)
86 | del argv # unused arg
87 | config = FLAGS.config
88 | if (hasattr(config, 'grow_frequency_multiplier') and
89 | config.grow_frequency_multiplier != 1):
90 | # Scale the frequency of the growth steps
91 | factor = config.grow_frequency_multiplier
92 | config.updater.update_frequency = int(
93 | config.updater.update_frequency * factor)
94 | config.updater.n_growth_steps = int(
95 | config.updater.n_growth_steps / factor)
96 | config.updater.n_grow_fraction *= factor
97 |
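  # Illustrative example (assumed numbers): with grow_frequency_multiplier=2
  # and an updater configured with update_frequency=1000, n_growth_steps=10
  # and n_grow_fraction=0.1, the block above yields update_frequency=2000,
  # n_growth_steps=5 and n_grow_fraction=0.2, so the same total width is grown
  # over the same number of steps, just in fewer, larger increments.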
98 | tf.io.gfile.makedirs(FLAGS.output_dir)
99 | logging.info('Saving checkpoints at %s', FLAGS.output_dir)
100 | tf.random.set_seed(config.seed)
101 | if FLAGS.tpu:
102 | resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
103 | tf.config.experimental_connect_to_cluster(resolver)
104 | topology = tf.tpu.experimental.initialize_tpu_system(resolver)
105 | logging.info('Topology:')
106 | logging.info('num_tasks: %d', topology.num_tasks)
107 | logging.info('num_tpus_per_task: %d', topology.num_tpus_per_task)
108 | strategy = tf.distribute.TPUStrategy(resolver)
109 | else:
110 | strategy = tf.distribute.MirroredStrategy()
111 | topology = None
112 |
113 | ds_builder = tfds.builder(config.dataset)
114 | if FLAGS.download_data:
115 | ds_builder.download_and_prepare()
116 | ds_info = ds_builder.info
117 | batch_size = config.per_core_batch_size * config.num_cores
118 |
119 |   # Scale arguments that assume a total batch size of 512.
120 | multiplier = 512. / batch_size
121 | if hasattr(config.updater, 'update_frequency'):
122 | config.updater.update_frequency = int(
123 | config.updater.update_frequency * multiplier)
124 | config.updater.start_iteration = int(
125 | config.updater.start_iteration * multiplier)
126 |
127 | train_dataset_size = ds_info.splits['train'].num_examples
128 | steps_per_epoch = train_dataset_size // batch_size
129 | logging.info('Steps per epoch %s', steps_per_epoch)
130 | logging.info('Size of the dataset %s', train_dataset_size)
131 | steps_per_eval = ds_info.splits['validation'].num_examples // batch_size
132 | num_classes = ds_info.features['label'].num_classes
133 |
134 | train_dataset = strategy.distribute_datasets_from_function(
135 | data.build_input_fn(ds_builder, batch_size, topology=topology,
136 | is_training=True))
137 | test_dataset = strategy.distribute_datasets_from_function(
138 | data.build_input_fn(ds_builder, batch_size, topology=topology,
139 | is_training=False))
140 | # Maybe create a grower.
141 | grow_type = getattr(config, 'grow_type', None)
142 | if grow_type == 'add_random':
143 | grower = growers.AddRandom()
144 | elif grow_type == 'add_firefly':
145 | grower = growers.AddFirefly()
146 | elif grow_type == 'add_gradmax_opt':
147 | grower = growers.AddGradmaxOptim()
148 | elif grow_type == 'add_gradmax':
149 | grower = growers.AddGradmax()
150 | else:
151 | logging.info('No growing')
152 | grower = None
153 |
154 | if grower:
155 | grower.epsilon = config.grow_epsilon
156 | grower.scale_method = config.grow_scale_method
157 | grower.is_outgoing_zero = config.is_outgoing_zero
158 |
159 | if config.scale_epochs:
160 | old_epochs = config.train_epochs
161 | # Adjust the total epochs to match big-baseline training FLOPs.
162 | if grower:
163 | config.train_epochs = updaters.adjust_epochs(
164 | config.train_epochs,
165 | config.model.width_multiplier,
166 | config.updater.update_frequency,
167 | config.updater.start_iteration,
168 | config.updater.n_growth_steps,
169 | steps_per_epoch
170 | )
171 | else:
172 | # baseline
173 | config.train_epochs = config.train_epochs / config.model.width_multiplier
174 | logging.info('Extended training from %s to %s', old_epochs,
175 | config.train_epochs)
176 |
177 | summary_writer = tf.summary.create_file_writer(
178 | os.path.join(FLAGS.output_dir, 'summaries'))
179 |
180 | with summary_writer.as_default():
181 | flat_param_dict = {}
182 | def flat_fn(dic):
183 | for k, v in dic.items():
184 | if isinstance(v, dict):
185 | flat_fn({f'{k}.{k2}': v2 for k2, v2 in v.items()})
186 | else:
187 | flat_param_dict[k] = str(v) if isinstance(v, list) else v
188 | flat_fn(config.to_dict())
189 | hp.hparams(flat_param_dict)
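    # Illustrative example (not part of the original file): flat_fn turns a
    # nested config such as {'optimizer': {'momentum': 0.9}} into the flat
    # entry {'optimizer.momentum': 0.9} before it is logged as an hparam.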
190 |
191 | grow_layer_tuples = []
192 | architecture = config.get('architecture', 'mb_v1')
193 | with strategy.scope():
194 | if architecture == 'mb_v1':
195 |       logging.info('Building MobileNet-v1 model')
196 | model = mb_v1.create_model(
197 | num_classes=num_classes,
198 | seed=config.seed,
199 | **config.model)
200 | grow_layer_tuples = model.get_grow_layer_tuples()
201 | else:
202 | raise ValueError(f'Unknown architecture: {architecture}')
203 | logging.info('#grow_layer_tuples: %s', len(grow_layer_tuples))
204 | logging.info('grow_layer_tuples[0]: %s', grow_layer_tuples[0])
205 | grow_metrics = {layers[0]: tf.keras.metrics.Sum()
206 | for layers in grow_layer_tuples}
207 | # Initialize the parameters.
208 | def compile_model_fn():
209 | model(tf.keras.Input((224, 224, 3)))
210 | compile_model_fn()
211 | logging.info('Model input shape: %s', model.input_shape)
212 | logging.info('Model output shape: %s', model.output_shape)
213 | logging.info('Model number of weights: %s', model.count_params())
214 | optimizer = get_optimizer(config.optimizer, config.train_epochs, batch_size,
215 | steps_per_epoch)
216 | train_metrics = {
217 | 'train/negative_log_likelihood':
218 | tf.keras.metrics.Mean(),
219 | 'train/accuracy':
220 | tf.keras.metrics.SparseCategoricalAccuracy(),
221 | 'train/loss':
222 | tf.keras.metrics.Mean(),
223 | }
224 |
225 | eval_metrics = {
226 | 'test/negative_log_likelihood':
227 | tf.keras.metrics.Mean(),
228 | 'test/accuracy':
229 | tf.keras.metrics.SparseCategoricalAccuracy(),
230 | }
231 | model.summary()
232 | checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
233 | latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
234 | initial_epoch = 0
235 | if latest_checkpoint:
236 |     # TODO This probably wouldn't work if the network is grown;
237 |     # so we may need to switch to saved models.
238 | # checkpoint.restore must be within a strategy.scope() so that optimizer
239 | # slot variables are mirrored.
240 | checkpoint.restore(latest_checkpoint)
241 | logging.info('Loaded checkpoint %s', latest_checkpoint)
242 | initial_epoch = optimizer.iterations.numpy() // steps_per_epoch
243 |
244 | # Create the updater.
245 | def loss_fn(inputs):
246 | images, labels = inputs
247 | logits = model(images, training=True)
248 | one_hot_labels = tf.one_hot(tf.cast(labels, tf.int32), num_classes)
249 | loss = tf.reduce_mean(
250 | tf.keras.losses.categorical_crossentropy(
251 | one_hot_labels,
252 | logits,
253 | from_logits=True))
254 | scaled_loss = loss / strategy.num_replicas_in_sync
255 |     # Don't add the regularization loss; it is unnecessary for zero variables.
256 | return scaled_loss
257 |
258 | updater_type = getattr(config, 'updater_type', None)
259 | if updater_type == 'round_robin':
260 | updater = updaters.RoundRobin(grower, grow_layer_tuples, loss_fn,
261 | compile_model_fn, **config.updater)
262 | elif updater_type == 'all_at_once':
263 | updater = updaters.AllAtOnce(grower, grow_layer_tuples, loss_fn,
264 | compile_model_fn, **config.updater)
265 | logging.info(message)
266 |
267 | if (epoch % 20 == 0) or (config.train_epochs == (epoch + 1)):
268 | test_iterator = iter(test_dataset)
269 | logging.info('Starting to run eval at epoch: %s', epoch)
270 | test_start_time = time.time()
271 | test_step(test_iterator, 'test', steps_per_eval)
272 | test_ms_per_example = (time.time() - test_start_time) * 1e6 / batch_size
273 |       logging.info('Done with eval.')
274 |
275 | logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
276 | eval_metrics['test/negative_log_likelihood'].result(),
277 | eval_metrics['test/accuracy'].result() * 100)
278 | total_results = {name: metric.result() for name, metric
279 | in eval_metrics.items()}
280 | total_results['train/ms_per_example'] = train_ms_per_example
281 | total_results['test/ms_per_example'] = test_ms_per_example
282 | with summary_writer.as_default():
283 | for name, result in total_results.items():
284 | tf.summary.scalar(name, result, step=epoch + 1)
285 |
286 | for metric in eval_metrics.values():
287 | metric.reset_states()
288 |
289 | if (config.checkpoint_interval > 0 and
290 | (epoch + 1) % config.checkpoint_interval == 0):
291 | checkpoint_name = checkpoint.save(
292 | os.path.join(FLAGS.output_dir, 'checkpoint'))
293 | logging.info('Saved checkpoint to %s', checkpoint_name)
294 |
295 | final_checkpoint_name = checkpoint.save(
296 | os.path.join(FLAGS.output_dir, 'checkpoint'))
297 | logging.info('Saved last checkpoint to %s', final_checkpoint_name)
298 |
299 |
300 | if __name__ == '__main__':
301 | app.run(main)
302 |
--------------------------------------------------------------------------------
/growneuron/cifar/main.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 GradMax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | r"""Wide ResNet 28-10 on CIFAR-10/100 trained with maximum likelihood.
17 |
18 | Hyperparameters differ slightly from the original paper's code
19 | (https://github.com/szagoruyko/wide-residual-networks) as TensorFlow uses, for
20 | example, l2 instead of weight decay, and a different parameterization for SGD's
21 | momentum.
22 |
23 | """
24 |
25 | import itertools
26 | import os
27 | import time
28 |
29 | from absl import app
30 | from absl import flags
31 | from absl import logging
32 | from growneuron import growers
33 | from growneuron import updaters
34 | from growneuron.cifar import data
35 | from growneuron.cifar import vgg
36 | from growneuron.cifar import wide_resnet
37 | import growneuron.layers as glayers
38 | from ml_collections import config_flags
39 | import tensorflow as tf
40 | import tensorflow_datasets as tfds
41 | import uncertainty_baselines.schedules as ub_schedules
42 | from tensorboard.plugins.hparams import api as hp
43 |
44 | config_flags.DEFINE_config_file(
45 | name='config',
46 | default='growneuron/cifar/configs/'
47 | 'baseline_big.py',
48 | help_string='training config file.')
49 | # common flags
50 |
51 | flags.DEFINE_string('data_dir', None,
52 |                     'data_dir to be used for tfds dataset construction. '
53 |                     'It is required when training with cloud TPUs.')
54 | flags.DEFINE_bool('download_data', False,
55 | 'Whether to download data locally when initializing a '
56 | 'dataset.')
57 | flags.DEFINE_string('output_dir', '/tmp/cifar', 'Output directory.')
58 | flags.DEFINE_bool('collect_profile', False,
59 | 'Whether to trace a profile with tensorboard')
60 |
61 | FLAGS = flags.FLAGS
62 |
63 |
64 | def get_optimizer(optimizer_config, train_epochs, batch_size, steps_per_epoch):
65 | """Given the config and training arguments returns an optimizer."""
66 |   # Linearly scale the learning rate and decay epochs from the vanilla settings.
67 | base_lr = optimizer_config.base_learning_rate * batch_size / 128
68 | lr_decay_epochs = [int(fraction * train_epochs)
69 | for fraction in optimizer_config.lr_decay_epochs]
70 | if optimizer_config.decay_type == 'step':
71 | lr_schedule = ub_schedules.WarmUpPiecewiseConstantSchedule(
72 | steps_per_epoch,
73 | base_lr,
74 | decay_ratio=optimizer_config.lr_decay_ratio,
75 | decay_epochs=lr_decay_epochs,
76 | warmup_epochs=optimizer_config.lr_warmup_epochs)
77 | elif optimizer_config.decay_type == 'cosine':
78 | lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
79 | base_lr, train_epochs * steps_per_epoch, alpha=0.0)
80 | else:
81 | lr_schedule = base_lr / 100.
82 | logging.info('No decay used')
83 | optimizer = tf.keras.optimizers.SGD(lr_schedule,
84 | momentum=optimizer_config.momentum,
85 | nesterov=optimizer_config.nesterov)
86 | return optimizer
87 |
88 |
89 | def main(argv):
90 | fmt = '[%(filename)s:%(lineno)s] %(message)s'
91 | formatter = logging.PythonFormatter(fmt)
92 | logging.get_absl_handler().setFormatter(formatter)
93 | del argv # unused arg
94 | config = FLAGS.config
95 |
96 | tf.io.gfile.makedirs(FLAGS.output_dir)
97 | logging.info('Saving checkpoints at %s', FLAGS.output_dir)
98 | tf.random.set_seed(config.seed)
99 |
100 | strategy = tf.distribute.MirroredStrategy()
101 |
102 | ds_builder = tfds.builder(config.dataset)
103 | if FLAGS.download_data:
104 | ds_builder.download_and_prepare()
105 | ds_info = ds_builder.info
106 | batch_size = config.per_core_batch_size * config.num_cores
107 |
108 |   # Scale arguments that assume a total batch size of 128.
109 | multiplier = 128. / batch_size
110 | if hasattr(config.updater, 'update_frequency'):
111 | config.updater.update_frequency = int(
112 | config.updater.update_frequency * multiplier)
113 | config.updater.start_iteration = int(
114 | config.updater.start_iteration * multiplier)
115 |
116 | train_dataset_size = ds_info.splits['train'].num_examples
117 | steps_per_epoch = train_dataset_size // batch_size
118 | logging.info('Steps per epoch %s', steps_per_epoch)
119 | logging.info('Size of the dataset %s', train_dataset_size)
120 | steps_per_eval = ds_info.splits['test'].num_examples // batch_size
121 | num_classes = ds_info.features['label'].num_classes
122 |
123 | train_dataset = strategy.distribute_datasets_from_function(
124 | data.build_input_fn(ds_builder, batch_size, topology=None,
125 | is_training=True,
126 | cache_dataset=config.cache_dataset))
127 |   # The grow batch size might differ from the training batch size.
128 | grow_batch_size = getattr(config, 'grow_batch_size', batch_size)
129 | grow_dataset = strategy.distribute_datasets_from_function(
130 | data.build_input_fn(ds_builder, grow_batch_size, topology=None,
131 | is_training=True,
132 | cache_dataset=config.cache_dataset))
133 | test_dataset = strategy.distribute_datasets_from_function(
134 | data.build_input_fn(ds_builder, batch_size, topology=None,
135 | is_training=False,
136 | cache_dataset=config.cache_dataset))
137 |   # Scale the training epochs to roughly match the big-baseline cost.
138 | arch_name = config.get('architecture', 'wide-resnet')
139 | # Maybe create a grower.
140 | grow_type = getattr(config, 'grow_type', None)
141 | if grow_type == 'add_random':
142 | grower = growers.AddRandom()
143 | elif grow_type == 'add_firefly':
144 | grower = growers.AddFirefly()
145 | elif grow_type == 'add_gradmax_opt':
146 | grower = growers.AddGradmaxOptim()
147 | elif grow_type == 'add_gradmax':
148 | grower = growers.AddGradmax()
149 | else:
150 | logging.info('No growing')
151 | grower = None
152 |
153 | if grower:
154 | grower.epsilon = config.grow_epsilon
155 | grower.scale_method = config.grow_scale_method
156 | grower.is_outgoing_zero = config.is_outgoing_zero
157 |
158 | if config.scale_epochs:
159 | if arch_name == 'wide-resnet':
160 | width_scale = config.model.block_width_multiplier
161 | elif arch_name == 'vgg':
162 | width_scale = config.model.width_multiplier
163 | else:
164 | raise ValueError(f'Unknown architecture: {arch_name}')
165 | old_epochs = config.train_epochs
166 | # Adjust the total epochs to match big-baseline training FLOPs.
167 | if grower:
168 | config.train_epochs = updaters.adjust_epochs(
169 | config.train_epochs,
170 | width_scale,
171 | config.updater.update_frequency,
172 | config.updater.start_iteration,
173 | config.updater.n_growth_steps,
174 | steps_per_epoch
175 | )
176 | else:
177 | # baseline
178 | config.train_epochs = config.train_epochs / width_scale
179 | logging.info('Extended training from %s to %s', old_epochs,
180 | config.train_epochs)
181 | summary_writer = tf.summary.create_file_writer(
182 | os.path.join(FLAGS.output_dir, 'summaries'))
183 |
184 | with summary_writer.as_default():
185 | flat_param_dict = {}
186 | def flat_fn(dic):
187 | for k, v in dic.items():
188 | if isinstance(v, dict):
189 | flat_fn({f'{k}.{k2}': v2 for k2, v2 in v.items()})
190 | else:
191 | flat_param_dict[k] = str(v) if isinstance(v, list) else v
192 | flat_fn(config.to_dict())
193 | hp.hparams(flat_param_dict)
194 |
195 | grow_layer_tuples = []
196 | with strategy.scope():
197 | if arch_name == 'wide-resnet':
198 | logging.info('Building ResNet model')
199 | model = wide_resnet.create_model(
200 | num_classes=num_classes,
201 | seed=config.seed,
202 | **config.model)
203 | for block_seq in model.group_seq:
204 | for block_layers, _ in block_seq:
205 | # We need to get all layers between the two grow layers.
206 | glayer_indices = [i for i, l in enumerate(block_layers)
207 | if isinstance(l, glayers.GrowLayer)]
208 | start_index, end_index = glayer_indices[0], glayer_indices[-1]
209 | grow_layer_tuples.append(block_layers[start_index:(end_index+1)])
210 | elif arch_name == 'vgg':
211 | logging.info('Building VGG model')
212 | model = vgg.create_model(
213 | num_classes=num_classes,
214 | seed=config.seed,
215 | **config.model)
216 | grow_layer_tuples = model.get_grow_layer_tuples()
217 | else:
218 | raise ValueError(f'Unknown architecture: {arch_name}')
219 | logging.info('grow_layer_tuples: %s', grow_layer_tuples)
220 | grow_metrics = {layers[0]: tf.keras.metrics.Sum()
221 | for layers in grow_layer_tuples}
222 | # Initialize the parameters.
223 | def compile_model_fn():
224 | model(tf.keras.Input((32, 32, 3)))
225 | compile_model_fn()
226 | logging.info('Model input shape: %s', model.input_shape)
227 | logging.info('Model output shape: %s', model.output_shape)
228 | logging.info('Model number of weights: %s', model.count_params())
229 | optimizer = get_optimizer(config.optimizer, config.train_epochs, batch_size,
230 | steps_per_epoch)
231 | train_metrics = {
232 | 'train/negative_log_likelihood':
233 | tf.keras.metrics.Mean(),
234 | 'train/accuracy':
235 | tf.keras.metrics.SparseCategoricalAccuracy(),
236 | 'train/loss':
237 | tf.keras.metrics.Mean(),
238 | }
239 |
240 | eval_metrics = {
241 | 'test/negative_log_likelihood':
242 | tf.keras.metrics.Mean(),
243 | 'test/accuracy':
244 | tf.keras.metrics.SparseCategoricalAccuracy(),
245 | }
246 | model.summary()
247 |
248 | checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
249 | latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
250 | initial_epoch = 0
251 | if latest_checkpoint:
252 |     # TODO This probably wouldn't work if the network is grown;
253 |     # so we may need to switch to saved models.
254 | # checkpoint.restore must be within a strategy.scope() so that optimizer
255 | # slot variables are mirrored.
256 | checkpoint.restore(latest_checkpoint)
257 | logging.info('Loaded checkpoint %s', latest_checkpoint)
258 | initial_epoch = optimizer.iterations.numpy() // steps_per_epoch
259 |
260 | def loss_fn(inputs):
261 | images, labels = inputs
262 | logits = model(images, training=True)
263 | one_hot_labels = tf.one_hot(tf.cast(labels, tf.int32), num_classes)
264 | loss = tf.reduce_mean(
265 | tf.keras.losses.categorical_crossentropy(
266 | one_hot_labels,
267 | logits,
268 | from_logits=True))
269 | scaled_loss = loss / strategy.num_replicas_in_sync
270 |     # Don't add the regularization loss; it is unnecessary for zero variables.
271 | return scaled_loss
272 |
273 | updater_type = getattr(config, 'updater_type', None)
274 | if updater_type == 'round_robin':
275 | updater = updaters.RoundRobin(grower, grow_layer_tuples, loss_fn,
276 | compile_model_fn, **config.updater)
277 | elif updater_type == 'all_at_once':
278 | updater = updaters.AllAtOnce(grower, grow_layer_tuples, loss_fn,
279 | compile_model_fn, **config.updater)
280 | else:
281 | updater = updaters.DummyUpdater(grow_layer_tuples)
282 |
283 | def get_update_fn(model):
284 | """Returns Per-Replica update function."""
285 | # We need to remap this as variable names change when the network is grown.
286 | variable_mapped_grow_metrics = {
287 | l.weights[0].name: metric for l, metric in grow_metrics.items()
288 | }
289 |
290 | @tf.function
291 | def _update_fn(inputs):
292 | images, labels = inputs
293 | with tf.GradientTape() as tape:
294 | logits = model(images, training=True)
295 | one_hot_labels = tf.one_hot(tf.cast(labels, tf.int32), num_classes)
296 | nll_loss = tf.reduce_mean(
297 | tf.keras.losses.categorical_crossentropy(
298 | one_hot_labels,
299 | logits,
300 | from_logits=True))
301 | l2_loss = sum(model.losses)
302 | loss = nll_loss + l2_loss
303 |         # Scale the loss as the distribution strategy will reduce-sum the gradients.
304 | scaled_loss = loss / strategy.num_replicas_in_sync
305 | grads = tape.gradient(scaled_loss, model.trainable_variables)
306 | # Logging some gradient norms
307 | for grad, var in zip(grads, model.trainable_variables):
308 | if var.name in variable_mapped_grow_metrics:
309 | sq_grad = tf.math.pow(grad, 2)
310 | variable_mapped_grow_metrics[var.name].update_state(sq_grad)
311 | optimizer.apply_gradients(zip(grads, model.trainable_variables))
312 | train_metrics['train/loss'].update_state(loss)
313 | train_metrics['train/negative_log_likelihood'].update_state(nll_loss)
314 | train_metrics['train/accuracy'].update_state(labels, logits)
315 |
316 | return _update_fn
317 |
318 | def train_step(iterator, grow_iterator):
319 | """Training StepFn."""
320 |     # This allows retracing. We need to retrace as the model changes.
321 | update_fn = get_update_fn(model)
322 | for _ in tf.range(tf.cast(steps_per_epoch, tf.int32)):
323 | # Maybe grow.
324 | is_update = updater.is_update_iteration(optimizer.iterations)
325 | if is_update:
326 | logging.info('Growing on iteration: %s', optimizer.iterations.numpy())
327 | with strategy.scope():
328 | updater.update_network(batch_data=next(grow_iterator),
329 | optimizer=optimizer)
330 | compile_model_fn()
331 |         # Regenerate the function so that the model is retraced after growing.
332 | update_fn = get_update_fn(model)
333 | logging.info('Model number of weights: %s', model.count_params())
334 | with summary_writer.as_default():
335 | logging.info('Widths after growth')
336 | for name, n_neuron in updater.get_grow_layer_stats():
337 | logging.info('%s: %d', name, n_neuron)
338 | tf.summary.scalar(f'n_neurons/{name}', n_neuron,
339 | step=optimizer.iterations)
340 | # Gradient Step.
341 | strategy.run(update_fn, args=(next(iterator),))
342 | # Logging
343 | if is_update or optimizer.iterations % config.get('log_freq', 100) == 1:
344 | logging.info('Train Loss: %.4f, Accuracy: %.2f%%',
345 | train_metrics['train/loss'].result(),
346 | train_metrics['train/accuracy'].result() * 100)
347 | total_results = {name: metric.result() for name, metric in
348 | train_metrics.items()}
349 | total_results['lr'] = optimizer.learning_rate(optimizer.iterations)
350 | total_results['params/total'] = model.count_params()
351 | for layer, metric in grow_metrics.items():
352 | total_results[f'grad/{layer.name}'] = metric.result()
353 | with summary_writer.as_default():
354 | for name, result in total_results.items():
355 | tf.summary.scalar(name, result, step=optimizer.iterations)
356 | for metric in itertools.chain(train_metrics.values(),
357 | grow_metrics.values()):
358 | metric.reset_states()
359 |
360 | def test_step(iterator, dataset_split, num_steps):
361 | """Evaluation StepFn."""
362 | @tf.function
363 | def step_fn(inputs):
364 | """Per-Replica StepFn."""
365 | images, labels = inputs
366 | logits = model(images, training=False)
367 | probs = tf.nn.softmax(logits)
368 | negative_log_likelihood = tf.reduce_mean(
369 | tf.keras.losses.sparse_categorical_crossentropy(labels, probs))
370 |
371 | eval_metrics[f'{dataset_split}/negative_log_likelihood'].update_state(
372 | negative_log_likelihood)
373 | eval_metrics[f'{dataset_split}/accuracy'].update_state(labels, probs)
374 |
375 | for _ in tf.range(tf.cast(num_steps, tf.int32)):
376 | strategy.run(step_fn, args=(next(iterator),))
377 |
378 | train_iterator = iter(train_dataset)
379 | grow_iterator = iter(grow_dataset)
380 |
381 | start_time = time.time()
382 | tb_callback = None
383 | if FLAGS.collect_profile:
384 | tb_callback = tf.keras.callbacks.TensorBoard(
385 | profile_batch=(100, 2000),
386 | log_dir=os.path.join(FLAGS.output_dir, 'logs'))
387 | tb_callback.set_model(model)
388 |
389 | for epoch in range(initial_epoch, config.train_epochs):
390 | logging.info('Starting to run epoch: %s', epoch)
391 | if tb_callback:
392 | tb_callback.on_epoch_begin(epoch)
393 | train_start_time = time.time()
394 | train_step(train_iterator, grow_iterator)
395 | train_ms_per_example = (time.time() - train_start_time) * 1e6 / batch_size
396 |
397 | current_step = (epoch + 1) * steps_per_epoch
398 | max_steps = steps_per_epoch * config.train_epochs
399 | time_elapsed = time.time() - start_time
400 | steps_per_sec = float(current_step) / time_elapsed
401 | eta_seconds = (max_steps - current_step) / steps_per_sec
402 | message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
403 | 'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
404 | current_step / max_steps,
405 | epoch + 1,
406 | config.train_epochs,
407 | steps_per_sec,
408 | eta_seconds / 60,
409 | time_elapsed / 60))
410 | logging.info(message)
411 | if tb_callback:
412 | tb_callback.on_epoch_end(epoch)
413 |
414 | test_iterator = iter(test_dataset)
415 | logging.info('Starting to run eval at epoch: %s', epoch)
416 | test_start_time = time.time()
417 | test_step(test_iterator, 'test', steps_per_eval)
418 | test_ms_per_example = (time.time() - test_start_time) * 1e6 / batch_size
419 |     logging.info('Done with eval.')
420 |
421 | logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
422 | eval_metrics['test/negative_log_likelihood'].result(),
423 | eval_metrics['test/accuracy'].result() * 100)
424 | total_results = {name: metric.result() for name, metric
425 | in eval_metrics.items()}
426 | total_results['train/ms_per_example'] = train_ms_per_example
427 | total_results['test/ms_per_example'] = test_ms_per_example
428 | with summary_writer.as_default():
429 | for name, result in total_results.items():
430 | tf.summary.scalar(name, result, step=epoch + 1)
431 |
432 | for metric in eval_metrics.values():
433 | metric.reset_states()
434 |
435 | if (config.checkpoint_interval > 0 and
436 | (epoch + 1) % config.checkpoint_interval == 0):
437 | checkpoint_name = checkpoint.save(
438 | os.path.join(FLAGS.output_dir, 'checkpoint'))
439 | logging.info('Saved checkpoint to %s', checkpoint_name)
440 |
441 | final_checkpoint_name = checkpoint.save(
442 | os.path.join(FLAGS.output_dir, 'checkpoint'))
443 | logging.info('Saved last checkpoint to %s', final_checkpoint_name)
444 |
445 |
446 | if __name__ == '__main__':
447 | app.run(main)
448 |
--------------------------------------------------------------------------------
/growneuron/growers.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 GradMax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """This module implements various growing algorithms.
17 | """
18 | import functools
19 | import logging
20 | import growneuron.layers as glayers
21 | import numpy as np
22 | from scipy.sparse.linalg.eigen import arpack
23 | import tensorflow as tf
24 |
25 |
26 | class LayerGrower():
27 | """Base class for growing layer algorithms.
28 |
29 | Subclasses should implement grow_neurons.
30 |     grad_fn: Should take a list of variables and return aggregated gradients.
31 |     grow_layers: list of GrowLayers. There are often 2 layers. The first is
32 |       the one we are adding neurons to and the second is the layer that
33 |       consumes neurons from the first layer. However, in some architectures
34 |       there could be layers in between that transform channel-wise
35 |       information independently, like BatchNorm and depth-wise convolutions.
36 |       In such cases, we grow identity neurons for the layers in between.
37 | """
38 | epsilon = 0.
39 | scale_method = 'mean_norm'
40 | strategy = None
41 | compile_fn = lambda: None
42 | loss_fn = lambda x: x
43 |
44 | def grow_neurons(self, grow_layers, batch_data, **kwargs):
45 | raise NotImplementedError()
46 |
47 |
48 | class AddRandom(LayerGrower):
49 | """Implements random growing."""
50 | is_outgoing_zero = False
51 | is_all_zero = False
52 |
53 | def grow_neurons(self, grow_layers, batch_data, n_new=1, scale=1.):
54 | del batch_data
55 | scales = (self.epsilon, scale)
56 | new_bias = 'zeros'
57 | if self.is_all_zero:
58 | scales = (self.epsilon, self.epsilon)
59 | new_bias = 'ones'
60 | elif self.is_outgoing_zero:
61 | scales = (scale, self.epsilon)
62 | for i, layer in enumerate(grow_layers):
63 | if i == 0:
64 | # First layer
65 | layer.add_neurons(n_new, new_weights='random', is_outgoing=False,
66 | scale=scales[0], scale_method=self.scale_method,
67 | new_bias=new_bias)
68 | elif i == (len(grow_layers) - 1):
69 | # Last layer
70 | layer.add_neurons(n_new, new_weights='random', is_outgoing=True,
71 | scale=scales[1], scale_method=self.scale_method)
72 | else:
73 | if isinstance(layer, glayers.GrowLayer):
74 | layer.add_neurons_identity(n_new)
75 |
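# Illustrative summary (not part of the original file) of how the flags above
# map to the (incoming, outgoing) initialization scales, with epsilon ~ 0:
#   default                 -> incoming ~0,      outgoing = scale, bias zeros
#   is_outgoing_zero=True   -> incoming = scale, outgoing ~0,      bias zeros
#   is_all_zero=True        -> incoming ~0,      outgoing ~0,      bias ones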
76 |
77 | class AddFirefly(AddRandom):
78 | """Implements Firefly style growing using direct optimization.
79 |
80 |   Implements Eq. 4 from the paper, without extra candidates and splitting.
81 | https://arxiv.org/abs/2102.08574
82 | """
83 | optim_n_step = 100
84 | optim_fn = lambda self: tf.keras.optimizers.Adam()
85 |
86 | def grow_neurons(self, grow_layers, batch_data, n_new=1, scale=1.):
87 | n_old_neuron = grow_layers[0].weights[0].shape[-1]
88 | # First add neurons randomly
89 | super().grow_neurons(grow_layers, batch_data, n_new=n_new, scale=scale)
90 | self.compile_fn()
91 | # Now optimize the random initialization
92 | layer_tuple = grow_layers[0], grow_layers[-1]
93 |
94 | optimizer = self.optim_fn()
95 | target_magnitudes = []
96 | # Record the magnitude of the new_weights.
97 | for concat_axis, layer in zip([-1, -2], layer_tuple):
98 | _, new_weights = tf.split(layer.weights[0], [n_old_neuron, -1],
99 | axis=concat_axis)
100 | target_magnitudes.append(
101 | np.mean(glayers.norm_l2(new_weights, keep_dim=concat_axis)))
102 | logging.info('Minimizing loss.')
103 | weights = [l.weights[0] for l in layer_tuple]
104 |
105 | @tf.function
106 | def update_fn(inputs):
107 | with tf.GradientTape() as tape:
108 | loss = self.loss_fn(inputs)
109 | grads = tape.gradient(loss, weights)
110 | masked_grads = []
111 | for concat_axis, grad in zip([-1, -2], grads):
112 | # Apply gradient only on new weights, zero out the rest.
113 | old_wgrad, new_wgrad = tf.split(grad, [n_old_neuron, -1],
114 | axis=concat_axis)
115 | masked_grad = tf.concat([tf.zeros_like(old_wgrad), new_wgrad],
116 | axis=concat_axis)
117 | masked_grads.append(masked_grad)
118 | optimizer.apply_gradients(zip(masked_grads, weights))
119 |       # The new weights are projected back to the target magnitude outside.
120 | return loss
121 |
122 | log_freq = self.optim_n_step // 10
123 | for i in range(self.optim_n_step):
124 | per_replica_losses = self.strategy.run(update_fn, args=(batch_data,))
125 | loss = self.strategy.reduce(
126 | tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
127 | if i % log_freq == 0:
128 | logging.info('Firefly iter: %d, loss: %s', i, loss)
129 | for concat_axis, weight, target_magnitude in zip(
130 | [-1, -2], weights, target_magnitudes):
131 | old_w, new_w = tf.split(weight, [n_old_neuron, -1],
132 | axis=concat_axis)
133 | normalized_w = glayers.normalize_l2(new_w, axis=concat_axis)
134 | normalized_new_w = normalized_w * target_magnitude
135 | weight.assign(
136 | tf.concat([old_w, normalized_new_w], axis=concat_axis))
137 | logging.info('Firefly final loss: %s', loss.numpy())
138 |
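# Illustrative sketch (not part of the original file) of the gradient masking
# used above: only the slice of the kernel that corresponds to the newly added
# units receives updates; the slice for the pre-existing units is zeroed out.
example_n_old = 4
example_grad = tf.ones([3, 6])  # hypothetical gradient for a 6-unit kernel
example_old_g, example_new_g = tf.split(example_grad, [example_n_old, -1],
                                        axis=-1)
example_masked = tf.concat([tf.zeros_like(example_old_g), example_new_g],
                           axis=-1)
# example_masked[:, :4] is all zeros, example_masked[:, 4:] is all ones.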
139 |
140 | class AddGradmaxOptim(AddRandom):
141 | """Implements Gradmax using direct optimization."""
142 | optim_n_step = 100
143 | optim_fn = lambda self: tf.keras.optimizers.Adam()
144 |
145 | def grow_neurons(self, grow_layers, batch_data, n_new=1, scale=1.):
146 |     # For simplicity we do a full backward and forward pass here, but note
147 |     # that the only things we need are the inputs at l-1 and the gradients at
148 |     # l+1. Those stay the same and don't need to be re-calculated each time.
149 | n_old_neuron = grow_layers[0].weights[0].shape[-1]
150 | # First add neurons randomly
151 | super().grow_neurons(grow_layers, batch_data, n_new=n_new, scale=scale)
152 | self.compile_fn()
153 | # Now optimize the random initialization
154 | if self.is_outgoing_zero:
155 | # We optimize incoming weights
156 | optim_layer, grad_layer = grow_layers[0], grow_layers[-1]
157 | concat_axis = -1
158 | grad_slic_fn = lambda a: a[Ellipsis, n_old_neuron:, :]
159 | else:
160 | # We optimize outgoing weights
161 | optim_layer, grad_layer = grow_layers[-1], grow_layers[0]
162 | concat_axis = -2
163 | grad_slic_fn = lambda a: a[Ellipsis, n_old_neuron:]
164 |
165 | optimizer = self.optim_fn()
166 | target_magnitude = None
167 | # Record the magnitude of the new_weights.
168 | _, new_weights = tf.split(optim_layer.weights[0], [n_old_neuron, -1],
169 | axis=concat_axis)
170 | target_magnitude = np.mean(glayers.norm_l2(new_weights,
171 | keep_dim=concat_axis))
172 | logging.info('Target magnitude: %s', target_magnitude)
173 | optim_layer_weight = optim_layer.weights[0]
174 | logging.info('Minimizing loss.')
175 |
176 | @tf.function
177 | def update_fn(inputs):
178 | with tf.GradientTape(persistent=True) as tape:
179 | loss = self.loss_fn(inputs)
180 | grad_layer_weight = grad_layer.weights[0]
181 | inner_grad = tape.gradient(loss, grad_layer_weight)
182 | # Maximize gradient norm.
183 | final_loss = -tf.norm(grad_slic_fn(inner_grad))
184 | grad = tape.gradient(final_loss, optim_layer_weight)
185 | # Apply gradient only on new weights, zero out the rest.
186 | old_wgrad, new_wgrad = tf.split(grad, [n_old_neuron, -1],
187 | axis=concat_axis)
188 | masked_grad = tf.concat([tf.zeros_like(old_wgrad), new_wgrad],
189 | axis=concat_axis)
190 | optimizer.apply_gradients([(masked_grad, optim_layer_weight)])
191 | return final_loss
192 |
193 | log_freq = self.optim_n_step // 10
194 | for i in range(self.optim_n_step):
195 | per_replica_losses = self.strategy.run(update_fn, args=(batch_data,))
196 | loss = self.strategy.reduce(
197 | tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
198 | if i % log_freq == 0:
199 | logging.info('Gradmax-opt: %d, loss: %s', i, loss)
200 | # Project new weights back to the target magnitude.
201 | # Record the magnitude of the new_weights.
202 | old_w, new_w = tf.split(optim_layer_weight, [n_old_neuron, -1],
203 | axis=concat_axis)
204 | normalized_w = glayers.normalize_l2(new_w, axis=concat_axis)
205 | normalized_new_w = normalized_w * target_magnitude
206 | optim_layer_weight.assign(
207 | tf.concat([old_w, normalized_new_w], axis=concat_axis))
208 | logging.info('Grad-max-opt final loss: %s', loss.numpy())
209 |
210 | def _grow_neurons_legacy(self, grow_layers, batch_data, n_new=1, scale=1.):
211 | """Old function to calculate gradmax-opt initialization efficiently."""
212 | # Note that this version doesn't work currently.
213 |     # The issue here is that the inputs are shared, thus it is a challenge to
214 |     # obtain them and re-shard. This path is efficient but might be
215 |     # unnecessarily complicated, thus we do a full pass like above.
216 |     logging.warning('This function is not doing the right thing in '
217 |                     'multi-worker setting.')
218 | n_old_neuron = grow_layers[0].weights[0].shape[-1]
219 | # First get the output gradient at l+1 and the input at l-1.
220 | aux_tensor = []
221 |     # For simplicity we do a full backward and forward pass here, but note
222 |     # that the only things we need are the inputs at l-1 and the gradients at
223 |     # l+1. Those stay the same and don't need to be re-calculated each time.
224 | def next_layer_callback(next_inputs, next_outputs):
225 | aux_tensor.append(tf.zeros_like(next_outputs))
226 | return next_inputs, (next_outputs + aux_tensor[-1])
227 | grow_layers[-1].add_callback('add_zeros', next_layer_callback)
228 | inp_tensor = []
229 | def first_layer_callback(next_inputs, next_outputs):
230 | inp_tensor.append(next_inputs)
231 | return next_inputs, next_outputs
232 | grow_layers[0].add_callback('collect_inp', first_layer_callback)
233 |
234 | def grad_fn(inputs):
235 | with tf.GradientTape() as tape:
236 | loss = self.loss_fn(inputs)
237 | return tape.gradient(loss, aux_tensor[0])
238 | per_replica_grads = self.strategy.run(grad_fn, args=(batch_data,))
239 | out_grads = self.strategy.reduce(
240 | tf.distribute.ReduceOp.SUM, per_replica_grads, axis=None)
241 |
242 | # Second add neurons randomly
243 | super().grow_neurons(grow_layers, batch_data, n_new=n_new,
244 | scale=scale)
245 | self.compile_fn()
246 | # Now optimize the random initialization
247 | if self.is_outgoing_zero:
248 | # We optimize incoming weights
249 | optim_layer, grad_layer = grow_layers[0], grow_layers[-1]
250 | concat_axis = -1
251 | grad_slic_fn = lambda a: a[Ellipsis, n_old_neuron:, :]
252 | else:
253 | # We optimize outgoing weights
254 | optim_layer, grad_layer = grow_layers[-1], grow_layers[0]
255 | concat_axis = -2
256 | grad_slic_fn = lambda a: a[Ellipsis, n_old_neuron:]
257 |
258 | optimizer = self.optim_fn()
259 | target_magnitude = None
260 | # Record the magnitude of the new_weights.
261 | _, new_weights = tf.split(optim_layer.weights[0], [n_old_neuron, -1],
262 | axis=concat_axis)
263 | target_magnitude = np.mean(glayers.norm_l2(new_weights,
264 | keep_dim=concat_axis))
265 | logging.info('Target magnitude: %s', target_magnitude)
266 | optim_layer_weight = optim_layer.weights[0]
267 |
268 | @tf.function
269 | def update_fn(inp_tensor, out_grads):
270 | with tf.GradientTape(persistent=True) as tape:
271 | x = inp_tensor
272 | for l in grow_layers:
273 | x = l(x, training=True)
274 | # This simulates having output grads at the end. But it is way more
275 | # efficient as we don't need to run the input again through the whole
276 | # network.
277 | # dL/dx = out_grads because grad_x(x * y) = y
278 | loss = tf.reduce_sum(x*out_grads)
279 | grad_layer_weight = grad_layer.weights[0]
280 | inner_grad = tape.gradient(loss, grad_layer_weight)
281 | # Maximize gradient norm.
282 | final_loss = -tf.norm(grad_slic_fn(inner_grad))
283 | grad = tape.gradient(final_loss, optim_layer_weight)
284 | # Apply gradient only on new weights, zero out the rest.
285 | old_wgrad, new_wgrad = tf.split(grad, [n_old_neuron, -1],
286 | axis=concat_axis)
287 | masked_grad = tf.concat([tf.zeros_like(old_wgrad), new_wgrad],
288 | axis=concat_axis)
289 | optimizer.apply_gradients([(masked_grad, optim_layer_weight)])
290 | return final_loss
291 | logging.info('Maximizing gradients')
292 | log_freq = self.optim_n_step // 10
293 | for i in range(self.optim_n_step):
294 | per_replica_losses = self.strategy.run(update_fn,
295 | args=(inp_tensor[0], out_grads))
296 | loss = self.strategy.reduce(
297 | tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
298 | if i % log_freq == 0:
299 | logging.info('GradmaxOptim iter: %d, loss: %s', i, loss.numpy())
300 | # Project new weights back to the target magnitude.
301 | # Record the magnitude of the new_weights.
302 | old_w, new_w = tf.split(optim_layer_weight, [n_old_neuron, -1],
303 | axis=concat_axis)
304 | normalized_w = glayers.normalize_l2(new_w, axis=concat_axis)
305 | normalized_new_w = normalized_w * target_magnitude
306 | optim_layer_weight.assign(
307 | tf.concat([old_w, normalized_new_w], axis=concat_axis))
308 | logging.info('Final Grad-Norm: %f', -loss.numpy())
309 |
310 |
311 | class AddGradmax(LayerGrower):
312 | """Implements Gradmax using auxiliary layer formulation."""
313 |
314 | def grow_neurons(self, grow_layers, batch_data, n_new=1, scale=1.):
315 | if len(grow_layers) == 2:
316 | current_layer, next_layer = grow_layers
317 | identity_layers = []
318 | else:
319 | assert len(grow_layers) > 2
320 | current_layer, next_layer = grow_layers[0], grow_layers[-1]
321 | identity_layers = grow_layers[1:-1]
322 | # There is only one candidate
323 | growth_candidates = [(current_layer, next_layer)]
324 | if not isinstance(current_layer.layer, type(next_layer.layer)):
325 |       # This is a temporary fix for dealing with heterogeneous layers.
326 |       # When two consecutive layers are of different types we grow randomly.
327 | # For example when a convolutional layer is followed by a fully connected
328 | # layer.
329 | logging.info('Growing randomly layers: %s %s, %s %s',
330 | current_layer.layer.name, type(current_layer.layer),
331 | next_layer.layer.name, type(next_layer.layer))
332 | tmp_grower = AddRandom()
333 | tmp_grower.epsilon = self.epsilon
334 | tmp_grower.scale_method = self.scale_method
335 | tmp_grower.is_outgoing_zero = False
336 | tmp_grower.grow_neurons(grow_layers, batch_data, n_new=n_new, scale=scale)
337 | return
338 | unused_eigenvals, eigenvecs = self.get_growth_directions(
339 | batch_data, growth_candidates, [n_new])[0]
340 | # Grow incoming connections
341 | current_layer.add_neurons(n_new, new_weights='random', is_outgoing=False,
342 | scale=self.epsilon,
343 | scale_method=self.scale_method)
344 | # Initialize intermediate layers as identity.
345 | for layer in identity_layers:
346 | if isinstance(layer, glayers.GrowLayer):
347 | layer.add_neurons_identity(n_new)
348 | # Top-k Eigenvectors
349 | new_weights = eigenvecs[:, :n_new]
350 | c_shape = next_layer.weights[0].shape
351 | if len(c_shape) == 4:
352 | # First reshape each neuron and then transpose last 2 dimensions.
353 | new_filter_shape = c_shape[:2] + [c_shape[-1], n_new]
354 | new_weights = np.reshape(new_weights, new_filter_shape)
355 | new_weights = np.transpose(new_weights, axes=(0, 1, 3, 2))
356 | elif len(c_shape) == 2:
357 | new_weights = new_weights.T
358 |
359 | next_layer.add_neurons(n_new, new_weights=new_weights, scale=scale,
360 | is_outgoing=True,
361 | scale_method=self.scale_method)
362 |
363 | def get_growth_directions(self, batch_data, growth_candidates, n_grows):
364 | """Efficiently retrieves eigen-decomposition for a set of candidates."""
365 | # Adding all callbacks.
366 | aux_layers = []
367 | post_process_fns = []
368 | for current_layer, next_layer in growth_candidates:
369 | aux_layer, post_process_fn = self.get_aux_layer(current_layer.layer,
370 | next_layer.layer)
371 | post_process_fns.append(post_process_fn)
372 | def grow_layer_callback(inputs, outputs, aux_layer=aux_layer,
373 | next_layer=next_layer):
374 | add_h = aux_layer(inputs)
375 | def next_layer_callback(next_inputs, next_outputs):
376 | return next_inputs, (next_outputs + add_h)
377 | next_layer.add_callback('add_aux', next_layer_callback)
378 | return inputs, outputs
379 | current_layer.add_callback('pass_aux', grow_layer_callback)
380 | aux_layers.append(aux_layer)
381 |
382 | def grad_fn(inputs):
383 | with tf.GradientTape() as tape:
384 | loss = self.loss_fn(inputs)
385 | grad_vars = [aux_layer.weights[0] for aux_layer in aux_layers]
386 | return tape.gradient(loss, grad_vars)
387 | per_replica_grads = self.strategy.run(grad_fn,
388 | args=(batch_data,))
389 | aux_grads = self.strategy.reduce(
390 | tf.distribute.ReduceOp.SUM, per_replica_grads, axis=None)
391 | grow_matrices = [
392 | post_process_fn(g)
393 | for g, post_process_fn in zip(aux_grads, post_process_fns)]
394 | # Reset Callbacks
395 | for current_layer, next_layer in growth_candidates:
396 | current_layer.reset_callbacks()
397 | next_layer.reset_callbacks()
398 | results = []
399 | # Calculate eigenvalues
400 | for grow_matrix, n_grow in zip(grow_matrices, n_grows):
401 | # M^{l+1} by M^{l+1}
402 | if n_grow > 0:
403 |         # svds is equivalent to calling eigsh on M.T @ M (without materializing
404 |         # this matrix), which can be faster or slower depending on the shape of M.
405 | _, s, vh = arpack.svds(grow_matrix, k=n_grow,
406 | return_singular_vectors='vh')
407 | eigenvals, eigenvecs = (s**2)[::-1], vh[::-1].T
408 | else:
409 | s, _, v = tf.linalg.svd(grow_matrix)
410 | eigenvals, eigenvecs = s**2, v
411 | results.append((eigenvals, eigenvecs))
412 | return results
413 |
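  # Illustrative check (a sketch, not part of the original file) of the
  # relation relied on above, i.e. that the squared singular values of a
  # matrix M are the top eigenvalues of M.T @ M:
  #   from scipy.sparse.linalg import svds
  #   m = np.random.default_rng(0).normal(size=(20, 8))
  #   _, s, _ = svds(m, k=2)
  #   eigvals = np.linalg.eigvalsh(m.T @ m)
  #   np.allclose(np.sort(s)**2, eigvals[-2:])  # True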
414 | def get_aux_layer(self, first_layer, second_layer):
415 |     """Creates auxiliary layers for growing new neurons between layers."""
416 | l = tf.keras.layers
417 | if isinstance(first_layer, l.Dense) and isinstance(second_layer, l.Dense):
418 | aux_layer = l.Dense(second_layer.units, activation=None, use_bias=False,
419 | kernel_initializer='zeros')
420 | post_process_fn = lambda a: a
421 | elif (isinstance(first_layer, l.Conv2D) and
422 | isinstance(second_layer, l.Conv2D)):
423 | # Combined auxiliary kernel would be the size of k1+k2-1.
424 | kernel_size = [k1+k2-1 for k1, k2 in
425 | zip(first_layer.kernel_size, second_layer.kernel_size)]
426 | # The auxiliary layer should have the combined stride.
427 | # Current implementation assumes tuple strides.
428 |       strides = [(s1 + s2) if ((s1 > 1) and (s2 > 1)) else (s1 + s2 - 1)
429 | for s1, s2 in zip(first_layer.strides, second_layer.strides)]
430 |       # Current implementation assumes the padding is the same for the 2 layers.
431 | aux_layer = l.Conv2D(second_layer.filters, kernel_size, activation=None,
432 | use_bias=False, padding=first_layer.padding,
433 | kernel_initializer='zeros', strides=strides)
434 | post_process_fn = functools.partial(
435 | process_conv_aux_gradient,
436 | second_kernel_size=second_layer.kernel_size)
437 | else:
438 | raise ValueError('Not Supported')
439 |
440 | return aux_layer, post_process_fn
441 |
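# Illustrative numbers (not part of the original file) for the composition
# rules in get_aux_layer above: two 3x3 convolutions give an auxiliary kernel
# of size 3 + 3 - 1 = 5x5; strides (1, 1) followed by (2, 2) compose to
# (2, 2), while (2, 2) followed by (2, 2) compose to (4, 4).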
442 |
443 | def process_conv_aux_gradient(grad, second_kernel_size):
444 | """Process the gradients of convolutional layer to generate grow matrix."""
445 | # shape(grad): ksize X ksize X m0 X m2 ; ksize=k1+k2-1
446 | # second_kernel_size == k2
447 | grad = tf.transpose(grad, perm=(2, 0, 1, 3))
448 | # shape(grad): m0 X ksize X ksize X m2
449 | patched_grow_matrix = extract_image_patches(grad, second_kernel_size)
450 | # shape(patched_grow_matrix): m0 X k1 X k1 X (m2 * k2 * k2)
451 | grow_matrix = tf.reshape(patched_grow_matrix,
452 | [-1, patched_grow_matrix.shape[-1]])
453 |   # shape(grow_matrix): (m0 * k1 * k1) X (m2 * k2 * k2)
454 | return grow_matrix
455 |
456 |
457 | def extract_image_patches(x, kernel_size, stride=(1, 1)):
458 | """Extract convolutional patches from the layer.
459 |
460 | Manual replacement of tf.extract_image_patches, since its gradient cannot
461 | be evaluated on TPU.
462 |
463 | Args:
464 | x: batched input data. Size: [batch, in_height, in_width, in_channels]
465 | kernel_size: Tuple of two integers. Size of kernel.
466 | stride: Tuple of two integers. Stride size.
467 |
468 | Returns:
469 | 4D Tensor (batch, in_rows, in_cols, patch_size) of extracted patches.
470 | """
471 | in_channels = x.get_shape()[3]
472 | kh, kw = kernel_size
473 | tile_filter = np.zeros(shape=[kh, kw, in_channels, kh * kw], dtype=np.float32)
474 | for i in range(kh):
475 | for j in range(kw):
476 | tile_filter[i, j, :, i * kw + j] = 1.0
477 |
478 | tile_filter_op = tf.constant(tile_filter, dtype=tf.float32)
479 | output = tf.nn.depthwise_conv2d(
480 | x, tile_filter_op, strides=[1, *stride, 1], padding='VALID')
481 | # reshaping below is needed so that 4th dimension of the output can be
482 | # reshaped into kernel[0] * kernel[1] * in_channels.
483 | batch, in_rows, in_cols, _ = output.get_shape()
484 | output = tf.reshape(
485 | output, shape=[batch, in_rows, in_cols, in_channels, kh * kw])
486 | output = tf.transpose(output, perm=[0, 1, 2, 4, 3])
487 | output = tf.reshape(output, [batch, in_rows, in_cols, -1])
488 |
489 | return output
490 |
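# Illustrative usage (not part of the original file): extracting 3x3 patches
# from a hypothetical 1x5x5x2 input with stride (1, 1) yields a tensor of
# shape (1, 3, 3, 18), i.e. (batch, in_rows, in_cols, kh * kw * in_channels).
example_x = tf.constant(np.arange(50, dtype=np.float32).reshape(1, 5, 5, 2))
example_patches = extract_image_patches(example_x, kernel_size=(3, 3))
assert example_patches.shape == (1, 3, 3, 18)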
--------------------------------------------------------------------------------