├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── label_dp ├── datasets.py ├── lr_schedules.py ├── main.py ├── models.py ├── profiles │ ├── __init__.py │ ├── __main__.py │ ├── p100_cifar10.py │ └── registry.py ├── train.py └── utils.py └── requirements.txt /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement (CLA). You (or your employer) retain the copyright to your 10 | contribution; this simply gives us permission to use and redistribute your 11 | contributions as part of the project. Head over to 12 | <https://cla.developers.google.com/> to see your current agreements on file or 13 | to sign a new one. 14 | 15 | You generally only need to submit a CLA once, so if you've already submitted one 16 | (even if it was for a different project), you probably don't need to do it 17 | again. 18 | 19 | ## Code Reviews 20 | 21 | All submissions, including submissions by project members, require review. We 22 | use GitHub pull requests for this purpose. Consult 23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 24 | information on using pull requests. 25 | 26 | ## Community Guidelines 27 | 28 | This project follows 29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions.
9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. 
For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repository contains the multi-stage label differential privacy 2 | training code for the paper 3 | 4 | > Badih Ghazi, Noah Golowich, Ravi Kumar, Pasin Manurangsi, Chiyuan Zhang. *Deep 5 | > Learning with Label Differential Privacy*. Advances in Neural Information 6 | > Processing Systems (**NeurIPS**), 2021. 7 | > [arxiv:2102.06062](https://arxiv.org/abs/2102.06062) 8 | 9 | **Note**: This is not an officially supported Google product. 10 | 11 | ## Abstract 12 | 13 | The Randomized Response (`RR`) algorithm is a classical technique to improve 14 | robustness in survey aggregation, and has been widely adopted in applications 15 | with differential privacy guarantees. We propose a novel algorithm, *Randomized 16 | Response with Prior* (`RRWithPrior`), which can provide more accurate results 17 | while maintaining the same level of privacy guaranteed by `RR`.
We then apply 18 | `RRWithPrior` to learn neural networks with *label* differential privacy 19 | (`LabelDP`), and show that when only the label needs to be protected, the model 20 | performance can be significantly improved over the previous state-of-the-art 21 | private baselines. Moreover, we study different ways to obtain priors, which 22 | when used with `RRWithPrior` can additionally improve the model performance, 23 | further reducing the accuracy gap between private and non-private models. We 24 | complement the empirical results with theoretical analysis showing that 25 | `LabelDP` is provably easier than protecting both the inputs and labels. 26 | 27 | ## Getting Started 28 | 29 | ### Requirements 30 | 31 | This codebase is implemented with [Jax](https://github.com/google/jax), 32 | [Flax](https://github.com/google/flax) and 33 | [Optax](https://github.com/deepmind/optax). To install the dependencies, run the 34 | following commands (see also the [Jax](https://github.com/google/jax) homepage 35 | for the latest installation guides for GPU/TPU support). 36 | 37 | ``` 38 | pip install --upgrade pip 39 | pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html 40 | 41 | pip install -r requirements.txt 42 | ``` 43 | 44 | ### 2-stage LabelDP Training for CIFAR-10 45 | 46 | Use the following command to run training: 47 | 48 | ``` 49 | python3 -m label_dp.main --base_workdir <base_workdir> --profile_key <profile_key> 50 | ``` 51 | 52 | An example `<profile_key>` is `cifar10/e2/lp-2st/run0`. Check out 53 | `profiles/p100_cifar10.py` for a list of all pre-defined profiles. One can also 54 | run the following command to query a list of all registered profile keys 55 | matching a given regex: 56 | 57 | ``` 58 | python3 -m label_dp.profiles <regex> 59 | ``` 60 | 61 | Note that the original experiments in the paper were run with TensorFlow. This 62 | codebase is a reimplementation with Jax.
The results we obtained with this 63 | codebase are slightly different from what was reported in the paper. We include 64 | the numbers here for your reference (CIFAR-10 with 2-stage training): 65 | 66 | Epsilon | Test Accuracy ± std | Accuracy from Table 1 67 | ------- | ------------------- | --------------------- 68 | 1.0 | 62.89 ± 2.07 | 63.67 69 | 2.0 | 88.11 ± 0.38 | 86.05 70 | 4.0 | 94.18 ± 0.13 | 93.37 71 | 8.0 | 95.18 ± 0.07 | 94.52 72 | 73 | ### Structure of the Code 74 | 75 | To use the training code in the current setup, you just need to define new 76 | profiles that specify hyperparameters such as what dataset to load, what 77 | optimizer to use, etc. Create a new file `profiles/pXXX_xxx.py` and import it 78 | from `profiles/__init__.py`. Functions starting with `register_` will be called 79 | automatically to register hyperparameter profiles. 80 | 81 | The code loads datasets using [tfds](https://www.tensorflow.org/datasets), so 82 | in principle any dataset available in tfds can be used. However, note that for 83 | convenience we load the entire dataset into memory when doing data splitting 84 | for multi-stage training. This should work for CIFAR-scale datasets. 85 | 86 | The key algorithm `RRWithPrior` (Algorithm 2 from the paper) is implemented in 87 | `rr_with_prior` (`train.py`) with numpy, and can be used independently in other 88 | scenarios. 89 | 90 | ## Citation 91 | 92 | ``` 93 | @article{ghazi2021deep, 94 | title={Deep Learning with Label Differential Privacy}, 95 | author={Ghazi, Badih and Golowich, Noah and Kumar, Ravi and Manurangsi, Pasin and Zhang, Chiyuan}, 96 | journal={Advances in Neural Information Processing Systems}, 97 | volume={34}, 98 | year={2021} 99 | } 100 | ``` 101 | -------------------------------------------------------------------------------- /label_dp/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""In-memory numpy image datasets loaded from tfds, plus simple augmentation."""

import numpy as np
import tensorflow_datasets as tfds


def batch_random_crop(batch_image_np, pad=4):
  """Randomly cropping images for data augmentation.

  Each image is zero-padded by `pad` pixels on every side, and an `h x w`
  window is cropped at an independently sampled offset. This amounts to a
  random translation by up to `pad` pixels in each direction.

  Args:
    batch_image_np: images, indexed as (n, h, w, c) by this function.
    pad: padding, which is also the maximum shift, in pixels. With pad=0 the
      images are returned unshifted.

  Returns:
    A new array holding one (h, w, c) crop per input image.
  """
  n, h, w, c = batch_image_np.shape
  # pad
  padded_image = np.zeros((n, h + 2 * pad, w + 2 * pad, c),
                          dtype=batch_image_np.dtype)
  # Explicit end indices instead of `-pad`: `x[pad:-pad]` is empty for pad=0.
  padded_image[:, pad:pad + h, pad:pad + w, :] = batch_image_np
  # crop: valid offsets are 0..2*pad inclusive (randint's bound is exclusive),
  # which gives a symmetric shift in [-pad, pad].
  idxs = np.random.randint(2 * pad + 1, size=(n, 2))
  cropped_image = np.array([
      padded_image[i, y:y + h, x:x + w, :]
      for i, (y, x) in enumerate(idxs)])
  return cropped_image


def batch_random_fliplr(batch_image_np):
  """Randomly do left-right flip on images.

  Each image is independently mirrored along axis 2 (the width axis in the
  (n, h, w, c) layout) with probability 1/2.

  Args:
    batch_image_np: images, indexed as (n, h, w, c).

  Returns:
    A new array of the (possibly flipped) images.
  """
  n = batch_image_np.shape[0]
  # -1 reverses the width axis, +1 keeps it as is.
  coins = np.random.choice([-1, 1], size=n)
  flipped_image = np.array([
      batch_image_np[i, :, ::coins[i], :]
      for i in range(n)])
  return flipped_image


def batch_random_cutout(batch_image_np, size=8):
  """Random cutout.

  Note we are using the same cutout region for all the images in the batch.

  Args:
    batch_image_np: images, indexed as (n, h, w, c). Modified in place.
    size: side length of the square region that is zeroed. It must not exceed
      h or w, otherwise numpy's randint raises ValueError.

  Returns:
    Images with random cutout (the same array object as `batch_image_np`).
  """
  _, h, w, _ = batch_image_np.shape
  # The top-left corner ranges over 0..h-size inclusive, so the square stays
  # inside the image and can also touch the bottom/right border.
  h0 = np.random.randint(0, h - size + 1)
  w0 = np.random.randint(0, w - size + 1)
  batch_image_np[:, h0:h0 + size, w0:w0 + size, :] = 0
  return batch_image_np


class TFDSNumpyDataset:
  """Images dataset loaded into memory as numpy array.

  The full data array in numpy format can be easily accessed. Suitable for
  smaller scale image datasets like MNIST (and variants), CIFAR-10 / CIFAR-100,
  SVHN, etc.
  """

  def __init__(self, name, random_crop=True, random_fliplr=True,
               random_cutout=0):
    """Constructs a dataset from tfds.

    Every split is loaded in full (`batch_size=-1`) and converted to numpy.

    Args:
      name: name of the dataset.
      random_crop: whether to perform random crop in data augmentation.
      random_fliplr: whether to perform random left-right flip in data aug.
      random_cutout: if non-zero, denote the cutout length.
    """
    self.name = name

    self._random_crop = random_crop
    self._random_fliplr = random_fliplr
    self._random_cutout = random_cutout
    self.ds, self.info = tfds.load(name, batch_size=-1,
                                   as_dataset_kwargs={'shuffle_files': False},
                                   with_info=True)
    self.ds_np = tfds.as_numpy(self.ds)

    self._add_index_feature()

  def _add_index_feature(self):
    """Adds 'index' feature if not present.

    The index is each example's position (0..n-1) within its split. For the
    splits that receive it, any 'id' feature is dropped as well.
    """
    for split in self.ds_np:
      if 'index' in self.ds_np[split]:
        continue
      n_sample = len(self.ds_np[split]['label'])
      index = np.arange(n_sample)
      if 'id' in self.ds_np[split]:
        # remove the 'id' feature, b/c jax cannot handle string type
        self.ds_np[split].pop('id')
      self.ds_np[split]['index'] = index

  @property
  def num_classes(self):
    """Number of classes of the 'label' feature, from the tfds metadata."""
    return self.info.features['label'].num_classes

  @property
  def use_onehot_label(self):
    """Labels are yielded as stored by tfds, not as one-hot vectors."""
    return False

  @property
  def data_scale(self):
    """Divisor applied by `normalize_images` (assumes 8-bit pixel values)."""
    return 255.0

  def get_num_examples(self, split_name):
    """Returns the number of examples in `split_name`."""
    return self.ds_np[split_name]['image'].shape[0]

  def get_input_shape(self, input_name):
    """Returns the per-example shape of `input_name` (only 'image' supported).

    Raises:
      KeyError: for any input other than 'image'.
    """
    if input_name == 'image':
      return self.ds_np['train']['image'].shape[1:]
    raise KeyError(f'getting input shape for {input_name}')

  def normalize_images(self, batch_image_np):
    """Casts images to float32 and divides them by `data_scale`."""
    images = batch_image_np.astype(np.float32) / self.data_scale
    return images

  def iterate(self, split_name, batch_size, shuffle=False, augmentation=False,
              subset_index=None):
    """Iterates over the dataset.

    Args:
      split_name: split to read, e.g. 'train'.
      batch_size: maximum batch size; the final batch may be smaller.
      shuffle: whether to randomly permute the examples (after subsetting).
      augmentation: whether to apply the configured crop / flip / cutout.
      subset_index: optional indices selecting a subset of the split.

    Yields:
      Dicts with the split's feature keys. 'image' is normalized to float32
      via `normalize_images` (and augmented if requested).
    """
    n_sample = self.get_num_examples(split_name)
    # make a shallow copy
    dset = dict(self.ds_np[split_name])

    if subset_index is not None:
      n_sample = len(subset_index)
      for key in dset:
        dset[key] = dset[key][subset_index]

    if shuffle:
      rp = np.random.permutation(n_sample)
      for key in dset:
        dset[key] = dset[key][rp]

    for i in range(0, n_sample, batch_size):
      batch = {key: val[i:i + batch_size]
               for key, val in dset.items()}
      # astype() in normalize_images makes a copy, so the in-place cutout
      # below never touches the stored dataset.
      batch['image'] = self.normalize_images(batch['image'])
      if augmentation:
        if self._random_crop:
          batch['image'] = batch_random_crop(batch['image'])
        if self._random_fliplr:
          batch['image'] = batch_random_fliplr(batch['image'])
        if self._random_cutout > 0:
          batch['image'] = batch_random_cutout(
              batch['image'], self._random_cutout)

      yield batch


class LabelRemappedTrainDataset:
  """A derived dataset where the labels are remapped according to the index.

  This class wraps a base dataset such as `TFDSNumpyDataset`:
  - The 'train' split is restricted to `subset_index[subset_mask]`.
  - Labels are (n, num_classes) vectors. For 'train', these are the rows of
    `label_mapping` when it is set; otherwise, and for every other split,
    they are one-hot float32 encodings.
  - The base dataset's labels are kept under 'orig_label'.
  """

  def __init__(self, dataset, subset_index):
    """Constructs the derived dataset.

    Args:
      dataset: base dataset providing `num_classes`, `get_num_examples`,
        `get_input_shape` and `iterate`.
      subset_index: indices of the base 'train' examples to use. It is
        boolean-masked with `subset_mask`, so presumably a numpy integer array.
    """
    self.dataset = dataset
    self.subset_index = subset_index
    # if not None, could be a (n, k) array that defines the k-dimensional
    # one-hot vector label for each vector. Note n here is the total number
    # of examples in the original dataset as the index in the original data
    # is used to address this array. However, this array does not need to have
    # meaningful values outside of the examples specified by subset_index.
    self.label_mapping = None
    # subset mask can be used to further filter out some
    # of the examples in the training set.
    # Use builtin `bool`: the `np.bool` alias was removed in NumPy 1.24 and
    # raises AttributeError there.
    self.subset_mask = np.ones(len(subset_index), dtype=bool)

  @property
  def num_classes(self):
    """Number of classes, delegated to the base dataset."""
    return self.dataset.num_classes

  @property
  def use_onehot_label(self):
    """Labels are always yielded as (n, num_classes) vectors."""
    return True

  def get_num_examples(self, split_name):
    """Returns the masked subset size for 'train', else the base count."""
    if split_name == 'train':
      return len(self.subset_index[self.subset_mask])
    else:
      return self.dataset.get_num_examples(split_name)

  def get_input_shape(self, input_name):
    """Delegates to the base dataset."""
    return self.dataset.get_input_shape(input_name)

  def iterate(self, split_name, batch_size, shuffle=False, augmentation=False,
              subset_index=None):
    """Iterate over the dataset.

    The arguments match the base dataset's `iterate`. However, `subset_index`
    must be None, because 'train' is already restricted by this object.

    Yields:
      Batches whose 'label' is a (n, num_classes) vector label and whose
      'orig_label' holds the base dataset's label.
    """
    assert subset_index is None, ('LabelRemappedTrainDataset does not support '
                                  'further subset indexing.')

    if split_name == 'train':
      for batch in self.dataset.iterate(
          'train', batch_size, shuffle=shuffle, augmentation=augmentation,
          subset_index=self.subset_index[self.subset_mask]):
        yield self.remap_batch_label(batch)
    else:
      for batch in self.dataset.iterate(
          split_name, batch_size, shuffle=shuffle, augmentation=augmentation):
        batch['orig_label'] = batch['label']
        batch['label'] = self._make_onehot(batch['label'])
        yield batch

  def remap_batch_label(self, batch):
    """Replaces batch['label'] in place and returns the same dict.

    The label is looked up in `label_mapping` by the example's 'index' when a
    mapping is set; otherwise it is one-hot encoded.
    """
    batch['orig_label'] = batch['label']
    if self.label_mapping is not None:
      batch['label'] = self.label_mapping[batch['index'], :]
    else:
      batch['label'] = self._make_onehot(batch['label'])
    return batch

  def _make_onehot(self, labels):
    """One-hot encodes integer `labels` as float32 rows of length num_classes."""
    return np.eye(self.num_classes, dtype=np.float32)[labels, :]
-------------------------------------------------------------------------------- /label_dp/lr_schedules.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Learning rate schedules. 16 | 17 | A thin wrapper over optax schedulers with a unifying constructing interface. 18 | """ 19 | 20 | from typing import Callable 21 | import optax 22 | 23 | 24 | Schedule = Callable[[int], float] 25 | 26 | 27 | def constant(base_lr, num_train_steps) -> Schedule: 28 | del num_train_steps 29 | return lambda _: base_lr 30 | 31 | 32 | def piecewise_constant(base_lr, num_train_steps, *, 33 | rampup_thresh=0.15, 34 | stages=((0.3, 0.1), (0.6, 0.1), (0.9, 0.1))) -> Schedule: 35 | """Piecewise constant learning rate with optional linear rampup. 36 | 37 | Args: 38 | base_lr: base learning rate. 39 | num_train_steps: total number of training steps. 40 | rampup_thresh: if not None, can specify a linear rampup. 41 | stages: a sequence of (step_ratio, scaling_factor). The step_ratio times 42 | the num_train_steps is the decaying boundary. The scaling factor for all 43 | the stages whose decaying boundary is less than the current step is 44 | multiplied to the base learning rate. 45 | 46 | Returns: 47 | A learning rate schedule. 
48 | """ 49 | lr_fn = optax.piecewise_constant_schedule( 50 | init_value=base_lr, 51 | boundaries_and_scales={int(r*num_train_steps): s for r, s in stages}) 52 | if rampup_thresh is not None and rampup_thresh > 0: 53 | rampup_steps = int(rampup_thresh * num_train_steps) 54 | rampup_fn = optax.linear_schedule( 55 | init_value=0, end_value=base_lr, transition_steps=rampup_steps) 56 | lr_fn = optax.join_schedules( 57 | schedules=[rampup_fn, lr_fn], boundaries=[rampup_steps]) 58 | return lr_fn 59 | -------------------------------------------------------------------------------- /label_dp/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Sets up environment and calls training.""" 16 | 17 | import copy 18 | import datetime 19 | import logging as native_logging 20 | import os 21 | from typing import Sequence 22 | 23 | from absl import app 24 | from absl import flags 25 | from absl import logging 26 | from clu import platform 27 | 28 | import jax 29 | import ml_collections 30 | import tensorflow as tf 31 | 32 | from label_dp import profiles 33 | from label_dp import train 34 | 35 | 36 | FLAGS = flags.FLAGS 37 | 38 | 39 | _BASE_WORKDIR = flags.DEFINE_string( 40 | 'base_workdir', '', 41 | 'Base directory for logs, checkpoints, and other outputs. 
' 42 | 'When a profile key is given, logs will go into subfolders ' 43 | 'specified by the key.') 44 | _PROFILE_KEY = flags.DEFINE_string('profile_key', None, 45 | 'Key to a pre-defined profile.') 46 | 47 | 48 | def main(argv: Sequence[str]) -> None: 49 | if len(argv) > 1: 50 | raise app.UsageError('Too many command-line arguments.') 51 | if jax.process_count() > 1: 52 | raise NotImplementedError() 53 | 54 | # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make 55 | # it unavailable to JAX. 56 | tf.config.experimental.set_visible_devices([], 'GPU') 57 | 58 | configs = profiles.Registry.get_profile(_PROFILE_KEY.value) 59 | configs = ml_collections.ConfigDict(copy.deepcopy(configs)) 60 | workdir = os.path.join(_BASE_WORKDIR.value, _PROFILE_KEY.value) 61 | 62 | # logging 63 | logdir = os.path.join(workdir, 'logs') 64 | tf.io.gfile.makedirs(logdir) 65 | log_file = os.path.join( 66 | logdir, datetime.datetime.now().strftime('%Y%m%d-%H%M%S') + '.txt') 67 | log_format = '%(asctime)s %(name)-12s %(levelname)-8s %(message)s' 68 | formatter = native_logging.Formatter(log_format) 69 | file_stream = tf.io.gfile.GFile(log_file, 'w') 70 | handler = native_logging.StreamHandler(file_stream) 71 | handler.setLevel(native_logging.INFO) 72 | handler.setFormatter(formatter) 73 | logging.get_absl_logger().addHandler(handler) 74 | 75 | if jax.process_index() == 0: 76 | work_unit = platform.work_unit() 77 | work_unit.create_artifact( 78 | artifact_type=platform.ArtifactType.DIRECTORY, 79 | artifact=workdir, description='Working directory') 80 | work_unit.create_artifact( 81 | artifact_type=platform.ArtifactType.FILE, 82 | artifact=log_file, description='Log file') 83 | 84 | train.multi_stage_train(configs.train, workdir) 85 | logging.flush() 86 | 87 | 88 | if __name__ == '__main__': 89 | app.run(main) 90 | -------------------------------------------------------------------------------- /label_dp/models.py: 
# --------------------------------------------------------------------------------
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ResNet V1 and V2 models implemented with Flax linen."""


import functools
from typing import Any, Callable, Optional, Sequence, Tuple

from flax import linen as nn
import jax.numpy as jnp

# Type of a (partially applied) Flax module constructor, e.g. a
# `functools.partial(nn.Conv, ...)`.
ModuleDef = Any


################################################################################
# ResNet V2
################################################################################
class BasicBlockV2(nn.Module):
  """Basic (two 3x3 convs) pre-activation block for a ResNet V2.

  The input is normalized and activated once (`preact`). When the residual
  branch changes the shape (stride or channel count) the shortcut is a 1x1
  projection of `preact`; otherwise the shortcut is the raw input.
  """

  channels: int  # output channels of the block
  conv: ModuleDef  # convolution constructor
  norm: ModuleDef  # normalization constructor
  act: Callable  # activation function
  strides: Tuple[int, int] = (1, 1)  # strides of the first 3x3 conv

  @nn.compact
  def __call__(self, x):
    preact = self.act(self.norm()(x))
    y = self.conv(self.channels, (3, 3), self.strides)(preact)
    y = self.act(self.norm()(y))
    y = self.conv(self.channels, (3, 3))(y)

    if y.shape != x.shape:
      shortcut = self.conv(self.channels, (1, 1), self.strides)(preact)
    else:
      shortcut = x
    return shortcut + y


class BottleneckBlockV2(nn.Module):
  """Bottleneck (1x1 -> 3x3 -> 1x1) pre-activation block for a ResNet V2.

  The block outputs `channels * 4` channels. As in `BasicBlockV2`, the
  shortcut is a strided 1x1 projection of the pre-activated input when the
  shape changes, and the identity otherwise.
  """

  channels: int  # bottleneck width; output has channels * 4 channels
  conv: ModuleDef  # convolution constructor
  norm: ModuleDef  # normalization constructor
  act: Callable  # activation function
  strides: Tuple[int, int] = (1, 1)  # strides of the 3x3 conv

  @nn.compact
  def __call__(self, x):
    preact = self.act(self.norm()(x))
    y = self.conv(self.channels, (1, 1))(preact)
    y = self.act(self.norm()(y))
    y = self.conv(self.channels, (3, 3), self.strides)(y)
    y = self.act(self.norm()(y))
    y = self.conv(self.channels * 4, (1, 1))(y)

    if y.shape != x.shape:
      shortcut = self.conv(self.channels * 4, (1, 1), self.strides)(preact)
    else:
      shortcut = x

    return shortcut + y


class ResNetV2(nn.Module):
  """ResNet v2.

  K. He, X. Zhang, S. Ren, and J. Sun. Identity mappings in deep residual
  networks. In ECCV, pages 630–645, 2016.

  Attributes:
    stage_sizes: number of blocks in each stage; stage i has
      `base_channels * 2 ** i` channels and (for i > 0) downsamples by 2 in
      its first block.
    block_class: block constructor (`BasicBlockV2` or `BottleneckBlockV2`).
    num_classes: if set, a final dense layer produces that many logits;
      otherwise the pooled features are returned.
    base_channels: channels of the stem and the first stage.
    act: activation function.
    dtype: computation dtype of convs, norms and the classifier.
    small_image: use a 3x3 stem without pooling (suitable for CIFAR) instead
      of the 7x7/stride-2 stem followed by max pooling.
    bn_cross_replica_axis_name: if not None, batch statistics are synced
      across replicas along this pmap axis name.
  """

  stage_sizes: Sequence[int]
  block_class: ModuleDef
  num_classes: Optional[int] = None
  base_channels: int = 64
  act: Callable = nn.relu
  dtype: Any = jnp.float32
  small_image: bool = False
  # if not None, batch statistics are sync-ed across replica according to
  # this axis_name used in pmap
  bn_cross_replica_axis_name: Optional[str] = None

  @nn.compact
  def __call__(self, x, train: bool = True):
    """Runs the network.

    Args:
      x: input images; assumed NHWC, since global average pooling is taken
        over axes 1 and 2.
      train: if False, BatchNorm uses its running averages.

    Returns:
      Logits of shape (batch, num_classes) when `num_classes` is set,
      otherwise pooled features of shape (batch, channels).
    """
    conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
    norm = functools.partial(
        nn.BatchNorm, use_running_average=not train,
        momentum=0.9, epsilon=1e-5, dtype=self.dtype,
        axis_name=self.bn_cross_replica_axis_name)

    if self.small_image:  # suitable for Cifar
      x = conv(self.base_channels, (3, 3), padding='SAME')(x)
    else:
      x = conv(self.base_channels, (7, 7), (2, 2), padding=[(3, 3), (3, 3)])(x)
      x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

    for i, n_blocks in enumerate(self.stage_sizes):
      for j in range(n_blocks):
        # Downsample at the first block of every stage except the first.
        strides = (2, 2) if i > 0 and j == 0 else (1, 1)
        x = self.block_class(self.base_channels * 2 ** i, strides=strides,
                             conv=conv, norm=norm, act=self.act)(x)

    # Pre-activation nets need a final norm + activation before pooling.
    x = self.act(norm(name='bn_final')(x))
    x = jnp.mean(x, axis=(1, 2))
    if self.num_classes is not None:
      x = nn.Dense(self.num_classes, dtype=self.dtype, name='classifier')(x)
    return x


CifarResNet18V2 = functools.partial(
    ResNetV2, stage_sizes=[2, 2, 2, 2], block_class=BasicBlockV2,
    small_image=True)

CifarResNet50V2 = functools.partial(
    ResNetV2, stage_sizes=[3, 4, 6, 3], block_class=BottleneckBlockV2,
    small_image=True)

ResNet50V2 = functools.partial(
    ResNetV2, stage_sizes=[3, 4, 6, 3], block_class=BottleneckBlockV2)


################################################################################
# ResNet V1
################################################################################
class BasicBlockV1(nn.Module):
  """Basic (two 3x3 convs) post-activation block for a ResNet V1.

  The last norm of the residual branch is zero-initialized
  (`scale_init=zeros`). When the shape changes, the shortcut is a strided 1x1
  conv + norm projection (`conv_proj` / `norm_proj`).
  """

  channels: int  # output channels of the block
  conv: ModuleDef  # convolution constructor
  norm: ModuleDef  # normalization constructor
  act: Callable  # activation function
  strides: Tuple[int, int] = (1, 1)  # strides of the first 3x3 conv

  @nn.compact
  def __call__(self, x):
    residual = x
    y = self.conv(self.channels, (3, 3), self.strides)(x)
    y = self.norm()(y)
    y = self.act(y)
    y = self.conv(self.channels, (3, 3))(y)
    y = self.norm(scale_init=nn.initializers.zeros)(y)

    if residual.shape != y.shape:
      residual = self.conv(self.channels, (1, 1),
                           self.strides, name='conv_proj')(residual)
      residual = self.norm(name='norm_proj')(residual)

    return self.act(residual + y)


class BottleneckBlockV1(nn.Module):
  """Bottleneck (1x1 -> 3x3 -> 1x1) post-activation block for ResNet V1.

  Outputs `channels * 4` channels; the last norm of the residual branch is
  zero-initialized, and the shortcut is projected when the shape changes.
  """

  channels: int  # bottleneck width; output has channels * 4 channels
  conv: ModuleDef  # convolution constructor
  norm: ModuleDef  # normalization constructor
  act: Callable  # activation function
  strides: Tuple[int, int] = (1, 1)  # strides of the 3x3 conv

  @nn.compact
  def __call__(self, x):
    residual = x
    y = self.conv(self.channels, (1, 1))(x)
    y = self.norm()(y)
    y = self.act(y)
    y = self.conv(self.channels, (3, 3), self.strides)(y)
    y = self.norm()(y)
    y = self.act(y)
    y = self.conv(self.channels * 4, (1, 1))(y)
    y = self.norm(scale_init=nn.initializers.zeros)(y)

    if residual.shape != y.shape:
      residual = self.conv(self.channels * 4, (1, 1),
                           self.strides, name='conv_proj')(residual)
      residual = self.norm(name='norm_proj')(residual)

    return self.act(residual + y)


class ResNetV1(nn.Module):
  """ResNetV1.

  K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image
  recognition. In CVPR, pages 770–778, 2016.

  Attributes:
    stage_sizes: number of blocks in each stage; stage i has
      `base_channels * 2 ** i` channels and (for i > 0) downsamples by 2 in
      its first block.
    block_class: block constructor (`BasicBlockV1` or `BottleneckBlockV1`).
    num_classes: if set, a final dense layer produces that many logits;
      otherwise the pooled features are returned.
    base_channels: channels of the stem and the first stage.
    act: activation function, used in the stem and in every block.
    dtype: computation dtype of convs, norms and the classifier.
    small_image: use a 3x3 stem without pooling (suitable for CIFAR) instead
      of the 7x7/stride-2 stem followed by max pooling.
    bn_cross_replica_axis_name: if not None, batch statistics are synced
      across replicas along this pmap axis name.
  """

  stage_sizes: Sequence[int]
  block_class: ModuleDef
  num_classes: Optional[int] = None
  base_channels: int = 64
  act: Callable = nn.relu
  dtype: Any = jnp.float32
  small_image: bool = False
  # if not None, batch statistics are sync-ed across replica according to
  # this axis_name used in pmap
  bn_cross_replica_axis_name: Optional[str] = None

  @nn.compact
  def __call__(self, x, train: bool = True):
    """Runs the network.

    Args:
      x: input images; assumed NHWC, since global average pooling is taken
        over axes 1 and 2.
      train: if False, BatchNorm uses its running averages.

    Returns:
      Logits of shape (batch, num_classes) when `num_classes` is set,
      otherwise pooled features of shape (batch, channels).
    """
    conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
    norm = functools.partial(
        nn.BatchNorm, use_running_average=not train, momentum=0.9,
        epsilon=1e-5, dtype=self.dtype,
        axis_name=self.bn_cross_replica_axis_name)

    if self.small_image:  # suitable for Cifar
      x = conv(self.base_channels, (3, 3), padding='SAME', name='conv_init')(x)
      x = norm(name='bn_init')(x)
      x = self.act(x)
    else:
      x = conv(self.base_channels, (7, 7), (2, 2), padding=[(3, 3), (3, 3)],
               name='conv_init')(x)
      x = norm(name='bn_init')(x)
      # Use the configured activation, consistent with the small-image stem
      # and the residual blocks (previously hard-coded to nn.relu).
      x = self.act(x)
      x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

    for i, block_size in enumerate(self.stage_sizes):
      for j in range(block_size):
        # Downsample at the first block of every stage except the first.
        strides = (2, 2) if i > 0 and j == 0 else (1, 1)
        x = self.block_class(self.base_channels * 2 ** i, strides=strides,
                             conv=conv, norm=norm, act=self.act)(x)

    x = jnp.mean(x, axis=(1, 2))
    if self.num_classes is not None:
      x = nn.Dense(self.num_classes, dtype=self.dtype, name='classifier')(x)
    return x


CifarResNet18V1 = functools.partial(
    ResNetV1, stage_sizes=[2, 2, 2, 2], block_class=BasicBlockV1,
    small_image=True)

CifarResNet50V1 = functools.partial(
    ResNetV1, stage_sizes=[3, 4, 6, 3], block_class=BottleneckBlockV1,
    small_image=True)

ResNet50V1 = functools.partial(
    ResNetV1, stage_sizes=[3, 4, 6, 3], block_class=BottleneckBlockV1)

# --------------------------------------------------------------------------------
# /label_dp/profiles/__init__.py:
# --------------------------------------------------------------------------------
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Profiles."""

# Profile modules must be imported here: the registry discovers them as
# attributes of this package whose names match `p<digits>_`.
from . import p100_cifar10

from . import registry


Registry = registry.Registry

# --------------------------------------------------------------------------------
# /label_dp/profiles/__main__.py:
# --------------------------------------------------------------------------------
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=invalid-name
"""Command-line helper to list / search profiles in the model zoo.

Usage: python3 -m profiles [REGEX]    (REGEX defaults to '.*')
"""
import sys

from . import registry


if __name__ == '__main__':
  pattern = sys.argv[1] if len(sys.argv) > 1 else '.*'
  registry.Registry.print_profiles(pattern)

# --------------------------------------------------------------------------------
# /label_dp/profiles/p100_cifar10.py:
# --------------------------------------------------------------------------------
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""CIFAR-10 experiments."""


def register_cifar10_experiment(registry):
  """Registers the two-stage label-DP CIFAR-10 profiles.

  For each epsilon in (1, 2, 4, 8), five repetitions are registered under
  `cifar10/e{epsilon}/lp-2st/run{rep}` (rep = 0..4); repetitions differ only in
  `run_seed`. Stage one applies randomized response ('rr') to the first
  `data_split` percent of the data; stage two applies randomized response
  with a prior ('rr-with-prior') to the rest.

  Args:
    registry: object exposing `register(spec)`, e.g. `Registry`.
  """
  # Settings shared by every profile in this family.
  batch_size, learning_rate = 512, 0.4
  num_epochs, cutout = 200, 8
  lr_fn, optimizer = 'piecewise_constant', 'sgd'

  # (epsilon, (stage-1 mixup, stage-2 mixup), stage-1 data %, temperature)
  per_epsilon = (
      (1, (16, 8), 80, 0.6),
      (2, (8, 8), 70, 0.5),
      (4, (8, 4), 60, 0.5),
      (8, (4, 4), 60, 0.5),
  )
  for epsilon, mixups, first_stage_data_ratio, temperature in per_epsilon:
    mixup1, mixup2 = mixups
    for rep in range(5):
      first_stage = dict(type='rr', seed=2019, eps=epsilon,
                         data_split=first_stage_data_ratio/100, mixup=mixup1)
      second_stage = dict(type='rr-with-prior', seed=2020, eps=epsilon,
                          data_split=1 - first_stage_data_ratio/100,
                          temperature=temperature, mixup=mixup2)
      train_configs = {
          'run_seed': 1234 + rep,
          'batch_size': batch_size,
          'half_precision': False,
          'l2_regu': 1e-4,
          'num_epochs': num_epochs,
          'eval_splits': ['test'],
          'reuse_last_stage_data': True,
          'mask_last_stage_label_by_prior': True,
          'data': {'name': 'cifar10', 'kwargs': {'random_cutout': cutout}},
          'base_lr': learning_rate,
          'lr_fn': {'name': lr_fn, 'kwargs': {}},
          'optimizer': get_optimizer(optimizer),
          'model': {'arch': 'CifarResNet18V2', 'kwargs': {}},
          'stage_specs': [first_stage, second_stage],
      }
      registry.register({
          'key': f'cifar10/e{epsilon}/lp-2st/run{rep}',
          'meta': {'target': 'main', 'platform': 'v100'},
          'train': train_configs,
      })


def get_optimizer(name):
  """Returns an optimizer config; 'sgd' gets Nesterov momentum 0.9."""
  kwargs = {'momentum': 0.9, 'nesterov': True} if name == 'sgd' else {}
  return {'name': name, 'kwargs': kwargs}
# --------------------------------------------------------------------------------
# /label_dp/profiles/registry.py:
# --------------------------------------------------------------------------------
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Profile registry."""
import re
import sys


class Registry(object):
  """Registry holding all model specs in the model zoo.

  Profiles are dicts with at least a 'key' entry. The registry is populated
  lazily on first use by calling every callable named `register_*` in the
  sub-modules of this package whose names match `p<digits>_`.
  """
  # Maps profile key -> profile dict; None until `build_registry` runs.
  registry = None

  @classmethod
  def build_registry(cls):
    """Builds the registry once; later calls are no-ops."""
    if cls.registry is not None:
      return

    # Assigned before running the register functions, so the `register`
    # calls they make do not re-enter this build.
    cls.registry = {}

    mod_profiles = sys.modules[sys.modules[__name__].__package__]
    profile_mods = [getattr(mod_profiles, x) for x in dir(mod_profiles)
                    if re.match(r'^p[0-9]+_', x)]
    profile_mods = [x for x in profile_mods if isinstance(x, type(sys))]
    profile_mods.sort(key=str)
    for p_mod in profile_mods:
      register_funcs = [getattr(p_mod, x) for x in dir(p_mod)
                        if x.startswith('register_')]
      for func in filter(callable, register_funcs):
        func(cls)

  @classmethod
  def register(cls, profile):
    """Adds `profile` under `profile['key']`.

    Raises:
      KeyError: if a profile with the same key is already registered.
    """
    # Registering before any query used to fail because `registry` was None.
    cls.build_registry()
    key = profile['key']
    if key in cls.registry:  # pylint: disable=unsupported-membership-test
      raise KeyError('duplicated profile key: {}'.format(key))
    cls.registry[key] = profile  # pylint: disable=unsupported-assignment-operation

  @classmethod
  def list_profiles(cls, regex):
    """Returns the profiles whose key matches `regex` (via `re.search`)."""
    cls.build_registry()
    return [profile for key, profile in cls.registry.items()
            if re.search(regex, key)]

  @classmethod
  def print_profiles(cls, regex):
    """Prints the keys of the profiles matching `regex`."""
    profiles = cls.list_profiles(regex)
    print('{} profiles found ====== with regex: {}'.format(
        len(profiles), regex))
    for i, profile in enumerate(profiles):
      print('  {:>3d}) {}'.format(i, profile['key']))

  @classmethod
  def get_profile(cls, key):
    """Returns the profile registered under `key`.

    Raises:
      ValueError: if `key` is None.
      KeyError: if no profile is registered under `key`.
    """
    if key is None:
      raise ValueError('A profile key must be specified.')
    cls.build_registry()
    try:
      return cls.registry[key]
    except KeyError:
      raise KeyError(f'unknown profile key: {key}') from None

# --------------------------------------------------------------------------------
# /label_dp/train.py:
# --------------------------------------------------------------------------------
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training script."""

import functools
import pathlib
from typing import Any

from absl import logging

from clu import metric_writers

from flax import jax_utils
from flax.optim import dynamic_scale as dynamic_scale_lib
from flax.training import common_utils
from flax.training import train_state
import jax
import ml_collections
import numpy as np
import tensorflow as tf

from label_dp import datasets
from label_dp import models
from label_dp import utils


class TrainState(train_state.TrainState):
  """Flax TrainState extended with epoch, model states and loss scaling."""
  # Number of completed training epochs.
  epoch: int
  # Non-trainable model variables: everything `model.init` returns except
  # 'params' (e.g. the 'batch_stats' collection updated in training).
  model_states: Any
  # Dynamic loss scaling state; None unless half precision on GPU.
  dynamic_scale: dynamic_scale_lib.DynamicScale


def create_train_state(rng, input_shape, half_precision, model, optimizer_cfgs):
  """Creates the initial training state.

  Args:
    rng: JAX PRNG key used to initialize the model.
    input_shape: shape of one input example, without the batch dimension.
    half_precision: whether half precision is requested. Dynamic loss scaling
      is only enabled for half precision on GPU.
    model: the Flax module to initialize.
    optimizer_cfgs: keyword arguments forwarded to `utils.build_optimizer`.

  Returns:
    A `TrainState` at epoch 0.
  """
  platform = jax.local_devices()[0].platform
  if half_precision and platform == 'gpu':
    dynamic_scale = dynamic_scale_lib.DynamicScale()
  else:
    dynamic_scale = None

  params, model_states = utils.initialize_model(
      rng, input_shape, model)
  tx = utils.build_optimizer(**optimizer_cfgs)
  state = TrainState.create(
      apply_fn=model.apply, params=params, tx=tx, model_states=model_states,
      dynamic_scale=dynamic_scale, epoch=0)
  return state


def build_model(model_configs, num_classes, dtype):
  """Instantiates `models.<model_configs.arch>`.

  `model_configs.kwargs` are forwarded to the constructor; `num_classes` and
  `dtype` override entries of the same name.
  """
  kwargs = dict(model_configs.kwargs, num_classes=num_classes, dtype=dtype)
  return getattr(models, model_configs.arch)(**kwargs)


def report_metrics(i_stage, report_name, epoch, metrics, writer):
  """Averages a list of per-step metric dicts and writes them as scalars.

  Scalars are tagged `stage{i_stage}/{report_name}/{metric_name}`.
  """
  metrics = common_utils.stack_forest(metrics)
  # `jax.tree_map` is deprecated (and removed in recent JAX releases).
  summary = jax.tree_util.tree_map(lambda x: float(x.mean()), metrics)
  summary = {f'stage{i_stage}/{report_name}/{k}': v for k, v in summary.items()}
  writer.write_scalars(epoch, summary)


def multi_stage_train(configs: ml_collections.ConfigDict, workdir: str):
  """Runs label-DP training over the stages in `configs.stage_specs`.

  The training set is split into one disjoint subset per stage. Before stage i
  is trained, the labels of its subset are randomized. Type 'rr' uses plain
  randomized response. Type 'rr-with-prior' uses randomized response with the
  previous stage's model as the prior.

  If `configs.reuse_last_stage_data` is set, the randomized data of stage i-1
  is merged into stage i. If `configs.mask_last_stage_label_by_prior` is also
  set, that data is first filtered by the prior. Every stage after the first
  starts from the previous stage's parameters and model states.

  Args:
    configs: training configuration (the 'train' dict of a profile).
    workdir: directory under which per-stage outputs are written.

  Raises:
    KeyError: if a stage spec has an unknown randomization type.
  """
  orig_dataset = datasets.TFDSNumpyDataset(
      name=configs.data.name, **configs.data.kwargs)
  stage_datasets = utils.derive_subset_dataset(
      orig_dataset, n_stages=len(configs.stage_specs),
      stage_splits=[spec['data_split'] for spec in configs.stage_specs],
      seed=sum(spec['seed'] for spec in configs.stage_specs))

  n_tr_total = orig_dataset.get_num_examples('train')
  input_dtype = utils.get_dtype(configs.half_precision)
  model = build_model(configs.model, orig_dataset.num_classes, input_dtype)
  last_stage_state = None
  for i_stage in range(len(stage_datasets)):
    if configs.stage_specs[i_stage]['type'] == 'rr':
      k_for_prior = compute_randomized_labels(
          stage_datasets[i_stage], configs.batch_size,
          configs.stage_specs[i_stage], n_tr_total)
    elif configs.stage_specs[i_stage]['type'] == 'rr-with-prior':
      k_for_prior = compute_randomized_labels_with_priors(
          stage_datasets[i_stage], model, last_stage_state, configs.batch_size,
          configs.stage_specs[i_stage], n_tr_total)
    else:
      raise KeyError('Unknown label randomization type')

    if i_stage != 0 and configs.reuse_last_stage_data:
      logging.info('Reusing randomized data from stage %d in stage %d',
                   i_stage - 1, i_stage)
      if configs.mask_last_stage_label_by_prior:
        logging.info('Filtering out egs from stage %d using learned prior.',
                     i_stage - 1)
        # Note we are passing in the dataset from the last stage and the k
        # from this stage
        filter_stage_data_by_prior(
            stage_datasets[i_stage-1], model, last_stage_state, k_for_prior,
            n_tr_total, configs.batch_size)
      merge_stage_data(stage_datasets[i_stage-1], stage_datasets[i_stage])

    logging.info('#' * 60)
    logging.info('# Stage %d training', i_stage)
    logging.info('#' * 60)
    last_stage_state = single_stage_train(
        configs, workdir, i_stage, stage_datasets, model, last_stage_state)


def single_stage_train(
    configs: ml_collections.ConfigDict, workdir: str, i_stage: int,
    stage_datasets, model, last_stage_state=None):
  """Trains one stage on `stage_datasets[i_stage]`.

  Outputs (TensorBoard summaries) go to `{workdir}/stage{i_stage}`. If
  `last_stage_state` is given, its params and model states initialize this
  stage. After each epoch the model is evaluated on `configs.eval_splits`.

  Returns:
    The final (unreplicated) `TrainState` of this stage.
  """
  workdir = pathlib.Path(workdir) / f'stage{i_stage}'
  tf.io.gfile.makedirs(str(workdir))

  writer = metric_writers.create_default_writer(
      str(workdir / 'tensorboard'), just_logging=jax.process_index() > 0)

  rng = jax.random.PRNGKey(configs.run_seed + i_stage)
  np_rng = np.random.RandomState(seed=configs.run_seed + i_stage + 123)
  mixup_alpha = configs.stage_specs[i_stage]['mixup']
  # Mixup weights ~ Beta(alpha, alpha); mixup is disabled when alpha <= 0.
  mixup_sampler = functools.partial(
      np_rng.beta, mixup_alpha, mixup_alpha) if mixup_alpha > 0 else None
  local_batch_size = utils.get_local_batch_size(configs.batch_size)
  dataset = stage_datasets[i_stage]

  n_train = dataset.get_num_examples('train')
  n_train_steps = int(configs.num_epochs * n_train / local_batch_size)
  lr_fn = utils.build_lr_fn(
      configs.lr_fn.name, configs.base_lr, n_train_steps, configs.lr_fn.kwargs)
  optimizer_cfgs = dict(configs.optimizer, learning_rate=lr_fn)
  state = create_train_state(
      rng, dataset.get_input_shape('image'), configs.half_precision, model,
      optimizer_cfgs)
  if last_stage_state is not None:
    state = state.replace(
        params=last_stage_state.params,
        model_states=last_stage_state.model_states)

  state = jax_utils.replicate(state)

  # pmap the train and eval functions
  p_train_step = jax.pmap(
      functools.partial(utils.train_step, model.apply, l2_regu=configs.l2_regu),
      axis_name='batch')
  p_eval_step = jax.pmap(functools.partial(utils.eval_step, model.apply),
                         axis_name='batch')

  def run_eval(epoch=0, split_name='test'):
    # Reads `state` from the enclosing scope at call time (latest state).
    eval_metrics = []
    for batch in utils.iterate_data(dataset, split_name, local_batch_size,
                                    desc=f'E{epoch:03d} eval-{split_name}'):
      metrics = p_eval_step(state, batch)
      eval_metrics.append(utils.metrics_to_numpy(metrics))

    report_metrics(i_stage, f'eval-{split_name}', epoch, eval_metrics, writer)

  start_epoch = jax_utils.unreplicate(state.epoch)
  logging.info('Start training from epoch %d...', start_epoch)

  with metric_writers.ensure_flushes(writer):
    while True:
      epoch = int(jax_utils.unreplicate(state.epoch))
      if epoch >= configs.num_epochs:
        break

      train_metrics = []
      for batch in utils.iterate_data(dataset, 'train', local_batch_size,
                                      augmentation=True, shuffle=True,
                                      desc=f'E{epoch+1:03d} train',
                                      mixup_sampler=mixup_sampler):
        state, metrics = p_train_step(state, batch)
        train_metrics.append(utils.metrics_to_numpy(metrics))

      state = state.replace(epoch=state.epoch + 1)
      epoch += 1
      report_metrics(i_stage, 'train', epoch, train_metrics, writer)

      # sync batch statistics across replicas
      state = utils.sync_batch_stats(state)

      for split in configs.eval_splits:
        run_eval(epoch, split)

  utils.block_until_computation_finish()
  return jax_utils.unreplicate(state)


def compute_randomized_labels(dataset, batch_size, spec, n_tr_total):
  """Assigns K-ary randomized-response labels to `dataset`.

  Each example keeps its true label with probability e^eps / (e^eps + K - 1)
  and otherwise gets each other label with probability 1 / (e^eps + K - 1).
  Results are written as one-hot rows of `dataset.label_mapping` (indexed by
  global example index); `dataset.subset_mask` is reset to all True.

  Args:
    dataset: a `datasets.LabelRemappedTrainDataset`.
    batch_size: batch size used to read the original labels.
    spec: stage spec with 'type' == 'rr', 'eps' and 'seed'.
    n_tr_total: total number of training examples (size of the mapping).

  Returns:
    The number of classes K, i.e. the `k` of plain randomized response.

  Raises:
    KeyError: if `spec['type']` is not 'rr'.
  """
  assert isinstance(dataset, datasets.LabelRemappedTrainDataset)
  if spec['type'] != 'rr':
    raise KeyError(f'Unknown type: {spec["type"]}')
  n_classes = dataset.num_classes
  orig_labels = np.zeros(n_tr_total, dtype=np.int64)

  for batch in dataset.iterate('train', batch_size):
    orig_labels[batch['index']] = batch['orig_label']

  # assign new labels
  rng = np.random.RandomState(seed=spec['seed'])
  dataset.label_mapping = np.zeros((n_tr_total, n_classes), dtype=np.float32)
  # `np.bool` was removed in NumPy 1.24; the builtin `bool` is equivalent.
  dataset.subset_mask = np.ones(len(dataset.subset_index), dtype=bool)
  rate = 1 / (np.exp(spec['eps']) + n_classes - 1)
  for idx in dataset.subset_index:
    prob = np.zeros(n_classes) + rate
    prob[orig_labels[idx]] = 1 - rate * (n_classes - 1)
    new_label = rng.choice(n_classes, 1, p=prob)
    dataset.label_mapping[idx][new_label] = 1

  return n_classes


def compute_randomized_labels_with_priors(
    dataset, model, last_stage_state, batch_size, spec, n_tr_total):
  """Assigns RRWithPrior labels to `dataset`.

  The prior per example is a mixture of the previous stage's temperature-scaled
  softmax and (with weight `spec['domain_specific_prior_weight']`, default 0)
  the dataset-provided `batch['prior']` sharpened by the same temperature.

  Returns:
    The average over the subset of the `k` chosen by `rr_with_prior`.
  """
  assert isinstance(dataset, datasets.LabelRemappedTrainDataset)
  assert spec['type'] == 'rr-with-prior'
  ds_weight = spec.get('domain_specific_prior_weight', 0.0)
  assert ds_weight >= 0.0 and ds_weight <= 1.0

  n_classes = dataset.num_classes
  dataset.label_mapping = np.zeros((n_tr_total, n_classes), dtype=np.float32)
  dataset.subset_mask = np.ones(len(dataset.subset_index), dtype=bool)
  rng = np.random.RandomState(seed=spec['seed'])
  logging.info('RRWithPrior labeling (T=%f, ds_weight=%f)',
               spec['temperature'], ds_weight)

  model_vars = {'params': last_stage_state.params,
                **last_stage_state.model_states}
  j_pred = jax.jit(functools.partial(model.apply, train=False))

  soft_k = 0.0
  for batch in dataset.iterate('train', batch_size):
    if np.isclose(ds_weight, 1.0):
      p_last_model = 0.0
    else:
      logits = j_pred(model_vars, batch['image'])
      logits /= spec['temperature']
      p_last_model = jax.device_get(jax.nn.softmax(logits))

    if np.isclose(ds_weight, 0.0):
      p_domain = 0.0
    else:
      p_domain = batch['prior']
      p_domain = np.power(p_domain, 1/spec['temperature'])
      p_domain = p_domain / np.sum(p_domain, axis=1, keepdims=True)

    probs = ds_weight*p_domain + (1-ds_weight)*p_last_model
    orig_labels = batch['orig_label']

    for i, prob in enumerate(probs):
      k, new_label = rr_with_prior(prob, spec['eps'], orig_labels[i], rng)
      soft_k += k
      dataset.label_mapping[batch['index'][i]][new_label] = 1

  # Guard against an empty subset (would otherwise divide by zero).
  soft_k /= max(len(dataset.subset_index), 1)
  logging.info('effective soft_k ~= %.2f (averaged on %d examples).',
               soft_k, len(dataset.subset_index))
  return soft_k


def rr_with_prior(prior, eps, y, rng):
  """Randomized response with prior.

  Chooses the k maximizing w_k = sum(top-k prior) / (1 + (k - 1) * e^-eps),
  the probability that RR-Top-k outputs the true label under the prior, and
  then runs RR-Top-k. The true label, if in the top k, is output with
  probability 1 / (1 + (k - 1) * e^-eps), and every other top-k label with
  probability e^-eps / (1 + (k - 1) * e^-eps). If the true label is not in
  the top k, the output is uniform over the top k. Labels outside the top k
  are never output.

  Args:
    prior: A K-length array where the k-th entry is the probability that the
      true label is k.
    eps: the epsilon value for which the randomized response is epsilon-DP.
    y: an integer indicating the true label.
    rng: a numpy random number generator for sampling.

  Returns:
    k, y_rr: k is the value used in rr-top-k; y_rr is the randomized label
    (a length-1 integer array).
  """
  idx_sort = np.flipud(np.argsort(prior))
  prior_sorted = prior[idx_sort]
  tmp = np.exp(-eps)
  # `k` here is 0-based: prior_sorted[:k + 1] holds the top-(k + 1) mass, so
  # the matching denominator is 1 + ((k + 1) - 1) * e^-eps = 1 + k * e^-eps.
  # (Using (k - 1) here inflated w_1 and biased the choice towards k = 1.)
  wks = [np.sum(prior_sorted[:(k+1)]) / (1 + k*tmp)
         for k in range(len(prior))]
  optim_k = np.argmax(wks) + 1

  adjusted_prior = np.zeros_like(prior) + tmp / (1 + (optim_k-1)*tmp)
  adjusted_prior[y] = 1 / (1 + (optim_k-1)*tmp)
  adjusted_prior[idx_sort[optim_k:]] = 0
  adjusted_prior /= np.sum(adjusted_prior)  # renorm in case y not in topk
  rr_label = rng.choice(len(prior), 1, p=adjusted_prior)
  return optim_k, rr_label


def filter_stage_data_by_prior(last_dataset, model, last_stage_state,
                               k_for_prior, n_tr_total, batch_size):
  """Masks out last-stage examples whose randomized label is not in the top-k.

  The top `int(k_for_prior)` classes are those predicted by the model in
  `last_stage_state`. Updates `last_dataset.subset_mask` in place.
  """
  assert k_for_prior is not None
  k_for_prior = int(k_for_prior)

  model_vars = {'params': last_stage_state.params,
                **last_stage_state.model_states}
  j_pred = jax.jit(functools.partial(model.apply, train=False))
  global_mask = np.ones(n_tr_total, dtype=bool)

  last_dataset.subset_mask[:] = True  # enable all examples
  for batch in last_dataset.iterate('train', batch_size):
    logits = j_pred(model_vars, batch['image'])
    _, topk_idx = jax.device_get(jax.lax.top_k(logits, k=k_for_prior))

    for j in range(batch['image'].shape[0]):
      if np.isclose(np.sum(batch['label'][j, topk_idx[j]]), 0):
        # the randomized label from the last stage is not in the topk prior
        global_mask[batch['index'][j]] = False

  n_filtered = 0
  for i, idx in enumerate(last_dataset.subset_index):
    if not global_mask[idx]:
      last_dataset.subset_mask[i] = False
      n_filtered += 1

  logging.info('%d egs removed due to randomized labels not in top %d prior',
               n_filtered, k_for_prior)


def merge_stage_data(dset_last_stage, dset):
  """Merges the (randomized) data from the last stage to be reused in this one.

  Copies the last stage's label mapping rows into `dset` and prepends its
  subset indices and masks. The two subsets must be disjoint.
  """
  assert isinstance(dset_last_stage, datasets.LabelRemappedTrainDataset)
  assert isinstance(dset, datasets.LabelRemappedTrainDataset)
  assert dset.label_mapping is not None
  assert dset_last_stage.label_mapping is not None
  # pylint: disable=g-explicit-length-test
  assert len(
      np.intersect1d(
          dset.subset_index, dset_last_stage.subset_index,
          assume_unique=True)) == 0

  dset.label_mapping[dset_last_stage.subset_index,
                     ...] = dset_last_stage.label_mapping[
                         dset_last_stage.subset_index, ...]
  dset.subset_index = np.concatenate(
      [dset_last_stage.subset_index, dset.subset_index], axis=0)
  dset.subset_mask = np.concatenate(
      [dset_last_stage.subset_mask, dset.subset_mask], axis=0)

# --------------------------------------------------------------------------------
# /label_dp/utils.py:
# --------------------------------------------------------------------------------
# Copyright 2022 Google LLC.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Training related utilities.""" 16 | 17 | import functools 18 | from typing import Optional 19 | 20 | from absl import logging 21 | 22 | from flax import jax_utils 23 | from flax.training import common_utils 24 | import jax 25 | import jax.numpy as jnp 26 | import numpy as np 27 | import optax 28 | import tqdm 29 | 30 | from label_dp import datasets 31 | from label_dp import lr_schedules 32 | 33 | 34 | def derive_subset_dataset(dataset, n_stages=2, stage_splits=None, seed=None): 35 | """Specifies data splits for each stage of training. 36 | 37 | Args: 38 | dataset: The original dataset to derive from. 39 | n_stages: number of stages. 40 | stage_splits: Must be a list of numbers summing to 1, indicating 41 | the ratio of data for each stage. 42 | seed: if stage_split is not None, this should not be None, and should 43 | be the random seed for generating the splits. 44 | 45 | Returns: 46 | A LabelRemappedTrainDataset. 
47 | """ 48 | n_tr = dataset.get_num_examples('train') 49 | assert len(stage_splits) == n_stages 50 | assert seed is not None 51 | assert np.isclose(sum(stage_splits), 1) 52 | cum_splits = np.cumsum(stage_splits) 53 | cum_split_counts = [0] + [int(n_tr * x) for x in cum_splits] 54 | cum_split_counts[-1] = n_tr 55 | rng = np.random.RandomState(seed=seed) 56 | perm = rng.permutation(n_tr) 57 | stage_subsets = [perm[cum_split_counts[i]:cum_split_counts[i+1]] 58 | for i in range(n_stages)] 59 | 60 | return [datasets.LabelRemappedTrainDataset(dataset, subset_index) 61 | for subset_index in stage_subsets] 62 | 63 | 64 | def cross_entropy_loss(logits, onehot_labels): 65 | log_softmax_logits = jax.nn.log_softmax(logits) 66 | batch_size = onehot_labels.shape[0] 67 | return -jnp.sum(onehot_labels * log_softmax_logits) / batch_size 68 | 69 | 70 | def classification_metrics(logits, onehot_labels): 71 | loss = cross_entropy_loss(logits, onehot_labels) 72 | acc = jnp.mean(jnp.argmax(logits, -1) == jnp.argmax(onehot_labels, -1)) 73 | metrics = {'loss': loss, 'accuracy': acc} 74 | metrics = jax.lax.pmean(metrics, 'batch') 75 | return metrics 76 | 77 | 78 | def l2_regularizer(coefficient, params): 79 | if coefficient <= 0: 80 | return 0 81 | params = jax.tree_leaves(params) 82 | weight_l2 = sum([jnp.sum(x ** 2) for x in params]) 83 | weight_penalty = coefficient * 0.5 * weight_l2 84 | return weight_penalty 85 | 86 | 87 | def block_until_computation_finish(): 88 | """Wait until computations are done.""" 89 | logging.flush() 90 | jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready() 91 | 92 | 93 | def get_local_batch_size(batch_size: int) -> int: 94 | """Gets local batch size to a host.""" 95 | # For example, if we have 2 hosts, each with 8 devices, batch_size=2048, then 96 | # - batch_size == 2048 97 | # - local_batch_size == 2048 / 2 == 1024 98 | # - jax.device_count() == 2*8 == 16 99 | # The dataset object will be sharded at host level, so each host will see 100 | # a 
different subset of the data. 101 | if batch_size % jax.device_count() > 0: 102 | raise ValueError('Batch size must be divisible by the number of devices') 103 | local_batch_size = batch_size // jax.process_count() 104 | logging.info( 105 | 'JAX process_index=%d, process_count=%d, device_count=%d, ' 106 | 'local_device_count=%d', jax.process_index(), jax.process_count(), 107 | jax.device_count(), jax.local_device_count()) 108 | return local_batch_size 109 | 110 | 111 | def get_dtype(half_precision: bool, platform: Optional[str] = None): 112 | """Gets concrete data type according to precision and platform. 113 | 114 | Args: 115 | half_precision: whether to use half precision (float16). 116 | platform: 'tpu' or 'gpu'. 117 | 118 | Returns: 119 | A data type to use according to the specification. 120 | """ 121 | if platform is None: 122 | platform = jax.local_devices()[0].platform 123 | 124 | if half_precision: 125 | if platform == 'tpu': 126 | return jnp.bfloat16 127 | else: 128 | return jnp.float16 129 | else: 130 | return jnp.float32 131 | 132 | 133 | def split_state_params(variables): 134 | """Separate (BatchNorm) states and trainable params.""" 135 | params = variables.pop('params') 136 | return variables, params 137 | 138 | 139 | def initialize_model(rng, input_shape, model): 140 | """Initializes parameters and states for a model.""" 141 | input_shape = (1, *input_shape) # add a dummy batch dimension 142 | @jax.jit 143 | def init(*args): 144 | return model.init(*args) 145 | variables = init({'params': rng}, jnp.ones(input_shape, model.dtype)) 146 | model_states, params = split_state_params(variables) 147 | return params, model_states 148 | 149 | 150 | def build_optimizer(name, learning_rate, kwargs): 151 | """Builds an optimizer.""" 152 | ctor = getattr(optax, name) 153 | return ctor(learning_rate=learning_rate, **kwargs) 154 | 155 | 156 | def build_lr_fn(name, base_lr, num_train_steps, kwargs): 157 | """Builds learning rate scheduler.""" 158 | return 
getattr(lr_schedules, name)(base_lr, num_train_steps, **kwargs) 159 | 160 | 161 | def train_step(apply_fn, state, batch, l2_regu, 162 | f_metrics=classification_metrics): 163 | """Performs a single training step.""" 164 | def loss_fn(params): 165 | variables = {'params': params, **state.model_states} 166 | logits, new_model_states = apply_fn( 167 | variables, batch['image'], train=True, mutable=['batch_stats']) 168 | loss = cross_entropy_loss(logits, batch['label']) 169 | loss = loss + l2_regularizer(l2_regu, variables['params']) 170 | return loss, (new_model_states, logits) 171 | 172 | state, aux, metrics = optimizer_step(loss_fn, state) 173 | new_model_states, logits = aux[1] 174 | new_state = state.replace(model_states=new_model_states) 175 | metrics.update(f_metrics(logits, batch['label'])) 176 | 177 | return new_state, metrics 178 | 179 | 180 | def optimizer_step(loss_fn, state): 181 | """Applies one optimizer step.""" 182 | dynamic_scale = state.dynamic_scale 183 | 184 | if dynamic_scale: 185 | grad_fn = dynamic_scale.value_and_grad( 186 | loss_fn, has_aux=True, axis_name='batch') 187 | dynamic_scale, is_fin, aux, grad = grad_fn(state.params) 188 | # dynamic loss takes care of averaging gradients across replicas 189 | else: 190 | grad_fn = jax.value_and_grad(loss_fn, has_aux=True) 191 | aux, grad = grad_fn(state.params) 192 | # Re-use same axis_name as in the call to `pmap(...train_step...)` below. 193 | grad = jax.lax.pmean(grad, axis_name='batch') 194 | 195 | metrics = {} 196 | new_model_states = aux[1][0] 197 | new_state = state.apply_gradients(grads=grad, model_states=new_model_states) 198 | if dynamic_scale: 199 | # if is_fin == False the gradients contain Inf/NaNs and optimizer state and 200 | # params should be restored (= skip this step). 
201 | new_state = new_state.replace( 202 | opt_state=jax.tree_map( 203 | functools.partial(jnp.where, is_fin), new_state.opt_state, 204 | state.opt_state), 205 | params=jax.tree_map( 206 | functools.partial(jnp.where, is_fin), new_state.params, 207 | state.params)) 208 | metrics['scale'] = dynamic_scale.scale 209 | 210 | return new_state, aux, metrics 211 | 212 | 213 | def eval_step(apply_fn, state, batch, f_metrics=classification_metrics): 214 | """Performs a single evaluation step.""" 215 | variables = {'params': state.params, **state.model_states} 216 | logits = apply_fn(variables, batch['image'], train=False, mutable=False) 217 | return f_metrics(logits, batch['label']) 218 | 219 | 220 | def iterate_data(dataset, split_name, batch_size, augmentation=False, 221 | shuffle=False, desc='', mixup_sampler=None, **kwargs): 222 | """Iterates over data.""" 223 | iterator = dataset.iterate(split_name, batch_size, shuffle=shuffle, 224 | augmentation=augmentation, **kwargs) 225 | if mixup_sampler is not None: 226 | def apply_mixup(batch): 227 | lm = mixup_sampler() 228 | batch['label'] = lm*batch['label'] + (1-lm)*np.flipud(batch['label']) 229 | batch['image'] = lm*batch['image'] + (1-lm)*np.flipud(batch['image']) 230 | return batch 231 | iterator = map(apply_mixup, iterator) 232 | 233 | iterator = map(common_utils.shard, iterator) 234 | iterator = jax_utils.prefetch_to_device(iterator, 2) 235 | iterator = tqdm.tqdm(iterator, desc=desc, disable=None, 236 | total=dataset.get_num_examples(split_name) // batch_size) 237 | return iterator 238 | 239 | 240 | def metrics_to_numpy(metrics): 241 | # We select the first element of x in order to get a single copy of a 242 | # device-replicated metric. 
243 | metrics = jax.tree_map(lambda x: x[0], metrics) 244 | metrics_np = jax.device_get(metrics) 245 | return metrics_np 246 | 247 | 248 | def sync_batch_stats(state): 249 | """Sync the batch statistics across replicas.""" 250 | if 'batch_stats' not in state.model_states: 251 | return state 252 | 253 | # An axis_name is passed to pmap which can then be used by pmean. 254 | # In this case each device has its own version of the batch statistics and 255 | # we average them. 256 | avg = jax.pmap(lambda x: jax.lax.pmean(x, 'x'), 'x') 257 | 258 | new_model_states = state.model_states | { 259 | 'batch_stats': avg(state.model_states['batch_stats'])} 260 | return state.replace(model_states=new_model_states) 261 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py>=1.0.0 2 | clu>=0.0.6 3 | tqdm>=4.63.0 4 | --------------------------------------------------------------------------------