├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── diffstride ├── __init__.py ├── examples │ ├── __init__.py │ ├── cifar10.gin │ ├── data.py │ ├── external_configurables.py │ ├── main.py │ └── train.py ├── pooling.py └── resnet.py ├── images └── diffstride.png ├── pyproject.toml ├── requirements.txt ├── setup.cfg ├── setup.py └── tests ├── pooling_test.py └── resnet_test.py /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License Agreement (CLA). You (or your employer) retain the copyright to your contribution; this simply gives us permission to use and redistribute your contributions as part of the project. Head over to <https://cla.developers.google.com/> to see your current agreements on file or to sign a new one. 9 | 10 | You generally only need to submit a CLA once, so if you've already submitted 11 | one (even if it was for a different project), you probably don't need to do it again. 12 | 13 | ## Code reviews 14 | 15 | All submissions, including submissions by project members, require review. We use GitHub pull requests for this purpose. Consult [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more information on using pull requests. 16 | 17 | ## Community Guidelines 18 | 19 | This project follows [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions.
9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. 
For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DiffStride: Learning strides in convolutional neural networks 2 | 3 | ![Overview](./images/diffstride.png) 4 | 5 | DiffStride is a pooling layer with learnable strides. Unlike strided convolutions, average pooling or max-pooling that require cross-validating stride values at each layer, DiffStride can be initialized with an arbitrary value at each layer (e.g. `(2, 2)`) and during training its strides will be optimized for the task at hand. 6 | 7 | We describe DiffStride in our ICLR 2022 paper [Learning Strides in Convolutional Neural Networks](https://arxiv.org/abs/2202.01653). Compared to the experiments described in the paper, this implementation uses a [Pre-Act Resnet](https://arxiv.org/abs/1603.05027) and uses [Mixup](https://arxiv.org/abs/1710.09412) in training.
8 | 9 | ## Installation 10 | 11 | To install the diffstride library, first clone this repo: 12 | 13 | ``` 14 | git clone https://github.com/google-research/diffstride.git 15 | ``` 16 | 17 | Then cd into the root and run the command: 18 | ``` 19 | pip install -e . 20 | ``` 21 | 22 | ## Example training 23 | 24 | To run an example training on CIFAR10 and save the result in TensorBoard: 25 | 26 | ``` 27 | python3 -m diffstride.examples.main \ 28 | --gin_config=cifar10.gin \ 29 | --gin_bindings="train.workdir = '/tmp/exp/diffstride/resnet18/'" 30 | ``` 31 | 32 | ## Using custom parameters 33 | This implementation uses [Gin](https://github.com/google/gin-config) to parametrize the model, data processing and training loop. 34 | To use custom parameters, one should edit `diffstride/examples/cifar10.gin`. 35 | 36 | For example, to train with SpectralPooling on cifar100: 37 | 38 | ``` 39 | data.load_datasets: 40 | name = 'cifar100' 41 | 42 | resnet.Resnet: 43 | pooling_cls = @pooling.FixedSpectralPooling 44 | ``` 45 | 46 | Or to train with strided convolutions and without Mixup: 47 | 48 | ``` 49 | data.load_datasets: 50 | mixup_alpha = 0.0 51 | 52 | resnet.Resnet: 53 | pooling_cls = None 54 | ``` 55 | 56 | ## Results 57 | This current implementation gives the following accuracy on CIFAR-10 and CIFAR-100, averaged over three runs. To show the robustness of DiffStride to stride initialization, we run both with the standard strides of ResNet (`resnet.resnet18.strides = '1, 1, 2, 2, 2'`) and with a 'poor' choice of strides (`resnet.resnet18.strides = '1, 1, 3, 2, 3'`). Unlike Strided Convolutions and fixed Spectral Pooling, DiffStride is not affected by the stride initialization.
58 | 59 | ### CIFAR-10 60 | 61 | | Pooling | Test Accuracy (%) w/ strides = (1, 1, 2, 2, 2)| Test Accuracy (%) w/ strides = (1, 1, 3, 2, 3)| 62 | | -------------------------------- | --------------------------------------------- | --------------------------------------------- | 63 | | Strided Convolution (Baseline) | 91.06 ± 0.04 | 89.21 ± 0.27 | 64 | | Spectral Pooling | 93.49 ± 0.05 | 92.00 ± 0.08 | 65 | | DiffStride | **94.20 ± 0.06** | **94.19 ± 0.15** | 66 | 67 | ### CIFAR-100 68 | 69 | | Pooling | Test Accuracy (%) w/ strides = (1, 1, 2, 2, 2)| Test Accuracy (%) w/ strides = (1, 1, 3, 2, 3)| 70 | | -------------------------------- | --------------------------------------------- | --------------------------------------------- | 71 | | Strided Convolution (Baseline) | 65.75 ± 0.39 | 60.82 ± 0.42 | 72 | | Spectral Pooling | 72.86 ± 0.23 | 67.74 ± 0.43 | 73 | | DiffStride | **76.08 ± 0.23** | **76.09 ± 0.06** | 74 | 75 | ## CPU/GPU Warning 76 | We rely on the TensorFlow FFT implementation, which requires the input data to be in the `channels_first` format. This is usually not the regular data format of most datasets (including CIFAR), and running with `channels_first` also prevents the use of convolutions on CPU. Therefore, even though we support the `channels_last` data format for CPU compatibility, we encourage the user to run with the `channels_first` data format *on GPU*. 77 | 78 | ## Reference 79 | If you use this repository, please consider citing: 80 | 81 | ``` 82 | @article{riad2022diffstride, 83 | title={Learning Strides in Convolutional Neural Networks}, 84 | author={Riad, Rachid and Teboul, Olivier and Grangier, David and Zeghidour, Neil}, 85 | journal={ICLR}, 86 | year={2022} 87 | } 88 | ``` 89 | 90 | ## Disclaimer 91 | This is not an official Google product.
92 | 93 | -------------------------------------------------------------------------------- /diffstride/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /diffstride/examples/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | -------------------------------------------------------------------------------- /diffstride/examples/cifar10.gin: -------------------------------------------------------------------------------- 1 | import diffstride.examples.external_configurables 2 | 3 | import diffstride.examples.data as data 4 | import diffstride.examples.train as train 5 | import diffstride.resnet as resnet 6 | import diffstride.pooling as pooling 7 | 8 | data.load_datasets: 9 | name = 'cifar10' 10 | batch_size = 128 11 | mixup_alpha = 0.2 12 | 13 | train.train: 14 | load_data_fn = @data.load_datasets 15 | model_cls = @resnet.resnet18 16 | optimizer_cls = @tf.keras.optimizers.SGD 17 | num_epochs = 400 18 | 19 | tf.keras.optimizers.SGD: 20 | learning_rate = @tf.keras.optimizers.schedules.PiecewiseConstantDecay() 21 | momentum = 0.9 22 | 23 | tf.keras.optimizers.schedules.PiecewiseConstantDecay: 24 | boundaries = [400, 18_000, 32_000] 25 | values = [0.01, 0.1, 0.01, 0.001] 26 | 27 | tf.keras.losses.CategoricalCrossentropy: 28 | label_smoothing = 0.2 29 | 30 | # Initialize the strides differently if needed. 31 | resnet.resnet18.strides = [1, 1, 2, 2, 2] 32 | resnet.Resnet: 33 | output_activation = 'softmax' 34 | weight_decay = 5e-4 35 | # Either @pooling.DiffStride, @pooling.FixedSpectralPooling or None 36 | # for strided convolutions. 37 | pooling_cls = @pooling.DiffStride 38 | 39 | batch_norm: 40 | momentum = 0.9 41 | epsilon = 1e-5 42 | 43 | # DiffStride settings (only used when pooling_cls = @pooling.DiffStride). 44 | pooling.DiffStride: 45 | smoothness_factor = 4.0 46 | cropping = True 47 | trainable = True 48 | shared_stride = False 49 | lower_limit_stride = None 50 | upper_limit_stride = None 51 | -------------------------------------------------------------------------------- /diffstride/examples/data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC.
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Prepares data.""" 17 | 18 | from typing import Tuple 19 | 20 | 21 | import gin 22 | import tensorflow as tf 23 | import tensorflow_datasets as tfds 24 | 25 | 26 | def convert(image: tf.Tensor) -> tf.Tensor: 27 | mean = tf.constant([[[0.4914, 0.4822, 0.4465]]]) 28 | std = tf.constant([[[0.2023, 0.1994, 0.2010]]]) 29 | image = tf.cast(image, tf.float32) / tf.uint8.max 30 | return (image - mean) / std 31 | 32 | 33 | @gin.configurable 34 | def augment(image: tf.Tensor, pad_size: int = 2) -> tf.Tensor: 35 | """Data augmentation: Random shifts and flips.""" 36 | shape = image.shape 37 | height, width = shape[:2] 38 | image = tf.image.pad_to_bounding_box(image, pad_size, pad_size, 39 | height + 2 * pad_size, 40 | width + 2 * pad_size) 41 | image = tf.image.random_crop(image, shape) 42 | image = tf.image.random_flip_left_right(image) 43 | return image 44 | 45 | 46 | def sample_beta_distribution(size: int, 47 | concentration_0: float = 0.2, 48 | concentration_1: float = 0.2) -> tf.Tensor: 49 | """Samples from a beta distribution.""" 50 | gamma_1_sample = tf.random.gamma(shape=[size], alpha=concentration_1) 51 | gamma_2_sample = tf.random.gamma(shape=[size], alpha=concentration_0) 52 | return gamma_1_sample / (gamma_1_sample + gamma_2_sample) 53 | 54 | 55 | def mix_up(images: tf.Tensor, 56 | labels: tf.Tensor, 57 | alpha: float = 0.2) -> Tuple[tf.Tensor, tf.Tensor]: 58 | 
"""Applies mixup to two examples.""" 59 | batch_size = images.shape[0] 60 | if batch_size != 2: 61 | raise ValueError(f'Mixup expects batch_size == 2 but got {batch_size}.') 62 | 63 | # Sample lambda and reshape it to do the mixup 64 | weight = sample_beta_distribution(1, alpha, alpha)[0] 65 | 66 | # Perform mixup on both images and labels by combining a pair of images/labels 67 | # (one from each dataset) into one image/label 68 | mixed_image = images[0] * weight + images[1] * (1 - weight) 69 | mixed_label = labels[0] * weight + labels[1] * (1 - weight) 70 | return (mixed_image, mixed_label) 71 | 72 | 73 | def prepare(ds: tf.data.Dataset, 74 | num_classes: int, 75 | batch_size: int = 32, 76 | training: bool = True, 77 | augment_fn=augment, 78 | mixup_alpha: float = 0.0) -> tf.data.Dataset: 79 | """Prepares a dataset for train/test.""" 80 | def transform( 81 | image: tf.Tensor, label: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: 82 | image = convert(image) 83 | if training and augment_fn is not None: 84 | image = augment_fn(image) 85 | label = tf.one_hot(label, depth=num_classes) 86 | return image, label 87 | 88 | ds = ds.map(transform, num_parallel_calls=tf.data.AUTOTUNE) 89 | if training: 90 | ds = ds.shuffle(1000) 91 | if mixup_alpha > 0.0: 92 | ds = ds.batch(2, drop_remainder=True) 93 | ds = ds.map(mix_up) 94 | ds = ds.batch(batch_size, drop_remainder=training) 95 | ds = ds.prefetch(tf.data.AUTOTUNE) 96 | return ds 97 | 98 | 99 | @gin.configurable 100 | def load_datasets(name: str, 101 | batch_size: int = 32, 102 | augment_fn=None, 103 | mixup_alpha: float = 0.0): 104 | """Loads a tf.Dataset corresponding to the given name.""" 105 | datasets, info = tfds.load(name, as_supervised=True, with_info=True) 106 | _, label_key = info.supervised_keys 107 | num_classes = info.features[label_key].num_classes 108 | ds_train = prepare( 109 | datasets['train'], 110 | num_classes=num_classes, 111 | batch_size=batch_size, 112 | training=True, 113 | augment_fn=augment_fn, 114 | 
mixup_alpha=mixup_alpha) 115 | ds_test = prepare( 116 | datasets['test'], 117 | num_classes=num_classes, 118 | batch_size=batch_size, 119 | training=False, 120 | augment_fn=None, 121 | mixup_alpha=0.0) 122 | return ds_train, ds_test, num_classes 123 | -------------------------------------------------------------------------------- /diffstride/examples/external_configurables.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Makes some classes and function gin configurable.""" 17 | 18 | import gin 19 | import gin.tf.external_configurables 20 | import tensorflow as tf 21 | import tensorflow_addons.optimizers as tfa_optimizers 22 | 23 | configurables = { 24 | 'tf.keras.layers': ( 25 | tf.keras.layers.Conv1D, 26 | tf.keras.layers.Conv1DTranspose, 27 | tf.keras.layers.Conv2D, 28 | tf.keras.layers.Conv2DTranspose, 29 | tf.keras.layers.Dense, 30 | tf.keras.layers.Flatten, 31 | tf.keras.layers.Reshape, 32 | tf.keras.layers.MaxPooling2D, 33 | tf.keras.layers.GlobalMaxPooling2D, 34 | tf.keras.layers.BatchNormalization, 35 | tf.keras.layers.LayerNormalization, 36 | ), 37 | 'tf.keras.regularizers': ( 38 | tf.keras.regularizers.L1, 39 | tf.keras.regularizers.L2, 40 | tf.keras.regularizers.L1L2, 41 | ), 42 | 'tf.keras.initializers': ( 43 | tf.keras.initializers.Constant, 44 | ), 45 | 'tf.keras.losses': ( 46 | tf.keras.losses.CategoricalCrossentropy, 47 | ), 48 | 'tf.keras.optimizers.schedules': ( 49 | tf.keras.optimizers.schedules.CosineDecay, 50 | tf.keras.optimizers.schedules.PiecewiseConstantDecay, 51 | ), 52 | 'tfa.optimizers': ( 53 | tfa_optimizers.MovingAverage, 54 | ), 55 | } 56 | 57 | for module in configurables: 58 | for v in configurables[module]: 59 | gin.config.external_configurable(v, module=module) 60 | -------------------------------------------------------------------------------- /diffstride/examples/main.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | r"""Launches the training loop given a configuration. 17 | 18 | For instance to train a resnet18 and cifar10, with diffstride and saves the 19 | training results locally to be displayed with tensorboard: 20 | 21 | python3 -m diffstride.example.main \ 22 | --gin_config=cifar10 \ 23 | --gin_bindings="train.workdir=/tmp/exp/diffstride/resnet18/" 24 | """ 25 | 26 | import os 27 | from typing import Sequence 28 | 29 | from absl import app 30 | from absl import flags 31 | from diffstride.examples import train 32 | import gin 33 | import tensorflow as tf 34 | 35 | flags.DEFINE_multi_string( 36 | 'gin_config', [], 'List of paths to the config files.') 37 | flags.DEFINE_multi_string( 38 | 'gin_bindings', [], 'Newline separated list of Gin parameter bindings.') 39 | flags.DEFINE_string( 40 | 'workdir', None, 'Sets the directory where to save tfevents.') 41 | flags.DEFINE_integer('seed', 1, 'Used for replication.') 42 | flags.DEFINE_string('configs_folder', 43 | 'diffstride/examples', 44 | 'Where to find the gin config files.') 45 | FLAGS = flags.FLAGS 46 | 47 | 48 | def main(argv: Sequence[str]) -> None: 49 | if len(argv) > 1: 50 | raise app.UsageError('Too many command-line arguments.') 51 | tf.random.set_seed(FLAGS.seed) 52 | gin_files = [os.path.join(FLAGS.configs_folder, x) for x in FLAGS.gin_config] 53 | gin.parse_config_files_and_bindings(gin_files, FLAGS.gin_bindings) 54 | train.train(workdir=FLAGS.workdir) 55 | 56 | 57 | if __name__ == '__main__': 58 | app.run(main) 59 | 
-------------------------------------------------------------------------------- /diffstride/examples/train.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Training library.""" 17 | 18 | import os 19 | from typing import Optional, Type 20 | from absl import logging 21 | 22 | import gin 23 | import tensorflow as tf 24 | 25 | 26 | @gin.configurable 27 | def train(load_data_fn, 28 | model_cls: Type[tf.keras.Model], 29 | optimizer_cls: Type[tf.keras.optimizers.Optimizer], 30 | num_epochs: int = 200, 31 | workdir: Optional[str] = '/tmp/diffstride/') -> tf.keras.Model: 32 | """Builds, compiles and fits `model_cls` on `load_data_fn()` (train_ds, test_ds, num_classes) with Keras `fit`, then returns the model.""" 33 | train_ds, test_ds, num_classes = load_data_fn() 34 | 35 | strategy = tf.distribute.MirroredStrategy() 36 | logging.info('Number of devices: %d', strategy.num_replicas_in_sync) 37 | 38 | # Use channels_first when a GPU is visible (pooling.py notes FFT pooling is faster in that layout), channels_last otherwise.
39 | with strategy.scope(): 40 | model = model_cls( 41 | num_output_classes=num_classes, 42 | channels_first=bool(tf.config.list_physical_devices('GPU'))) 43 | model.compile(optimizer=optimizer_cls(), 44 | loss=tf.keras.losses.CategoricalCrossentropy(), 45 | metrics=[tf.keras.metrics.CategoricalAccuracy()]) 46 | 47 | callbacks = [] 48 | if workdir is not None: # TensorBoard logs, checkpoints and backups are only written when workdir is set. 49 | callbacks.extend([ 50 | tf.keras.callbacks.TensorBoard( 51 | log_dir=workdir, write_steps_per_second=True), 52 | tf.keras.callbacks.ModelCheckpoint( 53 | filepath=os.path.join(workdir, 'ckpts')), 54 | tf.keras.callbacks.experimental.BackupAndRestore( # NOTE(review): `experimental` namespace; may be unavailable in newer TF releases — confirm pinned TF version. 55 | backup_dir=os.path.join(workdir, 'backup')) 56 | ]) 57 | model.fit(train_ds, validation_data=test_ds, 58 | epochs=num_epochs, callbacks=callbacks) 59 | return model 60 | -------------------------------------------------------------------------------- /diffstride/pooling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Pooling functions: spectral or spatial, with learnable stride or not. 17 | 18 | Warning: This module runs faster in channels_first data format due to the use of 19 | FFT. In case of channels_last, the tensors will be transposed to channels_first 20 | and then transposed back, which increases the time and memory overhead 21 | significantly.
It is therefore highly recommended to run with channels_first on 22 | a GPU. 23 | """ 24 | 25 | from typing import Optional, Tuple, Union 26 | 27 | import gin 28 | import tensorflow as tf 29 | 30 | OptionalDim = Union[int, tf.Tensor, None] 31 | Number = Union[float, int] 32 | Stride = Union[Number, Tuple[Number, Number]] 33 | CHANNELS_FIRST = 'channels_first' 34 | 35 | 36 | def compute_adaptive_span_mask(threshold: tf.float32, 37 | ramp_softness: tf.float32, 38 | pos: tf.Tensor) -> tf.Tensor: 39 | """Adaptive mask as proposed in https://arxiv.org/pdf/1905.07799.pdf. 40 | 41 | Args: 42 | threshold: Threshold that starts the ramp. 43 | ramp_softness: Smoothness of the ramp. 44 | pos: Position indices. 45 | 46 | Returns: 47 | A tf.Tensor containing the 48 | thresholdings for the mask with the same size of pos. 49 | """ 50 | output = (1.0 / ramp_softness) * (ramp_softness + threshold - pos) 51 | return tf.cast(tf.clip_by_value(output, 0.0, 1.0), dtype=tf.complex64) 52 | 53 | 54 | def fixed_spectral_pooling(inputs: tf.Tensor, 55 | lower_height: OptionalDim = None, 56 | upper_height: OptionalDim = None, 57 | upper_width: OptionalDim = None) -> Tuple[tf.Tensor]: 58 | """Fixed spectral pooling in 2D. Expects channels_first data format. 59 | 60 | Args: 61 | inputs: tf.Tensor[batch_size, channels_in, height, width] of input 62 | sequences, obtained from the tf.signal.rfft2d. 63 | lower_height: Lower height limit to apply in the Fourier domain. This limit 64 | represents the upper bound for the lower corner. 65 | upper_height: Upper height limit to apply in the Fourier domain. This limit 66 | represents the lower bound for the upper corner. 67 | upper_width: Width limit to apply in the Fourier domain. 68 | 69 | Returns: 70 | A tf.Tensor[batch_size, channels_out, height, width] containing the 71 | cropped coefficients of the Fourier transform. 
72 | """ 73 | return tf.concat([inputs[:, :, :lower_height, :upper_width], 74 | inputs[:, :, upper_height:, :upper_width]], 75 | axis=2) 76 | 77 | 78 | @gin.configurable 79 | class SpatialPooling(tf.keras.layers.AveragePooling2D): 80 | """Fixed pooling layer, computed in the spatial domain.""" 81 | 82 | def __init__(self, 83 | pool_size: Union[int, Tuple[int, int]] = (1, 1), 84 | strides: Stride = (2, 2), 85 | **kwargs): 86 | super().__init__( 87 | pool_size=pool_size, strides=strides, padding='same', **kwargs) 88 | 89 | 90 | @gin.configurable 91 | class FixedSpectralPooling(tf.keras.layers.Layer): 92 | """Fixed Spectral pooling layer, computed in the Fourier domain.""" 93 | 94 | def __init__(self, 95 | strides: Stride = (2.0, 2.0), 96 | data_format: str = CHANNELS_FIRST, 97 | **kwargs): 98 | """Fixed Spectral pooling layer. 99 | 100 | Args: 101 | strides: Fractional strides to apply via the Fourier domain. 102 | data_format: either 'channels_first' or 'channels_last'. Be aware that 103 | providing the data in channels_last format will significantly increase 104 | the overhead due to the need to transpose temporarily to channels_first. 105 | **kwargs: Additional arguments for parent class. 
106 | """ 107 | super().__init__(**kwargs) 108 | self._channels_first = data_format == CHANNELS_FIRST 109 | strides = ( 110 | (strides, strides) if isinstance(strides, (int, float)) else strides) 111 | strides = tuple(map(float, strides)) 112 | self._strides = strides 113 | if not strides[0] >= 1 and strides[1] >= 1: 114 | raise ValueError('Strides params need to be above 1, not ({}, {})'.format( 115 | str(strides[0]), str(strides[1]))) 116 | 117 | def build(self, input_shape): 118 | if self._channels_first: 119 | height, width = input_shape[2], input_shape[3] 120 | else: 121 | height, width = input_shape[1], input_shape[2] 122 | self.strides = self.add_weight( 123 | shape=(2,), 124 | initializer=tf.initializers.Constant(self._strides), 125 | trainable=False, 126 | dtype=tf.float32, 127 | name='strides') 128 | strided_height = height // self.strides[0] 129 | strided_height -= strided_height % 2 130 | strided_width = width // self.strides[1] 131 | # The parameter 2 is the minimum to avoid collapse of the feature map. 
132 | strided_height = tf.math.maximum(strided_height, 2) 133 | strided_width = tf.math.maximum(strided_width, 2) 134 | lower_height = strided_height // 2 135 | upper_height = height - lower_height 136 | upper_width = strided_width // 2 + 1 137 | self._output_shape = [int(strided_height), int(strided_width)] 138 | self._limits = [int(lower_height), int(upper_height), int(upper_width)] 139 | 140 | def call(self, inputs: tf.Tensor, training: bool = False): 141 | if not self._channels_first: 142 | inputs = tf.transpose(inputs, (0, 3, 1, 2)) 143 | batch_size, input_chans = inputs.shape.as_list()[:2] 144 | lh, uh, uw = self._limits 145 | output_height, output_width = self._output_shape 146 | f_inputs = tf.signal.rfft2d(inputs) 147 | output = fixed_spectral_pooling( 148 | f_inputs, lower_height=lh, upper_height=uh, upper_width=uw) 149 | result = tf.ensure_shape( 150 | tf.signal.irfft2d(output, fft_length=[output_height, output_width]), 151 | [batch_size, input_chans, output_height, output_width]) 152 | if not self._channels_first: 153 | result = tf.transpose(result, (0, 2, 3, 1)) 154 | return result 155 | 156 | 157 | class StrideConstraint(tf.keras.constraints.Constraint): 158 | """Constraint strides. 159 | 160 | Strides are constrained in [1,+infty) as default as smoothness factor 161 | always leave some feature map by default. 162 | """ 163 | 164 | def __init__(self, 165 | lower_limit: Optional[float] = None, 166 | upper_limit: Optional[float] = None, 167 | **kwargs): 168 | """Constraint strides. 169 | 170 | Args: 171 | lower_limit: Lower limit for the stride. 172 | upper_limit: Upper limit for the stride. 173 | **kwargs: Additional arguments for parent class. 
174 | """ 175 | super().__init__(**kwargs) 176 | self._lower_limit = lower_limit if lower_limit is not None else 1.0 177 | self._upper_limit = ( 178 | upper_limit if upper_limit is not None else tf.float32.max) 179 | 180 | def __call__(self, kernel): 181 | return tf.clip_by_value(kernel, self._lower_limit, self._upper_limit) 182 | 183 | 184 | @gin.configurable 185 | class DiffStride(tf.keras.layers.Layer): 186 | """Learnable Spectral pooling layer, computed in the Fourier domain. 187 | 188 | The adaptive window function is inspired from 189 | https://arxiv.org/pdf/1905.07799.pdf. 190 | """ 191 | 192 | def __init__(self, 193 | strides: Stride = (2.0, 2.0), 194 | smoothness_factor: float = 4.0, 195 | cropping: bool = True, 196 | trainable: bool = True, 197 | shared_stride: bool = False, 198 | lower_limit_stride: Optional[float] = None, 199 | upper_limit_stride: Optional[float] = None, 200 | data_format: str = CHANNELS_FIRST, 201 | **kwargs): 202 | """Learnable Spectral pooling layer. 203 | 204 | Vertical and horizontal positions are the indices of the feature map. It 205 | allows to selectively weight the output of the fourier transform based 206 | on these positions. 207 | Args: 208 | strides: Fractional strides to init before learning the reduction in the 209 | Fourier domain. 210 | smoothness_factor: Smoothness factor to reduce/crop the input feature map 211 | in the Fourier domain. 212 | cropping: Boolean to specify if the layer crops or set to 0 the 213 | coefficients outside the cropping window in the Fourier domain. 214 | trainable: Boolean to specify if the stride is learnable. 215 | shared_stride: If `True`, a single parameter is shared for vertical and 216 | horizontal strides. 217 | lower_limit_stride: Lower limit for the stride. It can be useful when 218 | there are memory issues, it avoids the stride converge to small values. 219 | upper_limit_stride: Upper limit for the stride. 220 | data_format: either `channels_first` or `channels_last`. 
Be aware that 221 | channels_last will increase the memory cost due transformation to 222 | channels_first. 223 | **kwargs: Additional arguments for parent class. 224 | """ 225 | super().__init__(**kwargs) 226 | self._cropping = cropping 227 | self._smoothness_factor = smoothness_factor 228 | self._shared_stride = shared_stride 229 | self.trainable = trainable 230 | self._lower_limit_stride = lower_limit_stride 231 | self._upper_limit_stride = upper_limit_stride 232 | self._channels_first = data_format == CHANNELS_FIRST 233 | 234 | # Ensures a tuple of floats. 235 | strides = ( 236 | (strides, strides) if isinstance(strides, (int, float)) else strides) 237 | strides = tuple(map(float, strides)) 238 | if strides[0] != strides[1] and shared_stride: 239 | raise ValueError('shared_stride requires the same initialization for ' 240 | f'vertical and horizontal strides but got {strides}') 241 | if strides[0] < 1 or strides[1] < 1: 242 | raise ValueError(f'Both strides should be >=1 but got {strides}') 243 | if smoothness_factor < 0.0: 244 | raise ValueError('Smoothness factor should be >= 0 but got ' 245 | f'{smoothness_factor}.') 246 | self._strides = strides 247 | 248 | def build(self, input_shape): 249 | del input_shape 250 | init = self._strides[0] if self._shared_stride else self._strides 251 | self.strides = self.add_weight( 252 | shape=(1,) if self._shared_stride else (2,), 253 | initializer=tf.initializers.Constant(init), 254 | trainable=self.trainable, 255 | dtype=tf.float32, 256 | name='strides', 257 | constraint=StrideConstraint( 258 | lower_limit=self._lower_limit_stride, 259 | upper_limit=self._upper_limit_stride)) 260 | 261 | def call(self, inputs: tf.Tensor, training: bool = False): 262 | if not self._channels_first: 263 | inputs = tf.transpose(inputs, (0, 3, 1, 2)) 264 | batch_size, channels = inputs.shape.as_list()[:2] 265 | height, width = tf.shape(inputs)[2], tf.shape(inputs)[3] 266 | 267 | horizontal_positions = tf.range(width // 2 + 1, 
dtype=tf.float32) 268 | vertical_positions = tf.range( 269 | height // 2 + height % 2, dtype=tf.float32) 270 | vertical_positions = tf.concat([ 271 | tf.reverse(vertical_positions[(height % 2):], axis=[0]), 272 | vertical_positions], axis=0) 273 | # This clipping by .assign is performed to allow gradient to flow, 274 | # even when the stride becomes too small, i.e. close to 1. 275 | min_vertical_stride = tf.cast(height, tf.float32) / ( 276 | tf.cast(height, tf.float32) - self._smoothness_factor) 277 | min_horizontal_stride = tf.cast(width, tf.float32) / ( 278 | tf.cast(width, tf.float32) - self._smoothness_factor) 279 | if self._shared_stride: 280 | min_stride = tf.math.maximum(min_vertical_stride, min_horizontal_stride) 281 | self.strides[0].assign(tf.math.maximum(self.strides[0], min_stride)) 282 | vertical_stride, horizontal_stride = self.strides[0], self.strides[0] 283 | else: 284 | self.strides[0].assign( 285 | tf.math.maximum(self.strides[0], min_vertical_stride)) 286 | self.strides[1].assign( 287 | tf.math.maximum(self.strides[1], min_horizontal_stride)) 288 | vertical_stride, horizontal_stride = self.strides[0], self.strides[1] 289 | 290 | # Explicitly calls the stride constraints on strides. 291 | vertical_stride = self.strides.constraint(vertical_stride) 292 | horizontal_stride = self.strides.constraint(horizontal_stride) 293 | 294 | strided_height = tf.cast(height, tf.float32) / vertical_stride 295 | strided_width = tf.cast(width, tf.float32) / horizontal_stride 296 | # Warning: Little discrepancy for the init of strided_height with 297 | # FixedSpectralPooling. As the gradient of the operation below is 0, it 298 | # is removed for DiffStride. 299 | # strided_height = strided_height - tf.math.floormod(strided_height, 2) 300 | # The parameter 2 is the minimum to avoid collapse of the feature map. 
301 | strided_height = tf.math.maximum(strided_height, 2.0) 302 | strided_width = tf.math.maximum(strided_width, 2.0) 303 | lower_height = strided_height / 2.0 304 | upper_width = strided_width / 2.0 + 1.0 305 | 306 | f_inputs = tf.signal.rfft2d(inputs) 307 | horizontal_mask = compute_adaptive_span_mask( 308 | upper_width, self._smoothness_factor, horizontal_positions) 309 | vertical_mask = compute_adaptive_span_mask( 310 | lower_height, self._smoothness_factor, vertical_positions) 311 | 312 | vertical_mask = tf.signal.fftshift(vertical_mask) 313 | output = f_inputs * horizontal_mask[None, None, None, :] 314 | output = output * vertical_mask[None, None, :, None] 315 | if self._cropping: 316 | horizontal_to_keep = tf.stop_gradient( 317 | tf.where(tf.cast(horizontal_mask, tf.float32) > 0.)[:, 0]) 318 | vertical_to_keep = tf.stop_gradient( 319 | tf.where(tf.cast(vertical_mask, tf.float32) > 0.)[:, 0]) 320 | 321 | output = tf.gather(output, indices=vertical_to_keep, axis=2) 322 | output = tf.gather(output, indices=horizontal_to_keep, axis=3) 323 | 324 | result = tf.ensure_shape( 325 | tf.signal.irfft2d(output), [batch_size, channels, None, None]) 326 | if not self._channels_first: 327 | result = tf.transpose(result, (0, 2, 3, 1)) 328 | return result 329 | 330 | def compute_output_shape(self, input_shape): 331 | batch_size, channels = input_shape[:2] 332 | return (batch_size, channels, None, None) 333 | -------------------------------------------------------------------------------- /diffstride/resnet.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Defines reset that are compatible with learnable stridding.""" 17 | 18 | import functools 19 | from typing import Optional, Sequence, Tuple, Union 20 | 21 | import gin 22 | import tensorflow as tf 23 | 24 | Number = Union[float, int] 25 | Stride = Union[Number, Tuple[Number, Number]] 26 | 27 | 28 | def data_format(channels_first: bool = True) -> str: 29 | return 'channels_first' if channels_first else 'channels_last' 30 | 31 | 32 | def conv2d( 33 | *args, channels_first: bool = True, weight_decay: float = 0.0, **kwargs): 34 | return tf.keras.layers.Conv2D( 35 | *args, 36 | kernel_initializer='he_normal', 37 | kernel_regularizer=tf.keras.regularizers.L2(weight_decay), 38 | data_format=data_format(channels_first), 39 | use_bias=False, 40 | **kwargs) 41 | 42 | 43 | @gin.configurable 44 | def batch_norm(channels_first: bool = True, **kwargs): 45 | axis = 1 if channels_first else 3 46 | return tf.keras.layers.BatchNormalization(axis=axis, **kwargs) 47 | 48 | 49 | @gin.configurable 50 | class ResidualLayer(tf.keras.layers.Layer): 51 | """A generic residual layer for Resnet, using the pre-act formulation. 52 | 53 | This resnet can represent an `IdBlock` or a `ProjBlock` by setting the 54 | `project` parameter and can be compatible with Spectral or Learnable poolings 55 | by setting the `pooling_cls` parameter. 56 | 57 | The pooling_cls and strides will be overwritten automatically in case of an 58 | ID block. 59 | 60 | The pre-act formulation applies batch norm and non-linearity before the first 61 | conv. 
62 | """ 63 | 64 | def __init__(self, 65 | filters: int, 66 | kernel_size: int = gin.REQUIRED, 67 | strides: Stride = (1, 1), 68 | pooling_cls=None, 69 | project: bool = False, 70 | channels_first: bool = True, 71 | weight_decay: float = 5e-3, 72 | **kwargs): 73 | super().__init__(**kwargs) 74 | 75 | # If we are in an Id Layer there is no striding of any kind. 76 | pooling_cls = None if not project else pooling_cls 77 | strides = (1, 1) if not project else strides 78 | # DiffStride compatibility: the strides go into the pooling layer. 79 | if pooling_cls is not None: 80 | conv_strides = (1, 1) 81 | self._pooling = pooling_cls( 82 | strides=strides, data_format=data_format(channels_first)) 83 | else: 84 | self._pooling = tf.identity 85 | conv_strides = strides 86 | 87 | self._strided_conv = conv2d( 88 | filters, kernel_size, strides=conv_strides, padding='same', 89 | channels_first=channels_first, weight_decay=weight_decay) 90 | 91 | # The second convolution is a regular one with no strides, no matter what. 
92 | self._unstrided_conv = conv2d( 93 | filters, kernel_size, strides=(1, 1), padding='same', 94 | channels_first=channels_first, weight_decay=weight_decay) 95 | self._bns = tuple(batch_norm(channels_first) for _ in range(2)) 96 | 97 | self._shortcut_conv = None 98 | if project: 99 | self._shortcut_conv = conv2d( 100 | filters, kernel_size=1, strides=conv_strides, padding='same', 101 | channels_first=channels_first, weight_decay=weight_decay) 102 | 103 | def call(self, inputs: tf.Tensor, training: bool = True) -> tf.Tensor: 104 | shortcut_x = inputs 105 | 106 | x = self._bns[0](inputs, training=training) 107 | x = tf.nn.relu(x) 108 | x = self._strided_conv(x) 109 | x = self._pooling(x) 110 | 111 | x = self._bns[1](x, training=training) 112 | x = tf.nn.relu(x) 113 | x = self._unstrided_conv(x) 114 | 115 | if self._shortcut_conv is not None: 116 | shortcut_x = self._shortcut_conv(shortcut_x) 117 | shortcut_x = self._pooling(shortcut_x) 118 | 119 | return x + shortcut_x 120 | 121 | 122 | @gin.configurable 123 | class ResnetBlock(tf.keras.Sequential): 124 | """A block of residual layers sharing the same number of filters. 125 | 126 | The first residual layer of the block and only this one might be strided. 127 | This parameter is controlled by the `project_first` parameters. 128 | 129 | The kwargs are passed down to the ResidualLayer. 
130 | """ 131 | 132 | def __init__(self, 133 | filters: int = gin.REQUIRED, 134 | strides: Stride = gin.REQUIRED, 135 | num_layers: int = 2, 136 | project_first: bool = True, 137 | **kwargs): 138 | residual_fn = functools.partial( 139 | ResidualLayer, filters=filters, strides=strides, **kwargs) 140 | blocks = [residual_fn(project=True)] if project_first else [] 141 | num_left_layers = num_layers - int(project_first) 142 | blocks.extend([residual_fn(project=False) for i in range(num_left_layers)]) 143 | super().__init__(blocks) 144 | 145 | 146 | @gin.configurable 147 | class Resnet(tf.keras.Sequential): 148 | """A generic Resnet class, using the pre-activation implementation. 149 | 150 | Depending on the number of blocks and the used filters, it can easily 151 | instantiate a Resnet18 or Resnet56. 152 | 153 | The kwargs are passed down to the ResnetBlock layer. 154 | """ 155 | 156 | def __init__(self, 157 | filters: Sequence[int], 158 | strides: Sequence[Stride], 159 | num_output_classes: int = gin.REQUIRED, 160 | output_activation: Optional[str] = None, 161 | id_only: Sequence[int] = (), 162 | channels_first: bool = True, 163 | pooling_cls=None, 164 | weight_decay: float = 5e-3, 165 | **kwargs): 166 | if len(filters) != len(strides): 167 | raise ValueError(f'The number of `filters` ({len(filters)}) should match' 168 | f' the number of strides ({len(strides)})') 169 | df = data_format(channels_first) 170 | layers = [ 171 | tf.keras.layers.Permute((3, 1, 2)) if channels_first else None, 172 | conv2d(filters[0], 3, padding='same', 173 | strides=(1, 1) if pooling_cls is not None else strides[0], 174 | channels_first=channels_first, weight_decay=weight_decay), 175 | pooling_cls( 176 | strides=strides[0], data_format=df) if pooling_cls else None, 177 | ] 178 | for i, (num_filters, stride) in enumerate(zip(filters[1:], strides[1:])): 179 | layers.append(ResnetBlock(filters=num_filters, 180 | strides=stride, 181 | project_first=(i not in id_only), 182 | 
channels_first=channels_first, 183 | weight_decay=weight_decay, 184 | pooling_cls=pooling_cls, 185 | **kwargs)) 186 | layers.extend([ 187 | batch_norm(channels_first), 188 | tf.keras.layers.ReLU(), 189 | tf.keras.layers.GlobalAveragePooling2D(data_format=df), 190 | tf.keras.layers.Flatten(), 191 | tf.keras.layers.Dense( 192 | num_output_classes, 193 | activation=output_activation, 194 | kernel_initializer='he_normal', 195 | kernel_regularizer=tf.keras.regularizers.L2(weight_decay), 196 | bias_regularizer=tf.keras.regularizers.L2(weight_decay), 197 | ), 198 | ]) 199 | super().__init__(list(filter(None, layers))) 200 | 201 | 202 | @gin.configurable 203 | def resnet18(strides=None, **kwargs): 204 | strides = [1, 1, 2, 2, 2] if strides is None else strides 205 | filters = [64, 64, 128, 256, 512] 206 | return Resnet( 207 | filters, strides, id_only=[0], num_layers=2, kernel_size=3, **kwargs) 208 | 209 | 210 | @gin.configurable 211 | def resnet56(strides=None, **kwargs): 212 | filters = [16, 16, 32, 64] 213 | strides = [1, 1, 2, 2] if strides is None else strides 214 | return Resnet(filters, strides, num_layers=9, kernel_size=3, **kwargs) 215 | -------------------------------------------------------------------------------- /images/diffstride.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/diffstride/b3ef50f6b837265317682bde26b1d15317e73155/images/diffstride.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools >= 40.9.0", 4 | "wheel", 5 | ] 6 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py>=1.0.0 2 | gin-config>=0.5.0 3 | 
tensorflow>=2.1 4 | tensorflow-datasets>=4.0.0 5 | tensorflow-addons>=0.15.0 6 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = diffstride 3 | version = 0.0.1 4 | license = Apache 2.0 5 | license_files = LICENSE 6 | author = Google LLC 7 | description = Learnable stride for ConvNets. 8 | keywords = convnet, deep learning, signal processing 9 | long_description = file: README.md 10 | long_description_content_type = text/markdown 11 | url = https://github.com/google-research/diffstride 12 | classifiers = 13 | Programming Language :: Python :: 3 14 | 15 | [options] 16 | packages = find: 17 | python_requires = >=3.7 18 | 19 | [options.packages.find] 20 | exclude = 21 | tests, docs, images -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | 16 | """Setup script for installing diffstride as a pip module.""" 17 | import os 18 | import setuptools 19 | 20 | 21 | __version__ = '0.1.0' 22 | 23 | # Reads the requirements from requirements.txt 24 | folder = os.path.dirname(__file__) 25 | path = os.path.join(folder, 'requirements.txt') 26 | install_requires = [] 27 | if os.path.exists(path): 28 | with open(path) as fp: 29 | install_requires = [line.strip() for line in fp] 30 | 31 | 32 | setuptools.setup(version=__version__, 33 | install_requires=install_requires) 34 | -------------------------------------------------------------------------------- /tests/pooling_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for pooling.""" 17 | 18 | from absl.testing import parameterized 19 | from diffstride import pooling 20 | import tensorflow as tf 21 | 22 | 23 | class PoolingTest(tf.test.TestCase, parameterized.TestCase): 24 | 25 | def setUp(self): 26 | super().setUp() 27 | tf.random.set_seed(0) 28 | 29 | def test_spatial_pooling(self): 30 | shape = (1, 64, 64, 3) # Because CPU. 
31 | pool = pooling.SpatialPooling(strides=(2, 4), data_format='channels_last') 32 | inputs = tf.random.uniform(shape) 33 | output = pool(inputs) 34 | self.assertEqual(output.shape, (1, 32, 16, 3)) 35 | 36 | @parameterized.parameters(['channels_first', 'channels_last']) 37 | def test_spectral_pooling(self, data_format): 38 | is_channels_last = data_format == 'channels_last' 39 | shape = (1, 64, 64, 3) if is_channels_last else (1, 3, 64, 64) 40 | pool = pooling.FixedSpectralPooling(strides=(2, 4), data_format=data_format) 41 | inputs = tf.random.uniform(shape) 42 | output = pool(inputs) 43 | output_shape = (1, 32, 16, 3) if is_channels_last else (1, 3, 32, 16) 44 | self.assertEqual(output.shape, output_shape) 45 | 46 | @parameterized.parameters(['channels_first', 'channels_last']) 47 | def test_learnable_spectral_pooling(self, data_format): 48 | is_channels_last = data_format == 'channels_last' 49 | shape = (1, 64, 64, 3) if is_channels_last else (1, 3, 64, 64) 50 | pool = pooling.DiffStride( 51 | strides=(2, 4), data_format=data_format) 52 | inputs = tf.random.uniform(shape) 53 | output = pool(inputs) 54 | output_shape = (1, 40, 24, 3) if is_channels_last else (1, 3, 40, 24) 55 | self.assertEqual(output.shape, output_shape) 56 | 57 | 58 | if __name__ == '__main__': 59 | tf.test.main() 60 | -------------------------------------------------------------------------------- /tests/resnet_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for resnet."""

from absl.testing import parameterized
from diffstride import resnet
import tensorflow as tf


class ResnetTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for resnets."""

  def setUp(self):
    super().setUp()
    tf.random.set_seed(0)

  def test_residual_layer(self):
    num_filters = 7
    shape = (2, 32, 32, 3)
    proj_layer = resnet.ResidualLayer(
        filters=num_filters, kernel_size=3, strides=2, channels_first=False,
        pooling_cls=None, project=True)
    inputs = tf.random.uniform(shape)
    output = proj_layer(inputs)
    self.assertEqual(output.shape, (2, 16, 16, 7))

    # Should have the same number of features.
    id_layer = resnet.ResidualLayer(
        filters=num_filters, kernel_size=3, channels_first=False,
        pooling_cls=None, project=False)
    output2 = id_layer(output)
    self.assertEqual(output2.shape, output.shape)

  def test_resnet_block(self):
    num_filters = 7
    num_layers = 10
    block = resnet.ResnetBlock(
        filters=num_filters, kernel_size=3, strides=(2, 4),
        num_layers=num_layers, project_first=True, channels_first=False)
    self.assertLen(block.layers, num_layers)

    shape = (2, 64, 64, 3)
    inputs = tf.random.uniform(shape)
    output = block(inputs)
    self.assertEqual(output.shape, (2, 32, 16, num_filters))

    # Without a projection layer, the strides are ignored and shapes preserved.
    block = resnet.ResnetBlock(
        filters=7, strides=(2, 4), kernel_size=3, channels_first=False,
        num_layers=num_layers, project_first=False)
    self.assertLen(block.layers, num_layers)
    output2 = block(output)
    self.assertEqual(output2.shape, output.shape)


if __name__ == '__main__':
  tf.test.main()