├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── diffstride
├── __init__.py
├── examples
│ ├── __init__.py
│ ├── cifar10.gin
│ ├── data.py
│ ├── external_configurables.py
│ ├── main.py
│ └── train.py
├── pooling.py
└── resnet.py
├── images
└── diffstride.png
├── pyproject.toml
├── requirements.txt
├── setup.cfg
├── setup.py
└── tests
├── pooling_test.py
└── resnet_test.py
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License Agreement (CLA). You (or your employer) retain the copyright to your contribution; this simply gives us permission to use and redistribute your contributions as part of the project. Head over to <https://cla.developers.google.com/> to see your current agreements on file or to sign a new one.
9 |
10 | You generally only need to submit a CLA once, so if you've already submitted
11 | one (even if it was for a different project), you probably don't need to do it again.
12 |
13 | ## Code reviews
14 |
15 | All submissions, including submissions by project members, require review. We use GitHub pull requests for this purpose. Consult [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more information on using pull requests.
16 |
17 | ## Community Guidelines
18 |
19 | This project follows [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
20 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DiffStride: Learning strides in convolutional neural networks
2 |
3 | 
4 |
5 | DiffStride is a pooling layer with learnable strides. Unlike strided convolutions, average pooling or max-pooling, which require cross-validating stride values at each layer, DiffStride can be initialized with an arbitrary value at each layer (e.g. `(2, 2)`) and during training its strides will be optimized for the task at hand.
6 |
7 | We describe DiffStride in our ICLR 2022 paper [Learning Strides in Convolutional Neural Networks](https://arxiv.org/abs/2202.01653). Compared to the experiments described in the paper, this implementation uses a [Pre-Act Resnet](https://arxiv.org/abs/1603.05027) and uses [Mixup](https://arxiv.org/abs/1710.09412) in training.
8 |
9 | ## Installation
10 |
11 | To install the diffstride library, first clone this repo:
12 |
13 | ```
14 | git clone https://github.com/google-research/diffstride.git
15 | ```
16 |
17 | Then cd into the root of the repository and run the command:
18 | ```
19 | pip install -e .
20 | ```
21 |
22 | ## Example training
23 |
24 | To run an example training on CIFAR10 and save the result in TensorBoard:
25 |
26 | ```
27 | python3 -m diffstride.examples.main \
28 | --gin_config=cifar10.gin \
29 | --gin_bindings="train.workdir = '/tmp/exp/diffstride/resnet18/'"
30 | ```
31 |
32 | ## Using custom parameters
33 | This implementation uses [Gin](https://github.com/google/gin-config) to parametrize the model, data processing and training loop.
34 | To use custom parameters, one should edit `diffstride/examples/cifar10.gin`.
35 |
36 | For example, to train with SpectralPooling on cifar100:
37 |
38 | ```
39 | data.load_datasets:
40 | name = 'cifar100'
41 |
42 | resnet.Resnet:
43 | pooling_cls = @pooling.FixedSpectralPooling
44 | ```
45 |
46 | Or to train with strided convolutions and without Mixup:
47 |
48 | ```
49 | data.load_datasets:
50 | mixup_alpha = 0.0
51 |
52 | resnet.Resnet:
53 | pooling_cls = None
54 | ```
55 |
56 | ## Results
57 | This current implementation gives the following accuracy on CIFAR-10 and CIFAR-100, averaged over three runs. To show the robustness of DiffStride to stride initialization, we run both with the standard strides of ResNet (`resnet.resnet18.strides = '1, 1, 2, 2, 2'`) and with a 'poor' choice of strides (`resnet.resnet18.strides = '1, 1, 3, 2, 3'`). Unlike Strided Convolutions and fixed Spectral Pooling, DiffStride is not affected by the stride initialization.
58 |
59 | ### CIFAR-10
60 |
61 | | Pooling | Test Accuracy (%) w/ strides = (1, 1, 2, 2, 2)| Test Accuracy (%) w/ strides = (1, 1, 3, 2, 3)|
62 | | -------------------------------- | --------------------------------------------- | --------------------------------------------- |
63 | | Strided Convolution (Baseline) | 91.06 ± 0.04 | 89.21 ± 0.27 |
64 | | Spectral Pooling | 93.49 ± 0.05 | 92.00 ± 0.08 |
65 | | DiffStride | **94.20 ± 0.06** | **94.19 ± 0.15** |
66 |
67 | ### CIFAR-100
68 |
69 | | Pooling | Test Accuracy (%) w/ strides = (1, 1, 2, 2, 2)| Test Accuracy (%) w/ strides = (1, 1, 3, 2, 3)|
70 | | -------------------------------- | --------------------------------------------- | --------------------------------------------- |
71 | | Strided Convolution (Baseline) | 65.75 ± 0.39 | 60.82 ± 0.42 |
72 | | Spectral Pooling | 72.86 ± 0.23 | 67.74 ± 0.43 |
73 | | DiffStride | **76.08 ± 0.23** | **76.09 ± 0.06** |
74 |
75 | ## CPU/GPU Warning
76 | We rely on the TensorFlow FFT implementation, which requires the input data to be in the `channels_first` format. This is usually not the native data format of most datasets (including CIFAR), and running with `channels_first` also prevents using convolutions on CPU. Therefore, even though we support the `channels_last` data format for CPU compatibility, we encourage users to run with the `channels_first` data format *on GPU*.
77 |
78 | ## Reference
79 | If you use this repository, please consider citing:
80 |
81 | ```
82 | @article{riad2022diffstride,
83 | title={Learning Strides in Convolutional Neural Networks},
84 | author={Riad, Rachid and Teboul, Olivier and Grangier, David and Zeghidour, Neil},
85 | journal={ICLR},
86 | year={2022}
87 | }
88 | ```
89 |
90 | ## Disclaimer
91 | This is not an official Google product.
92 |
93 |
--------------------------------------------------------------------------------
/diffstride/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/diffstride/examples/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/diffstride/examples/cifar10.gin:
--------------------------------------------------------------------------------
1 | import diffstride.examples.external_configurables
2 |
3 | import diffstride.examples.data as data
4 | import diffstride.examples.train as train
5 | import diffstride.resnet as resnet
6 | import diffstride.pooling as pooling
7 |
8 | data.load_datasets:
9 | name = 'cifar10'
10 | batch_size = 128
11 | mixup_alpha = 0.2
12 |
13 | train.train:
14 | load_data_fn = @data.load_datasets
15 | model_cls = @resnet.resnet18
16 | optimizer_cls = @tf.keras.optimizers.SGD
17 | num_epochs = 400
18 |
19 | tf.keras.optimizers.SGD:
20 | learning_rate = @tf.keras.optimizers.schedules.PiecewiseConstantDecay()
21 | momentum = 0.9
22 |
23 | tf.keras.optimizers.schedules.PiecewiseConstantDecay:
24 | boundaries = [400, 18_000, 32_000]
25 | values = [0.01, 0.1, 0.01, 0.001]
26 |
27 | tf.keras.losses.CategoricalCrossentropy:
28 | label_smoothing = 0.2
29 |
30 | # Initialize the strides differently if needed.
31 | resnet.resnet18.strides = [1, 1, 2, 2, 2]
32 | resnet.Resnet:
33 | output_activation = 'softmax'
34 | weight_decay = 5e-4
35 | # Either @pooling.DiffStride, @pooling.FixedSpectralPooling or None
36 | # for strided convolutions.
37 | pooling_cls = @pooling.DiffStride
38 |
39 | batch_norm:
40 | momentum = 0.9
41 | epsilon = 1e-5
42 |
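43 | # Hyper-parameters of the DiffStride layer (only used when `resnet.Resnet.pooling_cls = @pooling.DiffStride`; set `pooling_cls = None` above for the strided-convolution baseline).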
44 | pooling.DiffStride:
45 | smoothness_factor = 4.0
46 | cropping = True
47 | trainable = True
48 | shared_stride = False
49 | lower_limit_stride = None
50 | upper_limit_stride = None
51 |
--------------------------------------------------------------------------------
/diffstride/examples/data.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Prepares data."""
17 |
18 | from typing import Tuple
19 |
20 |
21 | import gin
22 | import tensorflow as tf
23 | import tensorflow_datasets as tfds
24 |
25 |
def convert(image: tf.Tensor) -> tf.Tensor:
  """Scales a uint8-range image to [0, 1] and standardizes each channel.

  The per-channel statistics are hard-coded (presumably CIFAR statistics —
  TODO confirm) with shape `[1, 1, 3]`, so they broadcast over the last axis:
  the image is expected to be channels-last with 3 channels.

  Args:
    image: image tensor with values in the uint8 range.

  Returns:
    A float32 tensor with the same shape as `image`.
  """
  scaled = tf.cast(image, tf.float32) / tf.uint8.max
  channel_mean = tf.constant([[[0.4914, 0.4822, 0.4465]]])
  channel_std = tf.constant([[[0.2023, 0.1994, 0.2010]]])
  return (scaled - channel_mean) / channel_std
31 |
32 |
@gin.configurable
def augment(image: tf.Tensor, pad_size: int = 2) -> tf.Tensor:
  """Data augmentation: random shifts (pad then crop) and horizontal flips.

  Args:
    image: image tensor whose first two static dimensions are height and
      width (e.g. `[height, width, channels]`).
    pad_size: number of zero pixels added on every side before cropping back
      to the original shape, i.e. the maximum shift in each direction.

  Returns:
    An augmented tensor with the same shape as `image`.
  """
  original_shape = image.shape
  height, width = original_shape[:2]
  padded = tf.image.pad_to_bounding_box(
      image,
      offset_height=pad_size,
      offset_width=pad_size,
      target_height=height + 2 * pad_size,
      target_width=width + 2 * pad_size)
  # Cropping back to the original shape at a random offset shifts the image.
  cropped = tf.image.random_crop(padded, original_shape)
  return tf.image.random_flip_left_right(cropped)
44 |
45 |
def sample_beta_distribution(size: int,
                             concentration_0: float = 0.2,
                             concentration_1: float = 0.2) -> tf.Tensor:
  """Draws `size` samples from a Beta distribution using two Gamma draws.

  Relies on the identity X / (X + Y) ~ Beta(a, b) for independent
  X ~ Gamma(a) and Y ~ Gamma(b), with a = `concentration_1` and
  b = `concentration_0`.

  Args:
    size: number of samples.
    concentration_0: concentration of the Gamma variable in the denominator
      only.
    concentration_1: concentration of the Gamma variable in the numerator.

  Returns:
    A float tensor of shape `[size]` with values in [0, 1].
  """
  numerator_draw = tf.random.gamma(shape=[size], alpha=concentration_1)
  complement_draw = tf.random.gamma(shape=[size], alpha=concentration_0)
  total = numerator_draw + complement_draw
  return numerator_draw / total
53 |
54 |
def mix_up(images: tf.Tensor,
           labels: tf.Tensor,
           alpha: float = 0.2) -> Tuple[tf.Tensor, tf.Tensor]:
  """Blends a batch of exactly two examples into a single mixup example.

  Args:
    images: stacked pair of images; its static first dimension must be 2.
    labels: stacked pair of labels matching `images` (one-hot when produced
      by `prepare`).
    alpha: concentration of the symmetric Beta(alpha, alpha) distribution
      from which the mixing weight is drawn.

  Returns:
    `(mixed_image, mixed_label)`, both without the leading pair dimension.

  Raises:
    ValueError: if the static first dimension of `images` is not 2.
  """
  batch_size = images.shape[0]
  if batch_size != 2:
    raise ValueError(f'Mixup expects batch_size == 2 but got {batch_size}.')

  # A single scalar weight, shared by the image and the label, so the mixed
  # label reflects exactly how much of each image was used.
  lam = sample_beta_distribution(1, alpha, alpha)[0]
  first_image, second_image = images[0], images[1]
  first_label, second_label = labels[0], labels[1]
  mixed_image = lam * first_image + (1 - lam) * second_image
  mixed_label = lam * first_label + (1 - lam) * second_label
  return mixed_image, mixed_label
71 |
72 |
def prepare(ds: tf.data.Dataset,
            num_classes: int,
            batch_size: int = 32,
            training: bool = True,
            augment_fn=augment,
            mixup_alpha: float = 0.0) -> tf.data.Dataset:
  """Prepares a dataset for train/test.

  Args:
    ds: dataset of `(image, label)` pairs with integer class labels.
    num_classes: number of classes used to one-hot encode the labels.
    batch_size: size of the output batches.
    training: whether to augment, shuffle, apply mixup and drop the last
      incomplete batch.
    augment_fn: optional per-image augmentation, applied only when training.
    mixup_alpha: concentration of the Beta(alpha, alpha) distribution used by
      mixup. Mixup is only applied when training and `mixup_alpha > 0`.

  Returns:
    A batched, prefetched dataset of `(image, one_hot_label)` pairs.
  """
  def transform(
      image: tf.Tensor, label: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    image = convert(image)
    if training and augment_fn is not None:
      image = augment_fn(image)
    label = tf.one_hot(label, depth=num_classes)
    return image, label

  def apply_mixup(
      images: tf.Tensor, labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    # Forward the configured alpha explicitly; otherwise `mix_up` silently
    # falls back to its own default and `mixup_alpha` only toggles mixup.
    return mix_up(images, labels, alpha=mixup_alpha)

  ds = ds.map(transform, num_parallel_calls=tf.data.AUTOTUNE)
  if training:
    ds = ds.shuffle(1000)
    if mixup_alpha > 0.0:
      # Pair consecutive shuffled examples and blend each pair into one
      # example, so an epoch holds half as many (mixed) examples.
      ds = ds.batch(2, drop_remainder=True)
      ds = ds.map(apply_mixup)
  ds = ds.batch(batch_size, drop_remainder=training)
  ds = ds.prefetch(tf.data.AUTOTUNE)
  return ds
97 |
98 |
@gin.configurable
def load_datasets(name: str,
                  batch_size: int = 32,
                  augment_fn=None,
                  mixup_alpha: float = 0.0):
  """Loads the train/test splits of a TFDS dataset and prepares them.

  Args:
    name: TFDS dataset name (e.g. 'cifar10').
    batch_size: batch size used for both splits.
    augment_fn: optional per-image augmentation, applied to the train split
      only. It defaults to None, so no augmentation happens unless it is
      bound (e.g. `data.load_datasets.augment_fn = @data.augment` in gin).
      NOTE(review): confirm that disabling augmentation by default is
      intended.
    mixup_alpha: mixup concentration for the train split; 0 disables mixup.
      The test split never uses mixup or augmentation.

  Returns:
    A `(train_dataset, test_dataset, num_classes)` tuple.
  """
  splits, info = tfds.load(name, as_supervised=True, with_info=True)
  _, label_key = info.supervised_keys
  num_classes = info.features[label_key].num_classes

  common = dict(num_classes=num_classes, batch_size=batch_size)
  ds_train = prepare(splits['train'], training=True, augment_fn=augment_fn,
                     mixup_alpha=mixup_alpha, **common)
  ds_test = prepare(splits['test'], training=False, augment_fn=None,
                    mixup_alpha=0.0, **common)
  return ds_train, ds_test, num_classes
123 |
--------------------------------------------------------------------------------
/diffstride/examples/external_configurables.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Makes some classes and functions gin configurable."""
17 |
18 | import gin
19 | import gin.tf.external_configurables
20 | import tensorflow as tf
21 | import tensorflow_addons.optimizers as tfa_optimizers
22 |
# Maps a gin module prefix (as used in `@<module>.<Name>` references from .gin
# files) to the external classes registered under that prefix.
configurables = {
    'tf.keras.layers': (
        tf.keras.layers.Conv1D,
        tf.keras.layers.Conv1DTranspose,
        tf.keras.layers.Conv2D,
        tf.keras.layers.Conv2DTranspose,
        tf.keras.layers.Dense,
        tf.keras.layers.Flatten,
        tf.keras.layers.Reshape,
        tf.keras.layers.MaxPooling2D,
        tf.keras.layers.GlobalMaxPooling2D,
        tf.keras.layers.BatchNormalization,
        tf.keras.layers.LayerNormalization,
    ),
    'tf.keras.regularizers': (
        tf.keras.regularizers.L1,
        tf.keras.regularizers.L2,
        tf.keras.regularizers.L1L2,
    ),
    'tf.keras.initializers': (
        tf.keras.initializers.Constant,
    ),
    'tf.keras.losses': (
        tf.keras.losses.CategoricalCrossentropy,
    ),
    'tf.keras.optimizers.schedules': (
        tf.keras.optimizers.schedules.CosineDecay,
        tf.keras.optimizers.schedules.PiecewiseConstantDecay,
    ),
    'tfa.optimizers': (
        tfa_optimizers.MovingAverage,
    ),
}

# Register every class with gin under its prefix, in dict insertion order.
# Gin only injects bindings into objects created through gin references
# (e.g. `@tf.keras.losses.CategoricalCrossentropy`), not into direct calls.
for module, classes in configurables.items():
  for v in classes:
    gin.config.external_configurable(v, module=module)
60 |
--------------------------------------------------------------------------------
/diffstride/examples/main.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | r"""Launches the training loop given a configuration.
17 |
18 | For instance, to train a resnet18 on cifar10 with diffstride and save the
19 | training results locally to be displayed with tensorboard:
20 |
21 | python3 -m diffstride.examples.main \
22 | --gin_config=cifar10.gin \
23 | --gin_bindings="train.workdir='/tmp/exp/diffstride/resnet18/'"
24 | """
25 |
26 | import os
27 | from typing import Sequence
28 |
29 | from absl import app
30 | from absl import flags
31 | from diffstride.examples import train
32 | import gin
33 | import tensorflow as tf
34 |
# Gin config file names; main() joins each of them with --configs_folder.
flags.DEFINE_multi_string(
    name='gin_config',
    default=[],
    help='List of paths to the config files.')
# Extra gin bindings, parsed together with the config files in main().
flags.DEFINE_multi_string(
    name='gin_bindings',
    default=[],
    help='Newline separated list of Gin parameter bindings.')
flags.DEFINE_string(
    name='workdir',
    default=None,
    help='Sets the directory where to save tfevents.')
flags.DEFINE_integer(name='seed', default=1, help='Used for replication.')
flags.DEFINE_string(
    name='configs_folder',
    default='diffstride/examples',
    help='Where to find the gin config files.')
FLAGS = flags.FLAGS
46 |
47 |
def main(argv: Sequence[str]) -> None:
  """Parses the gin configuration and launches `train.train`.

  Gin config files passed with --gin_config are resolved relative to
  --configs_folder and parsed together with --gin_bindings.

  Args:
    argv: remaining command-line arguments; only the program name is allowed.

  Raises:
    app.UsageError: if extra positional arguments are given.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  tf.random.set_seed(FLAGS.seed)
  gin_files = [os.path.join(FLAGS.configs_folder, x) for x in FLAGS.gin_config]
  gin.parse_config_files_and_bindings(gin_files, FLAGS.gin_bindings)
  # Arguments passed explicitly take precedence over gin bindings, so only
  # forward --workdir when it is set. Otherwise a `train.workdir` binding
  # (or train's own default) is used instead of being overridden by None.
  train_kwargs = {}
  if FLAGS.workdir is not None:
    train_kwargs['workdir'] = FLAGS.workdir
  train.train(**train_kwargs)


if __name__ == '__main__':
  app.run(main)
59 |
--------------------------------------------------------------------------------
/diffstride/examples/train.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Training library."""
17 |
18 | import os
19 | from typing import Optional, Type
20 | from absl import logging
21 |
22 | import gin
23 | import tensorflow as tf
24 |
25 |
@gin.configurable
def train(load_data_fn,
          model_cls: Type[tf.keras.Model],
          optimizer_cls: Type[tf.keras.optimizers.Optimizer],
          num_epochs: int = 200,
          workdir: Optional[str] = '/tmp/diffstride/',
          loss_cls: Type[tf.keras.losses.Loss] = (
              tf.keras.losses.CategoricalCrossentropy),
          ) -> tf.keras.Model:
  """Runs the training using keras .fit way.

  Args:
    load_data_fn: callable returning `(train_ds, test_ds, num_classes)`.
    model_cls: model constructor, called with the `num_output_classes` and
      `channels_first` keyword arguments.
    optimizer_cls: optimizer constructor, called with no arguments (in
      cifar10.gin its hyper-parameters come from gin bindings).
    num_epochs: number of training epochs.
    workdir: directory for TensorBoard logs, checkpoints and backups. When
      None, no callbacks are registered.
    loss_cls: loss constructor, called with no arguments. Gin only injects
      bindings such as `tf.keras.losses.CategoricalCrossentropy.label_smoothing`
      when the class is passed as a gin reference, e.g.
      `train.loss_cls = @tf.keras.losses.CategoricalCrossentropy`. The default
      class is instantiated directly and ignores those bindings.

  Returns:
    The trained keras model.
  """
  train_ds, test_ds, num_classes = load_data_fn()

  strategy = tf.distribute.MirroredStrategy()
  logging.info('Number of devices: %d', strategy.num_replicas_in_sync)

  with strategy.scope():
    # Run channels first on GPU (FFT-based pooling is faster in that format)
    # and channels last otherwise.
    model = model_cls(
        num_output_classes=num_classes,
        channels_first=bool(tf.config.list_physical_devices('GPU')))
    model.compile(optimizer=optimizer_cls(),
                  loss=loss_cls(),
                  metrics=[tf.keras.metrics.CategoricalAccuracy()])

  callbacks = []
  if workdir is not None:
    callbacks.extend([
        tf.keras.callbacks.TensorBoard(
            log_dir=workdir, write_steps_per_second=True),
        tf.keras.callbacks.ModelCheckpoint(
            filepath=os.path.join(workdir, 'ckpts')),
        tf.keras.callbacks.experimental.BackupAndRestore(
            backup_dir=os.path.join(workdir, 'backup'))
    ])
  model.fit(train_ds, validation_data=test_ds,
            epochs=num_epochs, callbacks=callbacks)
  return model
60 |
--------------------------------------------------------------------------------
/diffstride/pooling.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Pooling functions: spectral or spatial, with learnable stride or not.
17 |
18 | Warning: This module runs faster in channels_first data format due to the use of
19 | FFT. In case of channels_last, the tensors will be transposed to channels_first
20 | and then transposed back, which increases the time and memory overhead
21 | significantly. It is therefore highly recommended to run with channels_first on
22 | a GPU.
23 | """
24 |
25 | from typing import Optional, Tuple, Union
26 |
27 | import gin
28 | import tensorflow as tf
29 |
# Keras `data_format` value for (batch, channels, height, width) tensors.
CHANNELS_FIRST = 'channels_first'
# A numeric stride value.
Number = Union[float, int]
# One stride shared by both spatial axes, or a (vertical, horizontal) pair.
Stride = Union[float, int, Tuple[Number, Number]]
# A spatial limit that is static (int), dynamic (tf.Tensor) or absent (None).
OptionalDim = Union[int, tf.Tensor, None]
34 |
35 |
def compute_adaptive_span_mask(threshold: Union[float, tf.Tensor],
                               ramp_softness: Union[float, tf.Tensor],
                               pos: tf.Tensor) -> tf.Tensor:
  """Soft step mask, as proposed in https://arxiv.org/pdf/1905.07799.pdf.

  The mask is 1 for `pos <= threshold`, decreases linearly to 0 over a ramp of
  width `ramp_softness`, and is 0 for `pos >= threshold + ramp_softness`.

  Args:
    threshold: Position at which the ramp starts.
    ramp_softness: Width of the linear ramp. It is used as a divisor, so it
      must be non-zero.
    pos: Position indices.

  Returns:
    A complex64 tf.Tensor with real values in [0, 1]. Its shape is the
    broadcast of the inputs, i.e. the shape of `pos` when `threshold` and
    `ramp_softness` are scalars.
  """
  output = (1.0 / ramp_softness) * (ramp_softness + threshold - pos)
  # Complex dtype so the mask can multiply Fourier coefficients directly.
  return tf.cast(tf.clip_by_value(output, 0.0, 1.0), dtype=tf.complex64)
52 |
53 |
def fixed_spectral_pooling(inputs: tf.Tensor,
                           lower_height: OptionalDim = None,
                           upper_height: OptionalDim = None,
                           upper_width: OptionalDim = None) -> tf.Tensor:
  """Fixed spectral pooling in 2D. Expects channels_first data format.

  Keeps rows `[0, lower_height)` and `[upper_height, height)` and columns
  `[0, upper_width)` of the Fourier coefficients, concatenating the two row
  ranges along the height axis. With the row layout of `tf.signal.rfft2d`,
  these are the lowest positive and negative vertical frequencies.

  Args:
    inputs: tf.Tensor[batch_size, channels, height, width] of Fourier
      coefficients, obtained from tf.signal.rfft2d.
    lower_height: Lower height limit to apply in the Fourier domain. This limit
      represents the upper bound for the lower corner.
    upper_height: Upper height limit to apply in the Fourier domain. This limit
      represents the lower bound for the upper corner.
    upper_width: Width limit to apply in the Fourier domain.

  Returns:
    A single tf.Tensor[batch_size, channels, cropped_height, cropped_width]
    holding the cropped Fourier coefficients; the channel count is unchanged.
  """
  return tf.concat([inputs[:, :, :lower_height, :upper_width],
                    inputs[:, :, upper_height:, :upper_width]],
                   axis=2)
76 |
77 |
@gin.configurable
class SpatialPooling(tf.keras.layers.AveragePooling2D):
  """Fixed pooling layer, computed in the spatial domain.

  Thin wrapper around `AveragePooling2D` that always uses 'same' padding. With
  the default `pool_size=(1, 1)` this amounts to plain subsampling by
  `strides`.
  """

  def __init__(self,
               pool_size: Union[int, Tuple[int, int]] = (1, 1),
               strides: Stride = (2, 2),
               **kwargs):
    super().__init__(
        padding='same', strides=strides, pool_size=pool_size, **kwargs)
88 |
89 |
@gin.configurable
class FixedSpectralPooling(tf.keras.layers.Layer):
  """Fixed Spectral pooling layer, computed in the Fourier domain."""

  def __init__(self,
               strides: Stride = (2.0, 2.0),
               data_format: str = CHANNELS_FIRST,
               **kwargs):
    """Fixed Spectral pooling layer.

    Args:
      strides: Fractional strides to apply via the Fourier domain. A single
        number is used for both axes.
      data_format: either 'channels_first' or 'channels_last'. Be aware that
        providing the data in channels_last format will significantly increase
        the overhead due to the need to transpose temporarily to channels_first.
      **kwargs: Additional arguments for parent class.

    Raises:
      ValueError: if either stride is below 1.
    """
    super().__init__(**kwargs)
    self._channels_first = data_format == CHANNELS_FIRST
    strides = (
        (strides, strides) if isinstance(strides, (int, float)) else strides)
    strides = tuple(map(float, strides))
    self._strides = strides
    # Both strides must be >= 1; each one is checked independently.
    if strides[0] < 1 or strides[1] < 1:
      raise ValueError('Strides params need to be above 1, not ({}, {})'.format(
          str(strides[0]), str(strides[1])))

  def build(self, input_shape):
    """Creates the (non-trainable) stride weight and the static crop limits.

    Raises:
      ValueError: if the spatial dimensions of `input_shape` are unknown, since
        the crop limits are computed statically here.
    """
    if self._channels_first:
      height, width = input_shape[2], input_shape[3]
    else:
      height, width = input_shape[1], input_shape[2]
    if height is None or width is None:
      raise ValueError(
          'FixedSpectralPooling requires static spatial dimensions, got '
          f'height={height} and width={width}.')
    self.strides = self.add_weight(
        shape=(2,),
        initializer=tf.initializers.Constant(self._strides),
        trainable=False,
        dtype=tf.float32,
        name='strides')
    strided_height = height // self.strides[0]
    # Rounds the output height down to an even number.
    strided_height -= strided_height % 2
    strided_width = width // self.strides[1]
    # The parameter 2 is the minimum to avoid collapse of the feature map.
    strided_height = tf.math.maximum(strided_height, 2)
    strided_width = tf.math.maximum(strided_width, 2)
    lower_height = strided_height // 2
    upper_height = height - lower_height
    # rfft2d keeps width // 2 + 1 columns for a real signal of that width.
    upper_width = strided_width // 2 + 1
    self._output_shape = [int(strided_height), int(strided_width)]
    self._limits = [int(lower_height), int(upper_height), int(upper_width)]

  def call(self, inputs: tf.Tensor, training: bool = False):
    """Crops the rfft2d coefficients of `inputs` and transforms them back."""
    if not self._channels_first:
      inputs = tf.transpose(inputs, (0, 3, 1, 2))
    batch_size, input_chans = inputs.shape.as_list()[:2]
    lh, uh, uw = self._limits
    output_height, output_width = self._output_shape
    f_inputs = tf.signal.rfft2d(inputs)
    output = fixed_spectral_pooling(
        f_inputs, lower_height=lh, upper_height=uh, upper_width=uw)
    result = tf.ensure_shape(
        tf.signal.irfft2d(output, fft_length=[output_height, output_width]),
        [batch_size, input_chans, output_height, output_width])
    if not self._channels_first:
      result = tf.transpose(result, (0, 2, 3, 1))
    return result
155 |
156 |
class StrideConstraint(tf.keras.constraints.Constraint):
  """Clips strides to the closed interval [lower_limit, upper_limit].

  By default strides are constrained to [1, +inf). The lower bound of 1 is the
  only default restriction, because the smoothness factor always leaves some
  of the feature map.
  """

  def __init__(self,
               lower_limit: Optional[float] = None,
               upper_limit: Optional[float] = None,
               **kwargs):
    """Constraint strides.

    Args:
      lower_limit: Lower limit for the stride. Defaults to 1.0 when None.
      upper_limit: Upper limit for the stride. Defaults to the largest float32
        value when None, i.e. effectively unbounded.
      **kwargs: Additional arguments for parent class.
    """
    super().__init__(**kwargs)
    self._lower_limit = lower_limit if lower_limit is not None else 1.0
    self._upper_limit = (
        upper_limit if upper_limit is not None else tf.float32.max)

  def __call__(self, kernel):
    return tf.clip_by_value(kernel, self._lower_limit, self._upper_limit)

  def get_config(self):
    """Returns the clipping limits so the constraint can be re-created."""
    # float() turns the numpy scalar behind tf.float32.max into a plain,
    # JSON-friendly Python float.
    return {'lower_limit': float(self._lower_limit),
            'upper_limit': float(self._upper_limit)}
182 |
183 |
@gin.configurable
class DiffStride(tf.keras.layers.Layer):
  """Learnable Spectral pooling layer, computed in the Fourier domain.

  The adaptive window function is inspired from
  https://arxiv.org/pdf/1905.07799.pdf.
  """

  def __init__(self,
               strides: Stride = (2.0, 2.0),
               smoothness_factor: float = 4.0,
               cropping: bool = True,
               trainable: bool = True,
               shared_stride: bool = False,
               lower_limit_stride: Optional[float] = None,
               upper_limit_stride: Optional[float] = None,
               data_format: str = CHANNELS_FIRST,
               **kwargs):
    """Learnable Spectral pooling layer.

    Vertical and horizontal positions are the indices of the feature map. They
    allow the output of the Fourier transform to be weighted selectively based
    on these positions.

    Args:
      strides: Fractional strides to init before learning the reduction in the
        Fourier domain. A single number is used for both axes.
      smoothness_factor: Smoothness factor to reduce/crop the input feature map
        in the Fourier domain.
      cropping: Boolean to specify if the layer crops or set to 0 the
        coefficients outside the cropping window in the Fourier domain.
      trainable: Boolean to specify if the stride is learnable.
      shared_stride: If `True`, a single parameter is shared for vertical and
        horizontal strides.
      lower_limit_stride: Lower limit for the stride. It can be useful when
        there are memory issues, it avoids the stride converge to small values.
      upper_limit_stride: Upper limit for the stride.
      data_format: either `channels_first` or `channels_last`. Be aware that
        channels_last will increase the memory cost due transformation to
        channels_first.
      **kwargs: Additional arguments for parent class.

    Raises:
      ValueError: if `shared_stride` is set with different vertical and
        horizontal strides, if a stride is below 1, or if `smoothness_factor`
        is negative.
    """
    super().__init__(**kwargs)
    self._cropping = cropping
    self._smoothness_factor = smoothness_factor
    self._shared_stride = shared_stride
    self.trainable = trainable
    self._lower_limit_stride = lower_limit_stride
    self._upper_limit_stride = upper_limit_stride
    self._channels_first = data_format == CHANNELS_FIRST

    # Ensures a tuple of floats.
    strides = (
        (strides, strides) if isinstance(strides, (int, float)) else strides)
    strides = tuple(map(float, strides))
    if strides[0] != strides[1] and shared_stride:
      raise ValueError('shared_stride requires the same initialization for '
                       f'vertical and horizontal strides but got {strides}')
    if strides[0] < 1 or strides[1] < 1:
      raise ValueError(f'Both strides should be >=1 but got {strides}')
    if smoothness_factor < 0.0:
      raise ValueError('Smoothness factor should be >= 0 but got '
                       f'{smoothness_factor}.')
    self._strides = strides

  def build(self, input_shape):
    """Creates the `strides` weight: 1 value if shared, else (vertical, horizontal)."""
    del input_shape
    init = self._strides[0] if self._shared_stride else self._strides
    self.strides = self.add_weight(
        shape=(1,) if self._shared_stride else (2,),
        initializer=tf.initializers.Constant(init),
        trainable=self.trainable,
        dtype=tf.float32,
        name='strides',
        constraint=StrideConstraint(
            lower_limit=self._lower_limit_stride,
            upper_limit=self._upper_limit_stride))

  def call(self, inputs: tf.Tensor, training: bool = False):
    """Masks (and optionally crops) the rfft2d of `inputs`, then inverts it."""
    if not self._channels_first:
      inputs = tf.transpose(inputs, (0, 3, 1, 2))
    batch_size, channels = inputs.shape.as_list()[:2]
    height, width = tf.shape(inputs)[2], tf.shape(inputs)[3]

    horizontal_positions = tf.range(width // 2 + 1, dtype=tf.float32)
    # Symmetric position index per row; the fftshift applied to the mask below
    # moves position 0 onto row 0 of the rfft2d output.
    vertical_positions = tf.range(
        height // 2 + height % 2, dtype=tf.float32)
    vertical_positions = tf.concat([
        tf.reverse(vertical_positions[(height % 2):], axis=[0]),
        vertical_positions], axis=0)
    # This clipping by .assign is performed to allow gradient to flow,
    # even when the stride becomes too small, i.e. close to 1.
    min_vertical_stride = tf.cast(height, tf.float32) / (
        tf.cast(height, tf.float32) - self._smoothness_factor)
    min_horizontal_stride = tf.cast(width, tf.float32) / (
        tf.cast(width, tf.float32) - self._smoothness_factor)
    if self._shared_stride:
      min_stride = tf.math.maximum(min_vertical_stride, min_horizontal_stride)
      self.strides[0].assign(tf.math.maximum(self.strides[0], min_stride))
      vertical_stride, horizontal_stride = self.strides[0], self.strides[0]
    else:
      self.strides[0].assign(
          tf.math.maximum(self.strides[0], min_vertical_stride))
      self.strides[1].assign(
          tf.math.maximum(self.strides[1], min_horizontal_stride))
      vertical_stride, horizontal_stride = self.strides[0], self.strides[1]

    # Explicitly calls the stride constraints on strides.
    vertical_stride = self.strides.constraint(vertical_stride)
    horizontal_stride = self.strides.constraint(horizontal_stride)

    strided_height = tf.cast(height, tf.float32) / vertical_stride
    strided_width = tf.cast(width, tf.float32) / horizontal_stride
    # Warning: small discrepancy with FixedSpectralPooling, which rounds
    # strided_height down to an even number. That rounding has a zero gradient,
    # so it is omitted for DiffStride.
    # The parameter 2 is the minimum to avoid collapse of the feature map.
    strided_height = tf.math.maximum(strided_height, 2.0)
    strided_width = tf.math.maximum(strided_width, 2.0)
    lower_height = strided_height / 2.0
    upper_width = strided_width / 2.0 + 1.0

    f_inputs = tf.signal.rfft2d(inputs)
    horizontal_mask = compute_adaptive_span_mask(
        upper_width, self._smoothness_factor, horizontal_positions)
    vertical_mask = compute_adaptive_span_mask(
        lower_height, self._smoothness_factor, vertical_positions)

    vertical_mask = tf.signal.fftshift(vertical_mask)
    output = f_inputs * horizontal_mask[None, None, None, :]
    output = output * vertical_mask[None, None, :, None]
    if self._cropping:
      # Drops the coefficients whose mask is exactly zero.
      horizontal_to_keep = tf.stop_gradient(
          tf.where(tf.cast(horizontal_mask, tf.float32) > 0.)[:, 0])
      vertical_to_keep = tf.stop_gradient(
          tf.where(tf.cast(vertical_mask, tf.float32) > 0.)[:, 0])

      output = tf.gather(output, indices=vertical_to_keep, axis=2)
      output = tf.gather(output, indices=horizontal_to_keep, axis=3)

    result = tf.ensure_shape(
        tf.signal.irfft2d(output), [batch_size, channels, None, None])
    if not self._channels_first:
      result = tf.transpose(result, (0, 2, 3, 1))
    return result

  def compute_output_shape(self, input_shape):
    """Returns the output shape in the layer's own data format.

    The spatial dimensions depend on the learned strides and are therefore
    reported as unknown (None).
    """
    if self._channels_first:
      batch_size, channels = input_shape[:2]
      return (batch_size, channels, None, None)
    batch_size, channels = input_shape[0], input_shape[-1]
    return (batch_size, None, None, channels)
333 |
--------------------------------------------------------------------------------
/diffstride/resnet.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Defines ResNets that are compatible with learnable striding."""
17 |
18 | import functools
19 | from typing import Optional, Sequence, Tuple, Union
20 |
21 | import gin
22 | import tensorflow as tf
23 |
# A numeric stride value.
Number = Union[float, int]
# One stride shared by both spatial axes, or an explicit pair of strides.
Stride = Union[float, int, Tuple[Number, Number]]
26 |
27 |
def data_format(channels_first: bool = True) -> str:
  """Maps the boolean `channels_first` flag to a Keras `data_format` string."""
  if channels_first:
    return 'channels_first'
  return 'channels_last'
30 |
31 |
def conv2d(
    *args, channels_first: bool = True, weight_decay: float = 0.0, **kwargs):
  """Returns a bias-free `Conv2D` with He-normal init and L2 regularization.

  Positional and extra keyword arguments are forwarded unchanged to
  `tf.keras.layers.Conv2D`.
  """
  regularizer = tf.keras.regularizers.L2(weight_decay)
  layout = data_format(channels_first)
  return tf.keras.layers.Conv2D(
      *args,
      use_bias=False,
      kernel_initializer='he_normal',
      kernel_regularizer=regularizer,
      data_format=layout,
      **kwargs)
41 |
42 |
@gin.configurable
def batch_norm(channels_first: bool = True, **kwargs):
  """Returns a `BatchNormalization` layer over the channel axis of 4D inputs."""
  if channels_first:
    channel_axis = 1
  else:
    channel_axis = 3
  return tf.keras.layers.BatchNormalization(axis=channel_axis, **kwargs)
47 |
48 |
@gin.configurable
class ResidualLayer(tf.keras.layers.Layer):
  """A generic pre-activation residual layer for Resnet.

  Setting `project` selects between an `IdBlock` (False) and a `ProjBlock`
  (True). Setting `pooling_cls` makes the layer compatible with spectral or
  learnable poolings: the striding is then delegated to that pooling layer.

  For an ID block, `pooling_cls` and `strides` are ignored and replaced by no
  pooling and unit strides.

  The pre-act formulation applies batch norm and non-linearity before the first
  conv.
  """

  def __init__(self,
               filters: int,
               kernel_size: int = gin.REQUIRED,
               strides: Stride = (1, 1),
               pooling_cls=None,
               project: bool = False,
               channels_first: bool = True,
               weight_decay: float = 5e-3,
               **kwargs):
    super().__init__(**kwargs)

    # Identity layers never downsample.
    if not project:
      pooling_cls = None
      strides = (1, 1)

    if pooling_cls is None:
      self._pooling = tf.identity
      conv_strides = strides
    else:
      # DiffStride compatibility: the strides go into the pooling layer.
      conv_strides = (1, 1)
      self._pooling = pooling_cls(
          strides=strides, data_format=data_format(channels_first))

    make_conv = functools.partial(
        conv2d, padding='same', channels_first=channels_first,
        weight_decay=weight_decay)
    self._strided_conv = make_conv(filters, kernel_size, strides=conv_strides)
    # The second convolution is never strided.
    self._unstrided_conv = make_conv(filters, kernel_size, strides=(1, 1))
    self._bns = (batch_norm(channels_first), batch_norm(channels_first))

    # 1x1 projection of the shortcut, only for projection blocks.
    self._shortcut_conv = (
        make_conv(filters, kernel_size=1, strides=conv_strides)
        if project else None)

  def call(self, inputs: tf.Tensor, training: bool = True) -> tf.Tensor:
    """Returns residual(inputs) + shortcut(inputs)."""
    residual = self._bns[0](inputs, training=training)
    residual = self._strided_conv(tf.nn.relu(residual))
    residual = self._pooling(residual)

    residual = self._bns[1](residual, training=training)
    residual = self._unstrided_conv(tf.nn.relu(residual))

    shortcut = inputs
    if self._shortcut_conv is not None:
      # Same pooling instance as the main path, so both branches are resized
      # identically.
      shortcut = self._pooling(self._shortcut_conv(inputs))

    return residual + shortcut
120 |
121 |
@gin.configurable
class ResnetBlock(tf.keras.Sequential):
  """A stack of residual layers sharing the same number of filters.

  Only the first residual layer can be a (possibly strided) projection layer;
  this is controlled by `project_first`. All remaining layers are identity
  layers.

  Extra kwargs are forwarded to every `ResidualLayer`.
  """

  def __init__(self,
               filters: int = gin.REQUIRED,
               strides: Stride = gin.REQUIRED,
               num_layers: int = 2,
               project_first: bool = True,
               **kwargs):
    make_layer = functools.partial(
        ResidualLayer, filters=filters, strides=strides, **kwargs)
    layers = []
    if project_first:
      layers.append(make_layer(project=True))
    layers += [make_layer(project=False)
               for _ in range(num_layers - int(project_first))]
    super().__init__(layers)
144 |
145 |
@gin.configurable
class Resnet(tf.keras.Sequential):
  """A generic pre-activation Resnet.

  The number of blocks and filters determine the depth, e.g. Resnet18 or
  Resnet56. `filters[0]` and `strides[0]` configure the stem convolution; each
  following (filters, strides) pair configures one `ResnetBlock`.

  Extra kwargs are forwarded to every `ResnetBlock`.
  """

  def __init__(self,
               filters: Sequence[int],
               strides: Sequence[Stride],
               num_output_classes: int = gin.REQUIRED,
               output_activation: Optional[str] = None,
               id_only: Sequence[int] = (),
               channels_first: bool = True,
               pooling_cls=None,
               weight_decay: float = 5e-3,
               **kwargs):
    if len(filters) != len(strides):
      raise ValueError(f'The number of `filters` ({len(filters)}) should match'
                       f' the number of strides ({len(strides)})')
    layout = data_format(channels_first)

    stem = []
    if channels_first:
      # Moves the last axis (channels) to position 1.
      stem.append(tf.keras.layers.Permute((3, 1, 2)))
    stem.append(conv2d(
        filters[0], 3, padding='same',
        strides=(1, 1) if pooling_cls is not None else strides[0],
        channels_first=channels_first, weight_decay=weight_decay))
    if pooling_cls:
      stem.append(pooling_cls(strides=strides[0], data_format=layout))

    blocks = [
        ResnetBlock(filters=block_filters,
                    strides=block_stride,
                    project_first=(index not in id_only),
                    channels_first=channels_first,
                    weight_decay=weight_decay,
                    pooling_cls=pooling_cls,
                    **kwargs)
        for index, (block_filters, block_stride) in enumerate(
            zip(filters[1:], strides[1:]))
    ]

    head = [
        batch_norm(channels_first),
        tf.keras.layers.ReLU(),
        tf.keras.layers.GlobalAveragePooling2D(data_format=layout),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(
            num_output_classes,
            activation=output_activation,
            kernel_initializer='he_normal',
            kernel_regularizer=tf.keras.regularizers.L2(weight_decay),
            bias_regularizer=tf.keras.regularizers.L2(weight_decay),
        ),
    ]
    super().__init__(list(filter(None, stem + blocks + head)))
200 |
201 |
@gin.configurable
def resnet18(strides=None, **kwargs):
  """Builds a Resnet18: a stem conv plus 4 blocks of 2 residual layers."""
  if strides is None:
    strides = [1, 1, 2, 2, 2]
  return Resnet([64, 64, 128, 256, 512], strides,
                id_only=[0], num_layers=2, kernel_size=3, **kwargs)
208 |
209 |
@gin.configurable
def resnet56(strides=None, **kwargs):
  """Builds a Resnet56: a stem conv plus 3 blocks of 9 residual layers."""
  if strides is None:
    strides = [1, 1, 2, 2]
  return Resnet([16, 16, 32, 64], strides,
                num_layers=9, kernel_size=3, **kwargs)
215 |
--------------------------------------------------------------------------------
/images/diffstride.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/diffstride/b3ef50f6b837265317682bde26b1d15317e73155/images/diffstride.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 | "setuptools >= 40.9.0",
4 | "wheel",
5 | ]
6 | build-backend = "setuptools.build_meta"
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py>=1.0.0
2 | gin-config>=0.5.0
3 | tensorflow>=2.1
4 | tensorflow-datasets>=4.0.0
5 | tensorflow-addons>=0.15.0
6 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
[metadata]
name = diffstride
# setup.py passes version= explicitly; keep this value in sync with it.
version = 0.1.0
license = Apache 2.0
license_files = LICENSE
author = Google LLC
description = Learnable stride for ConvNets.
keywords = convnet, deep learning, signal processing
long_description = file: README.md
long_description_content_type = text/markdown
url = https://github.com/google-research/diffstride
classifiers =
    Programming Language :: Python :: 3

[options]
packages = find:
python_requires = >=3.7

[options.packages.find]
exclude =
    tests, docs, images
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Setup script for installing diffstride as a pip module."""
17 | import os
18 | import setuptools
19 |
20 |
__version__ = '0.1.0'


def _read_requirements(path):
  """Returns the requirement specifiers listed in the file at `path`.

  Blank lines and `#` comment lines are skipped. An empty list is returned
  when the file does not exist.
  """
  if not os.path.exists(path):
    return []
  with open(path, encoding='utf-8') as fp:
    stripped = (line.strip() for line in fp)
    return [line for line in stripped if line and not line.startswith('#')]


# Reads the requirements from the requirements.txt next to this file.
install_requires = _read_requirements(
    os.path.join(os.path.dirname(__file__), 'requirements.txt'))


setuptools.setup(version=__version__,
                 install_requires=install_requires)
34 |
--------------------------------------------------------------------------------
/tests/pooling_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for pooling."""
17 |
18 | from absl.testing import parameterized
19 | from diffstride import pooling
20 | import tensorflow as tf
21 |
22 |
class PoolingTest(tf.test.TestCase, parameterized.TestCase):
  """Output-shape tests for the spatial, fixed spectral and learnable poolings."""

  def setUp(self):
    super().setUp()
    tf.random.set_seed(0)

  @staticmethod
  def _shape(data_format, channels, height, width):
    """Returns a batch-of-one shape laid out according to `data_format`."""
    if data_format == 'channels_last':
      return (1, height, width, channels)
    return (1, channels, height, width)

  def test_spatial_pooling(self):
    # channels_last, since this test is meant to run on CPU.
    layer = pooling.SpatialPooling(strides=(2, 4), data_format='channels_last')
    output = layer(tf.random.uniform((1, 64, 64, 3)))
    self.assertEqual(output.shape, (1, 32, 16, 3))

  @parameterized.parameters(['channels_first', 'channels_last'])
  def test_spectral_pooling(self, data_format):
    layer = pooling.FixedSpectralPooling(
        strides=(2, 4), data_format=data_format)
    output = layer(tf.random.uniform(self._shape(data_format, 3, 64, 64)))
    self.assertEqual(output.shape, self._shape(data_format, 3, 32, 16))

  @parameterized.parameters(['channels_first', 'channels_last'])
  def test_learnable_spectral_pooling(self, data_format):
    layer = pooling.DiffStride(strides=(2, 4), data_format=data_format)
    output = layer(tf.random.uniform(self._shape(data_format, 3, 64, 64)))
    self.assertEqual(output.shape, self._shape(data_format, 3, 40, 24))


if __name__ == '__main__':
  tf.test.main()
60 |
--------------------------------------------------------------------------------
/tests/resnet_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for resnet."""
17 |
18 | from absl.testing import parameterized
19 | from diffstride import resnet
20 | import tensorflow as tf
21 |
22 |
class ResnetTest(tf.test.TestCase, parameterized.TestCase):
  """Output-shape tests for the residual layers and blocks."""

  def setUp(self):
    super().setUp()
    tf.random.set_seed(0)

  def test_residual_layer(self):
    num_filters = 7
    projection = resnet.ResidualLayer(
        filters=num_filters, kernel_size=3, strides=2, channels_first=False,
        pooling_cls=None, project=True)
    projected = projection(tf.random.uniform((2, 32, 32, 3)))
    self.assertEqual(projected.shape, (2, 16, 16, 7))

    # An identity layer keeps the number of features (and the spatial size).
    identity = resnet.ResidualLayer(
        filters=num_filters, kernel_size=3, channels_first=False,
        pooling_cls=None, project=False)
    self.assertEqual(identity(projected).shape, projected.shape)

  def test_resnet_block(self):
    num_filters, num_layers = 7, 10
    block = resnet.ResnetBlock(
        filters=num_filters, kernel_size=3, strides=(2, 4),
        num_layers=num_layers, project_first=True, channels_first=False)
    self.assertLen(block.layers, num_layers)
    output = block(tf.random.uniform((2, 64, 64, 3)))
    self.assertEqual(output.shape, (2, 32, 16, num_filters))

    block = resnet.ResnetBlock(
        filters=7, strides=(2, 4), kernel_size=3, channels_first=False,
        num_layers=num_layers, project_first=False)
    self.assertLen(block.layers, num_layers)
    self.assertEqual(block(output).shape, output.shape)


if __name__ == '__main__':
  tf.test.main()
70 |
--------------------------------------------------------------------------------