├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── label_dp
│   ├── datasets.py
│   ├── lr_schedules.py
│   ├── main.py
│   ├── models.py
│   ├── profiles
│   │   ├── __init__.py
│   │   ├── __main__.py
│   │   ├── p100_cifar10.py
│   │   └── registry.py
│   ├── train.py
│   └── utils.py
└── requirements.txt
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement (CLA). You (or your employer) retain the copyright to your
10 | contribution; this simply gives us permission to use and redistribute your
11 | contributions as part of the project. Head over to
12 | to see your current agreements on file or
13 | to sign a new one.
14 |
15 | You generally only need to submit a CLA once, so if you've already submitted one
16 | (even if it was for a different project), you probably don't need to do it
17 | again.
18 |
19 | ## Code Reviews
20 |
21 | All submissions, including submissions by project members, require review. We
22 | use GitHub pull requests for this purpose. Consult
23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
24 | information on using pull requests.
25 |
26 | ## Community Guidelines
27 |
28 | This project follows
29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
30 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | This repository contains the multi-stage label differential privacy training
2 | code for the paper
3 |
4 | > Badih Ghazi, Noah Golowich, Ravi Kumar, Pasin Manurangsi, Chiyuan Zhang. *Deep
5 | > Learning with Label Differential Privacy*. Advances in Neural Information
6 | > Processing Systems (**NeurIPS**), 2021.
7 | > [arxiv:2102.06062](https://arxiv.org/abs/2102.06062)
8 |
9 | **Note**: This is not an officially supported Google product.
10 |
11 | ## Abstract
12 |
13 | The Randomized Response (`RR`) algorithm is a classical technique to improve
14 | robustness in survey aggregation, and has been widely adopted in applications
15 | with differential privacy guarantees. We propose a novel algorithm, *Randomized
16 | Response with Prior* (`RRWithPrior`), which can provide more accurate results
17 | while maintaining the same level of privacy guaranteed by `RR`. We then apply
18 | `RRWithPrior` to learn neural networks with *label* differential privacy
19 | (`LabelDP`), and show that when only the label needs to be protected, the model
20 | performance can be significantly improved over the previous state-of-the-art
21 | private baselines. Moreover, we study different ways to obtain priors, which
22 | when used with `RRWithPrior` can additionally improve the model performance,
23 | further reducing the accuracy gap between private and non-private models. We
24 | complement the empirical results with theoretical analysis showing that
25 | `LabelDP` is provably easier than protecting both the inputs and labels.
26 |
27 | ## Getting Started
28 |
29 | ### Requirements
30 |
31 | This codebase is implemented with [JAX](https://github.com/google/jax),
32 | [Flax](https://github.com/google/flax) and
33 | [Optax](https://github.com/deepmind/optax). To install the dependencies, run the
34 | following commands (see also the [JAX](https://github.com/google/jax) homepage
35 | for the latest installation guides for GPU/TPU support).
36 |
37 | ```
38 | pip install --upgrade pip
39 | pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
40 |
41 | pip install -r requirements.txt
42 | ```
43 |
44 | ### 2-stage LabelDP Training for CIFAR-10
45 |
46 | Use the following command to run training:
47 |
48 | ```
49 | python3 -m label_dp.main --base_workdir=<base_workdir> --profile_key=<profile_key>
50 | ```
51 |
52 | An example profile key is `cifar10/e2/lp-2st/run0`. See
53 | `label_dp/profiles/p100_cifar10.py` for a list of all pre-defined profiles. You
54 | can also run the following command to list all registered profile keys
55 | matching a given regex:
56 |
57 | ```
58 | python3 -m label_dp.profiles
59 | ```
60 |
61 | Note that the original experiments in the paper were run with TensorFlow. This
62 | codebase is a reimplementation in JAX, and the results obtained with it are
63 | slightly different from those reported in the paper. We include the numbers
64 | here for reference (CIFAR-10 with 2-stage training):
65 |
66 | Epsilon | Test Accuracy (%) ± std | Accuracy in Table 1 of the Paper (%)
67 | ------- | ----------------------- | ------------------------------------
68 | 1.0 | 62.89 ± 2.07 | 63.67
69 | 2.0 | 88.11 ± 0.38 | 86.05
70 | 4.0 | 94.18 ± 0.13 | 93.37
71 | 8.0 | 95.18 ± 0.07 | 94.52
72 |
73 | ### Structure of the Code
74 |
75 | To use the training code in the current setup, you just need to define new
76 | profiles that specify hyperparameters such as what dataset to load, what
77 | optimizer to use, etc. Create a new file `profiles/pXXX_xxx.py` and import it
78 | from `profiles/__init__.py`. Functions starting with `register_` will be called
79 | automatically to register hyperparameter profiles.
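The registration convention described above (functions whose names start with `register_` are called automatically, and registered keys can be queried by regex) can be illustrated with a minimal, self-contained sketch. `label_dp/profiles/registry.py` is not shown here, so `ToyRegistry`, its `register`/`list_keys` methods, `call_register_functions`, and `register_cifar10_sketch` are all hypothetical names chosen for illustration, not the repository's API. Only the `register_` prefix convention and the regex lookup come from the text.

```python
import inspect
import re
import types


class ToyRegistry:
  """Hypothetical stand-in for the profile registry (not the real API)."""

  _profiles = {}

  @classmethod
  def register(cls, key, config):
    # Refuse silent overwrites so two modules cannot claim the same key.
    if key in cls._profiles:
      raise KeyError(f'duplicate profile key: {key}')
    cls._profiles[key] = config

  @classmethod
  def get_profile(cls, key):
    return cls._profiles[key]

  @classmethod
  def list_keys(cls, pattern='.*'):
    # Keys matching the regex anywhere in the string (re.search semantics).
    return sorted(k for k in cls._profiles if re.search(pattern, k))


def call_register_functions(module):
  """Calls every function in `module` whose name starts with 'register_'."""
  for name, fn in inspect.getmembers(module, inspect.isfunction):
    if name.startswith('register_'):
      fn()


# Hypothetical contents of a new profile module (e.g. profiles/pXXX_xxx.py).
def register_cifar10_sketch():
  for epsilon in (1.0, 2.0):
    ToyRegistry.register(f'cifar10/e{epsilon:g}/run0',
                         {'dataset': 'cifar10', 'epsilon': epsilon})


# Simulate the imported profile module with an in-memory module object.
profile_module = types.ModuleType('p999_sketch')
profile_module.register_cifar10_sketch = register_cifar10_sketch
call_register_functions(profile_module)
print(ToyRegistry.list_keys(r'^cifar10/e1'))  # prints ['cifar10/e1/run0']
```

Raising on duplicate keys is a design choice of the sketch; it turns an accidental key collision between two profile modules into an immediate error instead of a silently overwritten configuration.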
80 |
81 | The code loads datasets using [tfds](https://www.tensorflow.org/datasets), so
82 | in principle any dataset available in tfds can be used. Note, however, that for
83 | convenience the entire dataset is loaded into memory for the data splitting in
84 | multi-stage training. This should work for CIFAR-scale datasets.
85 |
86 | The key algorithm `RRWithPrior` (Algorithm 2 in the paper) is implemented in
87 | `rr_with_prior` (in `train.py`) with NumPy, and can be used independently in
88 | other scenarios.
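To illustrate what `RRWithPrior` computes, here is a minimal NumPy sketch written from the paper's description of the algorithm; it is not the repository's `rr_with_prior`, whose signature and internals may differ. Given a prior over the K classes and a privacy parameter ε, it picks the k that maximizes e^ε / (e^ε + k - 1) times the prior mass of the k most likely classes, then runs randomized response restricted to those top-k classes (RR-Top-k). The names `optimal_k`, `rr_top_k`, and `rr_with_prior_sketch` are hypothetical.

```python
import numpy as np


def optimal_k(prior, epsilon):
  """Returns (classes sorted by decreasing prior, best k) for RRWithPrior."""
  order = np.argsort(-prior)
  ks = np.arange(1, len(prior) + 1)
  # w_k = e^eps / (e^eps + k - 1) * (prior mass of the top-k classes)
  scores = np.exp(epsilon) / (np.exp(epsilon) + ks - 1) * np.cumsum(prior[order])
  return order, int(ks[np.argmax(scores)])


def rr_top_k(label, candidates, epsilon, rng):
  """Randomized response restricted to the candidate set (RR-Top-k)."""
  k = len(candidates)
  if label in candidates:
    # Keep the true label with probability e^eps / (e^eps + k - 1).
    keep_prob = np.exp(epsilon) / (np.exp(epsilon) + k - 1)
    if rng.random() < keep_prob:
      return int(label)
    # Each other candidate then gets probability 1 / (e^eps + k - 1).
    others = [c for c in candidates if c != label]
    return int(rng.choice(others))
  # True label is outside the candidate set: answer uniformly among candidates.
  return int(rng.choice(candidates))


def rr_with_prior_sketch(label, prior, epsilon, rng):
  order, k = optimal_k(np.asarray(prior, dtype=np.float64), epsilon)
  return rr_top_k(label, order[:k], epsilon, rng)


rng = np.random.default_rng(0)
# A confident prior shrinks the candidate set to the single top class.
print(rr_with_prior_sketch(2, [0.9, 0.05, 0.05], 1.0, rng))  # prints 0
```

With a uniform prior the score increases with k, so k = K and the sketch reduces to classical K-ary randomized response; a sharper prior (for example, one produced by an earlier training stage) shrinks the candidate set and raises the chance of answering with a plausible label.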
89 |
90 | ## Citation
91 |
92 | ```
93 | @article{ghazi2021deep,
94 | title={Deep Learning with Label Differential Privacy},
95 | author={Ghazi, Badih and Golowich, Noah and Kumar, Ravi and Manurangsi, Pasin and Zhang, Chiyuan},
96 | journal={Advances in Neural Information Processing Systems},
97 | volume={34},
98 | year={2021}
99 | }
100 | ```
101 |
--------------------------------------------------------------------------------
/label_dp/datasets.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Dataset."""
16 |
17 | import numpy as np
18 | import tensorflow_datasets as tfds
19 |
20 |
21 | def batch_random_crop(batch_image_np, pad=4):
22 |   """Randomly crops images for data augmentation."""
23 |   n, h, w, c = batch_image_np.shape
24 |   # pad
25 |   padded_image = np.zeros((n, h+2*pad, w+2*pad, c),
26 |                           dtype=batch_image_np.dtype)
27 |   padded_image[:, pad:-pad, pad:-pad, :] = batch_image_np
28 |   # crop (randint's upper bound is exclusive, so offsets span 0..2*pad)
29 |   idxs = np.random.randint(2*pad + 1, size=(n, 2))
30 |   cropped_image = np.array([
31 |       padded_image[i, y:y+h, x:x+w, :]
32 |       for i, (y, x) in enumerate(idxs)])
33 |   return cropped_image
34 |
35 |
36 | def batch_random_fliplr(batch_image_np):
37 |   """Randomly flips images left-right."""
38 |   n = batch_image_np.shape[0]
39 |   coins = np.random.choice([-1, 1], size=n)
40 |   flipped_image = np.array([
41 |       batch_image_np[i, :, ::coins[i], :]
42 |       for i in range(n)])
43 |   return flipped_image
44 |
45 |
46 | def batch_random_cutout(batch_image_np, size=8):
47 |   """Random cutout.
48 |
49 |   Note we are using the same cutout region for all the images in the batch.
50 |
51 |   Args:
52 |     batch_image_np: images.
53 |     size: side length of the square cutout region.
54 |
55 |   Returns:
56 |     The input images, modified in place with a random cutout.
57 |   """
58 |   _, h, w, _ = batch_image_np.shape
59 |   h0 = np.random.randint(0, h-size+1)  # upper bound is exclusive
60 |   w0 = np.random.randint(0, w-size+1)
61 |   batch_image_np[:, h0:h0+size, w0:w0+size, :] = 0
62 |   return batch_image_np
63 |
64 |
65 | class TFDSNumpyDataset:
66 |   """Image dataset loaded into memory as NumPy arrays.
67 |
68 |   The full data array in numpy format can be easily accessed. Suitable for
69 |   smaller scale image datasets like MNIST (and variants), CIFAR-10 / CIFAR-100,
70 |   SVHN, etc.
71 |   """
72 |
73 |   def __init__(self, name, random_crop=True, random_fliplr=True,
74 |                random_cutout=0):
75 |     """Constructs a dataset from tfds.
76 |
77 |     Args:
78 |       name: name of the dataset.
79 |       random_crop: whether to perform random crop in data augmentation.
80 |       random_fliplr: whether to perform random left-right flip in data aug.
81 |       random_cutout: if non-zero, denotes the cutout length.
82 |     """
83 |     self.name = name
84 |
85 |     self._random_crop = random_crop
86 |     self._random_fliplr = random_fliplr
87 |     self._random_cutout = random_cutout
88 |     self.ds, self.info = tfds.load(name, batch_size=-1,
89 |                                    as_dataset_kwargs={'shuffle_files': False},
90 |                                    with_info=True)
91 |     self.ds_np = tfds.as_numpy(self.ds)
92 |
93 |     self._add_index_feature()
94 |
95 |   def _add_index_feature(self):
96 |     """Adds 'index' feature if not present."""
97 |     for split in self.ds_np:
98 |       if 'index' in self.ds_np[split]:
99 |         continue
100 |       n_sample = len(self.ds_np[split]['label'])
101 |       index = np.arange(n_sample)
102 |       if 'id' in self.ds_np[split]:
103 |         # remove the 'id' feature, because JAX cannot handle string types
104 |         self.ds_np[split].pop('id')
105 |       self.ds_np[split]['index'] = index
106 |
107 |   @property
108 |   def num_classes(self):
109 |     return self.info.features['label'].num_classes
110 |
111 |   @property
112 |   def use_onehot_label(self):
113 |     return False
114 |
115 |   @property
116 |   def data_scale(self):
117 |     return 255.0
118 |
119 |   def get_num_examples(self, split_name):
120 |     return self.ds_np[split_name]['image'].shape[0]
121 |
122 |   def get_input_shape(self, input_name):
123 |     if input_name == 'image':
124 |       return self.ds_np['train']['image'].shape[1:]
125 |     raise KeyError(f'getting input shape for {input_name}')
126 |
127 |   def normalize_images(self, batch_image_np):
128 |     images = batch_image_np.astype(np.float32) / self.data_scale
129 |     return images
130 |
131 |   def iterate(self, split_name, batch_size, shuffle=False, augmentation=False,
132 |               subset_index=None):
133 |     """Iterates over the dataset."""
134 |     n_sample = self.get_num_examples(split_name)
135 |     # make a shallow copy
136 |     dset = dict(self.ds_np[split_name])
137 |
138 |     if subset_index is not None:
139 |       n_sample = len(subset_index)
140 |       for key in dset:
141 |         dset[key] = dset[key][subset_index]
142 |
143 |     if shuffle:
144 |       rp = np.random.permutation(n_sample)
145 |       for key in dset:
146 |         dset[key] = dset[key][rp]
147 |
148 |     for i in range(0, n_sample, batch_size):
149 |       batch = {key: val[i:i+batch_size]
150 |                for key, val in dset.items()}
151 |       batch['image'] = self.normalize_images(batch['image'])
152 |       if augmentation:
153 |         if self._random_crop:
154 |           batch['image'] = batch_random_crop(batch['image'])
155 |         if self._random_fliplr:
156 |           batch['image'] = batch_random_fliplr(batch['image'])
157 |         if self._random_cutout > 0:
158 |           batch['image'] = batch_random_cutout(
159 |               batch['image'], self._random_cutout)
160 |
161 |       yield batch
162 |
163 |
164 | class LabelRemappedTrainDataset:
165 |   """A derived dataset where the labels are remapped according to the index."""
166 |
167 |   def __init__(self, dataset, subset_index):
168 |     self.dataset = dataset
169 |     self.subset_index = subset_index
170 |     # if not None, could be a (n, k) array that defines the k-dimensional
171 |     # one-hot vector label for each example. Note n here is the total number
172 |     # of examples in the original dataset, because the original index
173 |     # is used to address this array. However, this array does not need to have
174 |     # meaningful values outside of the examples specified by subset_index.
175 |     self.label_mapping = None
176 |     # subset mask can be used to further filter out some
177 |     # of the examples in the training set
178 |     self.subset_mask = np.ones(len(subset_index), dtype=bool)
179 |
180 |   @property
181 |   def num_classes(self):
182 |     return self.dataset.num_classes
183 |
184 |   @property
185 |   def use_onehot_label(self):
186 |     return True
187 |
188 |   def get_num_examples(self, split_name):
189 |     if split_name == 'train':
190 |       return len(self.subset_index[self.subset_mask])
191 |     else:
192 |       return self.dataset.get_num_examples(split_name)
193 |
194 |   def get_input_shape(self, input_name):
195 |     return self.dataset.get_input_shape(input_name)
196 |
197 |   def iterate(self, split_name, batch_size, shuffle=False, augmentation=False,
198 |               subset_index=None):
199 |     """Iterate over the dataset."""
200 |     assert subset_index is None, ('LabelRemappedTrainDataset does not support '
201 |                                   'further subset indexing.')
202 |
203 |     if split_name == 'train':
204 |       for batch in self.dataset.iterate(
205 |           'train', batch_size, shuffle=shuffle, augmentation=augmentation,
206 |           subset_index=self.subset_index[self.subset_mask]):
207 |         yield self.remap_batch_label(batch)
208 |     else:
209 |       for batch in self.dataset.iterate(
210 |           split_name, batch_size, shuffle=shuffle, augmentation=augmentation):
211 |         batch['orig_label'] = batch['label']
212 |         batch['label'] = self._make_onehot(batch['label'])
213 |         yield batch
214 |
215 |   def remap_batch_label(self, batch):
216 |     batch['orig_label'] = batch['label']
217 |     if self.label_mapping is not None:
218 |       batch['label'] = self.label_mapping[batch['index'], :]
219 |     else:
220 |       batch['label'] = self._make_onehot(batch['label'])
221 |     return batch
222 |
223 |   def _make_onehot(self, labels):
224 |     return np.eye(self.num_classes, dtype=np.float32)[labels, :]
225 |
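The `label_mapping` comment in `LabelRemappedTrainDataset` says the array is addressed by each example's `index` in the original dataset and only needs meaningful rows for the examples in `subset_index`. A minimal NumPy sketch of building such an array and remapping a batch the way `remap_batch_label` does; `build_label_mapping` and the replacement labels are hypothetical placeholders (for instance, standing in for labels produced by a randomized-response step).

```python
import numpy as np


def build_label_mapping(num_total, num_classes, subset_index, new_labels):
  """Builds an (n, k) one-hot array addressed by original example index.

  Rows outside subset_index stay all-zero; they are never read when the
  training iterator only visits examples from the subset.
  """
  mapping = np.zeros((num_total, num_classes), dtype=np.float32)
  mapping[subset_index, new_labels] = 1.0
  return mapping


# Hypothetical setup: 6 examples in the original split, 3 classes, and a
# training subset of 3 examples whose labels were replaced.
subset_index = np.array([1, 3, 4])
new_labels = np.array([2, 0, 0])
label_mapping = build_label_mapping(6, 3, subset_index, new_labels)

# A batch as yielded by TFDSNumpyDataset.iterate carries the original 'index'.
batch = {'index': np.array([4, 1]), 'label': np.array([1, 1])}
batch['orig_label'] = batch['label']
batch['label'] = label_mapping[batch['index'], :]
print(batch['label'])  # one-hot rows for classes 0 and 2
```

Addressing the array by the original index (rather than by position within the subset) is what lets the subset be shuffled or masked with `subset_mask` without re-aligning the labels.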
--------------------------------------------------------------------------------
/label_dp/lr_schedules.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Learning rate schedules.
16 |
17 | A thin wrapper over optax schedulers with a unifying constructing interface.
18 | """
19 |
20 | from typing import Callable
21 | import optax
22 |
23 |
24 | Schedule = Callable[[int], float]
25 |
26 |
27 | def constant(base_lr, num_train_steps) -> Schedule:
28 |   del num_train_steps
29 |   return lambda _: base_lr
30 |
31 |
32 | def piecewise_constant(base_lr, num_train_steps, *,
33 |                        rampup_thresh=0.15,
34 |                        stages=((0.3, 0.1), (0.6, 0.1), (0.9, 0.1))) -> Schedule:
35 |   """Piecewise constant learning rate with optional linear rampup.
36 |
37 |   Args:
38 |     base_lr: base learning rate.
39 |     num_train_steps: total number of training steps.
40 |     rampup_thresh: if > 0, fraction of num_train_steps for a linear rampup.
41 |     stages: a sequence of (step_ratio, scaling_factor). step_ratio times
42 |       num_train_steps is a decay boundary. The base learning rate is multiplied
43 |       by the scaling factor of every stage whose boundary has been reached
44 |       (counted from the end of the rampup, since join_schedules shifts steps).
45 |
46 |   Returns:
47 |     A learning rate schedule.
48 |   """
49 |   lr_fn = optax.piecewise_constant_schedule(
50 |       init_value=base_lr,
51 |       boundaries_and_scales={int(r*num_train_steps): s for r, s in stages})
52 |   if rampup_thresh is not None and rampup_thresh > 0:
53 |     rampup_steps = int(rampup_thresh * num_train_steps)
54 |     rampup_fn = optax.linear_schedule(
55 |         init_value=0, end_value=base_lr, transition_steps=rampup_steps)
56 |     lr_fn = optax.join_schedules(
57 |         schedules=[rampup_fn, lr_fn], boundaries=[rampup_steps])
58 |   return lr_fn
59 |
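To make the resulting schedule concrete without depending on optax, here is a minimal pure-Python sketch that mirrors `piecewise_constant` under the documented behavior of the optax functions it calls: `linear_schedule` ramps from 0 to `base_lr`, `piecewise_constant_schedule` multiplies in a stage's scale once the step reaches its boundary, and `join_schedules` evaluates the second schedule at `step - rampup_steps`. `piecewise_constant_lr` is a hypothetical name; the repository itself returns an optax schedule, and this sketch is worth cross-checking against the optax version in use.

```python
def piecewise_constant_lr(step, base_lr, num_train_steps, rampup_thresh=0.15,
                          stages=((0.3, 0.1), (0.6, 0.1), (0.9, 0.1))):
  """Pure-Python mirror of lr_schedules.piecewise_constant (sketch)."""
  rampup_steps = 0
  if rampup_thresh is not None and rampup_thresh > 0:
    rampup_steps = int(rampup_thresh * num_train_steps)
  if step < rampup_steps:
    # Linear warmup from 0 to base_lr over rampup_steps steps.
    return base_lr * step / rampup_steps
  # join_schedules evaluates the second schedule at (step - boundary).
  shifted = step - rampup_steps
  lr = base_lr
  for ratio, scale in stages:
    # A stage's scale applies once the (shifted) step reaches its boundary.
    if shifted >= int(ratio * num_train_steps):
      lr *= scale
  return lr


for s in (0, 75, 150, 449, 450, 750, 999):
  print(s, piecewise_constant_lr(s, base_lr=0.1, num_train_steps=1000))
```

Under these semantics, with the default `rampup_thresh=0.15` and 1000 training steps, the decays land at steps 450 and 750 rather than 300 and 600, and the stage at 0.9 × `num_train_steps` would only take effect at step 1050, i.e. never within the run.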
--------------------------------------------------------------------------------
/label_dp/main.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Sets up environment and calls training."""
16 |
17 | import copy
18 | import datetime
19 | import logging as native_logging
20 | import os
21 | from typing import Sequence
22 |
23 | from absl import app
24 | from absl import flags
25 | from absl import logging
26 | from clu import platform
27 |
28 | import jax
29 | import ml_collections
30 | import tensorflow as tf
31 |
32 | from label_dp import profiles
33 | from label_dp import train
34 |
35 |
36 | FLAGS = flags.FLAGS
37 |
38 |
39 | _BASE_WORKDIR = flags.DEFINE_string(
40 |     'base_workdir', '',
41 |     'Base directory for logs, checkpoints, and other outputs. '
42 |     'When a profile key is given, logs will go into subfolders '
43 |     'specified by the key.')
44 | _PROFILE_KEY = flags.DEFINE_string('profile_key', None,
45 |                                    'Key to a pre-defined profile.')
46 |
47 |
48 | def main(argv: Sequence[str]) -> None:
49 |   if len(argv) > 1:
50 |     raise app.UsageError('Too many command-line arguments.')
51 |   if jax.process_count() > 1:
52 |     raise NotImplementedError('Multi-process training is not supported.')
53 |
54 |   # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
55 |   # it unavailable to JAX.
56 |   tf.config.experimental.set_visible_devices([], 'GPU')
57 |
58 |   configs = profiles.Registry.get_profile(_PROFILE_KEY.value)
59 |   configs = ml_collections.ConfigDict(copy.deepcopy(configs))
60 |   workdir = os.path.join(_BASE_WORKDIR.value, _PROFILE_KEY.value)
61 |
62 |   # logging
63 |   logdir = os.path.join(workdir, 'logs')
64 |   tf.io.gfile.makedirs(logdir)
65 |   log_file = os.path.join(
66 |       logdir, datetime.datetime.now().strftime('%Y%m%d-%H%M%S') + '.txt')
67 |   log_format = '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
68 |   formatter = native_logging.Formatter(log_format)
69 |   file_stream = tf.io.gfile.GFile(log_file, 'w')
70 |   handler = native_logging.StreamHandler(file_stream)
71 |   handler.setLevel(native_logging.INFO)
72 |   handler.setFormatter(formatter)
73 |   logging.get_absl_logger().addHandler(handler)
74 |
75 |   if jax.process_index() == 0:
76 |     work_unit = platform.work_unit()
77 |     work_unit.create_artifact(
78 |         artifact_type=platform.ArtifactType.DIRECTORY,
79 |         artifact=workdir, description='Working directory')
80 |     work_unit.create_artifact(
81 |         artifact_type=platform.ArtifactType.FILE,
82 |         artifact=log_file, description='Log file')
83 |
84 |   train.multi_stage_train(configs.train, workdir)
85 |   logging.flush()
86 |
87 |
88 | if __name__ == '__main__':
89 |   app.run(main)
90 |
--------------------------------------------------------------------------------
/label_dp/models.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """ResNet."""
16 |
17 |
18 | import functools
19 | from typing import Any, Callable, Optional, Sequence, Tuple
20 |
21 | from flax import linen as nn
22 | import jax.numpy as jnp
23 |
24 | ModuleDef = Any
25 |
26 |
27 | ################################################################################
28 | # ResNet V2
29 | ################################################################################
30 | class BasicBlockV2(nn.Module):
31 | """Basic Block for a ResNet V2."""
32 |
33 | channels: int
34 | conv: ModuleDef
35 | norm: ModuleDef
36 | act: Callable
37 | strides: Tuple[int, int] = (1, 1)
38 |
39 | @nn.compact
40 | def __call__(self, x):
41 | preact = self.act(self.norm()(x))
42 | y = self.conv(self.channels, (3, 3), self.strides)(preact)
43 | y = self.act(self.norm()(y))
44 | y = self.conv(self.channels, (3, 3))(y)
45 |
46 | if y.shape != x.shape:
47 | shortcut = self.conv(self.channels, (1, 1), self.strides)(preact)
48 | else:
49 | shortcut = x
50 | return shortcut + y
51 |
52 |
53 | class BottleneckBlockV2(nn.Module):
54 | """Bottleneck Block for a ResNet V2."""
55 |
56 | channels: int
57 | conv: ModuleDef
58 | norm: ModuleDef
59 | act: Callable
60 | strides: Tuple[int, int] = (1, 1)
61 |
62 | @nn.compact
63 | def __call__(self, x):
64 | preact = self.act(self.norm()(x))
65 | y = self.conv(self.channels, (1, 1))(preact)
66 | y = self.act(self.norm()(y))
67 | y = self.conv(self.channels, (3, 3), self.strides)(y)
68 | y = self.act(self.norm()(y))
69 | y = self.conv(self.channels * 4, (1, 1))(y)
70 |
71 | if y.shape != x.shape:
72 | shortcut = self.conv(self.channels * 4, (1, 1), self.strides)(preact)
73 | else:
74 | shortcut = x
75 |
76 | return shortcut + y
77 |
78 |
79 | class ResNetV2(nn.Module):
80 | """ResNet v2.
81 |
82 | K. He, X. Zhang, S. Ren, and J. Sun. Identity mappings in deep residual
83 | networks. In ECCV, pages 630–645, 2016.
84 | """
85 |
86 | stage_sizes: Sequence[int]
87 | block_class: ModuleDef
88 | num_classes: Optional[int] = None
89 | base_channels: int = 64
90 | act: Callable = nn.relu
91 | dtype: Any = jnp.float32
92 | small_image: bool = False
93 | # If not None, batch statistics are synced across replicas along
94 | # this axis_name (the one used in pmap).
95 | bn_cross_replica_axis_name: Optional[str] = None
96 |
97 | @nn.compact
98 | def __call__(self, x, train: bool = True):
99 | conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
100 | norm = functools.partial(
101 | nn.BatchNorm, use_running_average=not train,
102 | momentum=0.9, epsilon=1e-5, dtype=self.dtype,
103 | axis_name=self.bn_cross_replica_axis_name)
104 |
105 | if self.small_image: # suitable for Cifar
106 | x = conv(self.base_channels, (3, 3), padding='SAME')(x)
107 | else:
108 | x = conv(self.base_channels, (7, 7), (2, 2), padding=[(3, 3), (3, 3)])(x)
109 | x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
110 |
111 | for i, n_blocks in enumerate(self.stage_sizes):
112 | for j in range(n_blocks):
113 | strides = (2, 2) if i > 0 and j == 0 else (1, 1)
114 | x = self.block_class(self.base_channels * 2 ** i, strides=strides,
115 | conv=conv, norm=norm, act=self.act)(x)
116 |
117 | x = self.act(norm(name='bn_final')(x))
118 | x = jnp.mean(x, axis=(1, 2))
119 | if self.num_classes is not None:
120 | x = nn.Dense(self.num_classes, dtype=self.dtype, name='classifier')(x)
121 | return x
122 |
123 |
124 | CifarResNet18V2 = functools.partial(
125 | ResNetV2, stage_sizes=[2, 2, 2, 2], block_class=BasicBlockV2,
126 | small_image=True)
127 |
128 | CifarResNet50V2 = functools.partial(
129 | ResNetV2, stage_sizes=[3, 4, 6, 3], block_class=BottleneckBlockV2,
130 | small_image=True)
131 |
132 | ResNet50V2 = functools.partial(
133 | ResNetV2, stage_sizes=[3, 4, 6, 3], block_class=BottleneckBlockV2)
134 |
135 |
136 | ################################################################################
137 | # ResNet V1
138 | ################################################################################
139 | class BasicBlockV1(nn.Module):
140 | """Basic block for a ResNet V1."""
141 |
142 | channels: int
143 | conv: ModuleDef
144 | norm: ModuleDef
145 | act: Callable
146 | strides: Tuple[int, int] = (1, 1)
147 |
148 | @nn.compact
149 | def __call__(self, x):
150 | residual = x
151 | y = self.conv(self.channels, (3, 3), self.strides)(x)
152 | y = self.norm()(y)
153 | y = self.act(y)
154 | y = self.conv(self.channels, (3, 3))(y)
155 | y = self.norm(scale_init=nn.initializers.zeros)(y)
156 |
157 | if residual.shape != y.shape:
158 | residual = self.conv(self.channels, (1, 1),
159 | self.strides, name='conv_proj')(residual)
160 | residual = self.norm(name='norm_proj')(residual)
161 |
162 | return self.act(residual + y)
163 |
164 |
165 | class BottleneckBlockV1(nn.Module):
166 | """Bottleneck block for ResNet V1."""
167 |
168 | channels: int
169 | conv: ModuleDef
170 | norm: ModuleDef
171 | act: Callable
172 | strides: Tuple[int, int] = (1, 1)
173 |
174 | @nn.compact
175 | def __call__(self, x):
176 | residual = x
177 | y = self.conv(self.channels, (1, 1))(x)
178 | y = self.norm()(y)
179 | y = self.act(y)
180 | y = self.conv(self.channels, (3, 3), self.strides)(y)
181 | y = self.norm()(y)
182 | y = self.act(y)
183 | y = self.conv(self.channels * 4, (1, 1))(y)
184 | y = self.norm(scale_init=nn.initializers.zeros)(y)
185 |
186 | if residual.shape != y.shape:
187 | residual = self.conv(self.channels * 4, (1, 1),
188 | self.strides, name='conv_proj')(residual)
189 | residual = self.norm(name='norm_proj')(residual)
190 |
191 | return self.act(residual + y)
192 |
193 |
194 | class ResNetV1(nn.Module):
195 | """ResNetV1.
196 |
197 | K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image
198 | recognition. In CVPR, pages 770–778, 2016.
199 | """
200 |
201 | stage_sizes: Sequence[int]
202 | block_class: ModuleDef
203 | num_classes: Optional[int] = None
204 | base_channels: int = 64
205 | act: Callable = nn.relu
206 | dtype: Any = jnp.float32
207 | small_image: bool = False
208 | # If not None, batch statistics are synced across replicas along
209 | # this axis_name (the one used in pmap).
210 | bn_cross_replica_axis_name: Optional[str] = None
211 |
212 | @nn.compact
213 | def __call__(self, x, train: bool = True):
214 | conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
215 | norm = functools.partial(
216 | nn.BatchNorm, use_running_average=not train, momentum=0.9,
217 | epsilon=1e-5, dtype=self.dtype,
218 | axis_name=self.bn_cross_replica_axis_name)
219 |
220 | if self.small_image: # suitable for Cifar
221 | x = conv(self.base_channels, (3, 3), padding='SAME', name='conv_init')(x)
222 | x = norm(name='bn_init')(x)
223 | x = self.act(x)
224 | else:
225 | x = conv(self.base_channels, (7, 7), (2, 2), padding=[(3, 3), (3, 3)],
226 | name='conv_init')(x)
227 | x = norm(name='bn_init')(x)
228 | x = self.act(x)
229 | x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
230 |
231 | for i, block_size in enumerate(self.stage_sizes):
232 | for j in range(block_size):
233 | strides = (2, 2) if i > 0 and j == 0 else (1, 1)
234 | x = self.block_class(self.base_channels * 2 ** i, strides=strides,
235 | conv=conv, norm=norm, act=self.act)(x)
236 |
237 | x = jnp.mean(x, axis=(1, 2))
238 | if self.num_classes is not None:
239 | x = nn.Dense(self.num_classes, dtype=self.dtype, name='classifier')(x)
240 | return x
241 |
242 |
243 | CifarResNet18V1 = functools.partial(
244 | ResNetV1, stage_sizes=[2, 2, 2, 2], block_class=BasicBlockV1,
245 | small_image=True)
246 |
247 | CifarResNet50V1 = functools.partial(
248 | ResNetV1, stage_sizes=[3, 4, 6, 3], block_class=BottleneckBlockV1,
249 | small_image=True)
250 |
251 | ResNet50V1 = functools.partial(
252 | ResNetV1, stage_sizes=[3, 4, 6, 3], block_class=BottleneckBlockV1)
253 |
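Both `ResNetV2` and `ResNetV1` build the same stage/stride schedule: the first block of every stage after the first uses stride 2, and a block's shortcut gets a 1x1 projection whenever its output shape differs from its input. The pure-Python sketch below (hypothetical `block_schedule`, no Flax required) traces that schedule for the CIFAR variants. It assumes a 32x32 input and `nn.Conv`'s default `'SAME'` padding, so a stride-2 conv maps a size of n to ceil(n / 2); `expansion=4` mimics the bottleneck blocks, whose last conv outputs `channels * 4`.

```python
from typing import List, NamedTuple, Optional, Sequence


class BlockInfo(NamedTuple):
    stage: int
    block: int
    stride: int
    out_hw: int        # spatial height/width after the block
    out_channels: int
    projection: bool   # True if the shortcut needs a 1x1 conv


def block_schedule(stage_sizes: Sequence[int], base_channels: int = 64,
                   expansion: int = 1, in_hw: int = 32,
                   in_channels: Optional[int] = None) -> List[BlockInfo]:
    """Traces stride, output shape and shortcut type for each residual block."""
    # With small_image=True the stem is a 3x3 'SAME' conv: the spatial size is
    # unchanged and the channel count becomes base_channels.
    hw = in_hw
    c = base_channels if in_channels is None else in_channels
    schedule = []
    for i, n_blocks in enumerate(stage_sizes):
        for j in range(n_blocks):
            stride = 2 if i > 0 and j == 0 else 1
            out_c = base_channels * 2 ** i * expansion
            out_hw = -(-hw // stride)  # ceil(hw / stride), as with 'SAME'
            schedule.append(BlockInfo(i, j, stride, out_hw, out_c,
                                      (out_hw, out_c) != (hw, c)))
            hw, c = out_hw, out_c
    return schedule


r18 = block_schedule([2, 2, 2, 2])                 # CifarResNet18V2 / V1
r50 = block_schedule([3, 4, 6, 3], expansion=4)    # CifarResNet50V2 / V1
```

After the last block the model averages over the spatial axes (`jnp.mean(x, axis=(1, 2))`), so under these assumptions the classifier sees 512 features for the ResNet-18 variants and 2048 for the ResNet-50 variants. Note that the first bottleneck block already needs a projection, because it widens 64 channels to 256 even at stride 1.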
--------------------------------------------------------------------------------
/label_dp/profiles/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Profiles."""
16 |
17 | from . import p100_cifar10
18 |
19 | from . import registry
20 |
21 |
22 | Registry = registry.Registry
23 |
--------------------------------------------------------------------------------
/label_dp/profiles/__main__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # pylint: disable=invalid-name
16 | """Lists or searches profiles in the model zoo.
17 |
18 | python3 -m profiles
19 | """
20 | import sys
21 |
22 | from . import registry
23 |
24 |
25 | if __name__ == '__main__':
26 | key = '.*'
27 | if len(sys.argv) > 1:
28 | key = sys.argv[1]
29 |
30 | registry.Registry.print_profiles(key)
31 |
--------------------------------------------------------------------------------
/label_dp/profiles/p100_cifar10.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """CIFAR-10 experiments."""
16 |
17 |
18 | def register_cifar10_experiment(registry):
19 | """Register experiment configs."""
20 | batch_size = 512
21 | learning_rate = 0.4
22 |
23 | num_epochs = 200
24 | cutout = 8
25 | lr_fn = 'piecewise_constant'
26 | optimizer = 'sgd'
27 | for epsilon, hparam in [
28 | (1, {'mixup': (16, 8), 'data_split': 80, 'temperature': 0.6}),
29 | (2, {'mixup': (8, 8), 'data_split': 70, 'temperature': 0.5}),
30 | (4, {'mixup': (8, 4), 'data_split': 60, 'temperature': 0.5}),
31 | (8, {'mixup': (4, 4), 'data_split': 60, 'temperature': 0.5}),
32 | ]:
33 | for rep in range(5):
34 | mixup1, mixup2 = hparam['mixup']
35 | first_stage_data_ratio = hparam['data_split']
36 | temperature = hparam['temperature']
37 | key = f'cifar10/e{epsilon}/lp-2st/run{rep}'
38 | meta = {'target': 'main', 'platform': 'v100'}
39 | train_configs = {
40 | 'run_seed': 1234 + rep,
41 | 'batch_size': batch_size,
42 | 'half_precision': False,
43 | 'l2_regu': 1e-4,
44 | 'num_epochs': num_epochs,
45 | 'eval_splits': ['test'],
46 | 'reuse_last_stage_data': True,
47 | 'mask_last_stage_label_by_prior': True,
48 | 'data': {'name': 'cifar10', 'kwargs': {'random_cutout': cutout}},
49 | 'base_lr': learning_rate,
50 | 'lr_fn': {'name': lr_fn, 'kwargs': {}},
51 | 'optimizer': get_optimizer(optimizer),
52 | 'model': {'arch': 'CifarResNet18V2', 'kwargs': {}},
53 | 'stage_specs': [
54 | dict(type='rr', seed=2019, eps=epsilon,
55 | data_split=first_stage_data_ratio/100, mixup=mixup1),
56 | dict(type='rr-with-prior', seed=2020, eps=epsilon,
57 | data_split=1 - first_stage_data_ratio/100,
58 | temperature=temperature, mixup=mixup2)
59 | ]
60 | }
61 |
62 | spec = {'key': key, 'meta': meta, 'train': train_configs}
63 | registry.register(spec)
64 |
65 |
66 | def get_optimizer(name):
67 | cfg = {'name': name}
68 | if name == 'sgd':
69 | cfg['kwargs'] = {'momentum': 0.9, 'nesterov': True}
70 | else:
71 | cfg['kwargs'] = {}
72 | return cfg
73 |
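Each profile's `data_split` is a percentage: stage 0 (plain `rr`) gets `data_split / 100` of the training set, and stage 1 (`rr-with-prior`) gets the rest. The sketch below enumerates the 20 keys registered here (4 epsilons x 5 repetitions) and reproduces the cumulative-count arithmetic of `utils.derive_subset_dataset`. It assumes the full 50,000-image CIFAR-10 training split; `profile_keys` and `stage_counts` are hypothetical helpers.

```python
from itertools import accumulate
from typing import List

# 'data_split' per epsilon, copied from register_cifar10_experiment.
DATA_SPLIT_PCT = {1: 80, 2: 70, 4: 60, 8: 60}


def profile_keys(n_reps: int = 5) -> List[str]:
    """Keys in the same format as register_cifar10_experiment."""
    return [f'cifar10/e{eps}/lp-2st/run{rep}'
            for eps in DATA_SPLIT_PCT for rep in range(n_reps)]


def stage_counts(n_train: int, first_stage_pct: float) -> List[int]:
    """Training examples per stage, following derive_subset_dataset."""
    splits = [first_stage_pct / 100, 1 - first_stage_pct / 100]
    cum = [0] + [int(n_train * x) for x in accumulate(splits)]
    cum[-1] = n_train  # the last stage absorbs any rounding remainder
    return [cum[i + 1] - cum[i] for i in range(len(splits))]
```

For example, the epsilon=1 runs would train stage 0 on 40,000 examples and stage 1 on the remaining 10,000 (plus the filtered stage-0 data, since `reuse_last_stage_data` is `True`).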
--------------------------------------------------------------------------------
/label_dp/profiles/registry.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Profile registry."""
16 | import re
17 | import sys
18 |
19 |
20 | class Registry(object):
21 | """Registry holding all model specs in the model zoo."""
22 | registry = None
23 |
24 | @classmethod
25 | def build_registry(cls):
26 | """Builds registry, called upon first zoo query."""
27 | if cls.registry is not None:
28 | return
29 |
30 | cls.registry = dict()
31 |
32 | mod_profiles = sys.modules[sys.modules[__name__].__package__]
33 | profile_mods = [getattr(mod_profiles, x) for x in dir(mod_profiles)
34 | if re.match(r'^p[0-9]+_', x)]
35 | profile_mods = [x for x in profile_mods if isinstance(x, type(sys))]
36 | profile_mods.sort(key=str)
37 | for p_mod in profile_mods:
38 | register_funcs = [getattr(p_mod, x) for x in dir(p_mod)
39 | if x.startswith('register_')]
40 | register_funcs = filter(callable, register_funcs)
41 |
42 | for func in register_funcs:
43 | func(cls)
44 |
45 | @classmethod
46 | def register(cls, profile):
47 | key = profile['key']
48 | if key in cls.registry: # pylint: disable=unsupported-membership-test
49 | raise KeyError('duplicated profile key: {}'.format(key))
50 | cls.registry[key] = profile # pylint: disable=unsupported-assignment-operation
51 |
52 | @classmethod
53 | def list_profiles(cls, regex):
54 | cls.build_registry()
55 | profiles = [cls.registry[key] for key in cls.registry.keys()
56 | if re.search(regex, key)]
57 | return profiles
58 |
59 | @classmethod
60 | def print_profiles(cls, regex):
61 | profiles = cls.list_profiles(regex)
62 | print('{} profiles found ====== with regex: {}'.format(
63 | len(profiles), regex))
64 | for i, profile in enumerate(profiles):
65 | print(' {:>3d}) {}'.format(i, profile['key']))
66 |
67 | @classmethod
68 | def get_profile(cls, key):
69 | cls.build_registry()
70 | assert key is not None
71 | return cls.registry[key]
72 |
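`Registry` is a lazily built, class-level dictionary. The first query scans the `profiles` package for modules whose names match `^p[0-9]+_`, calls every `register_*` function in them with the class, and rejects duplicate keys. The toy sketch below keeps only the dictionary and regex-search behavior. `MiniRegistry` is a hypothetical name, and the module scan is replaced by an explicit list of register functions.

```python
import re
from typing import Callable, Dict, Iterable, List, Optional


class MiniRegistry:
    """Toy stand-in for profiles.Registry (hypothetical, for illustration)."""
    registry: Optional[Dict[str, dict]] = None

    @classmethod
    def build(cls, register_funcs: Iterable[Callable]) -> None:
        # Built once; later calls are no-ops, like Registry.build_registry.
        if cls.registry is not None:
            return
        cls.registry = {}
        for func in register_funcs:
            func(cls)  # each register_* function calls cls.register(...)

    @classmethod
    def register(cls, profile: dict) -> None:
        key = profile['key']
        if key in cls.registry:
            raise KeyError(f'duplicated profile key: {key}')
        cls.registry[key] = profile

    @classmethod
    def list_profiles(cls, regex: str) -> List[dict]:
        # re.search: the pattern may match anywhere in the key.
        return [p for key, p in cls.registry.items() if re.search(regex, key)]
```

With the real package, `python3 -m profiles <regex>` passes the regex to `print_profiles`; for example, a pattern such as `'e1/'` would list the epsilon=1 runs.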
--------------------------------------------------------------------------------
/label_dp/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Training script."""
16 |
17 | import functools
18 | import pathlib
19 | from typing import Any
20 |
21 | from absl import logging
22 |
23 | from clu import metric_writers
24 |
25 | from flax import jax_utils
26 | from flax.optim import dynamic_scale as dynamic_scale_lib
27 | from flax.training import common_utils
28 | from flax.training import train_state
29 | import jax
30 | import ml_collections
31 | import numpy as np
32 | import tensorflow as tf
33 |
34 | from label_dp import datasets
35 | from label_dp import models
36 | from label_dp import utils
37 |
38 |
39 | class TrainState(train_state.TrainState):
40 | epoch: int
41 | model_states: Any
42 | dynamic_scale: dynamic_scale_lib.DynamicScale
43 |
44 |
45 | def create_train_state(rng, input_shape, half_precision, model, optimizer_cfgs):
46 | """Creates initial training state."""
47 | dynamic_scale = None
48 | platform = jax.local_devices()[0].platform
49 | if half_precision and platform == 'gpu':
50 | dynamic_scale = dynamic_scale_lib.DynamicScale()
51 | else:
52 | dynamic_scale = None
53 |
54 | params, model_states = utils.initialize_model(
55 | rng, input_shape, model)
56 | tx = utils.build_optimizer(**optimizer_cfgs)
57 | state = TrainState.create(
58 | apply_fn=model.apply, params=params, tx=tx, model_states=model_states,
59 | dynamic_scale=dynamic_scale, epoch=0)
60 | return state
61 |
62 |
63 | def build_model(model_configs, num_classes, dtype):
64 | kwargs = dict(model_configs.kwargs, num_classes=num_classes, dtype=dtype)
65 | return getattr(models, model_configs.arch)(**kwargs)
66 |
67 |
68 | def report_metrics(i_stage, report_name, epoch, metrics, writer):
69 | metrics = common_utils.stack_forest(metrics)
70 | summary = jax.tree_util.tree_map(lambda x: float(x.mean()), metrics)
71 | summary = {f'stage{i_stage}/{report_name}/{k}': v for k, v in summary.items()}
72 | writer.write_scalars(epoch, summary)
73 |
74 |
75 | def multi_stage_train(configs: ml_collections.ConfigDict, workdir: str):
76 | """Multi-stage training."""
77 | orig_dataset = datasets.TFDSNumpyDataset(
78 | name=configs.data.name, **configs.data.kwargs)
79 | stage_datasets = utils.derive_subset_dataset(
80 | orig_dataset, n_stages=len(configs.stage_specs),
81 | stage_splits=[spec['data_split'] for spec in configs.stage_specs],
82 | seed=sum([spec['seed'] for spec in configs.stage_specs]))
83 |
84 | n_tr_total = orig_dataset.get_num_examples('train')
85 | input_dtype = utils.get_dtype(configs.half_precision)
86 | model = build_model(configs.model, orig_dataset.num_classes, input_dtype)
87 | last_stage_state = None
88 | for i_stage in range(len(stage_datasets)):
89 | if configs.stage_specs[i_stage]['type'] == 'rr':
90 | k_for_prior = compute_randomized_labels(
91 | stage_datasets[i_stage], configs.batch_size,
92 | configs.stage_specs[i_stage], n_tr_total)
93 | elif configs.stage_specs[i_stage]['type'] == 'rr-with-prior':
94 | k_for_prior = compute_randomized_labels_with_priors(
95 | stage_datasets[i_stage], model, last_stage_state, configs.batch_size,
96 | configs.stage_specs[i_stage], n_tr_total)
97 | else:
98 | raise KeyError('Unknown label randomization type')
99 |
100 | if i_stage != 0 and configs.reuse_last_stage_data:
101 | logging.info('Reusing randomized data from stage %d in stage %d',
102 | i_stage - 1, i_stage)
103 | if configs.mask_last_stage_label_by_prior:
104 | logging.info('Filtering out egs from stage %d using learned prior.',
105 | i_stage - 1)
106 | # Note we are passing in the dataset from the last stage and the k
107 | # from this stage
108 | filter_stage_data_by_prior(
109 | stage_datasets[i_stage-1], model, last_stage_state, k_for_prior,
110 | n_tr_total, configs.batch_size)
111 | merge_stage_data(stage_datasets[i_stage-1], stage_datasets[i_stage])
112 |
113 | logging.info('#' * 60)
114 | logging.info('# Stage %d training', i_stage)
115 | logging.info('#' * 60)
116 | last_stage_state = single_stage_train(
117 | configs, workdir, i_stage, stage_datasets, model, last_stage_state)
118 |
119 |
120 | def single_stage_train(
121 | configs: ml_collections.ConfigDict, workdir: str, i_stage: int,
122 | stage_datasets, model, last_stage_state=None):
123 | """Training for a single stage."""
124 | workdir = pathlib.Path(workdir) / f'stage{i_stage}'
125 | tf.io.gfile.makedirs(str(workdir))
126 |
127 | writer = metric_writers.create_default_writer(
128 | str(workdir / 'tensorboard'), just_logging=jax.process_index() > 0)
129 |
130 | rng = jax.random.PRNGKey(configs.run_seed + i_stage)
131 | np_rng = np.random.RandomState(seed=configs.run_seed + i_stage + 123)
132 | mixup_alpha = configs.stage_specs[i_stage]['mixup']
133 | mixup_sampler = functools.partial(
134 | np_rng.beta, mixup_alpha, mixup_alpha) if mixup_alpha > 0 else None
135 | local_batch_size = utils.get_local_batch_size(configs.batch_size)
136 | dataset = stage_datasets[i_stage]
137 |
138 | n_train = dataset.get_num_examples('train')
139 | n_train_steps = int(configs.num_epochs * n_train / local_batch_size)
140 | lr_fn = utils.build_lr_fn(
141 | configs.lr_fn.name, configs.base_lr, n_train_steps, configs.lr_fn.kwargs)
142 | optimizer_cfgs = dict(configs.optimizer, learning_rate=lr_fn)
143 | state = create_train_state(
144 | rng, dataset.get_input_shape('image'), configs.half_precision, model,
145 | optimizer_cfgs)
146 | if last_stage_state is not None:
147 | state = state.replace(
148 | params=last_stage_state.params,
149 | model_states=last_stage_state.model_states)
150 |
151 | state = jax_utils.replicate(state)
152 |
153 | # pmap the train and eval functions
154 | p_train_step = jax.pmap(
155 | functools.partial(utils.train_step, model.apply, l2_regu=configs.l2_regu),
156 | axis_name='batch')
157 | p_eval_step = jax.pmap(functools.partial(utils.eval_step, model.apply),
158 | axis_name='batch')
159 |
160 | def run_eval(epoch=0, split_name='test'):
161 | eval_metrics = []
162 | for batch in utils.iterate_data(dataset, split_name, local_batch_size,
163 | desc=f'E{epoch:03d} eval-{split_name}'):
164 | metrics = p_eval_step(state, batch)
165 | eval_metrics.append(utils.metrics_to_numpy(metrics))
166 |
167 | report_metrics(i_stage, f'eval-{split_name}', epoch, eval_metrics, writer)
168 |
169 | start_epoch = jax_utils.unreplicate(state.epoch)
170 | logging.info('Start training from epoch %d...', start_epoch)
171 |
172 | with metric_writers.ensure_flushes(writer):
173 | while True:
174 | epoch = int(jax_utils.unreplicate(state.epoch))
175 | if epoch >= configs.num_epochs:
176 | break
177 |
178 | train_metrics = []
179 | for batch in utils.iterate_data(dataset, 'train', local_batch_size,
180 | augmentation=True, shuffle=True,
181 | desc=f'E{epoch+1:03d} train',
182 | mixup_sampler=mixup_sampler):
183 | state, metrics = p_train_step(state, batch)
184 | train_metrics.append(utils.metrics_to_numpy(metrics))
185 |
186 | state = state.replace(epoch=state.epoch + 1)
187 | epoch += 1
188 | report_metrics(i_stage, 'train', epoch, train_metrics, writer)
189 |
190 | # sync batch statistics across replicas
191 | state = utils.sync_batch_stats(state)
192 |
193 | for split in configs.eval_splits:
194 | run_eval(epoch, split)
195 |
196 | utils.block_until_computation_finish()
197 | return jax_utils.unreplicate(state)
198 |
199 |
200 | def compute_randomized_labels(dataset, batch_size, spec, n_tr_total):
201 | """Computes randomized labels."""
202 | assert isinstance(dataset, datasets.LabelRemappedTrainDataset)
203 | n_classes = dataset.num_classes
204 | orig_labels = np.zeros(n_tr_total, dtype=np.int64)
205 |
206 | for batch in dataset.iterate('train', batch_size):
207 | orig_labels[batch['index']] = batch['orig_label']
208 |
209 | # assign new labels
210 | rng = np.random.RandomState(seed=spec['seed'])
211 | dataset.label_mapping = np.zeros((n_tr_total, n_classes), dtype=np.float32)
212 | dataset.subset_mask = np.ones(len(dataset.subset_index), dtype=bool)
213 | for idx in dataset.subset_index:
214 | if spec['type'] == 'rr':
215 | label = orig_labels[idx]
216 | rate = 1 / (np.exp(spec['eps']) + n_classes - 1)
217 | prob = np.zeros(n_classes) + rate
218 | prob[label] = 1 - rate * (n_classes - 1)
219 | new_label = rng.choice(n_classes, 1, p=prob)
220 | dataset.label_mapping[idx][new_label] = 1
221 | else:
222 | raise KeyError(f'Unknown type: {spec["type"]}')
223 |
224 | if spec['type'] == 'rr':
225 | return n_classes
226 |
227 |
228 | def compute_randomized_labels_with_priors(
229 | dataset, model, last_stage_state, batch_size, spec, n_tr_total):
230 | """Computes RR-with-prior labels, using the last-stage model as the prior."""
231 | assert isinstance(dataset, datasets.LabelRemappedTrainDataset)
232 | assert spec['type'] == 'rr-with-prior'
233 | ds_weight = spec.get('domain_specific_prior_weight', 0.0)
234 | assert ds_weight >= 0.0 and ds_weight <= 1.0
235 |
236 | n_classes = dataset.num_classes
237 | dataset.label_mapping = np.zeros((n_tr_total, n_classes), dtype=np.float32)
238 | dataset.subset_mask = np.ones(len(dataset.subset_index), dtype=bool)
239 | rng = np.random.RandomState(seed=spec['seed'])
240 | logging.info('RRWithPrior labeling (T=%f, ds_weight=%f)',
241 | spec['temperature'], ds_weight)
242 |
243 | model_vars = {'params': last_stage_state.params,
244 | **last_stage_state.model_states}
245 | j_pred = jax.jit(functools.partial(model.apply, train=False))
246 |
247 | soft_k = 0.0
248 | for batch in dataset.iterate('train', batch_size):
249 | if np.isclose(ds_weight, 1.0):
250 | p_last_model = 0.0
251 | else:
252 | logits = j_pred(model_vars, batch['image'])
253 | logits /= spec['temperature']
254 | p_last_model = jax.device_get(jax.nn.softmax(logits))
255 |
256 | if np.isclose(ds_weight, 0.0):
257 | p_domain = 0.0
258 | else:
259 | p_domain = batch['prior']
260 | p_domain = np.power(p_domain, 1/spec['temperature'])
261 | p_domain = p_domain / np.sum(p_domain, axis=1, keepdims=True)
262 |
263 | probs = ds_weight*p_domain + (1-ds_weight)*p_last_model
264 | orig_labels = batch['orig_label']
265 |
266 | for i, prob in enumerate(probs):
267 | k, new_label = rr_with_prior(prob, spec['eps'], orig_labels[i], rng)
268 | soft_k += k
269 | dataset.label_mapping[batch['index'][i]][new_label] = 1
270 |
271 | soft_k /= len(dataset.subset_index)
272 | logging.info('effective soft_k ~= %.2f (averaged on %d examples).',
273 | soft_k, len(dataset.subset_index))
274 | return soft_k
275 |
276 |
277 | def rr_with_prior(prior, eps, y, rng):
278 | """Randomized response with prior.
279 |
280 | Args:
281 | prior: A K-length array where the k-th entry is the probability that the
282 | true label is k.
283 | eps: the epsilon value for which the randomized response is epsilon-DP.
284 | y: an integer indicating the true label.
285 | rng: a numpy random number generator for sampling.
286 |
287 | Returns:
288 | k, y_rr: k is the value used in rr-top-k; y_rr is the randomized label.
289 | """
290 | idx_sort = np.flipud(np.argsort(prior))
291 | prior_sorted = prior[idx_sort]
292 | tmp = np.exp(-eps)
293 | wks = [np.sum(prior_sorted[:(k+1)]) / (1 + k*tmp)
294 | for k in range(len(prior))]
295 | optim_k = np.argmax(wks) + 1
296 |
297 | adjusted_prior = np.zeros_like(prior) + tmp / (1 + (optim_k-1)*tmp)
298 | adjusted_prior[y] = 1 / (1 + (optim_k-1)*tmp)
299 | adjusted_prior[idx_sort[optim_k:]] = 0
300 | adjusted_prior /= np.sum(adjusted_prior) # renorm in case y not in topk
301 | rr_label = rng.choice(len(prior), 1, p=adjusted_prior)
302 | return optim_k, rr_label
303 |
304 |
305 | def filter_stage_data_by_prior(last_dataset, model, last_stage_state,
306 | k_for_prior, n_tr_total, batch_size):
307 | """Masks out last-stage examples whose randomized label is not in the top-k prior."""
308 | assert k_for_prior is not None
309 | k_for_prior = int(k_for_prior)
310 |
311 | model_vars = {'params': last_stage_state.params,
312 | **last_stage_state.model_states}
313 | j_pred = jax.jit(functools.partial(model.apply, train=False))
314 | global_mask = np.ones(n_tr_total, dtype=bool)
315 |
316 | last_dataset.subset_mask[:] = True # enable all examples
317 | for batch in last_dataset.iterate('train', batch_size):
318 | logits = j_pred(model_vars, batch['image'])
319 | _, topk_idx = jax.device_get(jax.lax.top_k(logits, k=k_for_prior))
320 |
321 | for j in range(batch['image'].shape[0]):
322 | if np.isclose(np.sum(batch['label'][j, topk_idx[j]]), 0):
323 | # the randomized label from the last stage is not in the topk prior
324 | global_mask[batch['index'][j]] = False
325 |
326 | n_filtered = 0
327 | for i, idx in enumerate(last_dataset.subset_index):
328 | if not global_mask[idx]:
329 | last_dataset.subset_mask[i] = False
330 | n_filtered += 1
331 |
332 | logging.info('%d egs removed due to randomized labels not in top %d prior',
333 | n_filtered, k_for_prior)
334 |
335 |
336 | def merge_stage_data(dset_last_stage, dset):
337 | """Merges the (randomized) data from the last stage so it is reused in this stage."""
338 | assert isinstance(dset_last_stage, datasets.LabelRemappedTrainDataset)
339 | assert isinstance(dset, datasets.LabelRemappedTrainDataset)
340 | assert dset.label_mapping is not None
341 | assert dset_last_stage.label_mapping is not None
342 | # pylint: disable=g-explicit-length-test
343 | assert len(
344 | np.intersect1d(
345 | dset.subset_index, dset_last_stage.subset_index,
346 | assume_unique=True)) == 0
347 |
348 | dset.label_mapping[dset_last_stage.subset_index,
349 | ...] = dset_last_stage.label_mapping[
350 | dset_last_stage.subset_index, ...]
351 | dset.subset_index = np.concatenate(
352 | [dset_last_stage.subset_index, dset.subset_index], axis=0)
353 | dset.subset_mask = np.concatenate(
354 | [dset_last_stage.subset_mask, dset.subset_mask], axis=0)
355 |
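Stage 0 applies plain K-ary randomized response (`compute_randomized_labels`): the true label is kept with probability e^eps / (e^eps + K - 1), and each other label is output with probability 1 / (e^eps + K - 1). Later stages use `rr_with_prior`, which restricts the response to the top-k labels of a prior. It picks k in 1..K to maximize w_k = (prior mass of the top-k labels) / (1 + (k - 1) e^-eps), which is the probability that the output equals the true label when that label is distributed according to the prior, and it answers uniformly within the top-k when the true label falls outside it. The NumPy sketch below computes both output distributions instead of sampling from them, so the eps-DP ratio bound can be checked directly. `rr_probs` and `rr_with_prior_probs` are hypothetical names.

```python
import numpy as np


def rr_probs(y: int, n_classes: int, eps: float) -> np.ndarray:
    """Output distribution of K-ary randomized response for true label y."""
    rate = 1.0 / (np.exp(eps) + n_classes - 1)
    prob = np.full(n_classes, rate)
    prob[y] = 1.0 - rate * (n_classes - 1)  # equals e^eps / (e^eps + K - 1)
    return prob


def rr_with_prior_probs(prior, eps: float, y: int):
    """Returns (k, output distribution) of RR-with-prior for true label y."""
    prior = np.asarray(prior, dtype=np.float64)
    idx_sort = np.argsort(prior)[::-1]  # labels by decreasing prior mass
    prior_sorted = prior[idx_sort]
    tmp = np.exp(-eps)
    # w_k = P(y in top-k) * e^eps / (e^eps + k - 1), for k = 1..K.
    wks = [prior_sorted[:k].sum() / (1 + (k - 1) * tmp)
           for k in range(1, len(prior) + 1)]
    k = int(np.argmax(wks)) + 1
    probs = np.full_like(prior, tmp / (1 + (k - 1) * tmp))
    probs[y] = 1 / (1 + (k - 1) * tmp)
    probs[idx_sort[k:]] = 0.0  # only top-k labels can be output
    probs /= probs.sum()       # uniform over top-k if y was outside it
    return k, probs
```

A confident prior shrinks k (down to 1, where the output ignores the true label entirely), while a uniform prior gives k = K and recovers plain randomized response. To draw a label the way `rr_with_prior` does, sample from the returned distribution, e.g. `rng.choice(len(probs), 1, p=probs)` with a `np.random.RandomState`.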
--------------------------------------------------------------------------------
/label_dp/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Training related utilities."""
16 |
17 | import functools
18 | from typing import Optional
19 |
20 | from absl import logging
21 |
22 | from flax import jax_utils
23 | from flax.training import common_utils
24 | import jax
25 | import jax.numpy as jnp
26 | import numpy as np
27 | import optax
28 | import tqdm
29 |
30 | from label_dp import datasets
31 | from label_dp import lr_schedules
32 |
33 |
34 | def derive_subset_dataset(dataset, n_stages=2, stage_splits=None, seed=None):
35 | """Specifies data splits for each stage of training.
36 |
37 | Args:
38 | dataset: The original dataset to derive from.
39 | n_stages: number of stages.
40 | stage_splits: Must be a list of numbers summing to 1, indicating
41 | the ratio of data for each stage.
42 | seed: Random seed for generating the splits. Must not be None (the
43 | function asserts this).
44 |
45 | Returns:
46 | A list of n_stages LabelRemappedTrainDataset objects, one per stage.
47 | """
48 | n_tr = dataset.get_num_examples('train')
49 | assert len(stage_splits) == n_stages
50 | assert seed is not None
51 | assert np.isclose(sum(stage_splits), 1)
52 | cum_splits = np.cumsum(stage_splits)
53 | cum_split_counts = [0] + [int(n_tr * x) for x in cum_splits]
54 | cum_split_counts[-1] = n_tr
55 | rng = np.random.RandomState(seed=seed)
56 | perm = rng.permutation(n_tr)
57 | stage_subsets = [perm[cum_split_counts[i]:cum_split_counts[i+1]]
58 | for i in range(n_stages)]
59 |
60 | return [datasets.LabelRemappedTrainDataset(dataset, subset_index)
61 | for subset_index in stage_subsets]
62 |
63 |
64 | def cross_entropy_loss(logits, onehot_labels):
65 | log_softmax_logits = jax.nn.log_softmax(logits)
66 | batch_size = onehot_labels.shape[0]
67 | return -jnp.sum(onehot_labels * log_softmax_logits) / batch_size
68 |
69 |
70 | def classification_metrics(logits, onehot_labels):
71 | loss = cross_entropy_loss(logits, onehot_labels)
72 | acc = jnp.mean(jnp.argmax(logits, -1) == jnp.argmax(onehot_labels, -1))
73 | metrics = {'loss': loss, 'accuracy': acc}
74 | metrics = jax.lax.pmean(metrics, 'batch')
75 | return metrics
76 |
77 |
78 | def l2_regularizer(coefficient, params):
79 | if coefficient <= 0:
80 | return 0
81 |   params = jax.tree_util.tree_leaves(params)
82 | weight_l2 = sum([jnp.sum(x ** 2) for x in params])
83 | weight_penalty = coefficient * 0.5 * weight_l2
84 | return weight_penalty
85 |
86 |
87 | def block_until_computation_finish():
88 | """Wait until computations are done."""
89 | logging.flush()
90 | jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
91 |
92 |
93 | def get_local_batch_size(batch_size: int) -> int:
94 |   """Returns the per-host batch size, i.e. `batch_size` / process count."""
95 | # For example, if we have 2 hosts, each with 8 devices, batch_size=2048, then
96 | # - batch_size == 2048
97 | # - local_batch_size == 2048 / 2 == 1024
98 | # - jax.device_count() == 2*8 == 16
99 | # The dataset object will be sharded at host level, so each host will see
100 | # a different subset of the data.
101 | if batch_size % jax.device_count() > 0:
102 | raise ValueError('Batch size must be divisible by the number of devices')
103 | local_batch_size = batch_size // jax.process_count()
104 | logging.info(
105 | 'JAX process_index=%d, process_count=%d, device_count=%d, '
106 | 'local_device_count=%d', jax.process_index(), jax.process_count(),
107 | jax.device_count(), jax.local_device_count())
108 | return local_batch_size
109 |
110 |
111 | def get_dtype(half_precision: bool, platform: Optional[str] = None):
112 | """Gets concrete data type according to precision and platform.
113 |
114 | Args:
115 |     half_precision: use 16-bit floats (bfloat16 on TPU, float16 elsewhere).
116 |     platform: e.g. 'tpu' or 'gpu'; if None, taken from the first local device.
117 |
118 | Returns:
119 | A data type to use according to the specification.
120 | """
121 | if platform is None:
122 | platform = jax.local_devices()[0].platform
123 |
124 | if half_precision:
125 | if platform == 'tpu':
126 | return jnp.bfloat16
127 | else:
128 | return jnp.float16
129 | else:
130 | return jnp.float32
131 |
132 |
133 | def split_state_params(variables):
134 | """Separate (BatchNorm) states and trainable params."""
135 | params = variables.pop('params')
136 | return variables, params
137 |
138 |
139 | def initialize_model(rng, input_shape, model):
140 | """Initializes parameters and states for a model."""
141 | input_shape = (1, *input_shape) # add a dummy batch dimension
142 | @jax.jit
143 | def init(*args):
144 | return model.init(*args)
145 | variables = init({'params': rng}, jnp.ones(input_shape, model.dtype))
146 | model_states, params = split_state_params(variables)
147 | return params, model_states
148 |
149 |
150 | def build_optimizer(name, learning_rate, kwargs):
151 |   """Builds the optax optimizer named `name`, e.g. 'sgd' or 'adamw'."""
152 | ctor = getattr(optax, name)
153 | return ctor(learning_rate=learning_rate, **kwargs)
154 |
155 |
156 | def build_lr_fn(name, base_lr, num_train_steps, kwargs):
157 |   """Builds the learning rate schedule `name` from `lr_schedules`."""
158 | return getattr(lr_schedules, name)(base_lr, num_train_steps, **kwargs)
159 |
160 |
161 | def train_step(apply_fn, state, batch, l2_regu,
162 | f_metrics=classification_metrics):
163 | """Performs a single training step."""
164 | def loss_fn(params):
165 | variables = {'params': params, **state.model_states}
166 | logits, new_model_states = apply_fn(
167 | variables, batch['image'], train=True, mutable=['batch_stats'])
168 | loss = cross_entropy_loss(logits, batch['label'])
169 | loss = loss + l2_regularizer(l2_regu, variables['params'])
170 | return loss, (new_model_states, logits)
171 |
172 | state, aux, metrics = optimizer_step(loss_fn, state)
173 | new_model_states, logits = aux[1]
174 | new_state = state.replace(model_states=new_model_states)
175 | metrics.update(f_metrics(logits, batch['label']))
176 |
177 | return new_state, metrics
178 |
179 |
180 | def optimizer_step(loss_fn, state):
181 | """Applies one optimizer step."""
182 | dynamic_scale = state.dynamic_scale
183 |
184 | if dynamic_scale:
185 | grad_fn = dynamic_scale.value_and_grad(
186 | loss_fn, has_aux=True, axis_name='batch')
187 | dynamic_scale, is_fin, aux, grad = grad_fn(state.params)
188 |     # DynamicScale.value_and_grad already averages gradients across replicas.
189 | else:
190 | grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
191 | aux, grad = grad_fn(state.params)
192 |     # Must match the axis_name used when the train step is wrapped in pmap.
193 | grad = jax.lax.pmean(grad, axis_name='batch')
194 |
195 | metrics = {}
196 | new_model_states = aux[1][0]
197 | new_state = state.apply_gradients(grads=grad, model_states=new_model_states)
198 | if dynamic_scale:
199 |     # If is_fin is False the gradients contain Inf/NaNs, so keep the previous
200 |     # optimizer state and params (i.e. skip this step).
201 |     keep_if_finite = functools.partial(jnp.where, is_fin)
202 |     new_state = new_state.replace(
203 |         opt_state=jax.tree_util.tree_map(
204 |             keep_if_finite, new_state.opt_state, state.opt_state),
205 |         params=jax.tree_util.tree_map(
206 |             keep_if_finite, new_state.params, state.params),
207 |         dynamic_scale=dynamic_scale)  # carry the updated loss scale forward
208 | metrics['scale'] = dynamic_scale.scale
209 |
210 | return new_state, aux, metrics
211 |
212 |
213 | def eval_step(apply_fn, state, batch, f_metrics=classification_metrics):
214 | """Performs a single evaluation step."""
215 | variables = {'params': state.params, **state.model_states}
216 | logits = apply_fn(variables, batch['image'], train=False, mutable=False)
217 | return f_metrics(logits, batch['label'])
218 |
219 |
220 | def iterate_data(dataset, split_name, batch_size, augmentation=False,
221 | shuffle=False, desc='', mixup_sampler=None, **kwargs):
222 |   """Iterates over batches, optionally with mixup, sharded across devices."""
223 | iterator = dataset.iterate(split_name, batch_size, shuffle=shuffle,
224 | augmentation=augmentation, **kwargs)
225 | if mixup_sampler is not None:
226 | def apply_mixup(batch):
227 | lm = mixup_sampler()
228 | batch['label'] = lm*batch['label'] + (1-lm)*np.flipud(batch['label'])
229 | batch['image'] = lm*batch['image'] + (1-lm)*np.flipud(batch['image'])
230 | return batch
231 | iterator = map(apply_mixup, iterator)
232 |
233 | iterator = map(common_utils.shard, iterator)
234 | iterator = jax_utils.prefetch_to_device(iterator, 2)
235 | iterator = tqdm.tqdm(iterator, desc=desc, disable=None,
236 | total=dataset.get_num_examples(split_name) // batch_size)
237 | return iterator
238 |
239 |
240 | def metrics_to_numpy(metrics):
241 | # We select the first element of x in order to get a single copy of a
242 | # device-replicated metric.
243 |   metrics = jax.tree_util.tree_map(lambda x: x[0], metrics)
244 | metrics_np = jax.device_get(metrics)
245 | return metrics_np
246 |
247 |
248 | def sync_batch_stats(state):
249 | """Sync the batch statistics across replicas."""
250 | if 'batch_stats' not in state.model_states:
251 | return state
252 |
253 | # An axis_name is passed to pmap which can then be used by pmean.
254 | # In this case each device has its own version of the batch statistics and
255 | # we average them.
256 | avg = jax.pmap(lambda x: jax.lax.pmean(x, 'x'), 'x')
257 |
258 | new_model_states = state.model_states | {
259 | 'batch_stats': avg(state.model_states['batch_stats'])}
260 | return state.replace(model_states=new_model_states)
261 |
--------------------------------------------------------------------------------
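The split arithmetic in `derive_subset_dataset` can be checked in isolation. The sketch below is a NumPy transcription under stated assumptions: `stage_split_indices` is a hypothetical helper that reproduces only the index computation and does not build `datasets.LabelRemappedTrainDataset` objects.

```python
import numpy as np


def stage_split_indices(n_tr, stage_splits, seed):
  """Hypothetical helper: per-stage index arrays, as in derive_subset_dataset."""
  assert seed is not None
  assert np.isclose(sum(stage_splits), 1)
  cum_splits = np.cumsum(stage_splits)
  # Cumulative boundaries, e.g. [0, 2, 8] for n_tr=8 and splits [0.25, 0.75].
  bounds = [0] + [int(n_tr * x) for x in cum_splits]
  # Pin the last boundary to n_tr so float rounding cannot drop examples.
  bounds[-1] = n_tr
  # A seeded permutation makes the split reproducible across runs.
  perm = np.random.RandomState(seed=seed).permutation(n_tr)
  return [perm[bounds[i]:bounds[i + 1]] for i in range(len(stage_splits))]


subsets = stage_split_indices(n_tr=8, stage_splits=[0.25, 0.75], seed=0)
print([len(s) for s in subsets])  # [2, 6]
```

Because the stages are disjoint slices of a single permutation, every training example lands in exactly one stage.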
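As a sanity check on `cross_entropy_loss` and `l2_regularizer`, here is a NumPy transcription (an assumption for portability; the file itself uses `jax.nn.log_softmax` and flattens the parameter pytree into leaves). The loss is the per-example cross-entropy averaged over the batch, so it also accepts soft labels; the penalty is 0.5 * coefficient * (sum of squared weights). Here `params` is a plain list of arrays standing in for the flattened parameter tree.

```python
import numpy as np


def log_softmax(logits):
  # Subtracting the row max keeps exp() from overflowing; the result matches
  # log(softmax(logits)) up to floating-point error.
  shifted = logits - logits.max(axis=-1, keepdims=True)
  return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def cross_entropy_loss(logits, onehot_labels):
  # Sum over classes and batch, divided by batch size: mean per-example loss.
  batch_size = onehot_labels.shape[0]
  return -np.sum(onehot_labels * log_softmax(logits)) / batch_size


def l2_regularizer(coefficient, params):
  # params: a flat list of arrays (stand-in for the flattened pytree).
  if coefficient <= 0:
    return 0
  weight_l2 = sum(np.sum(x ** 2) for x in params)
  return coefficient * 0.5 * weight_l2


logits = np.zeros((2, 4))    # uniform predictions over 4 classes
labels = np.eye(4)[[0, 3]]   # two one-hot labels
print(cross_entropy_loss(logits, labels))  # log(4), about 1.3863
params = [np.ones((2, 2)), np.array([3.0])]
print(l2_regularizer(1e-4, params))  # 1e-4 * 0.5 * (4 + 9), about 6.5e-4
```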
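In `optimizer_step`, a step whose gradients contain Inf/NaN values is discarded by selecting between the new and old optimizer state and parameters with `jnp.where`. The NumPy sketch below illustrates that selection on a flat dict. `keep_if_finite` is a hypothetical name, and computing `is_fin` with `np.all(np.isfinite(...))` is an assumption; in the file, `is_fin` comes from Flax's dynamic loss scaling.

```python
import numpy as np


def keep_if_finite(is_fin, new_tree, old_tree):
  """Hypothetical stand-in for tree_map(partial(jnp.where, is_fin), new, old).

  Trees are flat dicts of arrays here; the file maps over arbitrary pytrees.
  """
  return {k: np.where(is_fin, new_tree[k], old_tree[k]) for k in new_tree}


old = {'w': np.array([1.0, 2.0])}
new = {'w': np.array([np.nan, 0.5])}
grads = np.array([np.inf, 0.1])
is_fin = np.all(np.isfinite(grads))  # False: one gradient overflowed
print(keep_if_finite(is_fin, new, old)['w'])  # [1. 2.]: the step is skipped
```

Selecting with `where` instead of a Python `if` keeps the step traceable under `jit`/`pmap`, where `is_fin` is an array rather than a concrete boolean.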
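`iterate_data` applies mixup by blending each example with the example at the mirrored position in the batch (`np.flipud` reverses the batch axis). The NumPy sketch below shows the effect on one-hot labels. It returns a new dict instead of mutating `batch` in place as the original does, and the Beta(0.2, 0.2) sampler is an assumption: this file does not show how `mixup_sampler` is constructed.

```python
import numpy as np


def apply_mixup(batch, lm):
  """Blends each example with its mirror in the batch, as iterate_data does."""
  # np.flipud reverses axis 0, so example i is paired with example n-1-i.
  return {
      'image': lm * batch['image'] + (1 - lm) * np.flipud(batch['image']),
      'label': lm * batch['label'] + (1 - lm) * np.flipud(batch['label']),
  }


rng = np.random.default_rng(0)


def mixup_sampler():
  # Assumed sampler: Beta(0.2, 0.2); the file does not show the real one.
  return rng.beta(0.2, 0.2)


batch = {
    'image': np.arange(4, dtype=np.float32).reshape(4, 1),
    'label': np.eye(2, dtype=np.float32)[[0, 0, 1, 1]],
}
mixed = apply_mixup(batch, lm=0.75)  # fixed weight for a readable result
# Labels become [0.75, 0.25] for the first two rows, [0.25, 0.75] for the rest.
random_mixed = apply_mixup(batch, lm=mixup_sampler())
```

Each mixed label row still sums to 1, which is why the soft-label cross-entropy in `cross_entropy_loss` can consume these batches unchanged.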
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py>=1.0.0
2 | clu>=0.0.6
3 | tqdm>=4.63.0
4 |
--------------------------------------------------------------------------------