├── .gitignore ├── README.md ├── configs ├── cifar10_resnet18.py ├── imagenet_resnet50.py └── mnist_mlp.py ├── figures └── main.png ├── input_pipeline.py ├── main.py ├── models.py ├── pruner.py ├── requirements.txt ├── train.py └── train_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | logdir/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Critical Influence of Overparameterization on Sharpness-aware Minimization 2 | 3 | This repository contains JAX/Flax source code for reproducing key results of the UAI 2025 paper: 4 | 5 | > **Critical Influence of Overparameterization on Sharpness-aware Minimization** \ 6 | [Sungbin Shin](https://ssbin4.github.io/)\***¹**, [Dongyeop Lee](https://edong6768.github.io/)\***¹**, [Maksym Andriushchenko](https://www.andriushchenko.me/)**²**, and [Namhoon Lee](https://namhoonlee.github.io/)**¹** (*equal contribution) \ 7 | **¹** Pohang University of Science and Technology (POSTECH), **²** École Polytechnique Fédérale de Lausanne (EPFL) \ 8 | Paper: https://arxiv.org/abs/2311.17539 9 | 10 | 11 | ```bibtex 12 | @inproceedings{shin2025critical, 13 | title={Critical Influence of Overparameterization on Sharpness-aware Minimization}, 14 | author={Shin, Sungbin and Lee, Dongyeop and Andriushchenko, Maksym and Lee, Namhoon}, 15 | booktitle={The 41st Conference on Uncertainty in Artificial Intelligence}, 16 | year={2025} 17 | } 18 | ``` 19 | 20 | ## TL;DR 21 | 22 | We uncover both empirical and theoretical results that indicate a critical influence of overparameterization on SAM. 23 | 24 | ## Abstract 25 | 26 | ![fig](./figures/main.png) 27 | 28 | > Training overparameterized neural networks often yields solutions with varying generalization capabilities, even when achieving similar training losses. 
Recent evidence indicates a strong correlation between the sharpness of a minimum and its generalization error, leading to increased interest in optimization methods that explicitly seek flatter minima for improved generalization. Despite its contemporary relevance to overparameterization, however, this sharpness-aware minimization (SAM) strategy has not been studied much yet as to exactly how it is affected by overparameterization. In this work, we analyze SAM under varying degrees of overparameterization, presenting both empirical and theoretical findings that reveal its critical influence on SAM's effectiveness. First, we conduct extensive numerical experiments across diverse domains, demonstrating that SAM consistently benefits from overparameterization. Next, we attribute this phenomenon to the interplay between the enlarged solution space and increased implicit bias resulting from overparameterization. Furthermore, we show that this effect is particularly pronounced in practical settings involving label noise and sparsity, and yet, sufficient regularization is necessary. Last but not least, we provide other theoretical insights into how overparameterization helps SAM achieve minima with more uniform Hessian moments compared to SGD, and much faster convergence at a linear rate. 29 | 30 | 31 | ## Environments 32 | 33 | ### Python 34 | - python 3.8.0 35 | 36 | ### cuda 37 | - cuda 11.4.4 38 | - cudnn 8.6.0 39 | - nccl 2.11.4 40 | 41 | ### Dependencies 42 | ```bash 43 | pip install -r requirements.txt 44 | ``` 45 | 46 | ## Usage 47 | ```bash 48 | python main.py --workdir={logging_dir} --config={config_file} 49 | ``` 50 | 51 | Examples of the config files are located in the `configs` directory. 52 | 53 | The degree of overparameterization is determined by `config.num_neurons` for MLP and `config.num_filters` for ResNet, while the degree of sparsification is determined by `config.sparsity`. 
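
Because `main.py` registers the config through `config_flags.DEFINE_config_file` from `ml_collections`, individual fields can also be overridden on the command line as `--config.<field>=<value>` instead of editing the config file. The sketch below builds one command per (width, sparsity) pair for a ResNet sweep. The helper name `sweep_commands`, the `./logdir` working directory, and the example widths and sparsities are illustrative assumptions, not part of this repository.

```python
import itertools
import shlex


def sweep_commands(config_file, widths, sparsities, workdir="./logdir"):
    """Build one `main.py` command line per (width, sparsity) pair.

    Hypothetical helper for illustration only; it formats strings and
    does not launch any training run.
    """
    commands = []
    for width, sparsity in itertools.product(widths, sparsities):
        args = [
            "python", "main.py",
            f"--workdir={workdir}",
            f"--config={config_file}",
            # Field overrides use the ml_collections
            # `--config.<field>=<value>` syntax.
            f"--config.num_filters={width}",
            f"--config.sparsity={sparsity}",
        ]
        commands.append(shlex.join(args))
    return commands


if __name__ == "__main__":
    # Example values are assumptions chosen for illustration.
    for cmd in sweep_commands("configs/cifar10_resnet18.py", [16, 64, 256], [0.0, 0.5]):
        print(cmd)
```

Each printed line can be run in a shell as is. For the MLP, `config.num_neurons` is a list, so editing `configs/mnist_mlp.py` directly is probably the simpler route.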
54 | 55 | -------------------------------------------------------------------------------- /configs/cifar10_resnet18.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2021 The Flax Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 28 | """CIFAR-10 / ResNet-18""" 29 | 30 | import ml_collections 31 | 32 | def get_config(): 33 | """Get the default hyperparameter configuration.""" 34 | config = ml_collections.ConfigDict() 35 | 36 | # As defined in the `models` module. 
37 | config.model = 'ResNet18' 38 | # `name` argument of tensorflow_datasets.builder() 39 | config.dataset = 'cifar10' 40 | 41 | config.lr_scheduler = 'step' 42 | config.learning_rate = 0.1 43 | config.momentum = 0.9 44 | config.batch_size = 128 45 | config.weight_decay = 0.0005 46 | 47 | config.num_epochs = 200 48 | config.log_every_steps = 100 49 | 50 | config.cache = False 51 | config.half_precision = False 52 | 53 | config.optimizer = 'sam' 54 | config.rho = 0.05 55 | 56 | config.seed = 1 57 | 58 | config.pruner='random' 59 | config.sparsity = 0.0 60 | 61 | config.num_filters = 64 62 | 63 | config.restore_checkpoint = False 64 | 65 | # If num_train_steps==-1 then the number of training steps is calculated from 66 | # num_epochs using the entire dataset. Similarly for steps_per_eval. 67 | config.num_train_steps = -1 68 | config.steps_per_eval = -1 69 | return config -------------------------------------------------------------------------------- /configs/imagenet_resnet50.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2021 The Flax Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 
19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 28 | """ImageNet / ResNet-50""" 29 | 30 | import ml_collections 31 | 32 | def get_config(): 33 | """Get the default hyperparameter configuration.""" 34 | config = ml_collections.ConfigDict() 35 | 36 | # As defined in the `models` module. 37 | config.model = 'ResNet50_ImageNet' 38 | # `name` argument of tensorflow_datasets.builder() 39 | config.dataset = 'imagenet2012' 40 | 41 | config.lr_scheduler = 'step_mnist' 42 | config.learning_rate = 0.2 43 | config.warmup_steps = 5000 44 | config.momentum = 0.9 45 | config.batch_size = 512 46 | config.weight_decay = 0.0001 47 | 48 | config.num_epochs = 90 49 | config.log_every_steps = 100 50 | 51 | config.cache = False 52 | config.half_precision = False 53 | 54 | config.optimizer = 'sam' 55 | config.rho = 0.05 56 | 57 | config.seed = 1 58 | 59 | config.pruner='random' 60 | config.sparsity = 0.0 61 | 62 | config.restore_checkpoint = False 63 | 64 | config.num_filters = 64 65 | 66 | # If num_train_steps==-1 then the number of training steps is calculated from 67 | # num_epochs using the entire dataset. Similarly for steps_per_eval. 68 | config.num_train_steps = -1 69 | config.steps_per_eval = -1 70 | return config -------------------------------------------------------------------------------- /configs/mnist_mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2021 The Flax Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 28 | """MNIST / 3-layer MLP""" 29 | 30 | import ml_collections 31 | 32 | def get_config(): 33 | """Get the default hyperparameter configuration.""" 34 | config = ml_collections.ConfigDict() 35 | 36 | # As defined in the `models` module. 
37 | config.model = 'MLP' 38 | # `name` argument of tensorflow_datasets.builder() 39 | config.dataset = 'mnist' 40 | 41 | config.lr_scheduler = 'step_mnist' 42 | config.learning_rate = 0.1 43 | config.momentum = 0.9 44 | config.batch_size = 128 45 | config.weight_decay = 0.0001 46 | 47 | config.num_epochs = 100 48 | config.log_every_steps = 100 49 | 50 | config.cache = False 51 | config.half_precision = False 52 | 53 | config.optimizer = 'sam' 54 | config.rho = 0.05 55 | 56 | config.seed = 1 57 | 58 | config.pruner='random' 59 | config.sparsity = 0.0 60 | 61 | config.num_neurons=[300, 100] 62 | 63 | config.restore_checkpoint = False 64 | 65 | # If num_train_steps==-1 then the number of training steps is calculated from 66 | # num_epochs using the entire dataset. Similarly for steps_per_eval. 67 | config.num_train_steps = -1 68 | config.steps_per_eval = -1 69 | return config -------------------------------------------------------------------------------- /figures/main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LOG-postech/SAM-overparam/a96543a160555481d5db3513daf67370e9062b5e/figures/main.png -------------------------------------------------------------------------------- /input_pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Some parts of the code are borrowed from https://github.com/google/flax/blob/main/examples/imagenet/input_pipeline.py 16 | 17 | """{mnist, cifar-10, imagenet} input pipeline. 18 | """ 19 | 20 | import jax 21 | import tensorflow as tf 22 | import tensorflow_datasets as tfds 23 | 24 | CIFAR10_MEAN_RGB = [0.4914 * 255, 0.4822 * 255, 0.4465 * 255] 25 | CIFAR10_STDDEV_RGB = [0.2023 * 255, 0.1994 * 255, 0.2010 * 255] 26 | 27 | CROP_PADDING = 32 28 | 29 | IMAGENET_MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255] 30 | IMAGENET_STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255] 31 | 32 | 33 | def distorted_bounding_box_crop( 34 | image_bytes, 35 | bbox, 36 | min_object_covered=0.1, 37 | aspect_ratio_range=(0.75, 1.33), 38 | area_range=(0.05, 1.0), 39 | max_attempts=100, 40 | ): 41 | """Generates cropped_image using one of the bboxes randomly distorted. 42 | 43 | See `tf.image.sample_distorted_bounding_box` for more documentation. 44 | 45 | Args: 46 | image_bytes: `Tensor` of binary image data. 47 | bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]` 48 | where each coordinate is [0, 1) and the coordinates are arranged 49 | as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole 50 | image. 51 | min_object_covered: An optional `float`. Defaults to `0.1`. The cropped 52 | area of the image must contain at least this fraction of any bounding 53 | box supplied. 54 | aspect_ratio_range: An optional list of `float`s. The cropped area of the 55 | image must have an aspect ratio = width / height within this range. 56 | area_range: An optional list of `float`s. The cropped area of the image 57 | must contain a fraction of the supplied image within this range. 58 | max_attempts: An optional `int`. Number of attempts at generating a cropped 59 | region of the image of the specified constraints. 
After `max_attempts` 60 | failures, return the entire image. 61 | Returns: 62 | cropped image `Tensor` 63 | """ 64 | shape = tf.io.extract_jpeg_shape(image_bytes) 65 | sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( 66 | shape, 67 | bounding_boxes=bbox, 68 | min_object_covered=min_object_covered, 69 | aspect_ratio_range=aspect_ratio_range, 70 | area_range=area_range, 71 | max_attempts=max_attempts, 72 | use_image_if_no_bounding_boxes=True, 73 | ) 74 | bbox_begin, bbox_size, _ = sample_distorted_bounding_box 75 | 76 | # Crop the image to the specified bounding box. 77 | offset_y, offset_x, _ = tf.unstack(bbox_begin) 78 | target_height, target_width, _ = tf.unstack(bbox_size) 79 | crop_window = tf.stack([offset_y, offset_x, target_height, target_width]) 80 | image = tf.io.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) 81 | 82 | return image 83 | 84 | 85 | def _resize(image, image_size): 86 | return tf.image.resize([image], [image_size, image_size], 87 | method=tf.image.ResizeMethod.BICUBIC)[0] 88 | 89 | 90 | def _at_least_x_are_equal(a, b, x): 91 | """At least `x` of `a` and `b` `Tensors` are equal.""" 92 | match = tf.equal(a, b) 93 | match = tf.cast(match, tf.int32) 94 | return tf.greater_equal(tf.reduce_sum(match), x) 95 | 96 | 97 | def _decode_and_random_crop(image_bytes, image_size): 98 | """Make a random crop of image_size.""" 99 | bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4]) 100 | image = distorted_bounding_box_crop( 101 | image_bytes, 102 | bbox, 103 | min_object_covered=0.1, 104 | aspect_ratio_range=(3.0 / 4, 4.0 / 3.0), 105 | area_range=(0.08, 1.0), 106 | max_attempts=10, 107 | ) 108 | original_shape = tf.io.extract_jpeg_shape(image_bytes) 109 | bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3) 110 | 111 | image = tf.cond( 112 | bad, 113 | lambda: _decode_and_center_crop(image_bytes, image_size), 114 | lambda: _resize(image, image_size), 115 | ) 116 | 117 | return image 118 | 
119 | 120 | def _decode_and_center_crop(image_bytes, image_size): 121 | """Crops to the center of the image with padding, then scales to image_size.""" 122 | shape = tf.io.extract_jpeg_shape(image_bytes) 123 | image_height = shape[0] 124 | image_width = shape[1] 125 | 126 | padded_center_crop_size = tf.cast( 127 | ( 128 | (image_size / (image_size + CROP_PADDING)) 129 | * tf.cast(tf.minimum(image_height, image_width), tf.float32) 130 | ), 131 | tf.int32, 132 | ) 133 | 134 | offset_height = ((image_height - padded_center_crop_size) + 1) // 2 135 | offset_width = ((image_width - padded_center_crop_size) + 1) // 2 136 | crop_window = tf.stack([ 137 | offset_height, 138 | offset_width, 139 | padded_center_crop_size, 140 | padded_center_crop_size, 141 | ]) 142 | image = tf.io.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) 143 | image = _resize(image, image_size) 144 | 145 | return image 146 | 147 | 148 | def normalize_image(image): 149 | """Normalize the given image with the ImageNet per-channel mean and stddev""" 150 | image -= tf.constant(IMAGENET_MEAN_RGB, shape=[1, 1, 3], dtype=image.dtype) 151 | image /= tf.constant(IMAGENET_STDDEV_RGB, shape=[1, 1, 3], dtype=image.dtype) 152 | return image 153 | 154 | def cifar10_augment(image, crop_padding=4, flip_lr=True): 155 | """Augment small image with random crop and h-flip. 156 | Args: 157 | image: image to augment 158 | crop_padding: number of pixels padded on each side before the random crop 159 | flip_lr: if True perform random horizontal flip 160 | Returns: 161 | augmented image 162 | """ 163 | HEIGHT = 32 164 | WIDTH = 32 165 | NUM_CHANNELS = 3 166 | 167 | assert crop_padding >= 0 168 | if crop_padding > 0: 169 | # Pad with zeros ('CONSTANT' mode), not reflection padding 170 | # (cf. https://arxiv.org/abs/1605.07146) 171 | # Section 3 172 | image = tf.pad( 173 | image, [[crop_padding, crop_padding], 174 | [crop_padding, crop_padding], [0, 0]], 'CONSTANT') 175 | 176 | # Randomly crop a [HEIGHT, WIDTH] section of the image. 
177 | image = tf.image.random_crop(image, [HEIGHT, WIDTH, NUM_CHANNELS]) 178 | 179 | if flip_lr: 180 | # Randomly flip the image horizontally. 181 | image = tf.image.random_flip_left_right(image) 182 | 183 | return image 184 | 185 | def cifar10_process_train_sample(x): 186 | cifar10_mean_rgb = tf.constant(CIFAR10_MEAN_RGB, shape=[1, 1, 3], dtype=tf.float32) 187 | cifar10_std_rgb = tf.constant(CIFAR10_STDDEV_RGB, shape=[1, 1, 3], dtype=tf.float32) 188 | image = tf.cast(x['image'], tf.float32) 189 | image = cifar10_augment(image, crop_padding=4, flip_lr=True) 190 | image = (image - cifar10_mean_rgb) / cifar10_std_rgb 191 | batch = {'image': image, 'label': x['label']} 192 | return batch 193 | 194 | def cifar10_process_test_sample(x): 195 | cifar10_mean_rgb = tf.constant(CIFAR10_MEAN_RGB, shape=[1, 1, 3], dtype=tf.float32) 196 | cifar10_std_rgb = tf.constant(CIFAR10_STDDEV_RGB, shape=[1, 1, 3], dtype=tf.float32) 197 | image = tf.cast(x['image'], tf.float32) 198 | image = (image - cifar10_mean_rgb) / cifar10_std_rgb 199 | batch = {'image': image, 'label': x['label']} 200 | return batch 201 | 202 | def imagenet_preprocess_for_train(image_bytes, dtype=tf.float32, image_size=224): 203 | """Preprocesses the given image for training. 204 | 205 | Args: 206 | image_bytes: `Tensor` representing an image binary of arbitrary size. 207 | dtype: data type of the image. 208 | image_size: image size. 209 | 210 | Returns: 211 | A preprocessed image `Tensor`. 212 | """ 213 | image = _decode_and_random_crop(image_bytes, image_size) 214 | image = tf.reshape(image, [image_size, image_size, 3]) 215 | image = tf.image.random_flip_left_right(image) 216 | image = normalize_image(image) 217 | image = tf.image.convert_image_dtype(image, dtype=dtype) 218 | return image 219 | 220 | def imagenet_preprocess_for_eval(image_bytes, dtype=tf.float32, image_size=224): 221 | """Preprocesses the given image for evaluation. 
222 | 223 | Args: 224 | image_bytes: `Tensor` representing an image binary of arbitrary size. 225 | dtype: data type of the image. 226 | image_size: image size. 227 | 228 | Returns: 229 | A preprocessed image `Tensor`. 230 | """ 231 | image = _decode_and_center_crop(image_bytes, image_size) 232 | image = tf.reshape(image, [image_size, image_size, 3]) 233 | image = normalize_image(image) 234 | image = tf.image.convert_image_dtype(image, dtype=dtype) 235 | return image 236 | 237 | def mnist_process_sample(x): 238 | """Preprocess the MNIST image (scale pixel values to [0, 1])""" 239 | image = tf.cast(x['image'], tf.float32) 240 | image = image / 255. 241 | batch = {'image': image, 'label': x['label']} 242 | return batch 243 | 244 | def create_split(dataset, dataset_builder, batch_size, train, cache=False): 245 | """Creates a split from the MNIST, CIFAR-10, or ImageNet dataset using TensorFlow Datasets. 246 | Args: 247 | dataset_builder: TFDS dataset builder for the dataset named by `dataset`. 248 | batch_size: the batch size returned by the data pipeline. 249 | train: Whether to load the train or evaluation split. 250 | cache: Whether to cache the dataset. 251 | Returns: 252 | A `tf.data.Dataset`. 
253 | """ 254 | if train: 255 | train_examples = dataset_builder.info.splits['train'].num_examples 256 | split_size = train_examples // jax.process_count() 257 | start = jax.process_index() * split_size 258 | split = f'train[{start}:{start + split_size}]' 259 | else: 260 | if dataset == 'imagenet2012': 261 | validate_examples = dataset_builder.info.splits['validation'].num_examples 262 | elif dataset == 'cifar10' or dataset == 'mnist': 263 | validate_examples = dataset_builder.info.splits['test'].num_examples 264 | split_size = validate_examples // jax.process_count() 265 | start = jax.process_index() * split_size 266 | if dataset == 'imagenet2012': 267 | split = f'validation[{start}:{start + split_size}]' 268 | elif dataset == 'cifar10' or dataset == 'mnist': 269 | split = f'test[{start}:{start + split_size}]' 270 | 271 | def decode_example(example): 272 | if dataset == 'cifar10': 273 | if train: 274 | return cifar10_process_train_sample(example) 275 | else: 276 | return cifar10_process_test_sample(example) 277 | elif dataset == 'imagenet2012': 278 | if train: 279 | image = imagenet_preprocess_for_train(example['image']) 280 | else: 281 | image = imagenet_preprocess_for_eval(example['image']) 282 | return {'image': image, 'label': example['label']} 283 | elif dataset == 'mnist': 284 | return mnist_process_sample(example) 285 | 286 | 287 | kwargs=dict() 288 | num_train_samples = dataset_builder.info.splits['train'].num_examples 289 | 290 | 291 | if dataset == 'cifar10': 292 | ds = tfds.load(dataset, split=split, **kwargs).cache() 293 | elif dataset == 'imagenet2012': 294 | ds = dataset_builder.as_dataset(split=split, decoders={'image': tfds.decode.SkipDecoding(),}) 295 | elif dataset == 'mnist': 296 | ds = dataset_builder.as_dataset(split=split) 297 | 298 | if dataset == 'imagenet2012': 299 | options = tf.data.Options() 300 | options.experimental_threading.private_threadpool_size = 0 301 | ds = ds.with_options(options) 302 | 303 | if cache: 304 | ds = ds.cache() 
305 | 306 | if train: 307 | ds = ds.repeat() 308 | ds = ds.shuffle(2_000, seed=0) 309 | 310 | ds = ds.map(decode_example, num_parallel_calls=tf.data.experimental.AUTOTUNE) 311 | ds = ds.batch(batch_size, drop_remainder=True) 312 | 313 | if not train: 314 | ds = ds.repeat() 315 | 316 | ds = ds.prefetch(tf.data.AUTOTUNE) 317 | else: 318 | ds = ds.map(decode_example, num_parallel_calls=tf.data.experimental.AUTOTUNE) 319 | if train: 320 | ds = ds.shuffle(num_train_samples, seed=0, reshuffle_each_iteration=True) 321 | 322 | if train: 323 | ds = ds.batch(batch_size) 324 | ds = ds.repeat() 325 | else: 326 | ds = ds.batch(batch_size) 327 | ds = ds.repeat() 328 | 329 | ds = ds.prefetch(10) 330 | 331 | return ds 332 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Some parts of the code are borrowed from https://github.com/google/flax/blob/main/examples/imagenet/main.py. 16 | 17 | """ 18 | Main file to run the code. 
19 | """ 20 | 21 | from absl import app 22 | from absl import flags 23 | from absl import logging 24 | from clu import platform 25 | import jax 26 | from ml_collections import config_flags 27 | import tensorflow as tf 28 | 29 | import train 30 | 31 | import os 32 | import shutil 33 | 34 | FLAGS = flags.FLAGS 35 | 36 | flags.DEFINE_string('workdir', None, 'Directory to store model data.') 37 | config_flags.DEFINE_config_file( 38 | 'config', 39 | None, 40 | 'File path to the training hyperparameter configuration.', 41 | lock_config=True) 42 | 43 | 44 | def main(argv): 45 | if len(argv) > 1: 46 | raise app.UsageError('Too many command-line arguments.') 47 | 48 | # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make 49 | # it unavailable to JAX. 50 | tf.config.experimental.set_visible_devices([], 'GPU') 51 | 52 | logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count()) 53 | logging.info('JAX local devices: %r', jax.local_devices()) 54 | 55 | # Add a note so that we can tell which task is which JAX host. 
56 | # (Depending on the platform task 0 is not guaranteed to be host 0) 57 | platform.work_unit().set_task_status(f'process_index: {jax.process_index()}, ' 58 | f'process_count: {jax.process_count()}') 59 | platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY, 60 | FLAGS.workdir, 'workdir') 61 | 62 | # For reproducible results 63 | os.environ["XLA_FLAGS"] = "xla_gpu_deterministic_reductions" 64 | os.environ["TF_CUDNN_DETERMINISTIC"] = "1" 65 | 66 | if FLAGS.config.optimizer == 'sam': 67 | assert (FLAGS.config.rho > 0.0) 68 | 69 | workdir_suffix = os.path.join( 70 | 'dataset_' + FLAGS.config.dataset, 71 | 'optimizer_' + FLAGS.config.optimizer, 72 | 'model_' + FLAGS.config.model, 73 | 'lr_' + str(FLAGS.config.learning_rate), 74 | 'wd_' + str(FLAGS.config.weight_decay), 75 | 'rho_' + str(FLAGS.config.rho), 76 | 'pruner_' + str(FLAGS.config.pruner), 77 | 'sparsity_' + str(FLAGS.config.sparsity), 78 | 'seed_' + str(FLAGS.config.seed) 79 | ) 80 | 81 | output_dir = os.path.join(FLAGS.workdir, workdir_suffix) 82 | 83 | if not FLAGS.config.restore_checkpoint: 84 | if os.path.exists(output_dir): # job restarted by cluster 85 | for f in os.listdir(output_dir): 86 | if os.path.isdir(os.path.join(output_dir, f)): 87 | shutil.rmtree(os.path.join(output_dir, f)) 88 | else: 89 | os.remove(os.path.join(output_dir, f)) 90 | else: 91 | os.makedirs(output_dir) 92 | 93 | train.train_and_evaluate(FLAGS.config, output_dir) 94 | 95 | if __name__ == '__main__': 96 | flags.mark_flags_as_required(['config', 'workdir']) 97 | app.run(main) 98 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Flax implementation of ResNet V1.""" 16 | # Some parts of the code are taken from https://github.com/google/flax/blob/main/examples/imagenet/models.py. 17 | 18 | from functools import partial 19 | from typing import Any, Callable, Sequence, Tuple 20 | 21 | from flax import linen as nn 22 | import jax.numpy as jnp 23 | 24 | from jax import random 25 | 26 | Array = Any 27 | ModuleDef = Any 28 | KeyArray = random.KeyArray 29 | DTypeLikeInexact = Any 30 | 31 | class MLP(nn.Module): 32 | """Standard MLP.""" 33 | num_classes: int 34 | num_neurons: Sequence[int] 35 | dtype: Any = jnp.float32 36 | act: Callable = nn.relu 37 | 38 | @nn.compact 39 | def __call__(self, x, train: bool = True): 40 | x = x.reshape((x.shape[0], -1)) 41 | for i, num_neuron in enumerate(self.num_neurons): 42 | x = nn.Dense(num_neuron, dtype=self.dtype)(x) 43 | x = self.act(x) 44 | x = nn.Dense(self.num_classes, dtype=self.dtype)(x) 45 | x = jnp.asarray(x, self.dtype) 46 | return x 47 | 48 | class ResNetBlock(nn.Module): 49 | """ResNet block.""" 50 | filters: int 51 | conv: ModuleDef 52 | norm: ModuleDef 53 | act: Callable 54 | strides: Tuple[int, int] = (1, 1) 55 | 56 | @nn.compact 57 | def __call__(self, x,): 58 | residual = x 59 | y = self.conv(self.filters, (3, 3), strides=self.strides, padding=[(1, 1), (1, 1)])(x) 60 | y = self.norm()(y) 61 | y = self.act(y) 62 | y = self.conv(self.filters, (3, 3), strides=(1, 1), padding=[(1, 1), (1, 1)])(y) 63 | y = self.norm()(y) 64 | 65 | if residual.shape != y.shape or self.strides != (1, 1): 66 | 
residual = self.conv(self.filters, (1, 1), 67 | self.strides, padding=[(0, 0), (0, 0)], name='conv_proj')(residual) 68 | residual = self.norm(name='norm_proj')(residual) 69 | 70 | return self.act(residual + y) 71 | 72 | 73 | class BottleneckResNetBlock(nn.Module): 74 | """Bottleneck ResNet block.""" 75 | 76 | filters: int 77 | conv: ModuleDef 78 | norm: ModuleDef 79 | act: Callable 80 | strides: Tuple[int, int] = (1, 1) 81 | 82 | @nn.compact 83 | def __call__(self, x): 84 | residual = x 85 | y = self.conv(self.filters, (1, 1))(x) 86 | y = self.norm()(y) 87 | y = self.act(y) 88 | y = self.conv(self.filters, (3, 3), self.strides)(y) 89 | y = self.norm()(y) 90 | y = self.act(y) 91 | y = self.conv(self.filters * 4, (1, 1))(y) 92 | y = self.norm(scale_init=nn.initializers.zeros_init())(y) 93 | 94 | if residual.shape != y.shape: 95 | residual = self.conv( 96 | self.filters * 4, (1, 1), self.strides, name='conv_proj' 97 | )(residual) 98 | residual = self.norm(name='norm_proj')(residual) 99 | 100 | return self.act(residual + y) 101 | 102 | 103 | class ResNetBigCifar(nn.Module): 104 | """ResNetV1.""" 105 | stage_sizes: Sequence[int] 106 | block_cls: ModuleDef 107 | num_classes: int 108 | num_filters: int = 64 109 | dtype: Any = jnp.float32 110 | act: Callable = nn.relu 111 | conv: ModuleDef = nn.Conv 112 | norm: ModuleDef = nn.BatchNorm 113 | 114 | @nn.compact 115 | def __call__(self, x, train: bool = True): 116 | conv = partial(self.conv, use_bias=False, dtype=self.dtype) 117 | norm = partial(self.norm, 118 | use_running_average=not train, 119 | momentum=0.9, 120 | epsilon=1e-5, 121 | dtype=self.dtype) 122 | 123 | x = conv(self.num_filters, (3, 3), strides=(1, 1), 124 | padding=[(1, 1), (1, 1)], 125 | name='conv_init')(x) 126 | x = norm(name='bn_init')(x) 127 | x = nn.relu(x) 128 | 129 | for i, block_size in enumerate(self.stage_sizes): 130 | for j in range(block_size): 131 | strides = (2, 2) if i > 0 and j == 0 else (1, 1) 132 | x = self.block_cls(self.num_filters * 2 ** 
i, 133 | strides=strides, 134 | conv=conv, 135 | norm=norm, 136 | act=self.act)(x) 137 | x = nn.avg_pool(x, (4, 4), strides=(4, 4), padding=[(0, 0), (0, 0)]) 138 | x = x.reshape(x.shape[0], -1) 139 | x = nn.Dense(self.num_classes, dtype=self.dtype)(x) 140 | x = jnp.asarray(x, self.dtype) 141 | return x 142 | 143 | class ResNetImagenet(nn.Module): 144 | """ResNetV1.""" 145 | stage_sizes: Sequence[int] 146 | block_cls: ModuleDef 147 | num_classes: int 148 | num_filters: int = 64 149 | dtype: Any = jnp.float32 150 | act: Callable = nn.relu 151 | conv: ModuleDef = nn.Conv 152 | 153 | @nn.compact 154 | def __call__(self, x, train: bool = True): 155 | conv = partial(self.conv, use_bias=False, dtype=self.dtype) 156 | norm = partial( 157 | nn.BatchNorm, 158 | use_running_average=not train, 159 | momentum=0.9, 160 | epsilon=1e-5, 161 | dtype=self.dtype, 162 | axis_name='batch', 163 | ) 164 | 165 | x = conv( 166 | self.num_filters, 167 | (7, 7), 168 | (2, 2), 169 | padding=[(3, 3), (3, 3)], 170 | name='conv_init', 171 | )(x) 172 | x = norm(name='bn_init')(x) 173 | x = nn.relu(x) 174 | x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME') 175 | for i, block_size in enumerate(self.stage_sizes): 176 | for j in range(block_size): 177 | strides = (2, 2) if i > 0 and j == 0 else (1, 1) 178 | x = self.block_cls( 179 | self.num_filters * 2**i, 180 | strides=strides, 181 | conv=conv, 182 | norm=norm, 183 | act=self.act, 184 | )(x) 185 | x = jnp.mean(x, axis=(1, 2)) 186 | x = nn.Dense(self.num_classes, dtype=self.dtype)(x) 187 | x = jnp.asarray(x, self.dtype) 188 | return x 189 | 190 | 191 | ResNet18 = partial(ResNetBigCifar, stage_sizes=[2, 2, 2, 2], 192 | block_cls=ResNetBlock) 193 | ResNet50_ImageNet = partial(ResNetImagenet, stage_sizes=[3, 4, 6, 3], 194 | block_cls=BottleneckResNetBlock) -------------------------------------------------------------------------------- /pruner.py: -------------------------------------------------------------------------------- 1 | import jax 2 
| from jax import lax 3 | from jax.tree_util import tree_map, tree_flatten, tree_unflatten 4 | from jax.flatten_util import ravel_pytree 5 | import jax.numpy as jnp 6 | 7 | """ 8 | Functions for sparsifying the networks. 9 | """ 10 | 11 | # Various saliency scores. 12 | # ----------------------------------------------------------------------------- 13 | 14 | def snip_score(params, batch, **kwargs): 15 | def loss_fn(params): 16 | """loss function used for training.""" 17 | logits, _= kwargs['apply_fn']( 18 | {'params': params, 'batch_stats': kwargs['batch_stats']}, 19 | batch['image'], 20 | train=True, 21 | mutable=['batch_stats']) 22 | loss = kwargs['loss_fn'](logits, batch['label']) 23 | return loss 24 | grad = jax.grad(loss_fn)(params) 25 | grad = jax.lax.pmean(grad, axis_name='batch') 26 | return tree_map(lambda w, g: lax.abs(w*g), params, grad) 27 | 28 | 29 | def magnitude_score(params, batch, **kwargs): 30 | return tree_map(lambda w: lax.abs(w), params) 31 | 32 | 33 | def random_score(params, batch, **kwargs): 34 | f_param, unravel = ravel_pytree(params) 35 | f_rand = jax.random.normal(kwargs['key'], f_param.shape) 36 | return unravel(f_rand) 37 | 38 | 39 | def compute_score(sc_type, params, batch, **kwargs): 40 | return globals()[f'{sc_type}_score'](params, batch, **kwargs) 41 | 42 | 43 | # Mask Utilities. 
44 | # ----------------------------------------------------------------------------- 45 | 46 | def compute_mask(scores, sp, pruner): 47 | """Generate a pruning mask from the given scores, keeping the highest-scoring (1 - sp) fraction of kernel weights.""" 48 | 49 | assert 0 <= sp <= 1 50 | 51 | # mask computing function given score and threshold 52 | def _mask_dict(sc, thr): 53 | if 'kernel' not in sc: return jnp.full(sc.shape, True) 54 | 55 | mask_dict = {'kernel': sc['kernel'] > thr} 56 | if 'bias' in sc: 57 | mask_dict['bias'] = jnp.full(sc['bias'].shape, True) 58 | 59 | return mask_dict 60 | 61 | if pruner == 'snip': 62 | scope = 'global' 63 | elif pruner == 'random': 64 | scope = 'local' 65 | else: raise ValueError(f"Unsupported pruner: {pruner}") # previously left `scope` undefined 66 | # flatten the scores pytree, treating each dict that contains 'kernel' as a leaf (instead of the individual jnp arrays) 67 | flat_tr, trdef = tree_flatten(scores, lambda tr: 'kernel' in tr) 68 | 69 | 70 | # sort by scores, use only kernel/weight parameters 71 | if scope=='global': 72 | flat_sc, _ = ravel_pytree([sc['kernel'] for sc in flat_tr if 'kernel' in sc]) 73 | sort_sc = jnp.sort(flat_sc) 74 | thr = sort_sc[int(sp*len(sort_sc))] # compute global threshold 75 | 76 | _mask_dict_g = lambda sc: _mask_dict(sc, thr) 77 | flat_mask = [*map(_mask_dict_g, flat_tr)] # compute mask 78 | 79 | elif scope=='local': 80 | sort_scs = [(jnp.sort(sc['kernel'].ravel()) if 'kernel' in sc else None) for sc in flat_tr] 81 | thrs = [sc if sc is None else sc[int(sp*len(sc))] for sc in sort_scs] # compute layer thresholds 82 | 83 | flat_mask = [*map(_mask_dict, flat_tr, thrs)] # compute mask 84 | 85 | mask = tree_unflatten(trdef, flat_mask) 86 | 87 | return mask 88 | 89 | @jax.jit 90 | def apply_mask(params, mask): 91 | """Apply pruning mask to the parameters""" 92 | return tree_map(lambda p, m: p*m, params, mask) 93 | 94 | def weight_sparsity(params): 95 | """Calculate the overall sparsity of the model (only for the kernels)""" 96 | flat_tr, _ = tree_flatten(params, lambda tr: 'kernel' in tr) 97 | flat_w, _ = ravel_pytree([m['kernel'] for m in flat_tr if 
'kernel' in m]) 98 | return (flat_w == 0).sum().item() / len(flat_w) # fraction of zeroed kernel weights -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | astunparse==1.6.3 3 | attrs==22.2.0 4 | cached-property==1.5.2 5 | cachetools==5.3.0 6 | charset-normalizer==3.0.1 7 | chex==0.1.6 8 | clu==0.0.6 9 | colorama==0.4.6 10 | commonmark==0.9.1 11 | contextlib2==21.6.0 12 | contourpy==1.0.7 13 | cycler==0.11.0 14 | dill==0.3.6 15 | dm-tree==0.1.8 16 | einops==0.6.1 17 | etils==1.0.0 18 | flatbuffers==1.12 19 | flax==0.6.5 20 | fonttools==4.38.0 21 | future==0.18.3 22 | gast==0.4.0 23 | google-auth==2.16.1 24 | google-auth-oauthlib==0.4.6 25 | google-pasta==0.2.0 26 | googleapis-common-protos==1.58.0 27 | grpcio==1.51.1 28 | h5py==3.8.0 29 | idna==3.4 30 | importlib-metadata==6.0.0 31 | importlib-resources==5.12.0 32 | jax==0.4.4 33 | jaxlib==0.4.4+cuda11.cudnn82 34 | keras==2.9.0 35 | Keras-Preprocessing==1.1.2 36 | kiwisolver==1.4.4 37 | libclang==15.0.6.1 38 | Markdown==3.4.1 39 | MarkupSafe==2.1.2 40 | matplotlib==3.7.0 41 | ml-collections==0.1.0 42 | ml-dtypes==0.0.4 43 | msgpack==1.0.4 44 | numpy==1.22.0 45 | oauthlib==3.2.2 46 | opt-einsum==3.3.0 47 | optax==0.1.3 48 | orbax==0.1.2 49 | packaging==23.0 50 | pandas==1.5.0 51 | Pillow==9.4.0 52 | promise==2.3 53 | protobuf==3.19.6 54 | pyasn1==0.4.8 55 | pyasn1-modules==0.2.8 56 | Pygments==2.14.0 57 | pyparsing==3.0.9 58 | python-dateutil==2.8.2 59 | pytz==2023.3 60 | PyYAML==6.0 61 | requests==2.28.2 62 | requests-oauthlib==1.3.1 63 | rich==11.2.0 64 | rsa==4.9 65 | scipy==1.10.1 66 | seaborn==0.12.2 67 | six==1.16.0 68 | tensorboard==2.9.1 69 | tensorboard-data-server==0.6.1 70 | tensorboard-plugin-wit==1.8.1 71 | tensorflow==2.9.3 72 | tensorflow-addons==0.19.0 73 | tensorflow-datasets==4.4.0 74 | tensorflow-estimator==2.9.0 75 | tensorflow-io-gcs-filesystem==0.30.0 76 | 
tensorflow-metadata==1.12.0 77 | tensorstore==0.1.32 78 | termcolor==2.2.0 79 | toolz==0.12.0 80 | tqdm==4.64.1 81 | typeguard==2.13.3 82 | typing_extensions==4.5.0 83 | urllib3==1.26.14 84 | Werkzeug==2.2.3 85 | wrapt==1.14.1 86 | zipp==3.14.0 87 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Some parts of the code are borrowed from https://github.com/google/flax/blob/main/examples/imagenet/train.py. 16 | 17 | 18 | """ 19 | File to train the network. 
20 | """ 21 | 22 | import time 23 | 24 | from absl import logging 25 | from clu import periodic_actions 26 | import numpy as np 27 | from flax import jax_utils 28 | from flax.training import checkpoints 29 | from flax.training import common_utils 30 | import jax 31 | from jax import lax 32 | from jax import random 33 | import ml_collections 34 | import tensorflow as tf 35 | 36 | import math 37 | from functools import partial 38 | 39 | import models 40 | from pruner import apply_mask 41 | from train_utils import * 42 | 43 | def create_model(config, num_classes): 44 | """Create the model.""" 45 | if config.model == 'MLP': 46 | model_cls = getattr(models, config.model) 47 | return model_cls(num_classes=num_classes, num_neurons=config.num_neurons) 48 | elif 'ResNet' in config.model: 49 | model_cls = getattr(models, config.model) 50 | return model_cls(num_classes=num_classes, num_filters=config.num_filters) 51 | 52 | @partial(jax.jit, static_argnames = ["optimizer", "loss_type"]) 53 | def train_step(state, batch, key, weight_decay, optimizer, rho, loss_type): 54 | """Perform a single training step.""" 55 | def loss_fn(params): 56 | """Loss function used for training.""" 57 | params = apply_mask(params, state.mask) # apply pruning mask 58 | logits, new_model_state = state.apply_fn( 59 | {'params': params, 'batch_stats': state.batch_stats}, 60 | batch['image'], 61 | rngs=dict(dropout=key), 62 | train=True, 63 | mutable=['batch_stats']) 64 | loss = loss_type(logits, batch['label']) 65 | return loss, (new_model_state, logits) 66 | 67 | def get_sam_gradient(params, rho): 68 | """Returns the updated model state and logits, together with the gradient of the SAM loss. 69 | 70 | See https://arxiv.org/abs/2010.01412 for more details. 71 | 72 | Args: 73 | params: The parameters of the model that we are training. 74 | rho: Size of the perturbation. 
75 | """ 76 | # compute gradient on the whole batch 77 | (_, (inner_state, logits)), grad = jax.value_and_grad(loss_fn, has_aux=True)(params) 78 | grad = dual_vector(grad) 79 | noised_params = jax.tree_map(lambda p, b: p + rho * b, params, grad) 80 | (_, (_, _)), grad = jax.value_and_grad( 81 | loss_fn, has_aux=True)(noised_params) 82 | return (inner_state, logits), grad 83 | 84 | if optimizer == 'sgd': # SGD 85 | (_, (new_model_state, logits)), grads = jax.value_and_grad( 86 | loss_fn, has_aux=True)( 87 | state.params) 88 | elif optimizer == 'sam': # SAM 89 | (new_model_state, logits), grads = get_sam_gradient(state.params, rho) 90 | 91 | # We manually apply weight decay in this way. 92 | grads = jax.tree_map(lambda g, p: g + weight_decay * p, grads, state.params) 93 | 94 | grads = jax.lax.pmean(grads, axis_name='batch') 95 | 96 | metrics = compute_metrics(logits, batch['label'], loss_type) 97 | 98 | new_state = state.apply_gradients(grads=grads, batch_stats=new_model_state['batch_stats']) 99 | 100 | return new_state, metrics 101 | 102 | @partial(jax.jit, static_argnames = ["loss_type"]) 103 | def eval_step(state, batch, loss_type): 104 | """Evaluate the model on the test data.""" 105 | variables = {'params': state.params, 'batch_stats': state.batch_stats} 106 | logits = state.apply_fn( 107 | variables, batch['image'], train=False, mutable=False) 108 | return compute_metrics(logits, batch['label'], loss_type) 109 | 110 | def restore_checkpoint(state, workdir): 111 | """Restore the model from the checkpoint.""" 112 | return checkpoints.restore_checkpoint(workdir, state) 113 | 114 | def save_checkpoint(state, workdir): 115 | """Save the model checkpoint.""" 116 | if jax.process_index() == 0: 117 | # get train state from the first replica 118 | state = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state)) 119 | step = int(state.step) 120 | checkpoints.save_checkpoint(workdir, state, step, keep=3) 121 | 122 | # pmean only works inside pmap because it needs an 
axis name. 123 | # This function will average the inputs across all devices. 124 | cross_replica_mean = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') 125 | 126 | def sync_batch_stats(state): 127 | """Sync the batch statistics across replicas.""" 128 | # Each device has its own version of the running average batch statistics and 129 | # we sync them before evaluation. 130 | if state.batch_stats == {}: 131 | return state 132 | return state.replace(batch_stats=cross_replica_mean(state.batch_stats)) 133 | 134 | 135 | def train_and_evaluate(config: ml_collections.ConfigDict, 136 | workdir: str) -> TrainState: 137 | """Execute model training and evaluation loop. 138 | Args: 139 | config: Hyperparameter configuration for training and evaluation. 140 | workdir: Directory where the tensorboard summaries are written to. 141 | Returns: 142 | Final TrainState. 143 | """ 144 | logging.info(config) 145 | 146 | rng = random.PRNGKey(config.seed) 147 | tf.random.set_seed(config.seed) 148 | np.random.seed(config.seed) 149 | 150 | if config.batch_size % jax.device_count() > 0: 151 | raise ValueError('Batch size must be divisible by the number of devices') 152 | 153 | ############################# Prepare Dataset ############################# 154 | train_iter, eval_iter, num_train_samples, num_val_samples, num_classes, input_shape = prepare_dataset(config) 155 | 156 | steps_per_epoch = ( 157 | math.ceil(num_train_samples / config.batch_size) 158 | ) 159 | 160 | if config.steps_per_eval == -1: 161 | steps_per_eval = math.ceil(num_val_samples / config.batch_size) 162 | else: 163 | steps_per_eval = config.steps_per_eval 164 | 165 | ############################# Prepare lr & Model ############################# 166 | 167 | base_learning_rate = config.learning_rate 168 | 169 | model = create_model(config, num_classes) 170 | 171 | learning_rate_fn = create_learning_rate_fn( 172 | config, base_learning_rate, steps_per_epoch) 173 | 174 | 175 | ############################# Prepare / Restore State 
############################# 176 | 177 | loss_type = partial(cross_entropy_loss, num_classes=num_classes) 178 | 179 | rng, init_rng = jax.random.split(rng) 180 | state = create_train_state(init_rng, model, input_shape, learning_rate_fn, loss_type, config, half_precision=config.half_precision, train_iter=train_iter) 181 | 182 | if config.restore_checkpoint: 183 | state = restore_checkpoint(state, workdir) 184 | # epoch_offset > 0 if restarting from checkpoint 185 | epoch_offset = int(state.step) // steps_per_epoch 186 | state = jax_utils.replicate(state) 187 | 188 | 189 | ############################# jit / pmap train_step ############################# 190 | 191 | jitted_train_step = jax.jit(train_step, static_argnames=["optimizer", "loss_type"]) 192 | 193 | p_train_step = jax.pmap( 194 | partial(jitted_train_step, weight_decay=config.weight_decay, optimizer=config.optimizer, 195 | rho=config.rho, loss_type=loss_type), 196 | axis_name='batch', 197 | ) 198 | p_eval_step = jax.pmap(partial(eval_step, loss_type=loss_type), axis_name='batch') 199 | 200 | 201 | ############################# Start Training ############################# 202 | 203 | hooks = [] 204 | if jax.process_index() == 0: 205 | hooks += [periodic_actions.Profile(num_profile_steps=5, logdir=workdir)] 206 | logging.info('Initial compilation, this might take some minutes...') 207 | 208 | total_steps = 0 209 | 210 | best_test_acc = -1 211 | best_epoch = -1 212 | 213 | for epoch in range(epoch_offset, int(config.num_epochs)): 214 | logging.info("Epoch %d / %d " % (epoch + 1, int(config.num_epochs))) 215 | 216 | train_loss_meter = AverageMeter() 217 | train_acc_meter = AverageMeter() 218 | 219 | start_time = time.time() 220 | for step in range(steps_per_epoch): 221 | batch = next(train_iter) 222 | rng, step_rng = jax.random.split(rng) 223 | sharded_keys = common_utils.shard_prng_key(step_rng) 224 | state, metrics = p_train_step(state, batch, sharded_keys) 225 | 
train_loss_meter.update(metrics['loss'].mean(), len(batch['label'][0])) 226 | train_acc_meter.update(metrics['accuracy'].mean(), len(batch['label'][0])) 227 | 228 | total_steps += 1 229 | 230 | if total_steps % config.log_every_steps == 0: 231 | logging.info("Epoch[%d] Step [%d/%d]: loss %.4f acc %.4f (time elapsed: %.4f)" % (epoch + 1, step, steps_per_epoch, metrics['loss'].mean(), metrics['accuracy'].mean(), time.time() - start_time)) 232 | 233 | 234 | cur_time = time.time() 235 | test_loss_meter = AverageMeter() 236 | test_acc_meter = AverageMeter() 237 | 238 | lr = learning_rate_fn(steps_per_epoch * epoch) 239 | 240 | state = sync_batch_stats(state) 241 | 242 | for step in range(steps_per_eval): 243 | batch = next(eval_iter) 244 | metrics = p_eval_step(state, batch) 245 | test_loss_meter.update(metrics['loss'].mean(), len(batch['label'][0])) 246 | test_acc_meter.update(metrics['accuracy'].mean(), len(batch['label'][0])) 247 | 248 | if test_acc_meter.avg > best_test_acc: 249 | best_test_acc = test_acc_meter.avg 250 | best_epoch = epoch 251 | elapsed_time = cur_time - start_time 252 | logging.info("Train: loss %.4f acc %.4f; Val: loss %.4f acc %.4f (lr %.4f / took %.2f seconds) \n" % (train_loss_meter.avg, train_acc_meter.avg, test_loss_meter.avg, test_acc_meter.avg, lr, elapsed_time)) 253 | 254 | if (epoch + 1) % 1 == 0 or (epoch + 1) == int(config.num_epochs): 255 | state = sync_batch_stats(state) 256 | save_checkpoint(state, workdir) 257 | 258 | # Wait until computations are done before exiting 259 | jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready() 260 | 261 | logging.info("Best test acc %.4f at epoch %d" % (best_test_acc, best_epoch + 1)) 262 | 263 | return state 264 | 265 | 266 | -------------------------------------------------------------------------------- /train_utils.py: -------------------------------------------------------------------------------- 1 | # Some parts of the code are borrowed from 
https://github.com/google/flax/blob/main/examples/imagenet/train.py. 2 | 3 | from typing import Any 4 | 5 | import jax 6 | from jax import lax 7 | import jax.numpy as jnp 8 | import tensorflow as tf 9 | import optax 10 | import ml_collections 11 | import tensorflow_datasets as tfds 12 | 13 | from flax.training import dynamic_scale as dynamic_scale_lib 14 | from flax.training import common_utils, train_state 15 | from flax import jax_utils 16 | from flax import struct, core 17 | 18 | from functools import partial 19 | 20 | from pruner import compute_mask, apply_mask, compute_score 21 | from input_pipeline import create_split 22 | 23 | """ 24 | Utility functions for training the network. 25 | """ 26 | 27 | class TrainState(train_state.TrainState): 28 | batch_stats: Any 29 | dynamic_scale: dynamic_scale_lib.DynamicScale 30 | mask: core.FrozenDict[str, Any] = struct.field(pytree_node=True) 31 | 32 | @jax.jit 33 | def dual_vector(y: jnp.ndarray) -> jnp.ndarray: 34 | """Returns the solution of max_x y^T x s.t. ||x||_2 <= 1. 35 | 36 | Args: 37 | y: A pytree of numpy ndarray, vector y in the equation above. 
38 | """ 39 | gradient_norm = jnp.sqrt(sum( 40 | [jnp.sum(jnp.square(e)) for e in jax.tree_util.tree_leaves(y)])) 41 | normalized_gradient = jax.tree_map(lambda x: x / (gradient_norm + 1e-12), y) 42 | return normalized_gradient 43 | 44 | class AverageMeter (object): 45 | """Class for calculating the average""" 46 | def __init__(self): 47 | self.reset () 48 | 49 | def reset(self): 50 | self.val = 0 51 | self.avg = 0 52 | self.sum = 0 53 | self.count = 0 54 | 55 | def update(self, val, n=1): 56 | self.val = val 57 | self.sum += val * n 58 | self.count += n 59 | self.avg = self.sum / self.count 60 | 61 | def compute_metrics(logits, labels, loss_fn): 62 | """Compute loss and the accuracy""" 63 | loss = loss_fn(logits, labels) 64 | accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) 65 | metrics = { 66 | 'loss': loss, 67 | 'accuracy': accuracy, 68 | } 69 | metrics = lax.pmean(metrics, axis_name='batch') 70 | return metrics 71 | 72 | def cross_entropy_loss(logits, labels, num_classes): 73 | """Standard cross entropy loss""" 74 | one_hot_labels = common_utils.onehot(labels, num_classes) 75 | xentropy = optax.softmax_cross_entropy(logits=logits, labels=one_hot_labels) 76 | return jnp.mean(xentropy) 77 | 78 | def prepare_dataset(config): 79 | """Create data iterators and the related infos""" 80 | if config.dataset == 'cifar10': 81 | _, image_size, _, _ = input_shape = (1, 32, 32, 3) 82 | num_classes = 10 83 | elif config.dataset == 'imagenet2012': 84 | _, image_size, _, _ = input_shape = (1, 224, 224, 3) 85 | num_classes = 1000 86 | elif config.dataset == 'mnist': 87 | _, image_size, _, _ = input_shape = (1, 28, 28, 1) 88 | num_classes = 10 89 | 90 | local_batch_size = config.batch_size // jax.process_count() 91 | 92 | platform = jax.local_devices()[0].platform 93 | 94 | if config.half_precision: 95 | if platform == 'tpu': 96 | input_dtype = tf.bfloat16 97 | else: 98 | input_dtype = tf.float16 99 | else: 100 | input_dtype = tf.float32 101 | 102 | dataset_builder = 
tfds.builder(config.dataset) 103 | if config.dataset == 'imagenet2012': 104 | manual_dataset_dir = "your_directory" # replace with your manual ImageNet download directory 105 | imagenet_download_config = tfds.download.DownloadConfig( 106 | extract_dir='./tmp/', 107 | manual_dir = manual_dataset_dir) 108 | dataset_builder.download_and_prepare(download_config=imagenet_download_config) 109 | test_set = 'validation' 110 | else: 111 | dataset_builder.download_and_prepare() 112 | test_set = 'test' 113 | 114 | train_iter = create_input_iter( 115 | config.dataset, 116 | dataset_builder, local_batch_size, image_size, input_dtype, train=True, 117 | cache=config.cache) 118 | 119 | eval_iter = create_input_iter( 120 | config.dataset, 121 | dataset_builder, local_batch_size, image_size, input_dtype, train=False, 122 | cache=config.cache) 123 | 124 | num_train_samples = dataset_builder.info.splits['train'].num_examples 125 | num_val_samples = dataset_builder.info.splits[test_set].num_examples 126 | 127 | return train_iter, eval_iter, num_train_samples, num_val_samples, num_classes, input_shape 128 | 129 | def initialized(key, input_shape, model, batch_stats=True): 130 | """Initialize the parameters and the batchnorm stats""" 131 | @jax.jit 132 | def init(*args): 133 | return model.init(*args) 134 | variables = init({'params': key, 'dropout': key}, jnp.ones(input_shape)) 135 | if not batch_stats: 136 | return variables['params'], {} 137 | else: 138 | return variables['params'], variables['batch_stats'] 139 | 140 | def prepare_tf_data(xs): 141 | """Convert an input batch from TF tensors to NumPy arrays.""" 142 | local_device_count = jax.local_device_count() 143 | def _prepare(x): 144 | # Use _numpy() for zero-copy conversion between TF and NumPy. 
145 | x = x._numpy() # pylint: disable=protected-access 146 | 147 | return x.reshape((local_device_count, -1) + x.shape[1:]) 148 | 149 | return jax.tree_util.tree_map(_prepare, xs) 150 | 151 | def create_learning_rate_fn( 152 | config: ml_collections.ConfigDict, 153 | base_learning_rate: float, 154 | steps_per_epoch: int): 155 | """Create learning rate schedule.""" 156 | 157 | if config.lr_scheduler == 'step': 158 | return optax.piecewise_constant_schedule(base_learning_rate, boundaries_and_scales={int(config.num_epochs * 0.5 * steps_per_epoch): 0.1, int(config.num_epochs * 0.75 * steps_per_epoch): 0.1}) 159 | elif config.lr_scheduler == 'step_mnist': 160 | return optax.piecewise_constant_schedule(base_learning_rate, boundaries_and_scales={int(config.num_epochs * 0.25 * steps_per_epoch): 0.1, int(config.num_epochs * 0.5 * steps_per_epoch): 0.1, int(config.num_epochs * 0.75 * steps_per_epoch): 0.1}) 161 | elif config.lr_scheduler == 'imagenet_cosine': 162 | # Taken from https://github.com/google-research/vision_transformer/blob/main/vit_jax/utils.py 163 | def step_fn(step): 164 | warmup_steps = config.warmup_steps 165 | lr = base_learning_rate 166 | total_steps = config.num_epochs * steps_per_epoch 167 | progress = (step - warmup_steps) / float(total_steps - warmup_steps) 168 | progress = jnp.clip(progress, 0.0, 1.0) 169 | lr = lr * 0.5 * (1. 
+ jnp.cos(jnp.pi * progress)) 170 | if warmup_steps: 171 | lr = lr * jnp.minimum(1., step / warmup_steps) 172 | return jnp.asarray(lr, dtype=jnp.float32) 173 | return step_fn 174 | 175 | def create_pruning_mask(params, pruner, sparsity, loss_fn, **kwargs): 176 | """Create pruning mask""" 177 | batch = next(kwargs['train_iter']) 178 | kwargs['train_iter'].__init__() 179 | 180 | kwargs['loss_fn'] = loss_fn 181 | 182 | p_compute_score = jax.pmap(partial(compute_score, sc_type=pruner, **kwargs), axis_name='batch') 183 | scores = p_compute_score(params=jax_utils.replicate(params), batch=batch) 184 | scores = jax_utils.unreplicate(scores) 185 | mask = compute_mask(scores, sparsity, pruner) 186 | masked_params = apply_mask(params, mask) 187 | 188 | return masked_params, mask 189 | 190 | def create_input_iter(dataset, dataset_builder, batch_size, image_size, dtype, train, 191 | cache): 192 | """Create data iterator""" 193 | ds = create_split( 194 | dataset, dataset_builder, batch_size, train=train, cache=cache) 195 | it = map(prepare_tf_data, ds) 196 | it = jax_utils.prefetch_to_device(it, 2) 197 | return it 198 | 199 | 200 | def create_train_state(rng, model, input_shape, learning_rate_fn, loss_fn, config, half_precision=False, **kwargs): 201 | """Create initial training state.""" 202 | dynamic_scale = None 203 | platform = jax.local_devices()[0].platform 204 | if half_precision and platform == 'gpu': 205 | dynamic_scale = dynamic_scale_lib.DynamicScale() 206 | else: 207 | dynamic_scale = None 208 | 209 | params, batch_stats = initialized(rng, input_shape, model, batch_stats=('ResNet' in config.model)) 210 | 211 | tx = optax.sgd( 212 | learning_rate=learning_rate_fn, 213 | momentum=config.momentum, 214 | nesterov=False, 215 | ) 216 | 217 | # compute pai mask 218 | params, mask = create_pruning_mask(params, config.pruner, config.sparsity, loss_fn, key=rng, batch_stats=batch_stats, apply_fn=model.apply, **kwargs) 219 | 220 | state = TrainState.create( 221 | 
apply_fn=model.apply, 222 | params=params, 223 | tx=tx, 224 | batch_stats=batch_stats, 225 | dynamic_scale=dynamic_scale, 226 | mask=mask) 227 | 228 | return state --------------------------------------------------------------------------------
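
The two-pass SAM gradient that `get_sam_gradient` in `train.py` computes with `dual_vector` from `train_utils.py` can be illustrated with a minimal pure-Python sketch. This is not the repository's implementation: it assumes plain lists in place of JAX pytrees and a hand-written gradient in place of `jax.value_and_grad(loss_fn)`. The names `sam_gradient` and `toy_grad` are hypothetical and introduced here only for illustration.

```python
import math


def dual_vector(y):
    """Return y scaled to unit L2 norm, i.e. the maximizer of y^T x s.t. ||x||_2 <= 1."""
    norm = math.sqrt(sum(v * v for v in y))
    # Same small epsilon as in train_utils.dual_vector to avoid division by zero.
    return [v / (norm + 1e-12) for v in y]


def sam_gradient(grad_fn, params, rho):
    """Two-pass SAM gradient (sketch).

    1. Compute the gradient at the current parameters.
    2. Move the parameters by rho along the normalized gradient.
    3. Return the gradient evaluated at the perturbed parameters.
    """
    grad = grad_fn(params)
    direction = dual_vector(grad)
    perturbed = [p + rho * d for p, d in zip(params, direction)]
    return grad_fn(perturbed)


def toy_grad(params):
    """Gradient of the hypothetical toy loss 0.5 * ||params||^2, which is params itself."""
    return list(params)


params = [3.0, 4.0]
plain_grad = toy_grad(params)                     # gradient an SGD step would use
sam_grad = sam_gradient(toy_grad, params, rho=0.5)  # gradient a SAM step would use
print(plain_grad, sam_grad)
```

For this toy loss the normalized gradient is approximately `[0.6, 0.8]`, so the SAM gradient is evaluated at roughly `[3.3, 4.4]` rather than at `[3.0, 4.0]`. In `train_step`, weight decay is added and the gradients are averaged across devices only after this SAM gradient has been computed.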