├── CONTRIBUTING.md ├── LICENSE ├── README.md └── distribution_shift_framework ├── __init__.py ├── classification ├── __init__.py ├── config.py ├── experiment.py ├── experiment_lib.py └── experiment_lib_test.py ├── configs ├── __init__.py └── disentanglement_config.py ├── core ├── __init__.py ├── adapt.py ├── adapt_train.py ├── algorithms │ ├── __init__.py │ ├── adversarial.py │ ├── base.py │ ├── erm.py │ ├── irm.py │ ├── losses.py │ └── sagnet.py ├── checkpointing.py ├── checkpointing_test.py ├── datasets │ ├── __init__.py │ ├── data_loaders.py │ ├── data_utils.py │ └── lowdata_wrapper.py ├── hyper.py ├── hyper_test.py ├── metrics │ ├── __init__.py │ └── metrics.py ├── model_zoo │ ├── __init__.py │ └── resnet.py └── pix │ ├── __init__.py │ ├── augment.py │ ├── color_conversion.py │ ├── corruptions.py │ └── postprocessing.py ├── requirements.txt └── run.sh /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement (CLA). You (or your employer) retain the copyright to your 10 | contribution; this simply gives us permission to use and redistribute your 11 | contributions as part of the project. Head over to 12 | <https://cla.developers.google.com/> to see your current agreements on file or 13 | to sign a new one. 14 | 15 | You generally only need to submit a CLA once, so if you've already submitted one 16 | (even if it was for a different project), you probably don't need to do it 17 | again. 18 | 19 | ## Code Reviews 20 | 21 | All submissions, including submissions by project members, require review. We 22 | use GitHub pull requests for this purpose.
Consult 23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 24 | information on using pull requests. 25 | 26 | ## Community Guidelines 27 | 28 | This project follows 29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 
35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Distribution Shift Framework 2 | 3 | This repository contains the code of the distribution shift framework presented 4 | in [A Fine-Grained Analysis on Distribution Shift](https://openreview.net/forum?id=Dl4LetuLdyK) 5 | (Wiles et al., 2022). 6 | 7 | ## Contents 8 | 9 | The framework allows you to train models with different training methods on 10 | datasets undergoing specific kinds of distribution shift.
11 | 12 | ### Training Methods 13 | 14 | Currently, the following training methods are supported (by setting the 15 | `algorithm` [config option](#config-options)): 16 | 17 | * Empirical Risk Minimization (**ERM**, [Vapnik, 1992](https://papers.nips.cc/paper/1991/hash/ff4d5fbbafdf976cfdc032e3bde78de5-Abstract.html)) 18 | * Invariant Risk Minimization (**IRM**, [Arjovsky et al., 2019](https://arxiv.org/abs/1907.02893)) 19 | * Deep Correlation Alignment (**Deep CORAL**, [Sun & Saenko, 2016](https://link.springer.com/chapter/10.1007/978-3-319-49409-8_35)) 20 | * Domain-Adversarial Training of Neural Networks (**DANN**, [Ganin et al., 2016](https://jmlr.org/papers/v17/15-239.html)) 21 | * Style-Agnostic Networks (**SagNet**, [Nam et al., 2021](https://openaccess.thecvf.com/content/CVPR2021/html/Nam_Reducing_Domain_Gap_by_Reducing_Style_Bias_CVPR_2021_paper.html)) 22 | * Batch Normalization Adaptation (**BN-Adapt**, [Schneider et al., 2020](https://proceedings.neurips.cc/paper/2020/hash/85690f81aadc1749175c187784afc9ee-Abstract.html)) 23 | * Just Train Twice (**JTT**, [Liu et al., 2021](http://proceedings.mlr.press/v139/liu21f.html)) 24 | * Inter-domain Mixup (**MixUp**, [Gulrajani & Lopez-Paz, 2021](https://openreview.net/forum?id=lQdXeXDoWtI)) 25 | 26 | ### Model Architectures 27 | 28 | The `model` [config option](#config-options) can be set to one of the following 29 | architectures: 30 | 31 | * ResNet18, ResNet50, ResNet101 ([He et al., 2016](https://ieeexplore.ieee.org/document/7780459)) 32 | * MLP ([Vapnik, 1992](https://papers.nips.cc/paper/1991/hash/ff4d5fbbafdf976cfdc032e3bde78de5-Abstract.html)) 33 | 34 | ### Datasets 35 | 36 | You can train on the following datasets (by setting the `dataset_name` 37 | [config option](#config-options)): 38 | 39 | * dSprites ([Matthey et al., 2017](https://github.com/deepmind/dsprites-dataset)) 40 | * SmallNorb ([LeCun et al., 2004](https://ieeexplore.ieee.org/document/1315150)) 41 | * Shapes3D ([Burgess & Kim,
2018](https://github.com/deepmind/3d-shapes)) 42 | 43 | Each dataset has a task (e.g. shape prediction on dSprites, set with the `label` 44 | [config option](#config-options)) and a set of properties (e.g. the colour of 45 | the shape in dSprites, set with the `property_label` 46 | [config option](#config-options)). 47 | 48 | ### Distribution Shift Scenarios 49 | 50 | You can evaluate your model under different conditions by varying the distribution 51 | of labels and properties in the configs. For each part of the distribution, 52 | you then assign a probability of sampling from it. 53 | 54 | * **Unseen data shift** (`ood`): Some parts of the distribution of the property 55 | are unseen at training time (e.g. certain colours may be unseen in 56 | dSprites). 57 | * **Spurious correlation** (`correlated`): Some property is correlated with the 58 | label at training time but not at test time (e.g. all circles are red in 59 | training). 60 | * **Low data drift** (`lowdata`): Certain combinations of label and property are seen at a 61 | lower rate during training while they are uniformly distributed during 62 | testing. 63 | 64 | Additionally, you can modify these scenarios with two conditions: 65 | 66 | * **Label noise** (`noise`): A certain percentage of the training labels are 67 | corrupted. 68 | * **Fixed dataset size** (`fixeddata`): We reduce the total training dataset 69 | size to a fixed amount. 70 | 71 | These scenarios can be set through the `test_case` 72 | [config option](#config-options) with the keywords in parentheses and an 73 | optional modifier separated by a full stop, e.g. `lowdata.noise` for low data 74 | drift with added label noise. 75 | 76 | ### Future Additions 77 | 78 | We plan to add additional methods, models and datasets from the paper as well 79 | as the raw results from all the experiments. 80 | 81 | ## Usage Instructions 82 | 83 | ### Installing 84 | 85 | The following has been tested using Python 3.9.9.
86 | 87 | For GPU support with JAX, edit `requirements.txt` before running `run.sh` 88 | (e.g., use `jaxlib==0.1.67+cuda111`). See JAX's installation 89 | [instructions](https://github.com/google/jax#installation) for more details. 90 | 91 | Execute `run.sh` to create and activate a virtualenv, install all necessary 92 | dependencies and run a test program to ensure that you can import all the 93 | modules. 94 | 95 | ``` 96 | # Run from the parent directory. 97 | sh distribution_shift_framework/run.sh 98 | ``` 99 | 100 | 101 | ### Running the Code 102 | 103 | To train a model, activate the virtualenv: 104 | 105 | ``` 106 | source /tmp/distribution_shift_framework/bin/activate 107 | ``` 108 | 109 | and then run 110 | 111 | ``` 112 | python3 -m distribution_shift_framework.classification.experiment \ 113 | --jaxline_mode=train \ 114 | --config=distribution_shift_framework/classification/config.py 115 | ``` 116 | 117 | For evaluation, run 118 | 119 | ``` 120 | python3 -m distribution_shift_framework.classification.experiment \ 121 | --jaxline_mode=eval \ 122 | --config=distribution_shift_framework/classification/config.py 123 | ``` 124 | 125 | ### Config Options {#config-options} 126 | 127 | Common settings can be changed through an options string appended to the config file path. 128 | The following options are available: 129 | 130 | * `algorithm`: Which training method to use. 131 | * `model`: The model architecture to evaluate. 132 | * `dataset_name`: The name of the dataset. 133 | * `test_case`: Which of the distribution shift scenarios to set up. 134 | * `label`: The label we're predicting. 135 | * `property_label`: Which property is treated as in or out of 136 | distribution (for the ood test_case), is correlated with the label 137 | (for the correlated setup) and is treated as having a low data region 138 | (for the low_data setup). 139 | * `number_of_seeds`: How many seeds to sweep over. 140 | * `batch_size`: Batch size used for training and evaluation.
141 | * `training_steps`: How many steps to train for. 142 | * `pretrained_checkpoint`: Path to a checkpoint for a pretrained model. 143 | * `overwrite_image_size`: Height and width to resize the images to. 0 means 144 | no resizing. 145 | * `eval_specific_ckpt`: Path to a checkpoint for a one-time evaluation. 146 | * `wids`: Which wids of the checkpoint to look at. 147 | * `sweep_index`: Which experiment from the sweep to run. 148 | * `use_fake_data`: Whether to use fake data for testing. 149 | 150 | 151 | Multiple options are separated by commas. For example: 152 | 153 | ``` 154 | python3 -m distribution_shift_framework.classification.experiment \ 155 | --jaxline_mode=train \ 156 | --config=distribution_shift_framework/classification/config.py:algorithm=SagNet,test_case=lowdata.noise,model=truncatedresnet18,property_label=label_object_hue,label=label_shape,dataset_name=shapes3d 157 | ``` 158 | 159 | This would train a **truncated ResNet18** with the **SagNet** algorithm in the 160 | **low data** setting with added **label noise** on the **Shapes3D** dataset. 161 | **Shape** is used as the label for classification while **object hue** is used 162 | as the property that the distribution shifts over. 163 | 164 | ### Sweeps 165 | 166 | By default, the program generates sweeps over multiple hyper-parameters depending 167 | on the chosen training method, dataset and distribution shift scenario. The 168 | `sweep_index` option lets you choose which of the configs in the sweep you want 169 | to run.
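To make the sweep selection concrete, here is a minimal standalone sketch of how such sweeps compose. The `sweep` and `product` helpers below are illustrative stand-ins for the framework's `hyper` module (not its actual implementation), and the swept keys are just examples:

```python
import itertools


def sweep(key, values):
  # One sweep axis: a list of {flattened_config_key: value} overrides.
  return [{key: v} for v in values]


def product(sweeps):
  # Cartesian product of axes; each entry merges one override per axis.
  return [{k: v for d in combo for k, v in d.items()}
          for combo in itertools.product(*sweeps)]


all_sweeps = product([
    sweep('config.random_seed', [0, 1, 2]),
    sweep('config.experiment_kwargs.config.training.batch_size', [64, 128]),
])
# all_sweeps has 3 * 2 = 6 entries; `sweep_index` selects one of them,
# and its key/value pairs are applied as overrides to the base config.
```

In this sketch, `sweep_index=0` would run seed 0 with batch size 64, and `sweep_index=5` seed 2 with batch size 128.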
170 | 171 | ## Citing this work 172 | 173 | If you use this code (or any derived code) in your work, please cite the 174 | accompanying paper: 175 | 176 | ``` 177 | @inproceedings{wiles2022fine, 178 | title={A Fine-Grained Analysis on Distribution Shift}, 179 | author={Olivia Wiles and Sven Gowal and Florian Stimberg and Sylvestre-Alvise Rebuffi and Ira Ktena and Krishnamurthy Dj Dvijotham and Ali Taylan Cemgil}, 180 | booktitle={International Conference on Learning Representations}, 181 | year={2022}, 182 | url={https://openreview.net/forum?id=Dl4LetuLdyK} 183 | } 184 | ``` 185 | 186 | ## License and Disclaimer 187 | 188 | Copyright 2022 DeepMind Technologies Limited. 189 | 190 | All software is licensed under the Apache License, Version 2.0 (Apache 2.0); 191 | you may not use this file except in compliance with the License. You may obtain 192 | a copy of the Apache 2.0 license at 193 | 194 | [https://www.apache.org/licenses/LICENSE-2.0](https://www.apache.org/licenses/LICENSE-2.0) 195 | 196 | All non-code materials are licensed under the Creative Commons Attribution 4.0 197 | International License (CC-BY License). You may obtain a copy of the CC-BY 198 | License at: 199 | 200 | [https://creativecommons.org/licenses/by/4.0/legalcode](https://creativecommons.org/licenses/by/4.0/legalcode) 201 | 202 | You may not use the non-code portions of this file except in compliance with the 203 | CC-BY License. 204 | 205 | Unless required by applicable law or agreed to in writing, software distributed 206 | under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 207 | CONDITIONS OF ANY KIND, either express or implied. See the License for the 208 | specific language governing permissions and limitations under the License. 209 | 210 | This is not an official Google product. 
211 | -------------------------------------------------------------------------------- /distribution_shift_framework/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2022 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | -------------------------------------------------------------------------------- /distribution_shift_framework/classification/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2022 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | -------------------------------------------------------------------------------- /distribution_shift_framework/classification/config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2022 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Config for the distribution shift classification experiment.""" 18 | 19 | import functools 20 | from typing import Any, Callable, List, Mapping, Optional, Tuple 21 | 22 | import chex 23 | from distribution_shift_framework.configs import disentanglement_config 24 | from distribution_shift_framework.core import adapt 25 | from distribution_shift_framework.core import adapt_train 26 | from distribution_shift_framework.core import algorithms 27 | from distribution_shift_framework.core import hyper 28 | from distribution_shift_framework.core.datasets import data_utils 29 | from distribution_shift_framework.core.model_zoo import resnet 30 | from distribution_shift_framework.core.pix import postprocessing 31 | import haiku as hk 32 | from jaxline import base_config 33 | import ml_collections 34 | 35 | 36 | 37 | DATASETS = ('dsprites', 'small_norb', 'shapes3d') 38 | LEARNERS = algorithms.__all__ 39 | ADAPTERS = ('BNAdapt',) 40 | TRAIN_ADAPTERS = ('JTT',) 41 | POSTPROCESSORS = ('mixup',) 42 | ALGORITHMS = LEARNERS + ADAPTERS + TRAIN_ADAPTERS + POSTPROCESSORS 43 | 44 | 45 | ConfigAndSweeps =
Tuple[ml_collections.ConfigDict, List[hyper.Sweep]] 46 | 47 | _EXP = 'config.experiment_kwargs.config' 48 | 49 | 50 | def parse_options( 51 | options: str, 52 | defaults: Mapping[str, Any], 53 | types: Optional[Mapping[str, Callable[[str], Any]]] = None 54 | ) -> Mapping[str, Any]: 55 | """Parse a "k1=v1,k2=v2" option string.""" 56 | if not options: 57 | return defaults 58 | if types is None: 59 | types = {} 60 | else: 61 | types = dict(**types) 62 | for k, v in defaults.items(): 63 | if k not in types: 64 | types[k] = type(v) 65 | kwargs = dict(t.split('=', 1) for t in options.split(',')) 66 | for k, v in kwargs.items(): 67 | if k in types: # Default type is `str`. 68 | kwargs[k] = ((v in ('True', 'true', 'yes')) if types[k] == bool 69 | else types[k](v)) 70 | # Only allow options where defaults are specified to avoid typos. 71 | for k in kwargs: 72 | if k not in defaults: 73 | raise ValueError('Unknown option `%s`.' % k) 74 | for k, v in defaults.items(): 75 | if k not in kwargs: 76 | kwargs[k] = v 77 | return kwargs 78 | 79 | 80 | def get_config(options: str = '') -> ml_collections.ConfigDict: 81 | """Return config object for training. 82 | 83 | Args: 84 | options: A comma-separated options string of the form 85 | key1=value1,key2=value2. The supported key-value pairs are the following: 86 | 87 | dataset_name -- The name of the dataset. 88 | model -- The model to evaluate. 89 | test_case -- Which of ood or correlated setups to run. 90 | label -- The label we're predicting. 91 | property_label -- Which property is treated as in or out of 92 | distribution (for the ood test_case), is correlated with the label 93 | (for the correlated setup) and is treated as having a low data region 94 | (for the low_data setup). 95 | algorithm -- What algorithm to use for training. 96 | number_of_seeds -- How many seeds to evaluate the models with. 97 | batch_size -- Batch size used for training and evaluation. 98 | training_steps -- How many steps to train for.
99 | pretrained_checkpoint -- Path to a checkpoint for a pretrained model. 100 | overwrite_image_size -- Height and width to resize the images to. 0 means 101 | no resizing. 102 | eval_specific_ckpt -- Path to a checkpoint for a one time evaluation. 103 | wids -- Which wids of the checkpoint to look at. 104 | sweep_index -- Which experiment from the sweep to run. 105 | use_fake_data -- Whether to use fake data for testing. 106 | Returns: 107 | ConfigDict: A dictionary of parameters. 108 | """ 109 | options = parse_options( 110 | options, 111 | defaults={ 112 | 'dataset_name': 'dsprites', 113 | 'model': 'resnet18', 114 | 'test_case': 'ood', 115 | 'label': 'label_shape', 116 | 'property_label': 'label_color', 117 | 'algorithm': 'ERM', 118 | 'number_of_seeds': 1, 119 | 'batch_size': 128, 120 | 'training_steps': 100_000, 121 | 'pretrained_checkpoint': '', 122 | 'overwrite_image_size': 0, # Zero means no resizing. 123 | 'eval_specific_ckpt': '', 124 | 'wids': '1-1', 125 | 'sweep_index': 0, 126 | 'use_fake_data': False, 127 | }) 128 | assert options['dataset_name'] in DATASETS 129 | assert options['algorithm'] in ALGORITHMS 130 | if options['algorithm'] in LEARNERS: 131 | learner = options['algorithm'] 132 | adapter = '' 133 | train_adapter = '' 134 | postprocessor = '' 135 | else: 136 | learner = 'ERM' 137 | if options['algorithm'] in ADAPTERS: 138 | adapter = options['algorithm'] 139 | elif options['algorithm'] in TRAIN_ADAPTERS: 140 | train_adapter = options['algorithm'] 141 | elif options['algorithm'] in POSTPROCESSORS: 142 | postprocessor = options['algorithm'] 143 | config = base_config.get_base_config() 144 | config.random_seed = 0 145 | config.checkpoint_dir = '/tmp' 146 | config.train_checkpoint_all_hosts = False 147 | 148 | training_steps = options['training_steps'] 149 | 150 | config.experiment_kwargs = ml_collections.ConfigDict() 151 | 152 | exp = config.experiment_kwargs.config = ml_collections.ConfigDict() 153 | exp.use_fake_data = options['use_fake_data'] 
154 | exp.enable_double_transpose = False 155 | 156 | # Training. 157 | exp.training = ml_collections.ConfigDict() 158 | exp.training.use_gt_images = False 159 | exp.training.save_images = False 160 | exp.training.batch_size = options['batch_size'] 161 | exp.training.adversarial_weight = 1. 162 | exp.training.label_noise = 0.0 163 | 164 | # Evaluation. 165 | exp.evaluation = ml_collections.ConfigDict() 166 | exp.evaluation.batch_size = options['batch_size'] 167 | exp.evaluation.metrics = ['top1_accuracy'] 168 | 169 | # Optimizer. 170 | exp.optimizer = ml_collections.ConfigDict() 171 | exp.optimizer.name = 'adam' 172 | exp.optimizer.kwargs = dict(learning_rate=0.001) 173 | 174 | # Data. 175 | exp.data = ml_collections.ConfigDict() 176 | if data_utils.is_disentanglement_dataset(options['dataset_name']): 177 | exp.data = disentanglement_config.get_renderers( 178 | options['test_case'], dataset_name=options['dataset_name'], 179 | label=options['label'], 180 | property_label=options['property_label']) 181 | data_sweep = disentanglement_config.get_renderer_sweep( 182 | options['test_case']) 183 | else: 184 | dataset_name = options['dataset_name'] 185 | raise ValueError(f'Unsupported dataset {dataset_name}') 186 | 187 | if exp.use_fake_data: 188 | # Data loaders skip valid and test samples and default values are so high 189 | # that we would need to generate too many fake datapoints. 190 | batch_size = options['batch_size'] 191 | if options['dataset_name'] in ('dsprites', 'shapes3d'): 192 | exp.data.train_kwargs.load_kwargs.dataset_kwargs.valid_size = batch_size 193 | exp.data.train_kwargs.load_kwargs.dataset_kwargs.test_size = batch_size 194 | exp.data.test_kwargs.load_kwargs.valid_size = batch_size 195 | exp.data.test_kwargs.load_kwargs.test_size = batch_size 196 | elif options['dataset_name'] == 'small_norb': 197 | exp.data.train_kwargs.load_kwargs.dataset_kwargs.valid_size = batch_size 198 | exp.data.test_kwargs.load_kwargs.valid_size = batch_size 199 | 200 | # Model. 
201 | model = options['model'] 202 | exp.model, model_sweep = globals()[f'get_{model}_config']( 203 | num_classes=exp.data.n_classes, resize_to=options['overwrite_image_size']) 204 | exp.pretrained_checkpoint = options['pretrained_checkpoint'] 205 | 206 | # Learning algorithm. 207 | exp.training.algorithm, learner_sweep = get_learner( 208 | learner, model, exp.data.n_classes) 209 | 210 | # Test time adaptation. 211 | if adapter: 212 | exp.adapter = get_adapter(adapter, exp.data.n_properties) 213 | else: 214 | exp.adapter = ml_collections.ConfigDict() 215 | 216 | # Adapt training parameters and state. 217 | if train_adapter: 218 | exp.training.learn_adapt = get_train_adapter( 219 | train_adapter, training_steps=training_steps) 220 | else: 221 | exp.training.learn_adapt = ml_collections.ConfigDict() 222 | 223 | # Postprocessing. 224 | if postprocessor: 225 | exp.postprocess = get_postprocessing_step(postprocessor) 226 | else: 227 | exp.postprocess = ml_collections.ConfigDict() 228 | 229 | if exp.data.train_kwargs.load_kwargs.get('shuffle_pre_sampling', False): 230 | exp_train_kwargs = 'config.experiment_kwargs.config.data.train_kwargs.' 231 | seeds = list(range(options['number_of_seeds'])) 232 | random_seedsweep = hyper.zipit([ 233 | hyper.sweep('config.random_seed', seeds), 234 | hyper.sweep(f'{exp_train_kwargs}load_kwargs.shuffle_pre_sample_seed', 235 | seeds)]) 236 | else: 237 | random_seedsweep = hyper.sweep('config.random_seed', 238 | list(range(options['number_of_seeds']))) 239 | 240 | all_sweeps = hyper.product( 241 | [random_seedsweep] + [data_sweep] + model_sweep + learner_sweep) 242 | 243 | dataset_name = options['dataset_name'] 244 | 245 | config.autoxprof_warmup_steps = 5 246 | config.autoxprof_measure_time_seconds = 50 247 | 248 | # Use steps so we get consistency between different models with different speeds.
249 | config.interval_type = 'steps' 250 | 251 | config.training_steps = training_steps 252 | config.log_train_data_interval = 1_000 253 | config.log_tensors_interval = 1_000 254 | config.save_checkpoint_interval = 1_000 255 | config.eval_specific_checkpoint_dir = options['eval_specific_ckpt'] 256 | if options['eval_specific_ckpt']: 257 | min_wid, max_wid = [int(w) for w in options['wids'].split('-')] 258 | config.eval_only = True 259 | config.one_off_evaluate = True 260 | all_sweeps = hyper.product([hyper.zipit([ 261 | hyper.sweep('config.eval_specific_checkpoint_dir', 262 | [options['eval_specific_ckpt'].format(wid=w) 263 | for w in range(min_wid, max_wid+1)]), 264 | all_sweeps])]) 265 | 266 | else: 267 | config.eval_only = False 268 | config.best_model_eval_metric = 'top1_accuracy' 269 | 270 | config.update_from_flattened_dict(all_sweeps[options['sweep_index']], 271 | 'config.') 272 | 273 | # Prevents accidentally setting keys that aren't recognized (e.g. in tests). 274 | config.lock() 275 | 276 | return config 277 | 278 | 279 | def get_postprocessing_step(postprocessing_name: str 280 | ) -> ml_collections.ConfigDict: 281 | """Config for postprocessing steps.""" 282 | postprocess = ml_collections.ConfigDict() 283 | postprocess.fn = getattr(postprocessing, postprocessing_name) 284 | postprocess.kwargs = ml_collections.ConfigDict() 285 | 286 | if postprocessing_name == 'mixup': 287 | postprocess.kwargs.alpha = 0.2 288 | postprocess.kwargs.beta = 0.2 289 | 290 | return postprocess 291 | 292 | 293 | def get_train_adapter(adapter_name: str, training_steps: int 294 | ) -> ml_collections.ConfigDict: 295 | """Config for adapting the training parameters.""" 296 | adapter = ml_collections.ConfigDict() 297 | adapter.fn = getattr(adapt_train, adapter_name) 298 | adapter.kwargs = ml_collections.ConfigDict() 299 | 300 | if adapter_name == 'JTT': 301 | adapter.kwargs.lmbda = 20 302 | adapter.kwargs.num_steps_in_first_iter = training_steps // 2 303 | return adapter 304 | 305 | 
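`get_postprocessing_step` above only stores `alpha` and `beta` for mixup; the transform itself is looked up with `getattr(postprocessing, ...)` from `pix/postprocessing.py`. As a rough, hypothetical NumPy stand-in (not the repo's implementation) for what a mixup `fn` with those kwargs conventionally computes:

```python
import numpy as np

def mixup(images, one_hot_labels, rng, alpha=0.2, beta=0.2):
  # Draw a mixing weight and convex-combine each example with a
  # randomly permuted partner from the same batch.
  lam = rng.beta(alpha, beta)
  perm = rng.permutation(images.shape[0])
  mixed_images = lam * images + (1.0 - lam) * images[perm]
  mixed_labels = lam * one_hot_labels + (1.0 - lam) * one_hot_labels[perm]
  return mixed_images, mixed_labels

rng = np.random.default_rng(0)
images = np.arange(4, dtype=np.float32).reshape(4, 1, 1, 1) * np.ones((4, 8, 8, 3))
labels = np.eye(4, dtype=np.float32)
mixed_images, mixed_labels = mixup(images, labels, rng)
assert mixed_images.shape == images.shape
# Mixed labels remain a convex combination: each row still sums to one.
assert np.allclose(mixed_labels.sum(axis=-1), 1.0)
```

In the experiment code, the configured `fn` is applied each training step to `inputs['image']` and `inputs['one_hot_label']`, so soft (mixed) labels flow into the loss.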
306 | def get_adapter(adapt_name: str, num_properties: int 307 | ) -> ml_collections.ConfigDict: 308 | """Config for how to adapt the model at test time.""" 309 | adapter = ml_collections.ConfigDict() 310 | adapter.fn = getattr(adapt, adapt_name) 311 | adapter.kwargs = ml_collections.ConfigDict(dict(n_properties=num_properties)) 312 | adapter.num_adaptation_steps = 1_000 313 | return adapter 314 | 315 | 316 | def get_learner(learner_name: str, 317 | model_name: str, 318 | num_classes: int = 10) -> ConfigAndSweeps: 319 | """Config for which learning algorithm to use.""" 320 | learner = ml_collections.ConfigDict() 321 | learner.fn = getattr(algorithms, learner_name) 322 | learner.kwargs = ml_collections.ConfigDict() 323 | 324 | learner_sweep = [] 325 | exp_algthm = f'{_EXP}.training.algorithm' 326 | if learner_name == 'IRM': 327 | learner.kwargs.lambda_penalty = 1. 328 | learner_sweep = [ 329 | hyper.sweep(f'{exp_algthm}.kwargs.lambda_penalty', 330 | [0.01, 0.1, 1, 10]) 331 | ] 332 | elif learner_name == 'DANN': 333 | learner.kwargs.mlp_output_sizes = () 334 | exp = f'{_EXP}.training' 335 | learner_sweep = [ 336 | hyper.sweep(f'{exp}.adversarial_weight', 337 | [0.01, 0.1, 1, 10]), 338 | hyper.sweep(f'{exp_algthm}.kwargs.mlp_output_sizes', 339 | [(64, 64)]) 340 | ] 341 | elif learner_name == 'CORAL': 342 | learner.kwargs.coral_weight = 1. 
343 | learner_sweep = [ 344 | hyper.sweep(f'{exp_algthm}.kwargs.coral_weight', 345 | [0.01, 0.1, 1, 10]) 346 | ] 347 | elif learner_name == 'SagNet': 348 | if model_name == 'truncatedresnet18': 349 | learner.kwargs.content_net_kwargs = ml_collections.ConfigDict(dict( 350 | output_sizes=(num_classes,))) 351 | learner.kwargs.style_net_kwargs = ml_collections.ConfigDict(dict( 352 | output_sizes=(num_classes,))) 353 | else: 354 | learner.kwargs.content_net_kwargs = ml_collections.ConfigDict(dict( 355 | output_sizes=(64, 64, num_classes))) 356 | learner.kwargs.style_net_kwargs = ml_collections.ConfigDict(dict( 357 | output_sizes=(64, 64, num_classes))) 358 | return learner, learner_sweep 359 | 360 | 361 | def _get_resizer(size: Optional[int]) -> Callable[[chex.Array], chex.Array]: 362 | if size is not None and size > 0: 363 | return functools.partial(data_utils.resize, size=(size, size)) 364 | return lambda x: x 365 | 366 | 367 | def get_mlp_config(n_layers: int = 4, n_hidden: int = 256, 368 | num_classes: int = 10, resize_to: Optional[int] = None 369 | ) -> ConfigAndSweeps: 370 | """Returns an MLP config and sweeps.""" 371 | resize = _get_resizer(resize_to) 372 | mlp = ml_collections.ConfigDict(dict( 373 | constructor=hk.nets.MLP, 374 | kwargs=dict(output_sizes=[n_hidden] * n_layers + [num_classes]), 375 | preprocess=lambda x: resize(x).reshape((x.shape[0], -1)))) 376 | sweep = hyper.sweep(f'{_EXP}.optimizer.kwargs.learning_rate', 377 | [0.01, 0.001, 1e-4]) 378 | return mlp, [sweep] 379 | 380 | 381 | def get_resnet18_config(num_classes: int = 10, 382 | resize_to: Optional[int] = None) -> ConfigAndSweeps: 383 | cnn = ml_collections.ConfigDict(dict( 384 | constructor=hk.nets.ResNet18, 385 | kwargs=dict(num_classes=num_classes), 386 | preprocess=_get_resizer(resize_to))) 387 | sweep = hyper.sweep(f'{_EXP}.optimizer.kwargs.learning_rate', 388 | [0.01, 0.001, 1e-4]) 389 | return cnn, [sweep] 390 | 391 | 392 | def get_resnet50_config(num_classes: int = 10, 393 | resize_to: 
Optional[int] = None) -> ConfigAndSweeps: 394 | cnn = ml_collections.ConfigDict(dict( 395 | constructor=hk.nets.ResNet50, 396 | kwargs=dict(num_classes=num_classes), 397 | preprocess=_get_resizer(resize_to))) 398 | sweep = hyper.sweep(f'{_EXP}.optimizer.kwargs.learning_rate', 399 | [0.01, 0.001, 1e-4]) 400 | return cnn, [sweep] 401 | 402 | 403 | def get_resnet101_config(num_classes: int = 10, 404 | resize_to: Optional[int] = None) -> ConfigAndSweeps: 405 | cnn = ml_collections.ConfigDict(dict( 406 | constructor=hk.nets.ResNet101, 407 | kwargs=dict(num_classes=num_classes), 408 | preprocess=_get_resizer(resize_to))) 409 | sweep = hyper.sweep(f'{_EXP}.optimizer.kwargs.learning_rate', 410 | [0.01, 0.001, 1e-4]) 411 | return cnn, [sweep] 412 | 413 | 414 | def get_truncatedresnet18_config( 415 | num_classes: int = 10, resize_to: Optional[int] = None) -> ConfigAndSweeps: 416 | """Config for a truncated ResNet.""" 417 | cnn = ml_collections.ConfigDict(dict( 418 | constructor=resnet.ResNet18, 419 | kwargs=dict(num_classes=num_classes), 420 | preprocess=_get_resizer(resize_to))) 421 | sweep = hyper.sweep(f'{_EXP}.optimizer.kwargs.learning_rate', 422 | [0.01, 0.001, 1e-4]) 423 | return cnn, [sweep] 424 | -------------------------------------------------------------------------------- /distribution_shift_framework/classification/experiment.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2022 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Run the formalisation pipeline.""" 18 | 19 | import functools 20 | 21 | from absl import app 22 | from absl import flags 23 | from distribution_shift_framework.classification import experiment_lib 24 | from jaxline import platform 25 | 26 | 27 | if __name__ == '__main__': 28 | flags.mark_flag_as_required('config') 29 | app.run(functools.partial(platform.main, experiment_lib.Experiment)) 30 | 31 | -------------------------------------------------------------------------------- /distribution_shift_framework/classification/experiment_lib.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2022 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | """Run the formalisation pipeline.""" 18 | 19 | import contextlib 20 | import functools 21 | import os 22 | from typing import Generator, Mapping, Optional, Tuple 23 | 24 | from absl import flags 25 | from absl import logging 26 | import chex 27 | from distribution_shift_framework.core import checkpointing 28 | from distribution_shift_framework.core.datasets import data_utils 29 | from distribution_shift_framework.core.metrics import metrics 30 | import haiku as hk 31 | import jax 32 | import jax.numpy as jnp 33 | from jaxline import experiment 34 | from jaxline import utils 35 | import ml_collections 36 | import numpy as np 37 | import optax 38 | from six.moves import cPickle as pickle 39 | import tensorflow as tf 40 | import tensorflow_datasets as tfds 41 | 42 | FLAGS = flags.FLAGS 43 | 44 | 45 | def get_per_device_batch_size(total_batch_size: int) -> int: 46 | num_devices = jax.device_count() 47 | per_device_batch_size, ragged = divmod(total_batch_size, num_devices) 48 | 49 | if ragged: 50 | raise ValueError( 51 | f'Global batch size {total_batch_size} must be divisible by the ' 52 | f'total number of devices {num_devices}') 53 | return per_device_batch_size 54 | 55 | 56 | class Experiment(experiment.AbstractExperiment): 57 | """Formalisation experiment.""" 58 | CHECKPOINT_ATTRS = { 59 | '_params': 'params', 60 | '_state': 'state', 61 | '_opt_state': 'opt_state', 62 | '_d_params': 'd_params', 63 | '_d_state': 'd_state', 64 | '_d_opt_state': 'd_opt_state', 65 | '_adapt_params': 'adapt_params', 66 | '_adapt_state': 'adapt_state' 67 | } 68 | 69 | def __init__(self, mode: str, init_rng: chex.PRNGKey, 70 | config: ml_collections.ConfigDict): 71 | """Initializes experiment.""" 72 | 73 | super(Experiment, self).__init__(mode=mode, init_rng=init_rng) 74 | 75 | self.mode = mode 76 | self.config = config 77 | self.init_rng = init_rng 78 | 79 | # Set up discriminator parameters. 
80 | self._d_params = None 81 | self._d_state = None 82 | self._d_opt_state = None 83 | 84 | # Double transpose trick to improve performance on TPUs. 85 | self._should_transpose_images = ( 86 | config.enable_double_transpose and 87 | jax.local_devices()[0].platform == 'tpu') 88 | 89 | self._params = None # params 90 | self._state = None # network state for stats like batchnorm 91 | self._opt_state = None # optimizer state 92 | self._adapt_params = None 93 | self._adapt_state = None 94 | self._label = config.data.label 95 | 96 | with utils.log_activity('transform functions'): 97 | self.forward = hk.transform_with_state(self._forward_fn) 98 | self.eval_batch = jax.pmap(self._eval_batch, axis_name='i') 99 | self.learner_fn = hk.transform_with_state(self._learner_fn) 100 | self.adversarial_fn = hk.transform_with_state(self._adversarial_fn) 101 | self.adapt_fn = self._adapt_fn 102 | self.adaptor = None 103 | 104 | self._update_func = jax.pmap( 105 | self._update_func, axis_name='i', donate_argnums=(0, 1, 2)) 106 | 107 | if mode == 'train': 108 | with utils.log_activity('initialize training'): 109 | self._init_train(init_rng) 110 | 111 | if getattr(self.config.training.learn_adapt, 'fn', None): 112 | learner_adapt_fn = self.config.training.learn_adapt.fn 113 | learner_adapt_kwargs = self.config.training.learn_adapt.kwargs 114 | self._train_adapter = learner_adapt_fn(**learner_adapt_kwargs) 115 | if self._adapt_params is None: 116 | self._adapt_params = self._params 117 | self._adapt_state = self._state 118 | self._train_adapter.set(self._adapt_params, self._adapt_state) 119 | else: 120 | self._train_adapter = None 121 | 122 | def optimizer(self) -> optax.GradientTransformation: 123 | optimizer_fn = getattr(optax, self.config.optimizer.name) 124 | return optimizer_fn(**self.config.optimizer.kwargs) 125 | 126 | def _maybe_undo_transpose_images(self, images: chex.Array) -> chex.Array: 127 | if self._should_transpose_images: 128 | return jnp.transpose(images, (1, 2, 3, 0)) # 
NHWC -> HWCN. 129 | return images 130 | 131 | def _maybe_transpose_images(self, images: chex.Array) -> chex.Array: 132 | if self._should_transpose_images: 133 | # We use the double transpose trick to improve performance for TPUs. 134 | # Note that there is a matching NHWC->HWCN transpose in the data pipeline. 135 | # Here we reset back to NHWC like our model expects. The compiler cannot 136 | # make this optimization for us since our data pipeline and model are 137 | # compiled separately. 138 | images = jnp.transpose(images, (3, 0, 1, 2)) # HWCN -> NHWC. 139 | return images 140 | 141 | def _postprocess_fn( 142 | self, 143 | inputs: data_utils.Batch, 144 | rng: chex.PRNGKey 145 | ) -> data_utils.Batch: 146 | if not hasattr(self.config, 'postprocess'): 147 | return inputs 148 | postprocessing = getattr(self.config.postprocess, 'fn', None) 149 | if postprocessing is None: 150 | return inputs 151 | postprocess_fn = functools.partial(postprocessing, 152 | **self.config.postprocess.kwargs) 153 | images = inputs['image'] 154 | labels = inputs['one_hot_label'] 155 | postprocessed_images, postprocessed_labels = postprocess_fn( 156 | images, labels, rng=rng) 157 | 158 | postprocessed_inputs = dict(**inputs) 159 | postprocessed_inputs['image'] = postprocessed_images 160 | postprocessed_inputs['one_hot_label'] = postprocessed_labels 161 | return postprocessed_inputs 162 | 163 | def _learner_fn(self, inputs: data_utils.Batch, 164 | reduction='mean') -> Tuple[data_utils.ScalarDict, chex.Array]: 165 | 166 | logits = self._forward_fn(inputs, is_training=True) 167 | 168 | if getattr(self.config.data, 'label_property', '') in inputs.keys(): 169 | property_vs = inputs[self.config.data.label_property] 170 | property_onehot = hk.one_hot(property_vs, self.config.data.n_properties) 171 | else: 172 | property_onehot = None 173 | 174 | algorithm_fn = self.config.training.algorithm.fn 175 | kwargs = self.config.training.algorithm.kwargs 176 | scalars, logits = 
algorithm_fn(**kwargs)( 177 | logits, inputs['one_hot_label'], property_vs=property_onehot, 178 | reduction=reduction) 179 | 180 | predicted_label = jnp.argmax(logits, axis=-1) 181 | top1_acc = jnp.equal(predicted_label, 182 | inputs[self._label]).astype(jnp.float32) 183 | scalars['top1_acc'] = top1_acc.mean() 184 | 185 | return scalars, logits 186 | 187 | def learner_adapt_weights_fn( 188 | self, params: optax.Params, state: optax.OptState, 189 | old_params: optax.Params, old_state: optax.OptState, 190 | inputs: data_utils.Batch, rng: chex.PRNGKey, 191 | global_step: chex.Array 192 | ) -> Tuple[Tuple[data_utils.ScalarDict, chex.Array], optax.OptState]: 193 | (scalars, logits), g_state = self._train_adapter( 194 | fn=functools.partial(self.learner_fn.apply, reduction=None), 195 | params=params, state=state, inputs=inputs, global_step=global_step, 196 | rng=rng, old_params=old_params, old_state=old_state) 197 | return (scalars, logits), g_state 198 | 199 | def _adversarial_fn(self, logits: chex.Array, 200 | inputs: data_utils.Batch) -> data_utils.ScalarDict: 201 | if getattr(self.config.data, 'label_property', '') in inputs.keys(): 202 | property_vs = inputs[self.config.data.label_property] 203 | property_onehot = hk.one_hot(property_vs, self.config.data.n_properties) 204 | else: 205 | property_onehot = None 206 | 207 | one_hot_labels = inputs['one_hot_label'] 208 | algorithm_fn = self.config.training.algorithm.fn 209 | kwargs = self.config.training.algorithm.kwargs 210 | return algorithm_fn(**kwargs).adversary( 211 | logits, property_vs=property_onehot, reduction='mean', 212 | targets=one_hot_labels) 213 | 214 | def _adapt_fn(self, params: optax.Params, state: optax.OptState, 215 | rng: chex.PRNGKey, is_final_eval: bool = False): 216 | adapt_fn = getattr(self.config.adapter, 'fn') 217 | adapt_kwargs = getattr(self.config.adapter, 'kwargs') 218 | 219 | forward_fn = functools.partial(self.forward.apply, is_training=True, 220 | test_local_stats=False) 221 | 
self.adaptor = adapt_fn(init_params=params, 222 | init_state=state, 223 | forward=jax.pmap(forward_fn, axis_name='i'), 224 | **adapt_kwargs) 225 | 226 | per_device_batch_size = get_per_device_batch_size( 227 | self.config.training.batch_size) 228 | 229 | ds = self._load_data(per_device_batch_size=per_device_batch_size, 230 | is_training=False, 231 | data_kwargs=self.config.data.test_kwargs) 232 | 233 | for step, batch in enumerate(ds, 1): 234 | logging.info('Updating using an adaptor function.') 235 | self.adaptor.update(batch, batch[self.config.data.label_property], rng) 236 | if (not is_final_eval and 237 | step > getattr(self.config.adapter, 'num_adaptation_steps')): 238 | break 239 | 240 | def _forward_fn(self, 241 | inputs: data_utils.Batch, 242 | is_training: bool, 243 | test_local_stats: bool = False) -> chex.Array: 244 | model_constructor = self.config.model.constructor 245 | model_instance = model_constructor(**self.config.model.kwargs.to_dict()) 246 | 247 | images = inputs['image'] 248 | images = self._maybe_transpose_images(images) 249 | images = self.config.model.preprocess(images) 250 | 251 | if isinstance(model_instance, hk.nets.MLP): 252 | return model_instance(images) 253 | return model_instance(images, is_training=is_training) 254 | 255 | def _d_loss_fn( 256 | self, d_params: optax.Params, d_state: optax.OptState, inputs: chex.Array, 257 | logits: chex.Array, 258 | rng: chex.PRNGKey 259 | ) -> Tuple[chex.Array, Tuple[data_utils.ScalarDict, optax.OptState]]: 260 | 261 | d_scalars, d_state = self.adversarial_fn.apply(d_params, d_state, rng, 262 | logits, inputs) 263 | if not d_scalars: 264 | # No adversary. 
265 | return 0., (d_scalars, d_state) 266 | 267 | scaled_loss = d_scalars['loss'] / jax.device_count() 268 | d_scalars = {f'adv_{k}': v for k, v in d_scalars.items()} 269 | 270 | return scaled_loss, (d_scalars, d_state) 271 | 272 | def _run_postprocess_fn(self, 273 | rng: chex.PRNGKey, 274 | inputs: data_utils.Batch) -> data_utils.Batch: 275 | inputs = self._postprocess_fn(inputs, rng) 276 | return inputs 277 | 278 | def _loss_fn( 279 | self, g_params: optax.Params, 280 | g_state: optax.OptState, 281 | d_params: optax.Params, 282 | d_state: optax.OptState, 283 | inputs: chex.Array, 284 | rng: chex.PRNGKey, 285 | global_step: chex.Array, 286 | old_g_params: Optional[optax.Params] = None, 287 | old_g_state: Optional[optax.OptState] = None 288 | ) -> Tuple[chex.Array, Tuple[ 289 | data_utils.ScalarDict, chex.Array, data_utils.Batch, optax.OptState]]: 290 | # Find the loss according to the generator. 291 | if getattr(self.config.training.learn_adapt, 'fn', None): 292 | # Use generator loss computed by a training adaptation algorithm. 293 | (scalars, logits), g_state = self.learner_adapt_weights_fn( 294 | params=g_params, 295 | state=g_state, 296 | old_params=old_g_params, 297 | old_state=old_g_state, 298 | rng=rng, 299 | inputs=inputs, 300 | global_step=global_step) 301 | else: 302 | (scalars, logits), g_state = self.learner_fn.apply(g_params, g_state, rng, 303 | inputs) 304 | 305 | d_scalars, _ = self.adversarial_fn.apply(d_params, d_state, rng, logits, 306 | inputs) 307 | 308 | # If there is an adversary: 309 | if 'loss' in d_scalars.keys(): 310 | # The generator wants the adversary to do poorly, so its loss is subtracted.
311 | adv_weight = self.config.training.adversarial_weight 312 | scalars['loss'] = scalars['loss'] - d_scalars['loss'] * adv_weight 313 | scalars.update({f'gen_adv_{k}': v for k, v in d_scalars.items()}) 314 | 315 | scaled_loss = scalars['loss'] / jax.device_count() 316 | return scaled_loss, (scalars, logits, inputs, g_state) 317 | 318 | # _ _ 319 | # | |_ _ __ __ _(_)_ __ 320 | # | __| '__/ _` | | '_ \ 321 | # | |_| | | (_| | | | | | 322 | # \__|_| \__,_|_|_| |_| 323 | # 324 | 325 | def _prepare_train_batch(self, rng: chex.PRNGKey, 326 | batch: data_utils.Batch) -> data_utils.Batch: 327 | noise_threshold = self.config.training.label_noise 328 | if noise_threshold > 0: 329 | random_labels = jax.random.randint( 330 | rng[0], 331 | shape=batch[self._label].shape, 332 | dtype=batch[self._label].dtype, 333 | minval=0, 334 | maxval=self.config.data.n_classes) 335 | mask = jax.random.uniform(rng[0], 336 | batch[self._label].shape) < noise_threshold 337 | batch[self._label] = (random_labels * mask + 338 | batch[self._label] * (1 - mask)) 339 | batch['one_hot_label'] = hk.one_hot( 340 | batch[self._label], self.config.data.n_classes) 341 | return batch 342 | 343 | def _init_train(self, rng: chex.PRNGKey): 344 | self._train_input = utils.py_prefetch(self._build_train_input) 345 | 346 | if self._params is None: 347 | logging.info('Initializing parameters randomly rather than restoring' 348 | ' from checkpoint.') 349 | batch = next(self._train_input) 350 | batch['one_hot_label'] = hk.one_hot(batch[self._label], 351 | self.config.data.n_classes) 352 | 353 | # Initialize generator. 354 | self._params, self._state = self._init_params(rng, batch) 355 | opt_init, _ = self.optimizer() 356 | self._opt_state = jax.pmap(opt_init)(self._params) 357 | 358 | # Initialize discriminator. 
359 | bcast_rng = utils.bcast_local_devices(rng) 360 | (_, dummy_logits), _ = jax.pmap(self.learner_fn.apply)(self._params, 361 | self._state, 362 | bcast_rng, batch) 363 | self._d_params, self._d_state = self._init_d_params( 364 | rng, dummy_logits, batch) 365 | opt_init, _ = self.optimizer() 366 | if self._d_params: 367 | self._d_opt_state = jax.pmap(opt_init)(self._d_params) 368 | else: 369 | # Is empty. 370 | self._d_opt_state = None 371 | 372 | def _init_params( 373 | self, rng: chex.PRNGKey, 374 | batch: data_utils.Batch) -> Tuple[optax.Params, optax.OptState]: 375 | init_net = jax.pmap(self.learner_fn.init) 376 | rng = utils.bcast_local_devices(rng) 377 | params, state = init_net(rng, batch) 378 | if not self.config.pretrained_checkpoint: 379 | return params, state 380 | ckpt_data = checkpointing.load_model( 381 | self.config.pretrained_checkpoint) 382 | ckpt_params, ckpt_state = ckpt_data['params'], ckpt_data['state'] 383 | 384 | ckpt_params = utils.bcast_local_devices(ckpt_params) 385 | ckpt_state = utils.bcast_local_devices(ckpt_state) 386 | 387 | def use_pretrained_if_shapes_match(params, ckpt_params): 388 | if params.shape == ckpt_params.shape: 389 | return ckpt_params 390 | logging.warning('Shape mismatch! 
Initialized parameter: %s, ' 391 | 'Pretrained parameter: %s.', 392 | params.shape, ckpt_params.shape) 393 | return params 394 | 395 | params = jax.tree_multimap( 396 | use_pretrained_if_shapes_match, params, ckpt_params) 397 | return params, ckpt_state 398 | 399 | def _init_d_params( 400 | self, rng: chex.PRNGKey, logits: chex.Array, 401 | batch: data_utils.Batch) -> Tuple[optax.Params, optax.OptState]: 402 | init_net = jax.pmap(self.adversarial_fn.init) 403 | rng = utils.bcast_local_devices(rng) 404 | return init_net(rng, logits, batch) 405 | 406 | def _write_images(self, writer, global_step: chex.Array, 407 | images: Mapping[str, chex.Array]): 408 | global_step = np.array(utils.get_first(global_step)) 409 | 410 | images_to_write = { 411 | k: self._maybe_transpose_images(utils.get_first(v)) 412 | for k, v in images.items()} 413 | 414 | writer.write_images(global_step, images_to_write) 415 | 416 | def _load_data(self, 417 | per_device_batch_size: int, 418 | is_training: bool, 419 | data_kwargs: ml_collections.ConfigDict 420 | ) -> Generator[data_utils.Batch, None, None]: 421 | 422 | with contextlib.ExitStack() as stack: 423 | if self.config.use_fake_data: 424 | stack.enter_context(tfds.testing.mock_data(num_examples=128)) 425 | ds = data_utils.load_dataset( 426 | is_training=is_training, 427 | batch_dims=[jax.local_device_count(), per_device_batch_size], 428 | transpose=self._should_transpose_images, 429 | data_kwargs=data_kwargs) 430 | return ds 431 | 432 | def _build_train_input(self) -> Generator[data_utils.Batch, None, None]: 433 | per_device_batch_size = get_per_device_batch_size( 434 | self.config.training.batch_size) 435 | return self._load_data(per_device_batch_size=per_device_batch_size, 436 | is_training=True, 437 | data_kwargs=self.config.data.train_kwargs) 438 | 439 | def _update_func( 440 | self, 441 | params: optax.Params, 442 | state: optax.OptState, 443 | opt_state: optax.OptState, 444 | global_step: chex.Array, 445 | batch: data_utils.Batch, 446 | 
rng: chex.PRNGKey, 447 | old_g_params: Optional[optax.Params] = None, 448 | old_g_state: Optional[optax.OptState] = None 449 | ) -> Tuple[Tuple[optax.Params, optax.Params], Tuple[ 450 | optax.OptState, optax.OptState], Tuple[optax.OptState, optax.OptState], 451 | data_utils.ScalarDict, data_utils.Batch]: 452 | """Updates the generator and discriminator parameters.""" 453 | # Unpack the generator and discriminator parameters, states and opt states. 454 | (g_params, d_params) = params 455 | (g_state, d_state) = state 456 | (g_opt_state, d_opt_state) = opt_state 457 | 458 | ################ 459 | # Generator. 460 | ################ 461 | # Compute the loss for the generator. 462 | inputs = self._run_postprocess_fn(rng, batch) 463 | grad_loss_fn = jax.grad(self._loss_fn, has_aux=True) 464 | scaled_grads, (g_scalars, logits, preprocessed_inputs, 465 | g_state) = grad_loss_fn(g_params, g_state, d_params, d_state, 466 | inputs, rng, global_step, 467 | old_g_params=old_g_params, 468 | old_g_state=old_g_state) 469 | 470 | # Update the generator. 471 | grads = jax.lax.psum(scaled_grads, axis_name='i') 472 | _, opt_apply = self.optimizer() 473 | 474 | updates, g_opt_state = opt_apply(grads, g_opt_state, g_params) 475 | g_params = optax.apply_updates(g_params, updates) 476 | 477 | ################ 478 | # Discriminator. 479 | ################ 480 | if not self._d_opt_state: 481 | # No discriminator. 482 | scalars = dict(global_step=global_step, **g_scalars) 483 | return ((g_params, d_params), (g_state, d_state), 484 | (g_opt_state, d_opt_state), scalars, preprocessed_inputs) 485 | 486 | # Compute the loss for the discriminator. 487 | grad_loss_fn = jax.grad(self._d_loss_fn, has_aux=True) 488 | scaled_grads, (d_scalars, d_state) = grad_loss_fn(d_params, d_state, batch, 489 | logits, rng) 490 | 491 | # Update the discriminator.
492 | grads = jax.lax.psum(scaled_grads, axis_name='i') 493 | _, opt_apply = self.optimizer() 494 | 495 | updates, d_opt_state = opt_apply(grads, d_opt_state, d_params) 496 | d_params = optax.apply_updates(d_params, updates) 497 | 498 | # For logging while training. 499 | scalars = dict( 500 | global_step=global_step, 501 | **g_scalars, 502 | **d_scalars) 503 | return ((g_params, d_params), (g_state, d_state), 504 | (g_opt_state, d_opt_state), scalars, preprocessed_inputs) 505 | 506 | def step(self, global_step: chex.Array, rng: chex.PRNGKey, writer, 507 | **unused_kwargs) -> chex.Array: 508 | """Perform one step of the model.""" 509 | 510 | batch = next(self._train_input) 511 | batch = self._prepare_train_batch(rng, batch) 512 | 513 | params, state, opt_state, scalars, preprocessed_batch = ( 514 | self._update_func( 515 | params=(self._params, self._d_params), 516 | state=(self._state, self._d_state), 517 | opt_state=(self._opt_state, self._d_opt_state), 518 | global_step=global_step, 519 | batch=batch, 520 | rng=rng, 521 | old_g_params=self._adapt_params, 522 | old_g_state=self._adapt_state)) 523 | (self._params, self._d_params) = params 524 | (self._state, self._d_state) = state 525 | (self._opt_state, self._d_opt_state) = opt_state 526 | 527 | if self._train_adapter: 528 | self._adapt_params, self._adapt_state = self._train_adapter.update( 529 | self._params, self._state, utils.get_first(global_step)) 530 | 531 | images = batch['image'] 532 | preprocessed_images = preprocessed_batch['image'] 533 | 534 | if self.config.training.save_images: 535 | self._write_images(writer, global_step, 536 | {'images': images, 537 | 'preprocessed_images': preprocessed_images}) 538 | 539 | # Just return the tracking metrics on the first device for logging. 
540 | return utils.get_first(scalars) 541 | 542 | # _ 543 | # _____ ____ _| | 544 | # / _ \ \ / / _` | | 545 | # | __/\ V / (_| | | 546 | # \___| \_/ \__,_|_| 547 | # 548 | 549 | def _load_eval_data( 550 | self, 551 | per_device_batch_size: int) -> Generator[data_utils.Batch, None, None]: 552 | return self._load_data(per_device_batch_size=per_device_batch_size, 553 | is_training=False, 554 | data_kwargs=self.config.data.test_kwargs) 555 | 556 | def _full_eval(self, rng: chex.PRNGKey, scalars: data_utils.ScalarDict, 557 | checkpoint_path: Optional[str] = None 558 | ) -> data_utils.ScalarDict: 559 | if checkpoint_path: 560 | ckpt_data = checkpointing.load_model(checkpoint_path) 561 | params, state = ckpt_data['params'], ckpt_data['state'] 562 | params = utils.bcast_local_devices(params) 563 | state = utils.bcast_local_devices(state) 564 | else: 565 | params, state = self._params, self._state 566 | 567 | # Iterate over all the test sets. 568 | original_subset = self.config.data.test_kwargs.load_kwargs.subset 569 | for test_subset in getattr(self.config.data, 'test_sets', ('test',)): 570 | self.config.data.test_kwargs.load_kwargs.subset = test_subset 571 | test_scalars = jax.device_get( 572 | self._eval_top1_accuracy(params, state, rng, is_final=True)) 573 | scalars.update( 574 | {f'{test_subset}_{k}': v for k, v in test_scalars.items()}) 575 | self.config.data.test_kwargs.load_kwargs.subset = original_subset 576 | return scalars 577 | 578 | def evaluate(self, global_step: chex.Array, rng: chex.PRNGKey, writer, 579 | **unused_args) -> data_utils.ScalarDict: 580 | """See base class.""" 581 | # Need to set these so `on_new_best_model` can do a full eval. 
582 | self._writer = writer 583 | self._rng = rng 584 | global_step = np.array(utils.get_first(global_step)) 585 | 586 | scalars = jax.device_get( 587 | self._eval_top1_accuracy(self._params, self._state, rng)) 588 | 589 | if FLAGS.config.eval_specific_checkpoint_dir: 590 | scalars = self._full_eval(rng, scalars, 591 | FLAGS.config.eval_specific_checkpoint_dir) 592 | 593 | logging.info('[Step %d] Eval scalars: %s', global_step, scalars) 594 | return scalars 595 | 596 | def on_new_best_model(self, best_state: ml_collections.ConfigDict): 597 | scalars = self._full_eval(self._rng, {}) 598 | if self._writer is not None: 599 | self._writer.write_scalars(best_state.global_step, scalars) 600 | ckpt_data = {} 601 | for self_key, ckpt_key in self.CHECKPOINT_ATTRS.items(): 602 | ckpt_data[ckpt_key] = getattr(self, self_key) 603 | checkpoint_path = checkpointing.get_checkpoint_dir(FLAGS.config) 604 | checkpointing.save_model(os.path.join(checkpoint_path, 'best.pkl'), 605 | ckpt_data) 606 | 607 | def _eval_top1_accuracy(self, params: optax.Params, state: optax.OptState, 608 | rng: chex.PRNGKey, is_final: bool = False 609 | ) -> data_utils.ScalarDict: 610 | """Evaluates an epoch.""" 611 | total_batch_size = self.config.evaluation.batch_size 612 | per_device_batch_size = total_batch_size 613 | eval_data = self._load_eval_data(per_device_batch_size) 614 | 615 | # If using an adaptive method. 616 | if getattr(self.config.adapter, 'fn', None): 617 | self.adapt_fn(params, state, rng, is_final_eval=is_final) 618 | self.adaptor.set_up_eval() 619 | 620 | # Accuracies for each set of corruptions. 
621 | labels = [] 622 | predicted_labels = [] 623 | features = [] 624 | for batch in eval_data: 625 | if self.adaptor is not None: 626 | logging.info('Running adaptation algorithm for evaluation.') 627 | property_label = batch[self.config.data.label_property] 628 | predicted_label, _ = self.adaptor.run( 629 | self.eval_batch, property_label, inputs=batch, rng=rng) 630 | else: 631 | predicted_label, _ = self.eval_batch(params, state, batch, rng) 632 | label = batch[self._label] 633 | feature = batch[self.config.data.label_property] 634 | 635 | # Concatenate along the pmapped direction. 636 | labels.append(jnp.concatenate(label)) 637 | features.append(jnp.concatenate(feature)) 638 | predicted_labels.append(jnp.concatenate(predicted_label)) 639 | 640 | # And finally concatenate along the first dimension. 641 | labels = jnp.concatenate(labels) 642 | features = jnp.concatenate(features) 643 | predicted_labels = jnp.concatenate(predicted_labels) 644 | 645 | # Compute the metrics. 646 | results = {} 647 | for metric in self.config.evaluation.metrics: 648 | logging.info('Evaluating metric %s.', str(metric)) 649 | metric_fn = getattr(metrics, metric, None) 650 | results[metric] = metric_fn(labels, features, predicted_labels, None) 651 | 652 | # Dump all the results by saving pickled results to disk. 653 | out_dir = checkpointing.get_checkpoint_dir(FLAGS.config) 654 | dataset = self.config.data.test_kwargs.load_kwargs.subset 655 | results_path = os.path.join(out_dir, f'results_{dataset}') 656 | if not tf.io.gfile.exists(results_path): 657 | tf.io.gfile.makedirs(results_path) 658 | 659 | # Save numpy arrays. 660 | with tf.io.gfile.GFile( 661 | os.path.join(results_path, 'results.pkl'), 'wb') as f: 662 | # Using protocol 4 as it's the default from Python 3.8 on. 
663 |     pickle.dump({'all_labels': labels, 'all_features': features,
664 |                  'all_predictions': predicted_labels}, f, protocol=4)
665 | 
666 |     return results
667 | 
668 |   def _eval_batch(self, params: optax.Params, state: optax.OptState,
669 |                   inputs: data_utils.Batch,
670 |                   rng: chex.PRNGKey
671 |                   ) -> Tuple[chex.Array, chex.Array]:
672 |     """Evaluates a batch, returning the predicted labels and logits."""
673 | 
674 |     logits, _ = self.forward.apply(
675 |         params, state, rng, inputs, is_training=False)
676 | 
677 |     inputs['one_hot_label'] = hk.one_hot(
678 |         inputs[self._label], self.config.data.n_classes)
679 |     (_, logits), _ = self.learner_fn.apply(params, state, rng, inputs)
680 | 
681 |     softmax_predictions = jax.nn.softmax(logits, axis=-1)
682 |     predicted_label = jnp.argmax(softmax_predictions, axis=-1)
683 | 
684 |     return predicted_label, logits
685 | 
--------------------------------------------------------------------------------
/distribution_shift_framework/classification/experiment_lib_test.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/python
 2 | #
 3 | # Copyright 2022 DeepMind Technologies Limited
 4 | #
 5 | # Licensed under the Apache License, Version 2.0 (the "License");
 6 | # you may not use this file except in compliance with the License.
 7 | # You may obtain a copy of the License at
 8 | #
 9 | #     http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | 17 | """Tests for distribution_shift_framework.classification.experiment_lib.""" 18 | 19 | from absl.testing import absltest 20 | from absl.testing import flagsaver 21 | from absl.testing import parameterized 22 | from distribution_shift_framework.classification import config 23 | from distribution_shift_framework.classification import experiment_lib 24 | import jax 25 | from jaxline import platform 26 | 27 | 28 | _PREV_JAX_CONFIG = None 29 | 30 | 31 | def setUpModule(): 32 | global _PREV_JAX_CONFIG 33 | _PREV_JAX_CONFIG = jax.config.values.copy() 34 | # Disable jax optimizations to speed up test. 35 | jax.config.update('jax_disable_most_optimizations', True) 36 | 37 | 38 | def tearDownModule(): 39 | # Set config to previous values. 40 | jax.config.values.update(**_PREV_JAX_CONFIG) 41 | 42 | 43 | class ExperimentLibTest(parameterized.TestCase): 44 | 45 | @parameterized.parameters([ 46 | # Different algorithms. 47 | dict(algorithm='CORAL', test_case='ood', model='resnet18', 48 | dataset_name='dsprites', label='label_shape', 49 | property_label='label_color', number_of_seeds=1), 50 | dict(algorithm='DANN', test_case='ood', model='resnet18', 51 | dataset_name='dsprites', label='label_shape', 52 | property_label='label_color', number_of_seeds=1), 53 | dict(algorithm='ERM', test_case='ood', model='resnet18', 54 | dataset_name='dsprites', label='label_shape', 55 | property_label='label_color', number_of_seeds=1), 56 | dict(algorithm='IRM', test_case='ood', model='resnet18', 57 | dataset_name='dsprites', label='label_shape', 58 | property_label='label_color', number_of_seeds=1), 59 | dict(algorithm='SagNet', test_case='ood', model='resnet18', 60 | dataset_name='dsprites', label='label_shape', 61 | property_label='label_color', number_of_seeds=1), 62 | # Different datasets. 
63 | dict(algorithm='ERM', test_case='ood', model='resnet18', 64 | dataset_name='small_norb', label='label_category', 65 | property_label='label_azimuth', number_of_seeds=1), 66 | dict(algorithm='ERM', test_case='ood', model='resnet18', 67 | dataset_name='shapes3d', label='label_shape', 68 | property_label='label_object_hue', number_of_seeds=1), 69 | # Different test cases. 70 | dict(algorithm='ERM', test_case='lowdata', model='resnet18', 71 | dataset_name='shapes3d', label='label_shape', 72 | property_label='label_object_hue', number_of_seeds=1), 73 | dict(algorithm='ERM', test_case='correlated.lowdata', model='resnet18', 74 | dataset_name='shapes3d', label='label_shape', 75 | property_label='label_object_hue', number_of_seeds=1), 76 | dict(algorithm='ERM', test_case='lowdata.noise', model='resnet18', 77 | dataset_name='shapes3d', label='label_shape', 78 | property_label='label_object_hue', number_of_seeds=1), 79 | dict(algorithm='ERM', test_case='lowdata.fixeddata', model='resnet18', 80 | dataset_name='shapes3d', label='label_shape', 81 | property_label='label_object_hue', number_of_seeds=1), 82 | ]) 83 | def test_train(self, **kwargs): 84 | kwargs['training_steps'] = 3 85 | kwargs['use_fake_data'] = True 86 | kwargs['batch_size'] = 8 87 | options = ','.join([f'{k}={v}' for k, v in kwargs.items()]) 88 | cfg = config.get_config(options) 89 | with flagsaver.flagsaver(config=cfg, jaxline_mode='train'): 90 | platform.main(experiment_lib.Experiment, []) 91 | 92 | 93 | if __name__ == '__main__': 94 | absltest.main() 95 | -------------------------------------------------------------------------------- /distribution_shift_framework/configs/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2022 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | -------------------------------------------------------------------------------- /distribution_shift_framework/configs/disentanglement_config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2022 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | """Configs for disentanglement datasets.""" 18 | import itertools 19 | from typing import Any, Callable, Optional, Sequence 20 | 21 | from distribution_shift_framework.core import hyper 22 | from distribution_shift_framework.core.datasets import data_loaders 23 | from distribution_shift_framework.core.datasets import data_utils 24 | from distribution_shift_framework.core.datasets import lowdata_wrapper 25 | import ml_collections 26 | import tensorflow.compat.v2 as tf 27 | 28 | 29 | _VALID_COLORS = (( 30 | 1, 31 | 0, 32 | 0, 33 | ), (0, 1, 0), (0, 0, 1)) 34 | _EXP = 'config.experiment_kwargs.config' 35 | _TRAIN_SPLIT = 'train' 36 | _TEST_SPLIT = 'valid' 37 | 38 | _ExampleFn = Callable[[tf.train.Example], tf.train.Example] 39 | 40 | 41 | def _color_preprocess(mode: str, 42 | preprocess: Optional[Callable[[str], _ExampleFn]] = None, 43 | label: str = 'label') -> _ExampleFn: 44 | """Preprocessing function to add colour to white pixels in a binary image.""" 45 | 46 | def _color_fn(example: tf.train.Example) -> tf.train.Example: 47 | if preprocess is not None: 48 | example = preprocess(mode)(example) 49 | 50 | example['image'] = tf.repeat(example['image'], 3, axis=2) 51 | example['image'] = tf.cast(example['image'], tf.float32) 52 | 53 | # Choose a random color. 
54 | color_id = tf.random.uniform( 55 | shape=(), minval=0, maxval=len(_VALID_COLORS), dtype=tf.int64) 56 | example['label_color'] = color_id 57 | 58 | colors = tf.constant(_VALID_COLORS, dtype=tf.float32)[color_id] 59 | example['image'] = example['image'] * colors 60 | example['label'] = example[label] 61 | 62 | example['fairness_features'] = { 63 | k: v for k, v in example.items() if k.startswith('label_') 64 | } 65 | return example 66 | 67 | return _color_fn 68 | 69 | 70 | def _get_base_config(dataset_name: str, label: str, property_label: str 71 | ) -> ml_collections.ConfigDict: 72 | """Get base config.""" 73 | data = ml_collections.ConfigDict() 74 | data.name = dataset_name 75 | data.test = dataset_name 76 | 77 | dataset_constants = data_utils.get_dataset_constants(dataset_name, label) 78 | data.label_property = property_label 79 | data.label = label 80 | 81 | data.n_classes = dataset_constants['num_classes'] 82 | data.num_channels = dataset_constants['num_channels'] 83 | data.image_size = dataset_constants['image_size'] 84 | data.variance = dataset_constants['variance'] 85 | 86 | if dataset_name != data_utils.DatasetNames.DSPRITES.value or ( 87 | property_label != 'label_color'): 88 | data.prop_values = dataset_constants['properties'][property_label] 89 | data.n_properties = len(data.prop_values) 90 | 91 | if dataset_name == data_utils.DatasetNames.DSPRITES.value and ( 92 | label == 'label_color' or property_label == 'label_color'): 93 | 94 | data.num_channels = 3 95 | 96 | if label == 'label_color': 97 | data.n_classes = 3 98 | if property_label == 'label_color': 99 | data.prop_values = (0, 1, 2) 100 | data.n_properties = 3 101 | 102 | return data 103 | 104 | 105 | def _get_filter_fns(values: Sequence[Any], 106 | perc_property: float, 107 | property_name: str) -> str: 108 | cutoff = max(int((len(values) - 1) * perc_property), 0) 109 | cutoff = values[cutoff] 110 | filter_fns = (f'{property_name}:{cutoff}:less_equal,' 111 | 
f'{property_name}:{cutoff}:greater')
112 |   return filter_fns
113 | 
114 | 
115 | def get_data_config(dataset_name: str, label: str, property_label: str
116 |                     ) -> ml_collections.ConfigDict:
117 |   """Get config for a given setup."""
118 |   data = _get_base_config(dataset_name, label, property_label)
119 | 
120 |   dataset_loader = getattr(data_loaders, f'unbatched_load_{dataset_name}', '')
121 |   preprocess_fn = getattr(data_loaders, f'{dataset_name}_preprocess', '')
122 |   full_dataset_loader = getattr(data_loaders, f'load_{dataset_name}', '')
123 | 
124 |   data.train_kwargs = ml_collections.ConfigDict()
125 |   data.train_kwargs.loader = lowdata_wrapper.load_data
126 | 
127 |   data.train_kwargs.load_kwargs = dict()
128 |   data.train_kwargs.load_kwargs.dataset_loader = dataset_loader
129 |   data.train_kwargs.load_kwargs.weights = [1.]
130 | 
131 |   data.train_kwargs.load_kwargs.dataset_kwargs = dict(subset=_TRAIN_SPLIT)
132 |   data.train_kwargs.load_kwargs.preprocess_fn = preprocess_fn
133 | 
134 |   # Set up filters and number of samples.
135 |   data.train_kwargs.load_kwargs.num_samples = '0'
136 |   # A string to define how the dataset is filtered (not a boolean value).
137 |   data.train_kwargs.load_kwargs.filter_fns = 'True'
138 | 
139 |   data.test_kwargs = ml_collections.ConfigDict()
140 |   data.test_kwargs.loader = full_dataset_loader
141 |   data.test_kwargs.load_kwargs = dict(subset=_TEST_SPLIT)
142 | 
143 |   if dataset_name == data_utils.DatasetNames.DSPRITES.value and (
144 |       label == 'label_color' or property_label == 'label_color'):
145 |     # Make the images different colours, as opposed to black and white.
146 | preprocess = data.train_kwargs.load_kwargs.preprocess_fn 147 | data.train_kwargs.load_kwargs.preprocess_fn = ( 148 | lambda m: _color_preprocess(m, preprocess, label)) 149 | data.test_kwargs.load_kwargs.preprocess_fn = ( 150 | lambda m: _color_preprocess(m, None, label)) 151 | 152 | return data 153 | 154 | 155 | def get_alldata_config(dataset_name: str, label: str, property_label: str 156 | ) -> ml_collections.ConfigDict: 157 | """Config when using the full dataset.""" 158 | loader = getattr(data_loaders, f'load_{dataset_name}', '') 159 | data = _get_base_config(dataset_name, label, property_label) 160 | 161 | data.train_kwargs = ml_collections.ConfigDict() 162 | data.train_kwargs.loader = loader 163 | data.train_kwargs.load_kwargs = dict(subset=_TRAIN_SPLIT) 164 | 165 | data.test_kwargs = ml_collections.ConfigDict() 166 | data.test_kwargs.loader = loader 167 | data.test_kwargs.load_kwargs = dict(subset=_TEST_SPLIT) 168 | return data 169 | 170 | 171 | def get_renderers(datatype: str, 172 | dataset_name: str, 173 | label: str, 174 | property_label: str) -> ml_collections.ConfigDict: 175 | if len(datatype.split('.')) > 1: 176 | renderer, _ = datatype.split('.') 177 | else: 178 | renderer = datatype 179 | 180 | return globals()[f'get_{renderer}_renderers']( 181 | dataset_name, label=label, property_label=property_label) 182 | 183 | 184 | def get_renderer_sweep(datatype: str) -> hyper.Sweep: 185 | if len(datatype.split('.')) > 1: 186 | _, sweep = datatype.split('.') 187 | else: 188 | sweep = datatype 189 | return globals()[f'get_{sweep}_sweep']() 190 | 191 | 192 | def get_resample_sweep() -> hyper.Sweep: 193 | """Sweep over the resampling operation of the different datasets.""" 194 | ratios = [1e-3] 195 | n_samples = [1_000_000] 196 | ratio_samples = list(itertools.product(ratios, n_samples)) 197 | ratio_samples_sweep = hyper.sweep( 198 | f'{_EXP}.data.train_kwargs.load_kwargs.num_samples', 199 | [f'{n_s},{int(max(1, n_s * r))}' for r, n_s in ratio_samples]) 200 
| resample_weights = hyper.sweep( 201 | f'{_EXP}.data.train_kwargs.load_kwargs.weights', 202 | [[1 - i, i] for i in [1e-4, 1e-3, 1e-2, 1e-1, 0.5]]) 203 | return hyper.product([ratio_samples_sweep, resample_weights]) 204 | 205 | 206 | def get_fixeddata_sweep() -> hyper.Sweep: 207 | """Sweep over the amount of data and noise present.""" 208 | ratios = [1e-3] 209 | n_samples = [1000, 10_000, 100_000, 1_000_000] 210 | ratio_samples = list(itertools.product(ratios, n_samples)) 211 | ratio_samples_sweep = hyper.sweep( 212 | f'{_EXP}.data.train_kwargs.load_kwargs.num_samples', 213 | [f'{n_s},{int(max(1, n_s * r))}' for r, n_s in ratio_samples]) 214 | return ratio_samples_sweep 215 | 216 | 217 | def get_noise_sweep() -> hyper.Sweep: 218 | return hyper.sweep(f'{_EXP}.training.label_noise', 219 | [i / float(10.) for i in list(range(7, 11))]) 220 | 221 | 222 | def get_lowdata_sweep() -> hyper.Sweep: 223 | return hyper.sweep( 224 | f'{_EXP}.data.train_kwargs.load_kwargs.num_samples', 225 | [f'0,{n_s}' for n_s in [1, 5, 10, 50, 100, 500, 1000, 5000, 10_000]]) 226 | 227 | 228 | def get_ood_sweep() -> hyper.Sweep: 229 | return hyper.sweep(f'{_EXP}.data.train_kwargs.load_kwargs.weights', 230 | [[1., 0.]]) 231 | 232 | 233 | def get_base_renderers(dataset_name: str, 234 | label: str = 'color', 235 | property_label: str = 'shape' 236 | ) -> ml_collections.ConfigDict: 237 | """Get base config for the given dataset, label and property value.""" 238 | data = get_data_config(dataset_name, label, property_label) 239 | data.train_kwargs.load_kwargs.filter_fns = 'True' 240 | data.train_kwargs.load_kwargs.num_samples = '0' 241 | data.train_kwargs.load_kwargs.weights = [1.] 
242 | return data 243 | 244 | 245 | def get_ood_renderers(dataset_name: str, 246 | label: str = 'color', 247 | property_label: str = 'shape' 248 | ) -> ml_collections.ConfigDict: 249 | """Get OOD config for the given dataset, label and property value.""" 250 | data = get_data_config(dataset_name, label, property_label) 251 | 252 | perc_props_in_train = 0.7 if dataset_name in ('dsprites') else 0.2 253 | data.train_kwargs.load_kwargs.filter_fns = _get_filter_fns( 254 | data.prop_values, perc_props_in_train, property_label) 255 | data.train_kwargs.load_kwargs.weights = [1., 0.] 256 | data.train_kwargs.load_kwargs.num_samples = '0,1000' 257 | return data 258 | 259 | 260 | def get_correlated_renderers(dataset_name: str, 261 | label: str = 'color', 262 | property_label: str = 'shape' 263 | ) -> ml_collections.ConfigDict: 264 | """Get correlated config for the given dataset, label and property value.""" 265 | data = get_data_config(dataset_name, label, property_label) 266 | data.train_kwargs.load_kwargs.filter_fns = ( 267 | f'{label}:{property_label}:equal,True') 268 | data.train_kwargs.load_kwargs.weights = [0.5, 0.5] 269 | num_samples = '0,500' if dataset_name == 'dsprites' else '0,50' 270 | data.train_kwargs.load_kwargs.num_samples = num_samples 271 | data.train_kwargs.load_kwargs.shuffle_pre_sampling = True 272 | data.train_kwargs.load_kwargs.shuffle_pre_sample_seed = 0 273 | return data 274 | 275 | 276 | def get_lowdata_renderers(dataset_name: str, 277 | label: str = 'color', 278 | property_label: str = 'shape' 279 | ) -> ml_collections.ConfigDict: 280 | """Get lowdata config for the given dataset, label and property value.""" 281 | data = get_ood_renderers(dataset_name, label, property_label) 282 | data.train_kwargs.load_kwargs.weights = [0.5, 0.5] 283 | data.train_kwargs.load_kwargs.num_samples = '0,10' 284 | data.train_kwargs.load_kwargs.shuffle_pre_sampling = True 285 | data.train_kwargs.load_kwargs.shuffle_pre_sample_seed = 0 286 | 287 | return data 288 | 
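
The sweep builders above (`get_resample_sweep`, `get_fixeddata_sweep`, `get_lowdata_sweep`, `get_ood_sweep`) all bottom out in `hyper.sweep` and `hyper.product` from `distribution_shift_framework.core.hyper`. As a rough mental model only — the stand-in `sweep` and `product` functions below are an assumption about those helpers' semantics, not the framework's actual implementation — each sweep maps one config key to a list of values, and `product` takes the cartesian product of several sweeps, merging the per-key overrides:

```python
import itertools
from typing import Any, Dict, List, Sequence


def sweep(key: str, values: Sequence[Any]) -> List[Dict[str, Any]]:
  # One override dict per value of the swept config key.
  return [{key: v} for v in values]


def product(sweeps: Sequence[List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
  # Cartesian product of several sweeps; each combination merges its
  # per-key overrides into a single experiment configuration.
  merged = []
  for combo in itertools.product(*sweeps):
    config: Dict[str, Any] = {}
    for override in combo:
      config.update(override)
    merged.append(config)
  return merged


# Mirrors get_resample_sweep (with the `config.experiment_kwargs.config.`
# prefix dropped for brevity): one num_samples spec crossed with the
# five resampling weights.
ratios, n_samples = [1e-3], [1_000_000]
ratio_samples = sweep(
    'data.train_kwargs.load_kwargs.num_samples',
    [f'{n_s},{int(max(1, n_s * r))}'
     for r, n_s in itertools.product(ratios, n_samples)])
weights = sweep('data.train_kwargs.load_kwargs.weights',
                [[1 - i, i] for i in [1e-4, 1e-3, 1e-2, 1e-1, 0.5]])
overrides = product([ratio_samples, weights])  # 1 x 5 = 5 configurations
```

Each resulting dict is one experiment: here, five runs that share the sample budget `'1000000,1000'` but differ in how heavily the second (low-data) subset is weighted.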
-------------------------------------------------------------------------------- /distribution_shift_framework/core/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2022 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | -------------------------------------------------------------------------------- /distribution_shift_framework/core/adapt.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2022 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 
17 | """Adaptation algorithms for modifying model parameters."""
18 | import abc
19 | from typing import Callable, Sequence
20 | 
21 | from absl import logging
22 | import chex
23 | from distribution_shift_framework.core.datasets import data_utils
24 | import haiku as hk
25 | import jax
26 | import jax.numpy as jnp
27 | import numpy as np
28 | import optax
29 | 
30 | 
31 | def _broadcast(tensor1, tensor2):
32 |   num_ones = len(tensor1.shape) - len(tensor2.shape)
33 |   return tensor2.reshape(tensor2.shape + (1,) * num_ones)
34 | 
35 | 
36 | def _bt_mult(tensor1, tensor2):
37 |   tensor2 = _broadcast(tensor1, tensor2)
38 |   return tensor1 * tensor2
39 | 
40 | 
41 | def _get_mean(tensor):
42 |   if len(tensor.shape) == 1:
43 |     return jnp.mean(tensor, keepdims=True)
44 |   else:
45 |     return jnp.mean(tensor, axis=(0, 1), keepdims=True)
46 | 
47 | 
48 | def _split_and_reshape(tree1, tree2):
49 |   """Reshapes tree1 to look like tree2, returning the result and the remainder."""
50 |   tree1_reshaped = jax.tree_map(
51 |       lambda a, b: a[:np.prod(b.shape[0:2])].reshape(b.shape), tree1, tree2)
52 |   tree1_modulo = jax.tree_map(lambda a, b: a[np.prod(b.shape[0:2]):], tree1,
53 |                               tree2)
54 |   return tree1_reshaped, tree1_modulo
55 | 
56 | 
57 | class Adapt(abc.ABC):
58 |   """Class to encapsulate an adaptation framework."""
59 | 
60 |   @abc.abstractmethod
61 |   def __init__(self, init_params: optax.Params, init_state: optax.OptState,
62 |                forward: Callable[..., chex.Array]):
63 |     """Initializes the adaptation algorithm.
64 | 
65 |     This operates as follows. Given a number of examples, the model can update
66 |     the parameters as it sees fit. Then, the updated parameters are run on an
67 |     unseen test set.
68 | 
69 |     Args:
70 |       init_params: The original parameters of the model.
71 |       init_state: The original state of the model.
72 |       forward: The forward call to the model.
73 | """ 74 | 75 | @abc.abstractmethod 76 | def update(self, inputs: data_utils.Batch, property_label: chex.Array, 77 | rng: chex.PRNGKey, **kwargs): 78 | """Updates the parameters of the adaptation algorithm. 79 | 80 | Args: 81 | inputs: The batch to be input to the model. 82 | property_label: The properties of the image. 83 | rng: The random key. 84 | **kwargs: Keyword arguments specific to the forward function. 85 | """ 86 | 87 | @abc.abstractmethod 88 | def run(self, fn: Callable[..., chex.Array], property_label: chex.Array, 89 | **fn_kwargs): 90 | """Runs the adaptation algorithm on a given function. 91 | 92 | Args: 93 | fn: The function we wish to apply the adapted parameters to. 94 | property_label: The property labels of the input values. 95 | **fn_kwargs: Additional kwargs to be input to the function fn. 96 | 97 | Returns: 98 | The result of fn using the adapted parameters according to the 99 | property_label value. 100 | """ 101 | 102 | 103 | class BNAdapt(Adapt): 104 | """Implements batch norm adaptation for a set of properties. 105 | 106 | Given a set of properties, and initial parameters/state, the batch 107 | normalization statistics are updated for each property value. 108 | """ 109 | 110 | def __init__(self, 111 | init_params: optax.Params, 112 | init_state: optax.OptState, 113 | forward: Callable[..., chex.Array], 114 | n_properties: int, 115 | n: int = 10, 116 | N: int = 100): 117 | """See parent.""" 118 | super().__init__( 119 | init_params=init_params, init_state=init_state, forward=forward) 120 | self.init_params = init_params 121 | self.init_state = init_state 122 | # Set the init state to 0. This will mean we always take the local stats. 
123 | self.empty_state = self._reset_state(self.init_state) 124 | 125 | self.n_properties = n_properties 126 | self.forward_fn = forward 127 | self.adapted_state = {n: None for n in range(n_properties)} 128 | self.interpolated_states = None 129 | 130 | # Set up parameters that control the amount of adaptation. 131 | self.w_new = n 132 | self.w_old = N 133 | 134 | # Set up the cached dataset values. 135 | self._cached_dataset = [None] * self.n_properties 136 | 137 | def _reset_state(self, old_state, keys=('average', 'hidden', 'counter')): 138 | """Set the average of the BN parameters to 0.""" 139 | state = hk.data_structures.to_mutable_dict(old_state) 140 | for k in state.keys(): 141 | if 'batchnorm' in k and 'ema' in k: 142 | logging.info('Resetting %s in BNAdapt.', k) 143 | for state_key in keys: 144 | state[k][state_key] = jnp.zeros_like(state[k][state_key]) 145 | state = hk.data_structures.to_haiku_dict(state) 146 | return state 147 | 148 | def _update_state(self, old_state, new_state, sz): 149 | """Update the state using the old and new running state.""" 150 | if old_state is None: 151 | old_state = self._reset_state(self.init_state) 152 | 153 | new_state = hk.data_structures.to_mutable_dict(new_state) 154 | for k in new_state.keys(): 155 | if 'batchnorm' in k and 'ema' in k: 156 | new_state_k = new_state[k]['average'] 157 | old_counter = _broadcast(old_state[k]['average'], 158 | old_state[k]['counter']) 159 | new_state_k = new_state_k * sz 160 | old_state_k = old_state[k]['average'] * old_counter 161 | 162 | counter = jnp.maximum(old_counter + sz, 1) 163 | new_state[k]['average'] = (new_state_k + old_state_k) / counter 164 | new_state[k]['counter'] = counter.squeeze() 165 | new_state = hk.data_structures.to_haiku_dict(new_state) 166 | return new_state 167 | 168 | def _interpolate_state(self, old_state, new_state): 169 | """Update the state using the old and new running state.""" 170 | if new_state is None: 171 | return old_state 172 | 173 | new_state = 
hk.data_structures.to_mutable_dict(new_state) 174 | new_ratio = self.w_new / (self.w_new + self.w_old) 175 | old_ratio = self.w_old / (self.w_new + self.w_old) 176 | for k in new_state.keys(): 177 | if 'batchnorm' in k and 'ema' in k: 178 | new_state[k]['average'] = ( 179 | new_state[k]['average'] * new_ratio + 180 | old_state[k]['average'] * old_ratio) 181 | new_state = hk.data_structures.to_haiku_dict(new_state) 182 | return new_state 183 | 184 | def update(self, inputs: data_utils.Batch, property_label: chex.Array, 185 | rng: chex.PRNGKey, **kwargs): 186 | """See parent.""" 187 | # First, update cached data. 188 | for n in range(0, self.n_properties): 189 | mask = property_label == n 190 | masked_batch = jax.tree_map(lambda a: a[mask], inputs) # pylint: disable=cell-var-from-loop 191 | if self._cached_dataset[n] is None: 192 | self._cached_dataset[n] = masked_batch 193 | else: 194 | self._cached_dataset[n] = jax.tree_map(lambda *a: jnp.concatenate(a), 195 | self._cached_dataset[n], 196 | masked_batch) 197 | 198 | # Then, if there are enough samples of a property, update the BN stats. 199 | for n in range(0, self.n_properties): 200 | # Update the adapted states with the output of the property labels. 201 | if (self._cached_dataset[n]['image'].shape[0] < np.prod( 202 | inputs['image'].shape[0:2])): 203 | continue 204 | 205 | # There are enough samples to do a forward pass. 206 | batch, mod_batch = _split_and_reshape(self._cached_dataset[n], inputs) 207 | _, state = self.forward_fn(self.init_params, self.empty_state, rng, batch, 208 | **kwargs) 209 | 210 | # Take the average over the cross replicas. 
211 | state = jax.tree_map(_get_mean, state) 212 | self._update_state( 213 | self.adapted_state[n], state, sz=np.prod(batch['image'].shape[:2])) 214 | self._cached_dataset[n] = mod_batch 215 | 216 | def set_up_eval(self): 217 | self.interpolated_states = [ 218 | self._interpolate_state( 219 | new_state=self.adapted_state[n], old_state=self.init_state) 220 | for n in range(self.n_properties) 221 | ] 222 | 223 | def run(self, fn: Callable[..., Sequence[chex.Array]], 224 | property_label: chex.Array, **fn_kwargs): 225 | """See parent.""" 226 | # Get the results for the initial parameters and state. 227 | result = fn(self.init_params, self.init_state, **fn_kwargs) 228 | 229 | # Compute the results for each set of properties. 230 | for n in range(0, self.n_properties): 231 | mask = property_label == n 232 | if mask.sum() == 0: 233 | continue 234 | 235 | # And update the result. 236 | result_prop = fn(self.init_params, self.interpolated_states[n], 237 | **fn_kwargs) 238 | result = [ 239 | _bt_mult(r, (1 - mask)) + _bt_mult(r_prop, mask) 240 | for r, r_prop in zip(result, result_prop) 241 | ] 242 | 243 | return result 244 | -------------------------------------------------------------------------------- /distribution_shift_framework/core/adapt_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2022 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | 
17 | """Just train twice (JTT) algorithm."""
18 | import abc
19 | from typing import Callable, Tuple
20 | 
21 | import chex
22 | from distribution_shift_framework.core.datasets import data_utils
23 | import jax.numpy as jnp
24 | import optax
25 | 
26 | Learner = Tuple[Tuple[data_utils.ScalarDict, chex.Array], optax.OptState]
27 | LearnerFN = Callable[..., Learner]
28 | 
29 | 
30 | class Adapt(abc.ABC):
31 |   """Encapsulates adapting parameters and state with auxiliary information.
32 | 
33 |   Given some initial set of parameters and the loss to be optimized, this
34 |   set of classes is free to update the underlying parameters via adaptation
35 |   based on the difficulty of samples (e.g. JTT) or via EWA.
36 |   """
37 | 
38 |   @abc.abstractmethod
39 |   def update(self, params: optax.Params, state: optax.OptState,
40 |              global_step: chex.Array):
41 |     """Updates and returns the new parameters and state.
42 | 
43 |     Args:
44 |       params: The parameters returned at this step.
45 |       state: The state returned at this step.
46 |       global_step: The training step.
47 | 
48 |     Returns:
49 |       The updated params and state.
50 |     """
51 | 
52 |   @abc.abstractmethod
53 |   def __call__(self, fn: LearnerFN, params: optax.Params, state: optax.OptState,
54 |                global_step: chex.Array, inputs: data_utils.Batch,
55 |                rng: chex.PRNGKey) -> Tuple[data_utils.ScalarDict, chex.Array]:
56 |     """Adapts the stored parameters according to the given information.
57 | 
58 |     Args:
59 |       fn: The loss function.
60 |       params: The parameters of the model at this step.
61 |       state: The state of the model at this step.
62 |       global_step: The step in the training pipeline.
63 |       inputs: The inputs to the loss function.
64 |       rng: The random key.
65 | 
66 |     Returns:
67 |       The scalars and logits which have been appropriately adapted.
68 | """ 69 | 70 | 71 | class JTT(Adapt): 72 | """Implementation of JTT algorithm.""" 73 | 74 | def __init__(self, lmbda: float, num_steps_in_first_iter: int): 75 | """Implementation of JTT. 76 | 77 | This algorithm first trains for some number of steps on the full training 78 | set. After this first stage, the parameters at the end of this stage are 79 | used to select the most difficult samples (those that are misclassified) 80 | and penalize the loss more heavily for these examples. 81 | 82 | Args: 83 | lmbda: How much to upsample the misclassified examples. 84 | num_steps_in_first_iter: How long to train on full dataset before 85 | computing the error set and reweighting misclassified samples. 86 | """ 87 | super().__init__() 88 | self.lmbda = lmbda 89 | self.num_steps_in_first_iter = num_steps_in_first_iter 90 | self.init_params = None 91 | self.init_state = None 92 | 93 | def update(self, params: optax.Params, state: optax.OptState, 94 | global_step: chex.Array): 95 | """See parent.""" 96 | if global_step < self.num_steps_in_first_iter: 97 | self.init_params = params 98 | self.init_state = state 99 | return params, state 100 | 101 | return self.init_params, self.init_state 102 | 103 | def set(self, params: optax.Params, state: optax.OptState): 104 | self.init_params = params 105 | self.init_state = state 106 | 107 | def __call__( 108 | self, fn: LearnerFN, params: optax.Params, state: optax.OptState, 109 | old_params: optax.Params, old_state: optax.OptState, 110 | global_step: chex.Array, inputs: data_utils.Batch, 111 | rng: chex.PRNGKey) -> Learner: 112 | """See parent.""" 113 | # Get the correct predictions with the params from the 1st training stage. 114 | (scalars, logits), g_state = fn(old_params, old_state, rng, inputs) 115 | predicted_label = jnp.argmax(logits, axis=-1) 116 | correct = jnp.equal(predicted_label, inputs['label']).astype(jnp.float32) 117 | 118 | # And now use this to reweight the current loss. 
119 | (scalars, logits), g_state = fn(params, state, rng, inputs) 120 | new_loss = ((1 - correct) * scalars['loss'] * self.lmbda + 121 | correct * scalars['loss']) 122 | 123 | # And return the correct loss for the stage of training. 124 | in_first_stage = global_step < self.num_steps_in_first_iter 125 | scalars['1stiter_loss'] = scalars['loss'].mean() 126 | scalars['loss'] = (scalars['loss'] * in_first_stage + new_loss * 127 | (1 - in_first_stage)).mean() 128 | return (scalars, logits), g_state 129 | -------------------------------------------------------------------------------- /distribution_shift_framework/core/algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2022 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Learning algorithms.""" 18 | from distribution_shift_framework.core.algorithms.adversarial import DANN 19 | from distribution_shift_framework.core.algorithms.erm import CORAL 20 | from distribution_shift_framework.core.algorithms.erm import ERM 21 | from distribution_shift_framework.core.algorithms.irm import IRM 22 | from distribution_shift_framework.core.algorithms.sagnet import SagNet 23 | 24 | # Learning algorithms. 
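The two-stage reweighting in `JTT.__call__` above reduces to a single formula: misclassified examples (under the first-stage parameters) get their loss scaled by `lmbda`, correct ones keep weight 1. A minimal NumPy sketch of just that formula, with illustrative values rather than the framework's real tensors:

```python
import numpy as np

# Per-example losses under the current parameters, and a 0/1 mask marking
# which examples the first-stage parameters classified correctly.
loss = np.array([0.2, 1.5, 0.1, 2.0])
correct = np.array([1.0, 0.0, 1.0, 0.0])
lmbda = 3.0  # upweighting factor for misclassified examples

# Misclassified examples are upweighted by lmbda; correct ones keep weight 1.
new_loss = (1 - correct) * loss * lmbda + correct * loss
print(new_loss)         # [0.2 4.5 0.1 6. ]
print(new_loss.mean())  # 2.7
```

With `lmbda = 1.0` this degenerates to the plain mean loss, which is why the first training stage (where `in_first_stage` is true) can share the same code path.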
25 | __all__ = (
26 |     'CORAL',
27 |     'DANN',
28 |     'ERM',
29 |     'IRM',
30 |     'SagNet',
31 | )
32 |
--------------------------------------------------------------------------------
/distribution_shift_framework/core/algorithms/adversarial.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2022 DeepMind Technologies Limited
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | #     http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Adversarial training of latent values."""
18 | from typing import Optional, Sequence, Tuple
19 |
20 | import chex
21 | from distribution_shift_framework.core.algorithms import base
22 | from distribution_shift_framework.core.algorithms import losses
23 | from distribution_shift_framework.core.datasets import data_utils
24 | import haiku as hk
25 | import jax.numpy as jnp
26 |
27 |
28 | class DANN(base.LearningAlgorithm):
29 |   """Uses adversarial training to train a property agnostic representation.
30 |
31 |   Based on the work of Ganin et al. Domain-Adversarial Training of Neural
32 |   Networks. https://jmlr.org/papers/volume17/15-239/15-239.pdf.
33 |
34 |   This learning setup takes a set of logits, property values, and targets. It
35 |   then enforces that the logits contain *no* information about the set of
36 |   properties.
37 | """ 38 | 39 | def __init__(self, 40 | loss_fn: base.LossFn = losses.softmax_cross_entropy, 41 | property_loss_fn: base.LossFn = losses.softmax_cross_entropy, 42 | mlp_output_sizes: Sequence[int] = (), 43 | name: str = 'DANN'): 44 | super().__init__(loss_fn=loss_fn, name=name) 45 | 46 | # Implicit assumptions in the code require classification. 47 | assert loss_fn == losses.softmax_cross_entropy 48 | assert property_loss_fn == losses.softmax_cross_entropy 49 | 50 | self.mlp_output_sizes = mlp_output_sizes 51 | self.property_loss_fn = property_loss_fn 52 | 53 | def __call__(self, 54 | logits: chex.Array, 55 | targets: chex.Array, 56 | property_vs: chex.Array, 57 | reduction: str = 'mean' 58 | ) -> Tuple[data_utils.ScalarDict, chex.Array]: 59 | ################### 60 | # Standard loss. 61 | ################### 62 | 63 | # Compute the regular loss function. 64 | erm = self.loss_fn(logits, targets, reduction=reduction) 65 | 66 | return {'loss': erm}, logits 67 | 68 | def adversary(self, 69 | logits: chex.Array, 70 | property_vs: chex.Array, 71 | reduction: str = 'mean', 72 | targets: Optional[chex.Array] = None) -> data_utils.ScalarDict: 73 | ################### 74 | # Adversarial loss. 75 | ################### 76 | adv_net = hk.nets.MLP( 77 | tuple(self.mlp_output_sizes) + (property_vs.shape[-1],)) 78 | 79 | # Get logits for estimating the property. 80 | adv_logits = adv_net(logits) 81 | # Enforce that the representation encodes nothing about the property values. 82 | adv_loss = self.property_loss_fn( 83 | adv_logits, property_vs, reduction=reduction) 84 | # How well are we estimating the property value? 
85 | prop_top1_acc = (jnp.argmax(adv_logits, 86 | axis=-1) == jnp.argmax(property_vs, 87 | axis=-1)).mean() 88 | 89 | return {'loss': adv_loss, 'prop_top1_acc': prop_top1_acc} 90 | -------------------------------------------------------------------------------- /distribution_shift_framework/core/algorithms/base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2022 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Base class for learning algorithms.""" 18 | import abc 19 | from typing import Callable, Optional, Tuple 20 | 21 | import chex 22 | from distribution_shift_framework.core.datasets import data_utils 23 | import haiku as hk 24 | 25 | 26 | LossFn = Callable[..., chex.Array] 27 | 28 | 29 | class LearningAlgorithm(hk.Module): 30 | """Class to encapsulate a learning algorithm.""" 31 | 32 | def __init__(self, loss_fn: LossFn, name: str = 'DANN', **kwargs): 33 | """Initializes the algorithm with the given loss function.""" 34 | super().__init__(name=name) 35 | self.loss_fn = loss_fn 36 | 37 | @abc.abstractmethod 38 | def __call__( 39 | self, 40 | logits: chex.Array, 41 | targets: chex.Array, 42 | reduction: str = 'mean', 43 | property_vs: Optional[chex.Array] = None 44 | ) -> Tuple[data_utils.ScalarDict, chex.Array]: 45 | """The loss function of the learning algorithm. 
46 |
47 |     Args:
48 |       logits: The predicted logits input to the training algorithm.
49 |       targets: The ground truth value to estimate.
50 |       reduction: How to combine the loss for different samples.
51 |       property_vs: An optional set of properties of the input data.
52 |
53 |     Returns:
54 |       scalars: A dictionary of key and scalar estimates. The key `loss`
55 |         is the loss that should be minimized.
56 |       preds: The raw softmax predictions.
57 |     """
58 |     pass
59 |
60 |   def adversary(self,
61 |                 logits: chex.Array,
62 |                 property_vs: chex.Array,
63 |                 reduction: str = 'mean',
64 |                 targets: Optional[chex.Array] = None) -> data_utils.ScalarDict:
65 |     """The adversarial loss function.
66 |
67 |     If la = LearningAlgorithm(), this function is applied in a min-max game
68 |     with la(). The model is trained to minimize the loss arising from la(),
69 |     while maximizing the loss from the adversary (la.adversary()). The
70 |     adversarial part of the model tries to minimize this loss.
71 |
72 |     Args:
73 |       logits: The predicted value input to the training algorithm.
74 |       property_vs: A set of properties of the input data.
75 |       reduction: How to combine the loss for different samples.
76 |       targets: The ground truth value to estimate (optional).
77 |
78 |     Returns:
79 |       scalars: A dictionary of key and scalar estimates. The key `adv_loss` is
80 |         the value that should be minimized (for the adversary) and maximized
81 |         (for the model). If empty, this learning algorithm has no adversary.
82 |     """
83 |     # Do nothing.
84 |     return {}
85 |
--------------------------------------------------------------------------------
/distribution_shift_framework/core/algorithms/erm.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2022 DeepMind Technologies Limited
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Empirical risk minimization for minimizing loss.""" 18 | import abc 19 | from typing import Tuple 20 | 21 | import chex 22 | from distribution_shift_framework.core.algorithms import base 23 | from distribution_shift_framework.core.algorithms import losses 24 | from distribution_shift_framework.core.datasets import data_utils 25 | import jax 26 | import jax.numpy as jnp 27 | 28 | 29 | class ERM(base.LearningAlgorithm): 30 | """Computes the empirical risk.""" 31 | 32 | def __init__(self, 33 | loss_fn: base.LossFn = losses.softmax_cross_entropy, 34 | name: str = 'empirical_risk'): 35 | super().__init__(loss_fn=loss_fn, name=name) 36 | 37 | def __call__(self, 38 | logits: chex.Array, 39 | targets: chex.Array, 40 | reduction: str = 'mean', 41 | **unused_kwargs) -> Tuple[data_utils.ScalarDict, chex.Array]: 42 | loss = self.loss_fn(logits, targets, reduction=reduction) 43 | return {'loss': loss}, logits 44 | 45 | 46 | class AbstractMMD(base.LearningAlgorithm): 47 | """Base class for the CORAL and MMD algorithms.""" 48 | 49 | def __init__(self, 50 | mmd_weight: float = 1., 51 | loss_fn: base.LossFn = losses.softmax_cross_entropy, 52 | name: str = 'coral'): 53 | super().__init__(loss_fn=loss_fn, name=name) 54 | self.mmd_weight = mmd_weight 55 | 56 | @abc.abstractmethod 57 | def _mmd(self, x: chex.Array, x_mask: chex.Array, y: chex.Array, 58 | y_mask: chex.Array) -> chex.Array: 59 | """Computes the MMD between two sets of masked features. 60 | 61 | Args: 62 | x: The first set of features. 
63 |       x_mask: Which of the x features should be considered.
64 |       y: The second set of features.
65 |       y_mask: Which of the y features should be considered.
66 |
67 |     Returns:
68 |       The MMD between the two sets of masked features.
69 |     """
70 |     pass
71 |
72 |   def __call__(self,
73 |                logits: chex.Array,
74 |                targets: chex.Array,
75 |                property_vs: chex.Array,
76 |                reduction: str = 'mean'
77 |                ) -> Tuple[data_utils.ScalarDict, chex.Array]:
78 |     """Compute the MMD loss where the domains are given by the properties."""
79 |     pnum = property_vs.shape[-1]
80 |     if len(property_vs.shape) != 2:
81 |       raise ValueError(
82 |           f'Properties have an unexpected shape: {property_vs.shape}.')
83 |
84 |     # For each pair of property values, compute the discrepancy between the
85 |     # corresponding groups of examples.
86 |     mmd_loss = {'loss': 0}
87 |     property_pairs = []
88 |     for i, property_v1 in enumerate(range(pnum)):
89 |       for property_v2 in range(i + 1, pnum):
90 |         property_pairs += [(property_v1, property_v2)]
91 |
92 |     def compute_pair_loss(mmd_loss, pair_vs):
93 |       property_v1, property_v2 = pair_vs
94 |
95 |       # One hot encoding.
96 |       mask1 = jnp.argmax(property_vs, axis=-1)[..., None] == property_v1
97 |       mask2 = jnp.argmax(property_vs, axis=-1)[..., None] == property_v2
98 |
99 |       loss = jax.lax.cond(
100 |           jnp.minimum(mask1.sum(), mask2.sum()) > 1,
101 |           lambda a: self._mmd(*a),
102 |           lambda _: jnp.zeros(()),
103 |           operand=(logits, mask1, logits, mask2))
104 |
105 |       t_mmd_loss = {'loss': loss}
106 |       mmd_loss = jax.tree_map(jnp.add, mmd_loss, t_mmd_loss)
107 |       return (mmd_loss, 0)
108 |
109 |     mmd_loss, _ = jax.lax.scan(compute_pair_loss, mmd_loss,
110 |                                jnp.array(property_pairs))
111 |
112 |     erm = self.loss_fn(logits, targets, reduction=reduction)
113 |     # How well are we estimating the labels?
114 |     top1_acc = (jnp.argmax(logits, axis=-1) == jnp.argmax(targets,
115 |                                                           axis=-1)).mean()
116 |
117 |     loss = mmd_loss['loss'] / (pnum * (pnum - 1)) * self.mmd_weight + erm
118 |     mmd_loss['loss'] = loss
119 |     mmd_loss['erm'] = erm
120 |     mmd_loss['top1_acc'] = top1_acc
121 |     return mmd_loss, logits
122 |
123 |
124 | class CORAL(AbstractMMD):
125 |   """The CORAL algorithm.
126 |
127 |   Computes the empirical risk and enforces that feature distributions match
128 |   across distributions (by minimizing the maximum mean discrepancy).
129 |   """
130 |
131 |   def __init__(self,
132 |                coral_weight: float = 1.,
133 |                loss_fn: base.LossFn = losses.softmax_cross_entropy,
134 |                name: str = 'coral'):
135 |     super().__init__(loss_fn=loss_fn, name=name, mmd_weight=coral_weight)
136 |
137 |   def _mmd(self, x: chex.Array, x_mask: chex.Array, y: chex.Array,
138 |            y_mask: chex.Array) -> chex.Array:
139 |     """Computes the MMD between two sets of masked features.
140 |
141 |     Args:
142 |       x: The first set of features.
143 |       x_mask: Which of the x features should be considered.
144 |       y: The second set of features.
145 |       y_mask: Which of the y features should be considered.
146 |
147 |     Returns:
148 |       The sum of the mean and covariance discrepancies between the two sets.
149 |     """
150 |     mean_x = (x * x_mask).sum(0, keepdims=True) / x_mask.sum()
151 |     mean_y = (y * y_mask).sum(0, keepdims=True) / y_mask.sum()
152 |     cent_x = (x - mean_x) * x_mask
153 |     cent_y = (y - mean_y) * y_mask
154 |
155 |     # Compute the covariances of the inputs.
156 |     cova_x = cent_x.T.dot(cent_x) / (x_mask.sum() - 1)
157 |     cova_y = cent_y.T.dot(cent_y) / (y_mask.sum() - 1)
158 |
159 |     d_x = x_mask.sum()
160 |     d_y = y_mask.sum()
161 |
162 |     mean_mse = ((mean_x - mean_y)**2).mean()
163 |     cov_mse = ((cova_x - cova_y)**2 / (4 * d_x * d_y)).mean()
164 |     return mean_mse + cov_mse
165 |
--------------------------------------------------------------------------------
/distribution_shift_framework/core/algorithms/irm.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2022 DeepMind Technologies Limited
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | #     http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Invariant risk minimization for minimizing loss."""
18 | from typing import Tuple
19 |
20 | import chex
21 | from distribution_shift_framework.core.algorithms import base
22 | from distribution_shift_framework.core.algorithms import losses
23 | from distribution_shift_framework.core.datasets import data_utils
24 | import haiku as hk
25 | import jax.numpy as jnp
26 |
27 |
28 | class IRM(base.LearningAlgorithm):
29 |   """Computes the invariant risk.
30 |
31 |   This learning algorithm is based on that of Arjovsky et al. Invariant Risk
32 |   Minimization. https://arxiv.org/abs/1907.02893.
33 |
34 |   It enforces that the optimal classifiers for representations with different
35 |   properties are the same.
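For intuition, the `_mmd` implemented by CORAL above is just mean and covariance matching between two feature matrices. A dependency-light NumPy sketch of the unmasked case (a simplification: the real method also applies per-example masks, and the names here are illustrative):

```python
import numpy as np

def coral_penalty(x: np.ndarray, y: np.ndarray) -> float:
  """Mean + covariance mismatch between two feature matrices of shape (n, d)."""
  mean_x = x.mean(axis=0, keepdims=True)
  mean_y = y.mean(axis=0, keepdims=True)
  cent_x = x - mean_x
  cent_y = y - mean_y
  # Unbiased covariance estimates of each feature set.
  cova_x = cent_x.T @ cent_x / (x.shape[0] - 1)
  cova_y = cent_y.T @ cent_y / (y.shape[0] - 1)
  mean_mse = ((mean_x - mean_y) ** 2).mean()
  cov_mse = ((cova_x - cova_y) ** 2 / (4 * x.shape[0] * y.shape[0])).mean()
  return float(mean_mse + cov_mse)

rng = np.random.default_rng(0)
x = rng.normal(size=(32, 8))
# Identical feature sets incur no penalty; a mean-shifted copy does.
assert coral_penalty(x, x) == 0.0
assert coral_penalty(x, x + 1.0) > 0.0
```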
36 | """ 37 | 38 | def __init__(self, 39 | lambda_penalty: float = 1., 40 | loss_fn: base.LossFn = losses.softmax_cross_entropy, 41 | name: str = 'invariant_risk'): 42 | super().__init__(loss_fn=loss_fn, name=name) 43 | self.penalty_weight = lambda_penalty 44 | 45 | def _apply_loss(self, weights, logits, targets): 46 | return self.loss_fn(logits * weights, targets, reduction='mean') 47 | 48 | def __call__(self, 49 | logits: chex.Array, 50 | targets: chex.Array, 51 | property_vs: chex.Array, 52 | reduction: str = 'mean' 53 | ) -> Tuple[data_utils.ScalarDict, chex.Array]: 54 | assert len(targets.shape) == 2 55 | erm = 0 56 | penalty = 0 57 | 58 | # For each property, estimate the weights of an optimal classifier. 59 | for property_v in range(property_vs.shape[-1]): 60 | if len(property_vs.shape) == 2: 61 | # One hot encoding. 62 | mask = jnp.argmax(property_vs, axis=-1)[..., None] == property_v 63 | masked_logits = mask * logits 64 | masked_targets = mask * targets 65 | else: 66 | raise ValueError( 67 | f'Properties have an unexpected shape: {property_vs.shape}.') 68 | 69 | weights = jnp.ones((1,)) 70 | 71 | # Compute empirical risk. 72 | erm += self._apply_loss(weights, masked_logits, masked_targets) 73 | 74 | # Compute penalty. 75 | grad_fn = hk.grad(self._apply_loss, argnums=0) 76 | grad_1 = grad_fn(weights, masked_logits[::2], masked_targets[::2]) 77 | grad_2 = grad_fn(weights, masked_logits[1::2], masked_targets[1::2]) 78 | penalty += (grad_1 * grad_2).sum() 79 | 80 | # How well are we estimating the labels? 
81 | top1_acc = (jnp.argmax(logits, axis=-1) == jnp.argmax(targets, 82 | axis=-1)).mean() 83 | 84 | return { 85 | 'loss': erm + self.penalty_weight * penalty, 86 | 'erm': erm, 87 | 'penalty': penalty, 88 | 'top1_acc': top1_acc 89 | }, logits 90 | -------------------------------------------------------------------------------- /distribution_shift_framework/core/algorithms/losses.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2022 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Common Losses to be used to train a model.""" 18 | 19 | import chex 20 | import jax.numpy as jnp 21 | import optax 22 | 23 | 24 | def softmax_cross_entropy(logits: chex.Array, 25 | labels: chex.Array, 26 | reduction: str = 'sum') -> chex.Array: 27 | """Computes softmax cross entropy given logits and one-hot class labels. 28 | 29 | Args: 30 | logits: Logit output values. 31 | labels: Ground truth one-hot-encoded labels. 32 | reduction: Type of reduction to apply to loss. 33 | 34 | Returns: 35 | Loss value. If `reduction` is `none`, this has the same shape as `labels`; 36 | otherwise, it is scalar. 37 | 38 | Raises: 39 | ValueError: If the type of `reduction` is unsupported. 
40 |   """
41 |   x = optax.softmax_cross_entropy(logits, labels)
42 |   if reduction == 'none' or reduction is None:
43 |     return jnp.asarray(x)
44 |   elif reduction == 'sum':
45 |     return jnp.asarray(x).sum()
46 |   elif reduction == 'mean':
47 |     return jnp.mean(jnp.asarray(x))
48 |   else:
49 |     raise ValueError('Unsupported reduction option.')
50 |
--------------------------------------------------------------------------------
/distribution_shift_framework/core/algorithms/sagnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2022 DeepMind Technologies Limited
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | #     http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Training representations to be style agnostic."""
18 | from typing import Any, Mapping, Optional, Tuple
19 |
20 | import chex
21 | from distribution_shift_framework.core.algorithms import base
22 | from distribution_shift_framework.core.algorithms import losses
23 | from distribution_shift_framework.core.datasets import data_utils
24 | import haiku as hk
25 | import jax
26 | import jax.numpy as jnp
27 | import ml_collections
28 |
29 |
30 | class SagNet(base.LearningAlgorithm):
31 |   """Implements a SagNet https://arxiv.org/pdf/1910.11645.pdf.
32 |
33 |   This is a method for training networks to be invariant to style for
34 |   improved domain generalization.
35 |   """
36 |
37 |   def __init__(self,
38 |                loss_fn: base.LossFn = losses.softmax_cross_entropy,
39 |                content_net_fn=hk.nets.MLP,
40 |                content_net_kwargs: Mapping[str,
41 |                                            Any] = (ml_collections.ConfigDict(
42 |                                                dict(output_sizes=(64, 64,
43 |                                                                   64)))),
44 |                style_net_fn=hk.nets.MLP,
45 |                style_net_kwargs: Mapping[str, Any] = ml_collections.ConfigDict(
46 |                    dict(output_sizes=(64, 64, 64))),
47 |                name: str = 'SagNet',
48 |                **kwargs):
49 |     super().__init__(loss_fn=loss_fn, name=name)
50 |     self._content_net_fn = content_net_fn
51 |     self._content_net_kwargs = content_net_kwargs
52 |
53 |     self._style_net_fn = style_net_fn
54 |     self._style_net_kwargs = style_net_kwargs
55 |
56 |   def _randomize(self, features, interpolate=False, eps=1e-5):
57 |     """Apply the AdaIN style operator (https://arxiv.org/abs/1703.06868)."""
58 |     b = features.shape[0]
59 |     alpha = jax.random.uniform(hk.next_rng_key(),
60 |                                (b,) + (1,) * len(features.shape[1:]))
61 |
62 |     is_image_shape = len(features.shape) == 4
63 |     if is_image_shape:
64 |       # Features is an image with shape BHWC.
65 |       b, h, w, c = features.shape
66 |       features = jnp.transpose(features, axes=(0, 3, 1, 2)).reshape(b, c, -1)
67 |
68 |     mean = jnp.mean(features, axis=(-1,), keepdims=True)
69 |     variance = jnp.var(features, axis=(-1,), keepdims=True)
70 |     features = (features - mean) / jnp.sqrt(variance + eps)
71 |
72 |     idx_swap = jax.random.permutation(hk.next_rng_key(), jnp.arange(b))
73 |     if interpolate:
74 |       mean = alpha * mean + (1 - alpha) * mean[idx_swap, ...]
75 |       variance = alpha * variance + (1 - alpha) * variance[idx_swap, ...]
76 |     else:
77 |       features = jax.lax.stop_gradient(features[idx_swap, ...])
78 |
79 |     features = features * jnp.sqrt(variance + eps) + mean
80 |     if is_image_shape:
81 |       features = jnp.transpose(features, axes=(0, 2, 1)).reshape(b, h, w, c)
82 |     return features
83 |
84 |   def _content_pred(self, features):
85 |     features = self._randomize(features, True)
86 |     return self._content_net_fn(**self._content_net_kwargs)(features)
87 |
88 |   def _style_pred(self, features):
89 |     features = self._randomize(features, False)
90 |     return self._style_net_fn(**self._style_net_kwargs)(features)
91 |
92 |   def __call__(self,
93 |                logits: chex.Array,
94 |                targets: chex.Array,
95 |                property_vs: chex.Array,
96 |                reduction: str = 'mean'
97 |                ) -> Tuple[data_utils.ScalarDict, chex.Array]:
98 |     """Train the content network."""
99 |     if len(logits.shape) == 4:
100 |       logits = jnp.mean(logits, axis=(1, 2))
101 |     preds = self._content_pred(logits)
102 |     loss_content = self.loss_fn(preds, targets)
103 |
104 |     # How well are we estimating the content?
105 |     top1_acc = (jnp.argmax(preds, axis=-1) == jnp.argmax(targets,
106 |                                                          axis=-1)).mean()
107 |     return {'loss': loss_content, 'top1_acc': top1_acc}, preds
108 |
109 |   def adversary(self,
110 |                 logits: chex.Array,
111 |                 property_vs: chex.Array,
112 |                 reduction: str = 'mean',
113 |                 targets: Optional[chex.Array] = None) -> data_utils.ScalarDict:
114 |     """Train the adversary which aims to predict style."""
115 |     if len(logits.shape) == 4:
116 |       logits = jnp.mean(logits, axis=(1, 2))
117 |     preds = self._style_pred(logits)
118 |     loss_style = self.loss_fn(preds, targets)
119 |     # How well are we estimating the style?
120 | top1_acc = (jnp.argmax(preds, axis=-1) == jnp.argmax(targets, 121 | axis=-1)).mean() 122 | return {'loss': loss_style, 'style_top1acc': top1_acc} 123 | -------------------------------------------------------------------------------- /distribution_shift_framework/core/checkpointing.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2022 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Checkpointing code. 18 | """ 19 | import os 20 | import pickle 21 | from typing import Mapping, Optional 22 | 23 | import jax 24 | import ml_collections 25 | import optax 26 | import tensorflow as tf 27 | 28 | 29 | def load_model(checkpoint_path: str) -> Mapping[str, optax.Params]: 30 | with tf.io.gfile.GFile(checkpoint_path, 'rb') as f: 31 | return pickle.load(f) 32 | 33 | 34 | def save_model(checkpoint_path: str, 35 | ckpt_dict: Mapping[str, optax.Params]): 36 | with tf.io.gfile.GFile(checkpoint_path, 'wb') as f: 37 | # Using protocol 4 as it's the default from Python 3.8 on. 
38 | pickle.dump(ckpt_dict, f, protocol=4) 39 | 40 | 41 | def get_checkpoint_dir(config: ml_collections.ConfigDict) -> Optional[str]: 42 | """Constructs the checkpoint directory from the config.""" 43 | 44 | if config.checkpoint_dir is None: 45 | return None 46 | path = os.path.join(config.checkpoint_dir, 47 | config.host_subdir.format(host_id=jax.process_index())) 48 | if not tf.io.gfile.exists(path): 49 | tf.io.gfile.makedirs(path) 50 | return path 51 | -------------------------------------------------------------------------------- /distribution_shift_framework/core/checkpointing_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2022 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
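The `save_model`/`load_model` pair in checkpointing.py is a thin pickle wrapper over `tf.io.gfile`. The same roundtrip can be sketched with the standard library alone (plain `open` stands in for `tf.io.gfile.GFile`; the checkpoint contents are made up):

```python
import os
import pickle
import tempfile

# A toy checkpoint dictionary standing in for real params/state.
ckpt_dict = {'params': {'w': [1.0, 2.0]}, 'step': 7}

path = os.path.join(tempfile.mkdtemp(), 'ckpt.pkl')
with open(path, 'wb') as f:
  pickle.dump(ckpt_dict, f, protocol=4)  # protocol 4: the default from 3.8 on
with open(path, 'rb') as f:
  restored = pickle.load(f)

assert restored == ckpt_dict
```

The test file below exercises exactly this roundtrip through the real `checkpointing` module, including non-trivial pytrees of arrays.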
16 | 17 | """Tests for distribution_shift_framework.core.checkpointing.""" 18 | 19 | import os 20 | 21 | from absl.testing import absltest 22 | from absl.testing import parameterized 23 | from distribution_shift_framework.core import checkpointing 24 | import jax 25 | import ml_collections 26 | import numpy as np 27 | import numpy.testing as npt 28 | import tensorflow as tf 29 | 30 | 31 | class CheckpointingTest(parameterized.TestCase): 32 | 33 | @parameterized.parameters([ 34 | dict(data={}), 35 | dict(data={'params': []}), 36 | dict(data={'state': None}), 37 | dict(data={'params': 3, 'stuff': 5.3, 'something': 'anything'}), 38 | dict(data={'params': {'stuff': 5.3, 'something': 'anything'}}), 39 | dict(data={'params': {'stuff': {'something': 'anything'}}}), 40 | dict(data={'params': {'stuff': {'something': np.random.rand(4, 3, 2)}}}), 41 | ]) 42 | def test_load_and_save_model(self, data): 43 | ckpt_file = os.path.join(self.create_tempdir(), 'ckpt.pkl') 44 | checkpointing.save_model(ckpt_file, data) 45 | loaded_data = checkpointing.load_model(ckpt_file) 46 | loaded_leaves, loaded_treedef = jax.tree_flatten(loaded_data) 47 | leaves, treedef = jax.tree_flatten(data) 48 | for leaf, loaded_leaf in zip(leaves, loaded_leaves): 49 | npt.assert_array_equal(leaf, loaded_leaf) 50 | self.assertEqual(treedef, loaded_treedef) 51 | 52 | def test_empty_checkpoint_dir(self): 53 | config = ml_collections.ConfigDict() 54 | config.checkpoint_dir = None 55 | self.assertIsNone(checkpointing.get_checkpoint_dir(config)) 56 | 57 | def test_get_checkpoint_dir(self): 58 | config = ml_collections.ConfigDict() 59 | temp_dir = self.create_tempdir() 60 | config.checkpoint_dir = os.path.join(temp_dir, 'my_exp') 61 | self.assertFalse(tf.io.gfile.exists(config.checkpoint_dir)) 62 | config.host_subdir = 'prefix_{host_id}_postfix' 63 | path = checkpointing.get_checkpoint_dir(config) 64 | self.assertEqual(os.path.join(temp_dir, 'my_exp', 'prefix_0_postfix'), path) 65 | 
self.assertTrue(tf.io.gfile.exists(path)) 66 | 67 | 68 | if __name__ == '__main__': 69 | absltest.main() 70 | -------------------------------------------------------------------------------- /distribution_shift_framework/core/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2022 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | -------------------------------------------------------------------------------- /distribution_shift_framework/core/datasets/data_loaders.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2022 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | """Loader and preprocessing functions for the datasets.""" 18 | 19 | from typing import Optional, Sequence 20 | 21 | import chex 22 | from distribution_shift_framework.core.datasets import data_utils 23 | import jax 24 | import numpy as np 25 | import tensorflow.compat.v2 as tf 26 | import tensorflow_datasets as tfds 27 | 28 | 29 | def shapes3d_normalize(image: chex.Array) -> chex.Array: 30 | return (image - .5) * 2 31 | 32 | 33 | def shapes3d_preprocess( 34 | mode: str = 'train' 35 | ) -> data_utils.TFPreprocessFn: 36 | del mode 37 | def _preprocess_fn(example): 38 | example['image'] = tf.image.convert_image_dtype( 39 | example['image'], dtype=tf.float32) 40 | example['label'] = example['label_shape'] 41 | return example 42 | return _preprocess_fn 43 | 44 | 45 | def unbatched_load_shapes3d(subset: str = 'train', 46 | valid_size: int = 10000, 47 | test_size: int = 10000) -> data_utils.Dataset: 48 | """Loads the 3D Shapes dataset without batching.""" 49 | if subset == 'train': 50 | ds = tfds.load(name='shapes3d', split=tfds.Split.TRAIN 51 | ).skip(valid_size + test_size) 52 | elif subset == 'valid': 53 | ds = tfds.load(name='shapes3d', split=tfds.Split.TRAIN 54 | ).skip(test_size).take(valid_size) 55 | elif subset == 'train_and_valid': 56 | ds = tfds.load(name='shapes3d', split=tfds.Split.TRAIN).skip(test_size) 57 | elif subset == 'test': 58 | ds = tfds.load(name='shapes3d', split=tfds.Split.TRAIN).take(test_size) 59 | else: 60 | raise ValueError('Unknown subset: "{}"'.format(subset)) 61 | return ds 62 | 63 | 64 | def load_shapes3d(batch_sizes: Sequence[int], 65 | subset: str = 'train', 66 | is_training: bool = True, 67 | num_samples: Optional[int] = None, 68 | preprocess_fn: Optional[data_utils.PreprocessFnGen] = None, 69 | transpose: bool = False, 70 | valid_size: int = 10000, 71 | test_size: int = 10000, 72 | drop_remainder: bool = True, 73 | local_cache: bool = True) -> data_utils.Dataset: 74 | """Loads the 3D Shapes dataset. 
75 | 76 | The 3D shapes dataset is available at https://github.com/deepmind/3d-shapes. 77 | It consists of images of 4 different shapes which vary along 6 factors: 78 | - Floor hue: 10 colors ranging over red, orange, yellow, green and blue 79 | - Wall hue: 10 colors ranging over red, orange, yellow, green and blue 80 | - Object hue: 10 colors ranging over red, orange, yellow, green and blue 81 | - Scale: 8 values for how large the object is. 82 | - Shape: 4 values -- (cube, sphere, cylinder, and oblong). 83 | - Orientation: 15 values rotating the object around the vertical axis. 84 | 85 | Args: 86 | batch_sizes: Specifies how to batch examples. I.e., if batch_sizes = [8, 4] 87 | then output images will have shapes (8, 4, height, width, 3). 88 | subset: Specifies which subset (train, valid, test or train_and_valid) to use. 89 | is_training: Whether to infinitely repeat and shuffle examples (`True`) or 90 | not (`False`). 91 | num_samples: The number of samples to crop each individual dataset variant 92 | from the start, or `None` to use the full dataset. 93 | preprocess_fn: Function mapped onto each example for pre-processing. 94 | transpose: Whether to permute image dimensions NHWC -> HWCN to speed up 95 | performance on TPUs. 96 | valid_size: Size of the validation set to take from the training set. 97 | test_size: Size of the test set to take from the training set. 98 | drop_remainder: Whether to drop the last batch(es) if they would not match 99 | the shapes specified by `batch_sizes`. 100 | local_cache: Whether to locally cache the dataset. 101 | 102 | Returns: 103 | ds: Fully configured dataset ready for training/evaluation.
104 | """ 105 | if preprocess_fn is None: 106 | preprocess_fn = shapes3d_preprocess 107 | ds = unbatched_load_shapes3d(subset=subset, valid_size=valid_size, 108 | test_size=test_size) 109 | total_batch_size = np.prod(batch_sizes) 110 | if subset == 'valid' and valid_size < total_batch_size: 111 | ds = ds.repeat().take(total_batch_size) 112 | ds = batch_and_shuffle(ds, batch_sizes, 113 | is_training=is_training, 114 | transpose=transpose, 115 | num_samples=num_samples, 116 | preprocess_fn=preprocess_fn, 117 | drop_remainder=drop_remainder, 118 | local_cache=local_cache) 119 | return ds 120 | 121 | 122 | def small_norb_normalize(image: chex.Array) -> chex.Array: 123 | return (image - .5) * 2 124 | 125 | 126 | def small_norb_preprocess( 127 | mode: str = 'train' 128 | ) -> data_utils.TFPreprocessFn: 129 | del mode 130 | def _preprocess_fn(example): 131 | example['image'] = tf.image.convert_image_dtype( 132 | example['image'], dtype=tf.float32) 133 | example['label'] = example['label_category'] 134 | return example 135 | return _preprocess_fn 136 | 137 | 138 | def unbatched_load_small_norb(subset: str = 'train', 139 | valid_size: int = 10000) -> data_utils.Dataset: 140 | """Load the small norb dataset.""" 141 | if subset == 'train': 142 | ds = tfds.load(name='smallnorb', split=tfds.Split.TRAIN).skip(valid_size) 143 | elif subset == 'valid': 144 | ds = tfds.load(name='smallnorb', split=tfds.Split.TRAIN).take(valid_size) 145 | elif subset == 'train_and_valid': 146 | ds = tfds.load(name='smallnorb', split=tfds.Split.TRAIN) 147 | elif subset == 'test': 148 | ds = tfds.load(name='smallnorb', split=tfds.Split.TEST) 149 | else: 150 | raise ValueError('Unknown subset: "{}"'.format(subset)) 151 | return ds 152 | 153 | 154 | def load_small_norb(batch_sizes: Sequence[int], 155 | subset: str = 'train', 156 | is_training: bool = True, 157 | num_samples: Optional[int] = None, 158 | preprocess_fn: Optional[data_utils.PreprocessFnGen] = None, 159 | transpose: bool = False, 160 | 
valid_size: int = 1000, 161 | drop_remainder: bool = True, 162 | local_cache: bool = True) -> data_utils.Dataset: 163 | """Loads the small norb dataset. 164 | 165 | The norb dataset is available at: 166 | https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/. 167 | 168 | It consists of 5 categories (Animals, People, Airplanes, Trucks, and Cars). 169 | These categories have 5 instances (different animals, airplanes, or types of 170 | cars). 171 | 172 | They vary by (which are consistent across categories and instances): 173 | 1. Elevation 174 | 2. Azimuth 175 | 3. Lighting 176 | 177 | Args: 178 | batch_sizes: Specifies how to batch examples. I.e., if batch_sizes = [8, 4] 179 | then output images will have shapes (8, 4, height, width, 3). 180 | subset: Specifies the subset (train, valid, test or train_and_valid) to use. 181 | is_training: Whether to infinitely repeat and shuffle examples (`True`) or 182 | not (`False`). 183 | num_samples: The number of samples to crop each individual dataset variant 184 | from the start, or `None` to use the full dataset. 185 | preprocess_fn: Function mapped onto each example for pre-processing. 186 | transpose: Whether to permute image dimensions NHWC -> HWCN to speed up 187 | performance on TPUs. 188 | valid_size: The size of the validation set. 189 | drop_remainder: Whether to drop the last batch(es) if they would not match 190 | the shapes specified by `batch_sizes`. 191 | local_cache: Whether to locally cache the dataset. 192 | 193 | Returns: 194 | ds: Fully configured dataset ready for training/evaluation. 
195 | """ 196 | if preprocess_fn is None: 197 | preprocess_fn = small_norb_preprocess 198 | 199 | ds = unbatched_load_small_norb(subset=subset, valid_size=valid_size) 200 | total_batch_size = np.prod(batch_sizes) 201 | if subset == 'valid' and valid_size < total_batch_size: 202 | ds = ds.repeat().take(total_batch_size) 203 | ds = batch_and_shuffle(ds, batch_sizes, 204 | is_training=is_training, 205 | transpose=transpose, 206 | num_samples=num_samples, 207 | preprocess_fn=preprocess_fn, 208 | drop_remainder=drop_remainder, 209 | local_cache=local_cache) 210 | return ds 211 | 212 | 213 | def dsprites_normalize(image: chex.Array) -> chex.Array: 214 | return (image - .5) * 2 215 | 216 | 217 | def dsprites_preprocess( 218 | mode: str = 'train' 219 | ) -> data_utils.TFPreprocessFn: 220 | del mode 221 | def _preprocess_fn(example): 222 | example['image'] = tf.image.convert_image_dtype( 223 | example['image'], dtype=tf.float32) * 255. 224 | example['label'] = example['label_shape'] 225 | return example 226 | return _preprocess_fn 227 | 228 | 229 | def unbatched_load_dsprites(subset: str = 'train', 230 | valid_size: int = 10000, 231 | test_size: int = 10000) -> data_utils.Dataset: 232 | """Loads the dsprites dataset without batching and prefetching.""" 233 | if subset == 'train': 234 | ds = tfds.load(name='dsprites', 235 | split=tfds.Split.TRAIN).skip(valid_size + test_size) 236 | elif subset == 'valid': 237 | ds = tfds.load(name='dsprites', 238 | split=tfds.Split.TRAIN).skip(test_size).take(valid_size) 239 | elif subset == 'train_and_valid': 240 | ds = tfds.load(name='dsprites', split=tfds.Split.TRAIN).skip(test_size) 241 | elif subset == 'test': 242 | ds = tfds.load(name='dsprites', split=tfds.Split.TRAIN).take(test_size) 243 | else: 244 | raise ValueError('Unknown subset: "{}"'.format(subset)) 245 | return ds 246 | 247 | 248 | def load_dsprites(batch_sizes: Sequence[int], 249 | subset: str = 'train', 250 | is_training: bool = True, 251 | num_samples: Optional[int] = 
None, 252 | preprocess_fn: Optional[data_utils.PreprocessFnGen] = None, 253 | transpose: bool = False, 254 | valid_size: int = 10000, 255 | test_size: int = 10000, 256 | drop_remainder: bool = True, 257 | local_cache: bool = True) -> data_utils.Dataset: 258 | """Loads the dsprites dataset. 259 | 260 | The dsprites dataset is available at: 261 | https://github.com/deepmind/dsprites-dataset. 262 | 263 | It consists of 3 shapes (heart, ellipse and square). 264 | 265 | Each shape varies by (consistently across shapes): 266 | 1. Scale (6 values) 267 | 2. Orientation: 40 values (rotates around the center of the object) 268 | 3. Position (X): 32 values 269 | 4. Position (Y): 32 values 270 | 271 | Args: 272 | batch_sizes: Specifies how to batch examples. I.e., if batch_sizes = [8, 4] 273 | then output images will have shapes (8, 4, height, width, 3). 274 | subset: Specifies the subset (train, valid, test or train_and_valid) to use. 275 | is_training: Whether to infinitely repeat and shuffle examples (`True`) or 276 | not (`False`). 277 | num_samples: The number of samples to crop each individual dataset variant 278 | from the start, or `None` to use the full dataset. 279 | preprocess_fn: Function mapped onto each example for pre-processing. 280 | transpose: Whether to permute image dimensions NHWC -> HWCN to speed up 281 | performance on TPUs. 282 | valid_size: The size of the validation set. 283 | test_size: The size of the test set. 284 | drop_remainder: Whether to drop the last batch(es) if they would not match 285 | the shapes specified by `batch_sizes`. 286 | local_cache: Whether to locally cache the dataset. 287 | 288 | Returns: 289 | ds: Fully configured dataset ready for training/evaluation.
290 | """ 291 | if preprocess_fn is None: 292 | preprocess_fn = dsprites_preprocess 293 | 294 | ds = unbatched_load_dsprites(subset=subset, valid_size=valid_size, 295 | test_size=test_size) 296 | total_batch_size = np.prod(batch_sizes) 297 | if subset == 'valid' and valid_size < total_batch_size: 298 | ds = ds.repeat().take(total_batch_size) 299 | ds = batch_and_shuffle(ds, batch_sizes, 300 | is_training=is_training, 301 | transpose=transpose, 302 | num_samples=num_samples, 303 | preprocess_fn=preprocess_fn, 304 | drop_remainder=drop_remainder, 305 | local_cache=local_cache) 306 | return ds 307 | 308 | 309 | def batch_and_shuffle( 310 | ds: data_utils.Dataset, 311 | batch_sizes: Sequence[int], 312 | preprocess_fn: Optional[data_utils.PreprocessFnGen] = None, 313 | is_training: bool = True, 314 | num_samples: Optional[int] = None, 315 | transpose: bool = False, 316 | drop_remainder: bool = True, 317 | local_cache: bool = False) -> data_utils.Dataset: 318 | """Performs post-processing on datasets (i.e., batching, transposing). 319 | 320 | Args: 321 | ds: The dataset. 322 | batch_sizes: Specifies how to batch examples. I.e., if batch_sizes = [8, 4] 323 | then output images will have shapes (8, 4, height, width, 3). 324 | preprocess_fn: Function mapped onto each example for pre-processing. 325 | is_training: Whether to infinitely repeat and shuffle examples (`True`) or 326 | not (`False`). 327 | num_samples: The number of samples to crop each individual dataset variant 328 | from the start, or `None` to use the full dataset. 329 | transpose: Whether to permute image dimensions NHWC -> HWCN to speed up 330 | performance on TPUs. 331 | drop_remainder: Whether to drop the last batch(es) if they would not match 332 | the shapes specified by `batch_sizes`. 333 | local_cache: Whether to locally cache the dataset. 334 | Returns: 335 | ds: Dataset with all the post-processing applied. 
336 | """ 337 | if num_samples: 338 | ds = ds.take(num_samples) 339 | if local_cache: 340 | ds = ds.cache() 341 | if is_training: 342 | ds = ds.repeat() 343 | total_batch_size = np.prod(batch_sizes) 344 | shuffle_buffer = 10 * total_batch_size 345 | ds = ds.shuffle(buffer_size=shuffle_buffer, seed=jax.process_index()) 346 | if preprocess_fn is not None: 347 | ds = ds.map(preprocess_fn('train' if is_training else 'test'), 348 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 349 | for i, batch_size in enumerate(reversed(batch_sizes)): 350 | ds = ds.batch(batch_size, drop_remainder=drop_remainder) 351 | if i == 0 and transpose: 352 | ds = ds.map(data_utils.transpose_fn) # NHWC -> HWCN. 353 | ds = ds.prefetch(tf.data.experimental.AUTOTUNE) 354 | return ds 355 | -------------------------------------------------------------------------------- /distribution_shift_framework/core/datasets/data_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2022 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
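A note on the batching convention used throughout this file: `batch_and_shuffle` calls `ds.batch` once per entry of `batch_sizes`, innermost first, so `batch_sizes = [8, 4]` yields elements of shape `(8, 4, ...)`. A pure-NumPy sketch of that reversed double-batching (illustrative only, not the tf.data pipeline itself):

```python
import numpy as np

def nested_batch(examples, batch_sizes, drop_remainder=True):
  """Batches a sequence once per entry of `batch_sizes`, innermost first,
  mirroring the loop over reversed(batch_sizes) in batch_and_shuffle."""
  batched = list(examples)
  for batch_size in reversed(batch_sizes):
    limit = len(batched) - batch_size + 1 if drop_remainder else len(batched)
    batched = [batched[i:i + batch_size] for i in range(0, limit, batch_size)]
  return batched  # Each element is one fully nested batch.

# 64 fake 2x2 "images" batched as [8, 4]: each yielded batch is (8, 4, 2, 2).
images = [np.zeros((2, 2)) for _ in range(64)]
batches = nested_batch(images, [8, 4])
print(np.asarray(batches[0]).shape)  # (8, 4, 2, 2)
```

The leading two axes typically correspond to (device, per-device batch) when the arrays are later sharded with `jax.pmap`.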
16 | 17 | """Data utility functions.""" 18 | import enum 19 | from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Sequence, Tuple, Union 20 | 21 | import chex 22 | import jax 23 | import ml_collections 24 | import numpy as np 25 | import tensorflow.compat.v2 as tf 26 | import tensorflow_datasets as tfds 27 | 28 | # Type Aliases 29 | Batch = Dict[str, np.ndarray] 30 | ScalarDict = Dict[str, chex.Array] 31 | # Objects that can be treated like tensors in TF2. 32 | TFTensorLike = Union[np.ndarray, tf.Tensor, tf.Variable] 33 | # pytype: disable=not-supported-yet 34 | TFTensorNest = Union[TFTensorLike, Iterable['TFTensorNest'], 35 | Mapping[str, 'TFTensorNest']] 36 | # pytype: enable=not-supported-yet 37 | PreprocessFnGen = Callable[[str], Callable[[chex.ArrayTree], chex.ArrayTree]] 38 | TFPreprocessFn = Callable[[TFTensorNest], TFTensorNest] 39 | 40 | Dataset = tf.data.Dataset 41 | 42 | # Disentanglement datasets. 43 | 44 | SHAPES3D_PROPERTIES = { 45 | 'label_scale': tuple(range(8)), 46 | 'label_orientation': tuple(range(15)), 47 | 'label_floor_hue': tuple(range(10)), 48 | 'label_object_hue': tuple(range(10)), 49 | 'label_wall_hue': tuple(range(10)), 50 | 'label_shape': tuple(range(4)), 51 | 'label_color': 52 | tuple(range(3)) # Only added through preprocessing. 
53 | } 54 | 55 | SMALL_NORB_PROPERTIES = { 56 | 'label_azimuth': tuple(range(18)), 57 | 'label_elevation': tuple(range(9)), 58 | 'label_lighting': tuple(range(6)), 59 | 'label_category': tuple(range(5)), 60 | } 61 | 62 | DSPRITES_PROPERTIES = { 63 | 'label_scale': tuple(range(6)), 64 | 'label_orientation': tuple(range(40)), 65 | 'label_x_position': tuple(range(32)), 66 | 'label_y_position': tuple(range(32)), 67 | 'label_shape': tuple(range(3)), 68 | } 69 | 70 | 71 | class DatasetNames(enum.Enum): 72 | """Names of the datasets.""" 73 | SHAPES3D = 'shapes3d' 74 | SMALL_NORB = 'small_norb' 75 | DSPRITES = 'dsprites' 76 | 77 | 78 | class NumChannels(enum.Enum): 79 | """Number of channels of the images.""" 80 | SHAPES3D = 3 81 | SMALL_NORB = 1 82 | DSPRITES = 1 83 | 84 | 85 | class Variance(enum.Enum): 86 | """Variance of the pixels in the images.""" 87 | SHAPES3D = 0.155252 88 | SMALL_NORB = 0.031452 89 | DSPRITES = 0.04068864749147259 90 | 91 | 92 | class ImageSize(enum.Enum): 93 | """Size of the images.""" 94 | SHAPES3D = 64 95 | SMALL_NORB = 96 96 | DSPRITES = 64 97 | 98 | 99 | def is_disentanglement_dataset(dataset_name: str) -> bool: 100 | return dataset_name in (DatasetNames.SHAPES3D.value, 101 | DatasetNames.SMALL_NORB.value, 102 | DatasetNames.DSPRITES.value) 103 | 104 | 105 | def get_dataset_constants(dataset_name: str, 106 | label: str = 'label', 107 | variant: Optional[str] = None) -> Mapping[str, Any]: 108 | """Returns a dictionary with several constants for the dataset.""" 109 | if variant: 110 | properties_name = f'{dataset_name.upper()}_{variant.upper()}_PROPERTIES' 111 | else: 112 | properties_name = f'{dataset_name.upper()}_PROPERTIES' 113 | properties = globals()[properties_name] 114 | num_channels = NumChannels[dataset_name.upper()].value 115 | 116 | if dataset_name == DatasetNames.DSPRITES.value and label == 'label_color': 117 | num_classes = 3 118 | else: 119 | num_classes = len(properties[label]) 120 | 121 | return { 122 | 'properties': 
properties, 123 | 'num_channels': num_channels, 124 | 'num_classes': num_classes, 125 | 'variance': Variance[dataset_name.upper()].value, 126 | 'image_size': ImageSize[dataset_name.upper()].value 127 | } 128 | 129 | 130 | def transpose_fn(batch: Batch) -> Batch: 131 | # Transpose for performance on TPU. 132 | batch = dict(**batch) 133 | batch['image'] = tf.transpose(batch['image'], (1, 2, 3, 0)) 134 | return batch 135 | 136 | 137 | def load_dataset(is_training: bool, 138 | batch_dims: Sequence[int], 139 | transpose: bool, 140 | data_kwargs: Optional[ml_collections.ConfigDict] = None 141 | ) -> Generator[Batch, None, None]: 142 | """Wrapper to load a dataset.""" 143 | 144 | data_loader = data_kwargs['loader'] 145 | batch_kwd = getattr(data_kwargs, 'batch_kwd', 'batch_sizes') 146 | batch_kwargs = {batch_kwd: batch_dims} 147 | 148 | dataset = data_loader( 149 | is_training=is_training, 150 | transpose=transpose, 151 | **batch_kwargs, 152 | **data_kwargs['load_kwargs']) 153 | 154 | is_numpy = getattr(data_kwargs, 'is_numpy', False) 155 | if not is_numpy: 156 | dataset = iter(tfds.as_numpy(dataset)) 157 | 158 | return dataset 159 | 160 | 161 | def resize(image: chex.Array, size: Tuple[int, int]) -> chex.Array: 162 | """Resizes a batch of images using bilinear interpolation.""" 163 | return jax.image.resize(image, 164 | (image.shape[0], size[0], size[1], image.shape[3]), 165 | method='bilinear', antialias=False) 166 | -------------------------------------------------------------------------------- /distribution_shift_framework/core/datasets/lowdata_wrapper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2022 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Generates the low data versions for the disentanglement datasets.""" 18 | from typing import Callable, Optional, Sequence 19 | 20 | from distribution_shift_framework.core.datasets import data_utils 21 | import jax 22 | import ml_collections 23 | import tensorflow.compat.v2 as tf 24 | 25 | 26 | def _create_filter_fn(filter_string: str) -> Callable[..., bool]: 27 | """Creates a filter function based on the string. 28 | 29 | Given a string of 30 | "key11:val11:comp11^key12:val12:comp12|...^keyNk:valNk:compNk" 31 | the string is parsed as the OR of several AND statements. The ORs are at the 32 | top level (denoted by |), and divide into a set of AND statements. 33 | The AND values are denoted by ^ and operate at the bottom level. 34 | For each "keyij:valij:compij" pairing, keyij is the key in the dataset, 35 | valij is the value the key is compared against and compij is the tensorflow 36 | comparison function: e.g. less, less_equal, equal, not_equal, greater_equal, greater. 37 | 38 | Note that parentheses and infinite depth are *not* supported yet. 39 | 40 | Example 1: for dSprites: "label_scale:3:equal". 41 | This will select all samples from dSprites where the label_scale parameter is 42 | equal to 3. 43 | 44 | Example 2: for Shapes3D: 45 | "wall_hue_value:0.3:less_equal^floor_hue_value:0.3:less_equal". 46 | This will select all samples from Shapes3D where the wall hue and floor hue 47 | are less than or equal to 0.3.
48 | 49 | Example 3: for smallNORB: 50 | ('label_azimuth:7:less^label_category:0:equal|' 51 | 'label_azimuth:7:greater_equal^label_category:0:not_equal'). 52 | This will select all samples from smallNORB which either have azimuth of less 53 | than 7 and category 0 or azimuth of greater or equal 7 and a category other 54 | than 0. 55 | 56 | Args: 57 | filter_string: The filter string that is used to make the filter function. 58 | 59 | Returns: 60 | filter_fn: A function that takes a batch and returns True or False if it 61 | matches the filter string. 62 | """ 63 | all_comparisons = filter_string.split('|') 64 | 65 | def filter_fn(x): 66 | or_filter = False 67 | 68 | # Iterate over all the OR comparisons. 69 | for or_comparison in all_comparisons: 70 | and_comparisons = or_comparison.split('^') 71 | 72 | and_filter = True 73 | # Iterate over all the AND comparisons. 74 | for and_comparison in and_comparisons: 75 | key, value, comp = and_comparison.split(':') 76 | if value in x.keys(): 77 | value = x[value] 78 | else: 79 | value = tf.cast(float(value), x[key].dtype) 80 | bool_fn = getattr(tf, comp) 81 | # Accumulate the and comparisons. 82 | and_filter = tf.logical_and(and_filter, bool_fn(x[key], value)) 83 | # Accumulate the or comparisons. 
84 | or_filter = tf.logical_or(or_filter, and_filter) 85 | 86 | return or_filter 87 | 88 | return filter_fn 89 | 90 | 91 | def load_data(batch_sizes: Sequence[int], 92 | dataset_loader: Callable[..., data_utils.Dataset], 93 | num_samples: str, 94 | filter_fns: str, 95 | dataset_kwargs: ml_collections.ConfigDict, 96 | shuffle_pre_sampling: bool = False, 97 | shuffle_pre_sample_seed: int = 0, 98 | local_cache: bool = True, 99 | is_training: bool = True, 100 | transpose: bool = True, 101 | drop_remainder: bool = True, 102 | prefilter: Optional[Callable[..., bool]] = None, 103 | preprocess_fn: Optional[data_utils.PreprocessFnGen] = None, 104 | shuffle_buffer: Optional[int] = 100_000, 105 | weights: Optional[Sequence[float]] = None) -> data_utils.Dataset: 106 | """A low data wrapper around a tfds dataset. 107 | 108 | This wrapper creates a set of datasets according to the parameters. For each 109 | filtering function and number of samples, the dataset defined by the 110 | dataset_loader and **dataset_kwargs is filtered and the first N samples are 111 | taken. All datasets are concatenated together and a sample is drawn with 112 | equal probability from each dataset. 113 | 114 | Args: 115 | batch_sizes: Specifies how to batch examples. I.e., if batch_sizes = [8, 4] 116 | then output images will have shapes (8, 4, height, width, 3). 117 | dataset_loader: The tfds dataset loader. 118 | num_samples: A comma-separated string of the number of samples each returned 119 | dataset will contain. I.e., if num_samples = '1,2,3' then the first 120 | filtering operation will create a dataset with 1 sample, the second a 121 | dataset of 2 samples, and so on. 122 | filter_fns: A comma-separated string of the filter strings for each part of 123 | the dataset. 124 | dataset_kwargs: A dict of the kwargs to pass to dataset_loader. 125 | shuffle_pre_sampling: Whether to shuffle presampling and thereby get a 126 | different set of samples. 127 | shuffle_pre_sample_seed: What seed to use for presampling.
128 | local_cache: Whether to cache the concatenated dataset. Good to do if the 129 | dataset fits in memory. 130 | is_training: Whether this is train or test. 131 | transpose: Whether to permute image dimensions NHWC -> HWCN to speed up 132 | performance on TPUs. 133 | drop_remainder: Whether to drop the last batch(es) if they would not match 134 | the shapes specified by `batch_sizes`. 135 | prefilter: Filter to apply to the dataset. 136 | preprocess_fn: Function mapped onto each example for pre-processing. 137 | shuffle_buffer: How big the buffer for shuffling the images is. 138 | weights: The probabilities to select samples from each dataset. 139 | 140 | Returns: 141 | A tf.Dataset instance. 142 | """ 143 | ds = dataset_loader(**dataset_kwargs) 144 | 145 | if preprocess_fn: 146 | ds = ds.map( 147 | preprocess_fn('train' if is_training else 'test'), 148 | num_parallel_calls=tf.data.AUTOTUNE) 149 | 150 | if prefilter: 151 | ds = ds.filter(prefilter) 152 | 153 | filter_fns = filter_fns.split(',') 154 | num_samples = [int(n) for n in num_samples.split(',')] 155 | 156 | assert len(filter_fns) == len(num_samples) 157 | 158 | all_ds = [] 159 | for filter_fn, n_sample in zip(filter_fns, num_samples): 160 | if filter_fn != 'True': 161 | filter_fn = _create_filter_fn(filter_fn) 162 | filtered_ds = ds.filter(filter_fn) 163 | else: 164 | filtered_ds = ds 165 | 166 | if shuffle_pre_sampling: 167 | filtered_ds = filtered_ds.shuffle( 168 | buffer_size=shuffle_buffer, seed=shuffle_pre_sample_seed) 169 | 170 | if n_sample: 171 | filtered_ds = filtered_ds.take(n_sample) 172 | if local_cache or n_sample: 173 | filtered_ds = filtered_ds.cache() 174 | 175 | if is_training: 176 | filtered_ds = filtered_ds.repeat() 177 | 178 | shuffle_buffer = ( 179 | min(n_sample, shuffle_buffer) if n_sample > 0 else shuffle_buffer) 180 | filtered_ds = filtered_ds.shuffle( 181 | buffer_size=shuffle_buffer, seed=jax.process_index()) 182 | all_ds.append(filtered_ds) 183 | 184 | ds = 
tf.data.Dataset.sample_from_datasets( 185 | all_ds, weights=weights, seed=None) 186 | 187 | for i, batch_size in enumerate(reversed(batch_sizes)): 188 | ds = ds.batch(batch_size, drop_remainder=drop_remainder) 189 | 190 | if i == 0 and transpose: 191 | ds = ds.map(data_utils.transpose_fn) # NHWC -> HWCN. 192 | 193 | return ds 194 | -------------------------------------------------------------------------------- /distribution_shift_framework/core/hyper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2022 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Functions to create and combine hyper parameter sweeps.""" 18 | 19 | import functools 20 | import itertools 21 | from typing import Any, Dict, Iterable, List, Sequence 22 | 23 | # A sweep is a list of parameter mappings that defines a set of experiments. 
24 | Sweep = List[Dict[str, Any]] 25 | 26 | 27 | def sweep(parameter_name: str, values: Iterable[Any]) -> Sweep: 28 | """Creates a sweep from a list of values for a parameter.""" 29 | return [{parameter_name: value} for value in values] 30 | 31 | 32 | def product(sweeps: Sequence[Sweep]) -> Sweep: 33 | """Builds a sweep from the cartesian product of a list of sweeps.""" 34 | return [functools.reduce(_combine_parameter_dicts, param_dicts, {}) 35 | for param_dicts in itertools.product(*sweeps)] 36 | 37 | 38 | def zipit(sweeps: Sequence[Sweep]) -> Sweep: 39 | """Builds a sweep from zipping a list of sweeps.""" 40 | return [functools.reduce(_combine_parameter_dicts, param_dicts, {}) 41 | for param_dicts in zip(*sweeps)] 42 | 43 | 44 | def _combine_parameter_dicts(x: Dict[str, Any], y: Dict[str, Any] 45 | ) -> Dict[str, Any]: 46 | if x.keys() & y.keys(): 47 | raise ValueError('Cannot combine sweeps that set the same parameters. ' 48 | f'Keys in x: {x.keys()}, keys in y: {y.keys()}, ' 49 | f'overlap: {x.keys() & y.keys()}') 50 | return {**x, **y} 51 | 52 | -------------------------------------------------------------------------------- /distribution_shift_framework/core/hyper_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2022 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
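The sweep helpers in `hyper.py` above compose as follows. This standalone sketch restates `sweep`, `product` and `zipit` (with the duplicate-key check omitted for brevity) so the example runs on its own; the parameter names are made up for illustration:

```python
import functools
import itertools

def sweep(parameter_name, values):
  """One experiment per value of a single parameter."""
  return [{parameter_name: value} for value in values]

def product(sweeps):
  """Cartesian product: every combination across the given sweeps."""
  return [functools.reduce(lambda x, y: {**x, **y}, dicts, {})
          for dicts in itertools.product(*sweeps)]

def zipit(sweeps):
  """Positional zip: the i-th entry of every sweep combined."""
  return [functools.reduce(lambda x, y: {**x, **y}, dicts, {})
          for dicts in zip(*sweeps)]

lr = sweep('learning_rate', [1e-3, 1e-4])
wd = sweep('weight_decay', [0.0, 0.1])
print(len(product([lr, wd])))  # 4 experiments: every (lr, wd) pair.
print(zipit([lr, wd]))         # 2 experiments: entries paired by position.
```

`product` grows multiplicatively with each added sweep, while `zipit` keeps the experiment count equal to the shortest sweep, which is why the tests below cover both combinators.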
16 | 17 | """Tests for distribution_shift_framework.core.hyper.""" 18 | 19 | 20 | from absl.testing import absltest 21 | from absl.testing import parameterized 22 | from distribution_shift_framework.core import hyper 23 | 24 | 25 | class HyperTest(parameterized.TestCase): 26 | 27 | @parameterized.parameters([ 28 | dict(parameter_name='a', values=[1, 2, 3], 29 | expected_sweep=[{'a': 1}, {'a': 2}, {'a': 3}]), 30 | dict(parameter_name='b', values=[.1, .2, .3], 31 | expected_sweep=[{'b': .1}, {'b': .2}, {'b': .3}]), 32 | dict(parameter_name='c', values=[True, False], 33 | expected_sweep=[{'c': True}, {'c': False}]), 34 | dict(parameter_name='d', values=['one', 'two', 'three'], 35 | expected_sweep=[{'d': 'one'}, {'d': 'two'}, {'d': 'three'}]), 36 | dict(parameter_name='e', values=[1, 0.5, True, 'string'], 37 | expected_sweep=[{'e': 1}, {'e': 0.5}, {'e': True}, {'e': 'string'}]), 38 | dict(parameter_name='f', values=[], 39 | expected_sweep=[]), 40 | ]) 41 | def test_sweep(self, parameter_name, values, expected_sweep): 42 | self.assertEqual(expected_sweep, hyper.sweep(parameter_name, values)) 43 | 44 | @parameterized.parameters([ 45 | dict(sweeps=[], 46 | expected_sweep=[{}]), 47 | dict(sweeps=[hyper.sweep('param1', [1, 2, 3, 4, 5, 6])], 48 | expected_sweep=[ 49 | {'param1': 1}, {'param1': 2}, {'param1': 3}, 50 | {'param1': 4}, {'param1': 5}, {'param1': 6}, 51 | ]), 52 | dict(sweeps=[hyper.sweep('param1', [1, 2, 3]), 53 | hyper.sweep('param2', [4, 5, 6])], 54 | expected_sweep=[ 55 | {'param1': 1, 'param2': 4}, 56 | {'param1': 1, 'param2': 5}, 57 | {'param1': 1, 'param2': 6}, 58 | {'param1': 2, 'param2': 4}, 59 | {'param1': 2, 'param2': 5}, 60 | {'param1': 2, 'param2': 6}, 61 | {'param1': 3, 'param2': 4}, 62 | {'param1': 3, 'param2': 5}, 63 | {'param1': 3, 'param2': 6}, 64 | ]), 65 | dict(sweeps=[hyper.sweep('param1', [1, 2]), 66 | hyper.sweep('param2', [3, 4]), 67 | hyper.sweep('param3', [5, 6])], 68 | expected_sweep=[ 69 | {'param1': 1, 'param2': 3, 'param3': 5}, 70 | 
{'param1': 1, 'param2': 3, 'param3': 6}, 71 | {'param1': 1, 'param2': 4, 'param3': 5}, 72 | {'param1': 1, 'param2': 4, 'param3': 6}, 73 | {'param1': 2, 'param2': 3, 'param3': 5}, 74 | {'param1': 2, 'param2': 3, 'param3': 6}, 75 | {'param1': 2, 'param2': 4, 'param3': 5}, 76 | {'param1': 2, 'param2': 4, 'param3': 6}, 77 | ]), 78 | dict(sweeps=[hyper.sweep('param1', [1, 2., 'Three']), 79 | hyper.sweep('param2', [True, 'Two', 3.0])], 80 | expected_sweep=[ 81 | {'param1': 1, 'param2': True}, 82 | {'param1': 1, 'param2': 'Two'}, 83 | {'param1': 1, 'param2': 3.0}, 84 | {'param1': 2., 'param2': True}, 85 | {'param1': 2., 'param2': 'Two'}, 86 | {'param1': 2., 'param2': 3.0}, 87 | {'param1': 'Three', 'param2': True}, 88 | {'param1': 'Three', 'param2': 'Two'}, 89 | {'param1': 'Three', 'param2': 3.0}, 90 | ]), 91 | ]) 92 | def test_product(self, sweeps, expected_sweep): 93 | self.assertEqual(expected_sweep, hyper.product(sweeps)) 94 | 95 | def test_product_raises_valueerror_for_same_name(self): 96 | sweep1 = hyper.sweep('param1', [1, 2, 3]) 97 | sweep2 = hyper.sweep('param2', [4, 5, 6]) 98 | sweep3 = hyper.sweep('param1', [7, 8, 9]) 99 | with self.assertRaises(ValueError): 100 | hyper.product([sweep1, sweep2, sweep3]) 101 | 102 | @parameterized.parameters([ 103 | dict(sweeps=[], 104 | expected_sweep=[]), 105 | dict(sweeps=[hyper.sweep('param1', [1, 2, 3, 4, 5, 6])], 106 | expected_sweep=[ 107 | {'param1': 1}, {'param1': 2}, {'param1': 3}, 108 | {'param1': 4}, {'param1': 5}, {'param1': 6}, 109 | ]), 110 | dict(sweeps=[hyper.sweep('param1', [1, 2, 3]), 111 | hyper.sweep('param2', [4, 5, 6])], 112 | expected_sweep=[ 113 | {'param1': 1, 'param2': 4}, 114 | {'param1': 2, 'param2': 5}, 115 | {'param1': 3, 'param2': 6}, 116 | ]), 117 | dict(sweeps=[hyper.sweep('param1', [1, 2, 3]), 118 | hyper.sweep('param2', [4, 5, 6]), 119 | hyper.sweep('param3', [7, 8, 9])], 120 | expected_sweep=[ 121 | {'param1': 1, 'param2': 4, 'param3': 7}, 122 | {'param1': 2, 'param2': 5, 'param3': 8}, 123 | 
{'param1': 3, 'param2': 6, 'param3': 9}, 124 | ]), 125 | dict(sweeps=[hyper.sweep('param1', [1, 2., 'Three']), 126 | hyper.sweep('param2', [True, 'Two', 3.0])], 127 | expected_sweep=[ 128 | {'param1': 1, 'param2': True}, 129 | {'param1': 2., 'param2': 'Two'}, 130 | {'param1': 'Three', 'param2': 3.0}, 131 | ]), 132 | dict(sweeps=[hyper.sweep('param1', [1, 2, 3]), 133 | hyper.sweep('param2', [4, 5, 6, 7])], 134 | expected_sweep=[ 135 | {'param1': 1, 'param2': 4}, 136 | {'param1': 2, 'param2': 5}, 137 | {'param1': 3, 'param2': 6}, 138 | ]), 139 | ]) 140 | def test_zipit(self, sweeps, expected_sweep): 141 | self.assertEqual(expected_sweep, hyper.zipit(sweeps)) 142 | 143 | def test_zipit_raises_valueerror_for_same_name(self): 144 | sweep1 = hyper.sweep('param1', [1, 2, 3]) 145 | sweep2 = hyper.sweep('param2', [4, 5, 6]) 146 | sweep3 = hyper.sweep('param1', [7, 8, 9]) 147 | with self.assertRaises(ValueError): 148 | hyper.zipit([sweep1, sweep2, sweep3]) 149 | 150 | 151 | if __name__ == '__main__': 152 | absltest.main() 153 | -------------------------------------------------------------------------------- /distribution_shift_framework/core/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2022 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | -------------------------------------------------------------------------------- /distribution_shift_framework/core/metrics/metrics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2022 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """List of standard evaluation metrics.""" 18 | from typing import Dict, Sequence 19 | 20 | import chex 21 | import jax 22 | import jax.numpy as jnp 23 | import numpy as np 24 | import scipy.stats as stats 25 | import sklearn 26 | import sklearn.metrics 27 | 28 | 29 | def pearson(predictions: chex.Array, labels: chex.Array) -> chex.Scalar: 30 | """Computes the Pearson correlation coefficient. 31 | 32 | Assumes all inputs are numpy arrays. 33 | 34 | Args: 35 | predictions: The predicted class labels. 36 | labels: The true class labels. 37 | 38 | Returns: 39 | cc: The Pearson correlation coefficient between predictions and labels. 40 | """ 41 | cc = stats.pearsonr(predictions, labels)[0] 42 | return cc 43 | 44 | 45 | def f1_score(average: str, 46 | predictions: chex.Array, 47 | labels: chex.Array) -> chex.Scalar: 48 | """Computes the F1 score. 49 | 50 | Assumes all inputs are numpy arrays. 51 | 52 | Args: 53 | average: How to accumulate the f1 score (macro or weighted). 54 | predictions: The predicted class labels. 55 | labels: The true class labels.
56 | 57 | Returns: 58 | f1: The f1 score. 59 | """ 60 | f1 = sklearn.metrics.f1_score( 61 | y_true=labels, y_pred=predictions, average=average, labels=np.unique(labels)) 62 | 63 | return f1 64 | 65 | 66 | def recall_score(average: str, 67 | predictions: chex.Array, 68 | labels: chex.Array) -> chex.Scalar: 69 | """Computes the recall score. 70 | 71 | Assumes all inputs are numpy arrays. 72 | 73 | Args: 74 | average: How to accumulate the recall score (macro or weighted). 75 | predictions: The predicted class labels. 76 | labels: The true class labels. 77 | 78 | Returns: 79 | recall: The recall score. 80 | """ 81 | recall = sklearn.metrics.recall_score( 82 | y_true=labels, y_pred=predictions, average=average, labels=np.unique(labels)) 83 | 84 | return recall 85 | 86 | 87 | def top_k_accuracy(logits: chex.Array, 88 | labels: chex.Array, 89 | k: int) -> chex.Scalar: 90 | """Computes the top-k accuracy. 91 | 92 | Args: 93 | logits: The network predictions. 94 | labels: The true class labels. 95 | k: Accuracy at what k. 96 | 97 | Returns: 98 | top_k_accuracy: The top k accuracy. 99 | """ 100 | chex.assert_equal_shape_prefix([logits, labels], 1) 101 | chex.assert_rank(logits, 2) # [bs, num_classes] 102 | chex.assert_rank(labels, 1) # [bs] 103 | 104 | _, top_ks = jax.vmap(lambda x: jax.lax.top_k(x, k=k))(logits) 105 | 106 | return jnp.mean(jnp.sum(top_ks == labels[:, None], axis=-1)) 107 | 108 | 109 | def compute_all_metrics(predictions: chex.Array, labels: chex.Array, 110 | metrics: Sequence[str]) -> Dict[str, chex.Scalar]: 111 | """Computes a set of metrics given the predictions and labels. 112 | 113 | Args: 114 | predictions: A tensor of shape (N, *): the predicted values. 115 | labels: A tensor of shape (N, *): the ground truth values. 116 | metrics: A sequence of strings describing the metrics to be evaluated. 117 | This can be one of 'pearson' (to compute the Pearson correlation 118 | coefficient), 'f1_{average}', 'recall_{average}'.
For f1 and 119 | recall the value {average} is defined in the scikit-learn API: 120 | https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html 121 | 122 | Returns: 123 | scalars: A dict containing (metric name, score) items with the metric name 124 | and associated score as a float value. 125 | """ 126 | scalars = {} 127 | for metric in metrics: 128 | if metric == 'pearson': 129 | scalars['pearson'] = pearson(predictions, labels) 130 | elif 'f1' in metric: 131 | scalars[metric] = f1_score(metric.split('_')[1], predictions, labels) 132 | elif 'recall' in metric: 133 | scalars[metric] = recall_score(metric.split('_')[1], predictions, labels) 134 | 135 | return scalars 136 | 137 | 138 | def top1_accuracy(labels, features, predictions, latents): 139 | """Computes top-1 accuracy; features and latents are unused.""" 140 | del features 141 | del latents 142 | return np.equal(predictions, labels).mean() 143 | -------------------------------------------------------------------------------- /distribution_shift_framework/core/model_zoo/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2022 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License.
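As a sanity check on the semantics of `top_k_accuracy` above: an example counts as correct when its true label appears among the k highest logits. A plain-Python equivalent (illustrative only, not part of the framework's API):

```python
def top_k_accuracy_py(logits, labels, k):
  # logits: per-class scores per example; labels: true class ids.
  hits = 0
  for scores, label in zip(logits, labels):
    # Indices of the k highest-scoring classes for this example.
    top_k = sorted(range(len(scores)), key=lambda c: scores[c],
                   reverse=True)[:k]
    hits += label in top_k
  return hits / len(labels)


logits = [[0.1, 0.7, 0.2],
          [0.5, 0.3, 0.2],
          [0.05, 0.05, 0.9]]
labels = [2, 2, 2]
```

With these inputs only the third example is top-1 correct, while the first example also becomes correct at k=2.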
16 | -------------------------------------------------------------------------------- /distribution_shift_framework/core/model_zoo/resnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2022 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """A reimplementation of resnet that exposes intermediate values.""" 18 | 19 | from typing import Mapping, Optional, Sequence, Union 20 | 21 | import chex 22 | import haiku as hk 23 | import jax 24 | import jax.numpy as jnp 25 | 26 | 27 | class BlockV1(hk.Module): 28 | """ResNet V1 block with optional bottleneck.""" 29 | 30 | def __init__( 31 | self, 32 | channels: int, 33 | stride: Union[int, Sequence[int]], 34 | use_projection: bool, 35 | bn_config: Mapping[str, float], 36 | bottleneck: bool, 37 | name: Optional[str] = None, 38 | ): 39 | super().__init__(name=name) 40 | self.use_projection = use_projection 41 | 42 | bn_config = dict(bn_config) 43 | bn_config.setdefault('create_scale', True) 44 | bn_config.setdefault('create_offset', True) 45 | bn_config.setdefault('decay_rate', 0.999) 46 | 47 | if self.use_projection: 48 | self.proj_conv = hk.Conv2D( 49 | output_channels=channels, 50 | kernel_shape=1, 51 | stride=stride, 52 | with_bias=False, 53 | padding='SAME', 54 | name='shortcut_conv') 55 | 56 | self.proj_batchnorm = hk.BatchNorm(name='shortcut_batchnorm', **bn_config) 
57 | 58 | channel_div = 4 if bottleneck else 1 59 | conv_0 = hk.Conv2D( 60 | output_channels=channels // channel_div, 61 | kernel_shape=1 if bottleneck else 3, 62 | stride=1, 63 | with_bias=False, 64 | padding='SAME', 65 | name='conv_0') 66 | bn_0 = hk.BatchNorm(name='batchnorm_0', **bn_config) 67 | 68 | conv_1 = hk.Conv2D( 69 | output_channels=channels // channel_div, 70 | kernel_shape=3, 71 | stride=stride, 72 | with_bias=False, 73 | padding='SAME', 74 | name='conv_1') 75 | 76 | bn_1 = hk.BatchNorm(name='batchnorm_1', **bn_config) 77 | layers = ((conv_0, bn_0), (conv_1, bn_1)) 78 | 79 | if bottleneck: 80 | conv_2 = hk.Conv2D( 81 | output_channels=channels, 82 | kernel_shape=1, 83 | stride=1, 84 | with_bias=False, 85 | padding='SAME', 86 | name='conv_2') 87 | 88 | bn_2 = hk.BatchNorm(name='batchnorm_2', scale_init=jnp.zeros, **bn_config) 89 | layers = layers + ((conv_2, bn_2),) 90 | 91 | self.layers = layers 92 | 93 | def __call__(self, 94 | inputs: chex.Array, 95 | is_training: bool, 96 | test_local_stats: bool) -> chex.Array: 97 | out = shortcut = inputs 98 | 99 | if self.use_projection: 100 | shortcut = self.proj_conv(shortcut) 101 | shortcut = self.proj_batchnorm(shortcut, is_training, test_local_stats) 102 | 103 | for i, (conv_i, bn_i) in enumerate(self.layers): 104 | out = conv_i(out) 105 | out = bn_i(out, is_training, test_local_stats) 106 | if i < len(self.layers) - 1: # Don't apply relu on last layer 107 | out = jax.nn.relu(out) 108 | 109 | return jax.nn.relu(out + shortcut) 110 | 111 | 112 | class BlockV2(hk.Module): 113 | """ResNet V2 block with optional bottleneck.""" 114 | 115 | def __init__( 116 | self, 117 | channels: int, 118 | stride: Union[int, Sequence[int]], 119 | use_projection: bool, 120 | bn_config: Mapping[str, float], 121 | bottleneck: bool, 122 | name: Optional[str] = None, 123 | ): 124 | super().__init__(name=name) 125 | self.use_projection = use_projection 126 | 127 | bn_config = dict(bn_config) 128 | 
bn_config.setdefault('create_scale', True) 129 | bn_config.setdefault('create_offset', True) 130 | 131 | if self.use_projection: 132 | self.proj_conv = hk.Conv2D( 133 | output_channels=channels, 134 | kernel_shape=1, 135 | stride=stride, 136 | with_bias=False, 137 | padding='SAME', 138 | name='shortcut_conv') 139 | 140 | channel_div = 4 if bottleneck else 1 141 | conv_0 = hk.Conv2D( 142 | output_channels=channels // channel_div, 143 | kernel_shape=1 if bottleneck else 3, 144 | stride=1, 145 | with_bias=False, 146 | padding='SAME', 147 | name='conv_0') 148 | 149 | bn_0 = hk.BatchNorm(name='batchnorm_0', **bn_config) 150 | 151 | conv_1 = hk.Conv2D( 152 | output_channels=channels // channel_div, 153 | kernel_shape=3, 154 | stride=stride, 155 | with_bias=False, 156 | padding='SAME', 157 | name='conv_1') 158 | 159 | bn_1 = hk.BatchNorm(name='batchnorm_1', **bn_config) 160 | layers = ((conv_0, bn_0), (conv_1, bn_1)) 161 | 162 | if bottleneck: 163 | conv_2 = hk.Conv2D( 164 | output_channels=channels, 165 | kernel_shape=1, 166 | stride=1, 167 | with_bias=False, 168 | padding='SAME', 169 | name='conv_2') 170 | 171 | # NOTE: Some implementations of ResNet50 v2 suggest initializing 172 | # gamma/scale here to zeros. 
173 | bn_2 = hk.BatchNorm(name='batchnorm_2', **bn_config) 174 | layers = layers + ((conv_2, bn_2),) 175 | 176 | self.layers = layers 177 | 178 | def __call__(self, 179 | inputs: chex.Array, 180 | is_training: bool, 181 | test_local_stats: bool) -> chex.Array: 182 | x = shortcut = inputs 183 | 184 | for i, (conv_i, bn_i) in enumerate(self.layers): 185 | x = bn_i(x, is_training, test_local_stats) 186 | x = jax.nn.relu(x) 187 | if i == 0 and self.use_projection: 188 | shortcut = self.proj_conv(x) 189 | x = conv_i(x) 190 | 191 | return x + shortcut 192 | 193 | 194 | class BlockGroup(hk.Module): 195 | """Higher level block for ResNet implementation.""" 196 | 197 | def __init__( 198 | self, 199 | channels: int, 200 | num_blocks: int, 201 | stride: Union[int, Sequence[int]], 202 | bn_config: Mapping[str, float], 203 | resnet_v2: bool, 204 | bottleneck: bool, 205 | use_projection: bool, 206 | name: Optional[str] = None, 207 | ): 208 | super().__init__(name=name) 209 | 210 | block_cls = BlockV2 if resnet_v2 else BlockV1 211 | self.blocks = [] 212 | for i in range(num_blocks): 213 | self.blocks.append( 214 | block_cls( 215 | channels=channels, 216 | stride=(1 if i else stride), 217 | use_projection=(i == 0 and use_projection), 218 | bottleneck=bottleneck, 219 | bn_config=bn_config, 220 | name=f'block_{i}')) 221 | 222 | def __call__(self, 223 | inputs: chex.Array, 224 | is_training: bool, 225 | test_local_stats: bool) -> chex.Array: 226 | out = inputs 227 | for block in self.blocks: 228 | out = block(out, is_training, test_local_stats) 229 | return out 230 | 231 | 232 | def check_length(length: int, value: Sequence[int], name: str): 233 | if len(value) != length: 234 | raise ValueError(f'`{name}` must be of length 4 not {len(value)}') 235 | 236 | 237 | class ResNet(hk.Module): 238 | """ResNet model.""" 239 | 240 | BlockGroup = BlockGroup # pylint: disable=invalid-name 241 | BlockV1 = BlockV1 # pylint: disable=invalid-name 242 | BlockV2 = BlockV2 # pylint: 
disable=invalid-name 243 | 244 | def __init__( 245 | self, 246 | blocks_per_group: Sequence[int], 247 | num_classes: int, 248 | bn_config: Optional[Mapping[str, float]] = None, 249 | resnet_v2: bool = False, 250 | bottleneck: bool = True, 251 | channels_per_group: Sequence[int] = (256, 512, 1024, 2048), 252 | use_projection: Sequence[bool] = (True, True, True, True), 253 | name: Optional[str] = None, 254 | ): 255 | """Constructs a ResNet model. 256 | 257 | Args: 258 | blocks_per_group: A sequence of length 4 that indicates the number of 259 | blocks created in each group. 260 | num_classes: The number of classes to classify the inputs into. 261 | bn_config: A dictionary of two elements, ``decay_rate`` and ``eps`` to be 262 | passed on to the :class:`~haiku.BatchNorm` layers. By default the 263 | ``decay_rate`` is ``0.9`` and ``eps`` is ``1e-5``. 264 | resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults to 265 | ``False``. 266 | bottleneck: Whether the block should bottleneck or not. Defaults to 267 | ``True``. 268 | channels_per_group: A sequence of length 4 that indicates the number of 269 | channels used for each block in each group. 270 | use_projection: A sequence of length 4 that indicates whether each 271 | residual block should use projection. 272 | name: Name of the module. 273 | """ 274 | super().__init__(name=name) 275 | self.resnet_v2 = resnet_v2 276 | 277 | bn_config = dict(bn_config or {}) 278 | bn_config.setdefault('decay_rate', 0.9) 279 | bn_config.setdefault('eps', 1e-5) 280 | bn_config.setdefault('create_scale', True) 281 | bn_config.setdefault('create_offset', True) 282 | 283 | logits_config = dict({}) 284 | logits_config.setdefault('w_init', jnp.zeros) 285 | logits_config.setdefault('name', 'logits') 286 | 287 | # Number of blocks in each group for ResNet.
288 | check_length(4, blocks_per_group, 'blocks_per_group') 289 | check_length(4, channels_per_group, 'channels_per_group') 290 | 291 | self.initial_conv = hk.Conv2D( 292 | output_channels=64, 293 | kernel_shape=7, 294 | stride=2, 295 | with_bias=False, 296 | padding='SAME', 297 | name='initial_conv') 298 | 299 | if not self.resnet_v2: 300 | self.initial_batchnorm = hk.BatchNorm( 301 | name='initial_batchnorm', **bn_config) 302 | 303 | self.block_groups = [] 304 | strides = (1, 2, 2, 1) 305 | for i in range(4): 306 | self.block_groups.append( 307 | BlockGroup( 308 | channels=channels_per_group[i], 309 | num_blocks=blocks_per_group[i], 310 | stride=strides[i], 311 | bn_config=bn_config, 312 | resnet_v2=resnet_v2, 313 | bottleneck=bottleneck, 314 | use_projection=use_projection[i], 315 | name=f'block_group_{i}')) 316 | 317 | def __call__(self, 318 | inputs: chex.Array, 319 | is_training: bool, 320 | test_local_stats: bool = False) -> chex.Array: 321 | out = inputs 322 | out = self.initial_conv(out) 323 | if not self.resnet_v2: 324 | out = self.initial_batchnorm(out, is_training, test_local_stats) 325 | out = jax.nn.relu(out) 326 | 327 | out = hk.max_pool( 328 | out, window_shape=(1, 3, 3, 1), strides=(1, 2, 2, 1), padding='SAME') 329 | 330 | for block_group in self.block_groups: 331 | out = block_group(out, is_training, test_local_stats) 332 | 333 | return out 334 | 335 | 336 | class ResNet18(ResNet): 337 | """ResNet18.""" 338 | 339 | def __init__(self, 340 | num_classes: int, 341 | bn_config: Optional[Mapping[str, float]] = None, 342 | resnet_v2: bool = False, 343 | name: Optional[str] = None): 344 | """Constructs a ResNet model. 345 | 346 | Args: 347 | num_classes: The number of classes to classify the inputs into. 348 | bn_config: A dictionary of two elements, ``decay_rate`` and ``eps`` to be 349 | passed on to the :class:`~haiku.BatchNorm` layers. 350 | resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults to 351 | ``False``. 
352 | name: Name of the module. 353 | """ 354 | super().__init__( 355 | blocks_per_group=(2, 2, 2, 2), 356 | num_classes=num_classes, 357 | bn_config=bn_config, 358 | resnet_v2=resnet_v2, 359 | bottleneck=False, 360 | channels_per_group=(64, 128, 256, 2048), 361 | use_projection=(False, True, True, True), 362 | name=name) 363 | 364 | 365 | class ResNet34(ResNet): 366 | """ResNet34.""" 367 | 368 | def __init__(self, 369 | num_classes: int, 370 | bn_config: Optional[Mapping[str, float]] = None, 371 | resnet_v2: bool = False, 372 | name: Optional[str] = None): 373 | """Constructs a ResNet model. 374 | 375 | Args: 376 | num_classes: The number of classes to classify the inputs into. 377 | bn_config: A dictionary of two elements, ``decay_rate`` and ``eps`` to be 378 | passed on to the :class:`~haiku.BatchNorm` layers. 379 | resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults to 380 | ``False``. 381 | name: Name of the module. 382 | """ 383 | super().__init__( 384 | blocks_per_group=(3, 4, 6, 3), 385 | num_classes=num_classes, 386 | bn_config=bn_config, 387 | resnet_v2=resnet_v2, 388 | bottleneck=False, 389 | channels_per_group=(64, 128, 256, 512), 390 | use_projection=(False, True, True, True), 391 | name=name) 392 | 393 | 394 | class ResNet50(ResNet): 395 | """ResNet50.""" 396 | 397 | def __init__(self, 398 | num_classes: int, 399 | bn_config: Optional[Mapping[str, float]] = None, 400 | resnet_v2: bool = False, 401 | name: Optional[str] = None): 402 | """Constructs a ResNet model. 403 | 404 | Args: 405 | num_classes: The number of classes to classify the inputs into. 406 | bn_config: A dictionary of two elements, ``decay_rate`` and ``eps`` to be 407 | passed on to the :class:`~haiku.BatchNorm` layers. 408 | resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults to 409 | ``False``. 410 | name: Name of the module. 
411 | """ 412 | super().__init__( 413 | blocks_per_group=(3, 4, 6, 3), 414 | num_classes=num_classes, 415 | bn_config=bn_config, 416 | resnet_v2=resnet_v2, 417 | bottleneck=True, 418 | name=name) 419 | 420 | 421 | class ResNet101(ResNet): 422 | """ResNet101.""" 423 | 424 | def __init__(self, 425 | num_classes: int, 426 | bn_config: Optional[Mapping[str, float]] = None, 427 | resnet_v2: bool = False, 428 | name: Optional[str] = None): 429 | """Constructs a ResNet model. 430 | 431 | Args: 432 | num_classes: The number of classes to classify the inputs into. 433 | bn_config: A dictionary of two elements, ``decay_rate`` and ``eps`` to be 434 | passed on to the :class:`~haiku.BatchNorm` layers. 435 | resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults to 436 | ``False``. 437 | name: Name of the module. 438 | """ 439 | super().__init__( 440 | blocks_per_group=(3, 4, 23, 3), 441 | num_classes=num_classes, 442 | bn_config=bn_config, 443 | resnet_v2=resnet_v2, 444 | bottleneck=True, 445 | name=name) 446 | 447 | 448 | class ResNet152(ResNet): 449 | """ResNet152.""" 450 | 451 | def __init__(self, 452 | num_classes: int, 453 | bn_config: Optional[Mapping[str, float]] = None, 454 | resnet_v2: bool = False, 455 | name: Optional[str] = None): 456 | """Constructs a ResNet model. 457 | 458 | Args: 459 | num_classes: The number of classes to classify the inputs into. 460 | bn_config: A dictionary of two elements, ``decay_rate`` and ``eps`` to be 461 | passed on to the :class:`~haiku.BatchNorm` layers. 462 | resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults to 463 | ``False``. 464 | name: Name of the module. 
465 | """ 466 | super().__init__( 467 | blocks_per_group=(3, 8, 36, 3), 468 | num_classes=num_classes, 469 | bn_config=bn_config, 470 | resnet_v2=resnet_v2, 471 | bottleneck=True, 472 | name=name) 473 | 474 | 475 | class ResNet200(ResNet): 476 | """ResNet200.""" 477 | 478 | def __init__(self, 479 | num_classes: int, 480 | bn_config: Optional[Mapping[str, float]] = None, 481 | resnet_v2: bool = False, 482 | name: Optional[str] = None): 483 | """Constructs a ResNet model. 484 | 485 | Args: 486 | num_classes: The number of classes to classify the inputs into. 487 | bn_config: A dictionary of two elements, ``decay_rate`` and ``eps`` to be 488 | passed on to the :class:`~haiku.BatchNorm` layers. 489 | resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults to 490 | ``False``. 491 | name: Name of the module. 492 | """ 493 | super().__init__( 494 | blocks_per_group=(3, 24, 36, 3), 495 | num_classes=num_classes, 496 | bn_config=bn_config, 497 | resnet_v2=resnet_v2, 498 | bottleneck=True, 499 | name=name) 500 | -------------------------------------------------------------------------------- /distribution_shift_framework/core/pix/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2022 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
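The ResNet variants above differ only in `blocks_per_group` and the `bottleneck` flag; the conventional name encodes the nominal layer count (which, by convention, includes the initial convolution and a fully connected head, even though this implementation returns features rather than logits). A quick sanity check of that arithmetic:

```python
def nominal_depth(blocks_per_group, bottleneck):
  # Bottleneck blocks have 3 convolutions, basic blocks have 2; the
  # conventional count adds the initial conv and the classifier layer.
  convs_per_block = 3 if bottleneck else 2
  return sum(blocks_per_group) * convs_per_block + 2
```

For instance, ResNet50 is `(3, 4, 6, 3)` bottleneck blocks: 16 blocks × 3 convs + 2 = 50.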
16 | -------------------------------------------------------------------------------- /distribution_shift_framework/core/pix/augment.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2022 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """This module provides image augmentation functions. 18 | 19 | All functions expect float-encoded images, with values between 0 and 1, but 20 | do not clip their outputs. 21 | """ 22 | 23 | import chex 24 | from distribution_shift_framework.core.pix import color_conversion 25 | import jax 26 | import jax.numpy as jnp 27 | 28 | 29 | def _auto_contrast(image: chex.Array, cutoff: int = 0) -> chex.Array: 30 | """The auto contrast transform: remove top/bottom % and rescale histogram. 31 | 32 | Args: 33 | image: an RGB image given as a float tensor in [0, 1]. 34 | cutoff: what % of higher/lower pixels to remove 35 | 36 | Returns: 37 | The new image with auto contrast applied. 
38 | """ 39 | im_rgbs = [] 40 | indices = jnp.arange(0, 256, 1) 41 | for rgb in range(0, image.shape[2]): 42 | im_rgb = image[:, :, rgb:rgb + 1] 43 | hist = jnp.histogram(im_rgb, bins=256, range=(0, 1))[0] 44 | 45 | hist_cumsum = hist.cumsum() 46 | # Determine % samples 47 | cut_lower = hist_cumsum[-1] * cutoff // 100 48 | cut_higher = hist_cumsum[-1] * (100 - cutoff) // 100 49 | 50 | # The lower offset 51 | offset_lo = (hist_cumsum < cut_lower) * indices 52 | offset_lo = offset_lo.max() / 256. 53 | 54 | # The higher offset 55 | offset_hi = (hist_cumsum <= cut_higher) * indices 56 | offset_hi = offset_hi.max() / 256. 57 | 58 | # Remove cutoff% samples from low/hi end 59 | im_rgb = (im_rgb - offset_lo).clip(0, 1) + offset_lo 60 | im_rgb = (im_rgb + 1 - offset_hi).clip(0, 1) - (1 - offset_hi) 61 | 62 | # And renormalize 63 | offset = (offset_hi - offset_lo) < 1 / 256. 64 | im_rgb = (im_rgb - offset_lo) / (offset_hi - offset_lo + offset) 65 | 66 | # And return 67 | im_rgbs.append(im_rgb) 68 | 69 | return jnp.concatenate(im_rgbs, axis=2) 70 | 71 | 72 | def auto_contrast(image: chex.Array, cutoff: chex.Array) -> chex.Array: 73 | if len(image.shape) < 4: 74 | return _auto_contrast(image, cutoff) 75 | 76 | else: 77 | return jax.vmap(_auto_contrast)(image, cutoff.astype(jnp.int32)) 78 | 79 | 80 | def _equalize(image: chex.Array) -> chex.Array: 81 | """The equalize transform: make histogram cover full scale. 82 | 83 | Args: 84 | image: an RGB image given as a float tensor in [0, 1]. 85 | 86 | Returns: 87 | The equalized image. 
88 | """ 89 | im_rgbs = [] 90 | 91 | im = (image * 255).astype(jnp.int32).clip(0, 255) 92 | for rgb in range(0, im.shape[2]): 93 | im_rgb = im[:, :, rgb:rgb + 1] 94 | 95 | hist = jnp.histogram(im_rgb, bins=256, range=(0, 256))[0] 96 | 97 | last_nonzero_value = hist.sum() - hist.cumsum() 98 | last_nonzero_value = last_nonzero_value + last_nonzero_value.max() * ( 99 | last_nonzero_value == 0) 100 | step = (hist.sum() - last_nonzero_value.min()) // 255 101 | n = step // 2 102 | 103 | im_rgb_new = jnp.zeros((im_rgb.shape), dtype=im_rgb.dtype) 104 | 105 | def for_loop(i, values): 106 | (im, n, hist, step, im_rgb) = values 107 | im = im + (n // step) * (im_rgb == i) 108 | 109 | return (im, n + hist[i], hist, step, im_rgb) 110 | 111 | result, _, _, _, _ = jax.lax.fori_loop(0, 256, for_loop, 112 | (im_rgb_new, n, hist, step, im_rgb)) 113 | 114 | im_rgbs.append(result.astype(jnp.float32) / 255.) 115 | return jnp.concatenate(im_rgbs, 2) 116 | 117 | 118 | def equalize(image: chex.Array, unused_cutoff: chex.Array) -> chex.Array: 119 | if len(image.shape) < 4: 120 | return _equalize(image) 121 | else: 122 | return jax.vmap(_equalize)(image) 123 | 124 | 125 | def _posterize(image: chex.Array, bits: chex.Array) -> chex.Array: 126 | """The posterize transform: remove least significant bits. 127 | 128 | Args: 129 | image: an RGB image given as a float tensor in [0, 1]. 130 | bits: how many bits to ignore. 131 | 132 | Returns: 133 | The posterized image. 134 | """ 135 | mask = ~(2**(8 - bits) - 1) 136 | image = (image * 255).astype(jnp.int32).clip(0, 255) 137 | 138 | image = jnp.bitwise_and(image, mask) 139 | return image.astype(jnp.float32) / 255. 
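The bit mask in `_posterize` is easiest to see on a single 8-bit intensity value; a standalone sketch of the same masking (pure Python, for illustration only):

```python
def posterize_value(value, bits):
  # Zero out the (8 - bits) least significant bits of an 8-bit value.
  mask = ~((1 << (8 - bits)) - 1) & 0xFF
  return value & mask
```

With `bits=3`, `0b11001011` (203) keeps only its top three bits, giving `0b11000000` (192); `bits=8` leaves values unchanged.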
140 | 141 | 142 | def posterize(image: chex.Array, bits: chex.Array) -> chex.Array: 143 | if len(image.shape) < 4: 144 | return _posterize(image, bits) 145 | else: 146 | return jax.vmap(_posterize)(image, bits.astype(jnp.uint8)) 147 | 148 | 149 | def _solarize(image: chex.Array, threshold: chex.Array) -> chex.Array: 150 | """The solarization transformation: pixels > threshold are inverted. 151 | 152 | Args: 153 | image: an RGB image given as a float tensor in [0, 1]. 154 | threshold: the threshold in [0, 1] above which to invert the image. 155 | 156 | Returns: 157 | The solarized image. 158 | """ 159 | image = (1 - image) * (image >= threshold) + image * (image < threshold) 160 | return image 161 | 162 | 163 | def solarize(image: chex.Array, threshold: chex.Array) -> chex.Array: 164 | if len(image.shape) < 4: 165 | return _solarize(image, threshold) 166 | else: 167 | return jax.vmap(_solarize)(image, threshold) 168 | 169 | 170 | def adjust_color(image: chex.Array, 171 | factor: chex.Numeric, 172 | channel: int = 0, 173 | channel_axis: int = -1) -> chex.Array: 174 | """Shifts one channel of an RGB image by a given additive amount. 175 | 176 | Args: 177 | image: an RGB image, given as a float tensor in [0, 1]. 178 | factor: the additive amount to shift the selected channel by. 179 | channel: the RGB channel to manipulate. 180 | channel_axis: the index of the channel axis. 181 | 182 | Returns: 183 | The color adjusted image. 184 | """ 185 | red, green, blue = color_conversion.split_channels(image, channel_axis) 186 | 187 | if channel == 0: 188 | red = jnp.clip(red + factor, 0., 1.) 189 | elif channel == 1: 190 | green = jnp.clip(green + factor, 0., 1.) 191 | else: 192 | blue = jnp.clip(blue + factor, 0., 1.)
193 | 194 | return jnp.stack((red, green, blue), axis=channel_axis) 195 | -------------------------------------------------------------------------------- /distribution_shift_framework/core/pix/color_conversion.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2022 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Color conversion utilities. 18 | 19 | These used to be in the dm_pix library but have been removed. I've added them 20 | back here for the time being. 21 | """ 22 | 23 | from typing import Tuple 24 | 25 | import chex 26 | import jax.numpy as jnp 27 | 28 | 29 | def split_channels( 30 | image: chex.Array, 31 | channel_axis: int, 32 | ) -> Tuple[chex.Array, chex.Array, chex.Array]: 33 | chex.assert_axis_dimension(image, axis=channel_axis, expected=3) 34 | split_axes = jnp.split(image, 3, axis=channel_axis) 35 | return tuple(map(lambda x: jnp.squeeze(x, axis=channel_axis), split_axes)) 36 | 37 | 38 | def rgb_to_hsv( 39 | image_rgb: chex.Array, 40 | *, 41 | channel_axis: int = -1, 42 | ) -> chex.Array: 43 | """Converts an image from RGB to HSV. 44 | 45 | Args: 46 | image_rgb: an RGB image, with float values in range [0, 1]. Behavior outside 47 | of these bounds is not guaranteed. 48 | channel_axis: the channel axis. image_rgb should have 3 layers along this 49 | axis. 
50 | 51 | Returns: 52 | An HSV image, with float values in range [0, 1], stacked along channel_axis. 53 | """ 54 | red, green, blue = split_channels(image_rgb, channel_axis) 55 | return jnp.stack( 56 | rgb_planes_to_hsv_planes(red, green, blue), axis=channel_axis) 57 | 58 | 59 | def rgb_planes_to_hsv_planes( 60 | red: chex.Array, 61 | green: chex.Array, 62 | blue: chex.Array, 63 | ) -> Tuple[chex.Array, chex.Array, chex.Array]: 64 | """Converts red, green, blue color planes to hue, saturation, value planes. 65 | 66 | All planes should have the same shape, with float values in range [0, 1]. 67 | Behavior outside of these bounds is not guaranteed. 68 | 69 | Reference implementation: 70 | https://github.com/tensorflow/tensorflow/blob/262f4ad303c78a99e0974c4b17892db2255738a0/tensorflow/compiler/tf2xla/kernels/image_ops.cc#L36-L68 71 | 72 | Args: 73 | red: the red color plane. 74 | green: the green color plane. 75 | blue: the blue color plane. 76 | 77 | Returns: 78 | A tuple of (hue, saturation, value) planes, as float values in range [0, 1]. 79 | """ 80 | value = jnp.maximum(jnp.maximum(red, green), blue) 81 | minimum = jnp.minimum(jnp.minimum(red, green), blue) 82 | range_ = value - minimum 83 | 84 | saturation = jnp.where(value > 0, range_ / value, 0.) 85 | norm = 1. / (6. * range_) 86 | 87 | hue = jnp.where(value == green, 88 | norm * (blue - red) + 2. / 6., 89 | norm * (red - green) + 4. / 6.) 90 | hue = jnp.where(value == red, norm * (green - blue), hue) 91 | hue = jnp.where(range_ > 0, hue, 0.) + (hue < 0.) 92 | 93 | return hue, saturation, value 94 | 95 | 96 | def hsv_planes_to_rgb_planes( 97 | hue: chex.Array, 98 | saturation: chex.Array, 99 | value: chex.Array, 100 | ) -> Tuple[chex.Array, chex.Array, chex.Array]: 101 | """Converts hue, saturation, value planes to red, green, blue color planes. 102 | 103 | All planes should have the same shape, with float values in range [0, 1]. 104 | Behavior outside of these bounds is not guaranteed.
105 | 106 | Reference implementation: 107 | https://github.com/tensorflow/tensorflow/blob/262f4ad303c78a99e0974c4b17892db2255738a0/tensorflow/compiler/tf2xla/kernels/image_ops.cc#L71-L94 108 | 109 | Args: 110 | hue: the hue plane (wrapping). 111 | saturation: the saturation plane. 112 | value: the value plane. 113 | 114 | Returns: 115 | A tuple of (red, green, blue) planes, as float values in range [0, 1]. 116 | """ 117 | dh = (hue % 1.0) * 6. # Wrap when hue >= 360°. 118 | dr = jnp.clip(jnp.abs(dh - 3.) - 1., 0., 1.) 119 | dg = jnp.clip(2. - jnp.abs(dh - 2.), 0., 1.) 120 | db = jnp.clip(2. - jnp.abs(dh - 4.), 0., 1.) 121 | one_minus_s = 1. - saturation 122 | 123 | red = value * (one_minus_s + saturation * dr) 124 | green = value * (one_minus_s + saturation * dg) 125 | blue = value * (one_minus_s + saturation * db) 126 | 127 | return red, green, blue 128 | -------------------------------------------------------------------------------- /distribution_shift_framework/core/pix/corruptions.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2022 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implementation of the ImageNet-C corruptions for sanity checks and eval. 
18 | 19 | 20 | All severity values are taken from ImageNet-C at 21 | https://github.com/hendrycks/robustness/blob/master/ImageNet-C/create_c/make_imagenet_c.py 22 | """ 23 | 24 | import chex 25 | from distribution_shift_framework.core.pix import color_conversion 26 | import dm_pix 27 | import jax 28 | import jax.numpy as jnp 29 | import numpy as np 30 | 31 | 32 | def scale_image(image: chex.Array, z_factor: chex.Numeric) -> chex.Array: 33 | """Resizes an image by the given zoom factor.""" 34 | 35 | # Rescale the spatial dimensions, keeping batch and channel dims fixed. 36 | b, h, w, c = image.shape 37 | resize_x = jax.image.scale_and_translate( 38 | image, 39 | shape=(b, int(h * z_factor), int(w * z_factor), c), 40 | method='bilinear', 41 | antialias=False, 42 | scale=jnp.ones((2,)) * z_factor, 43 | translation=jnp.zeros((2,)), 44 | spatial_dims=(1, 2)) 45 | 46 | return resize_x 47 | 48 | 49 | def zoom_blur(image: chex.Array, severity: int = 1, rng: chex.PRNGKey = None 50 | ) -> chex.Array: 51 | """The zoom blur corruption from ImageNet-C.""" 52 | del rng 53 | 54 | c = [ 55 | np.arange(1, 1.11, 0.01), 56 | np.arange(1, 1.16, 0.01), 57 | np.arange(1, 1.21, 0.02), 58 | np.arange(1, 1.26, 0.02), 59 | np.arange(1, 1.31, 0.03) 60 | ][severity - 1] 61 | 62 | _, h, w, _ = image.shape 63 | image_zoomed = jnp.zeros_like(image) 64 | for zoom_factor in c: 65 | t_image_zoomed = scale_image(image, zoom_factor) 66 | 67 | b = int(h * zoom_factor - h) // 2 68 | t_image_zoomed = t_image_zoomed[:, b:b + h, b:b + w, :] 69 | image_zoomed += t_image_zoomed 70 | 71 | image_zoomed = (image_zoomed + image) / (c.shape[0] + 1) 72 | return image_zoomed 73 | 74 | 75 | def gaussian_blur(image: chex.Array, 76 | severity: int = 1, 77 | rng: chex.PRNGKey = None) -> chex.Array: 78 | """Gaussian blur corruption for ImageNet-C.""" 79 | del rng 80 | c = [1, 2, 3, 4, 6][severity - 1] 81 | return dm_pix.gaussian_blur(image, sigma=c, kernel_size=image.shape[1]) 82 | 83 | 84 | def speckle_noise(image: chex.Array, 85 | severity: int = 1, 86 | rng: chex.PRNGKey = None) -> chex.Array:
87 | """Speckle noise corruption in ImageNet-C.""" 88 | c = [.15, .2, 0.35, 0.45, 0.6][severity - 1] 89 | 90 | image = image + image * jax.random.normal(rng, shape=image.shape) * c 91 | return jnp.clip(image, a_min=0, a_max=1) 92 | 93 | 94 | def impulse_noise(image: chex.Array, 95 | severity: int = 1, 96 | rng: chex.PRNGKey = None) -> chex.Array: 97 | """Impulse noise corruption in ImageNet-C.""" 98 | c = [.03, .06, .09, 0.17, 0.27][severity - 1] 99 | x = jnp.clip(image, 0, 1) 100 | p = c 101 | q = 0.5 102 | out = x 103 | 104 | flipped = jax.random.choice( 105 | rng, 2, shape=x.shape, p=jax.numpy.array([1 - p, p])) 106 | salted = jax.random.choice( 107 | rng, 2, shape=x.shape, p=jax.numpy.array([1 - q, q])) 108 | peppered = 1 - salted 109 | 110 | mask = flipped * salted 111 | out = out * (1 - mask) + mask 112 | 113 | mask = flipped * peppered 114 | out = out * (1 - mask) 115 | return jnp.clip(out, a_min=0, a_max=1) 116 | 117 | 118 | def shot_noise(image: chex.Array, severity: int = 1, rng: chex.PRNGKey = None 119 | ) -> chex.Array: 120 | """Shot noise in ImageNet-C corruptions.""" 121 | c = [60, 25, 12, 5, 3][severity - 1] 122 | 123 | x = jnp.clip(image, 0, 1) 124 | x = jax.random.poisson(rng, lam=x * c, shape=x.shape) / c 125 | return jnp.clip(x, a_min=0, a_max=1) 126 | 127 | 128 | def gaussian_noise(image: chex.Array, 129 | severity: int = 1, 130 | rng: chex.PRNGKey = None) -> chex.Array: 131 | """Gaussian noise in ImageNet-C corruptions.""" 132 | c = [.08, .12, 0.18, 0.26, 0.38][severity - 1] 133 | 134 | x = image + jax.random.normal(rng, shape=image.shape) * c 135 | return jnp.clip(x, a_min=0, a_max=1) 136 | 137 | 138 | def brightness(image: chex.Array, severity: int = 1, rng: chex.PRNGKey = None 139 | ) -> chex.Array: 140 | """The brightness corruption from ImageNet-C.""" 141 | del rng 142 | c = [.1, .2, .3, .4, .5][severity - 1] 143 | 144 | x = jnp.clip(image, 0, 1) 145 | hsv = color_conversion.rgb_to_hsv(x) 146 | h, s, v = 
color_conversion.split_channels(hsv, -1) 147 | v = jnp.clip(v + c, 0, 1) 148 | rgb_adjusted = color_conversion.hsv_planes_to_rgb_planes(h, s, v) 149 | rgb = jnp.stack(rgb_adjusted, axis=-1) 150 | 151 | return rgb 152 | 153 | 154 | def saturate(image: chex.Array, severity: int = 1, rng: chex.PRNGKey = None 155 | ) -> chex.Array: 156 | """The saturation corruption from ImageNet-C.""" 157 | del rng 158 | c = [(0.3, 0), (0.1, 0), (2, 0), (5, 0.1), (20, 0.2)][severity - 1] 159 | 160 | x = jnp.clip(image, 0, 1) 161 | hsv = color_conversion.rgb_to_hsv(x) 162 | h, s, v = color_conversion.split_channels(hsv, -1) 163 | s = jnp.clip(s * c[0] + c[1], 0, 1) 164 | rgb_adjusted = color_conversion.hsv_planes_to_rgb_planes(h, s, v) 165 | rgb = jnp.stack(rgb_adjusted, axis=-1) 166 | 167 | return rgb 168 | 169 | 170 | def contrast(image: chex.Array, severity: int = 1, rng: chex.PRNGKey = None 171 | ) -> chex.Array: 172 | """The contrast corruption from ImageNet-C.""" 173 | del rng 174 | c = [0.4, .3, .2, .1, .05][severity - 1] 175 | 176 | return dm_pix.adjust_contrast(image, factor=c) 177 | -------------------------------------------------------------------------------- /distribution_shift_framework/core/pix/postprocessing.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2022 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | """Implementation of post(-augmentation) processing steps.""" 18 | 19 | from typing import Tuple 20 | 21 | import chex 22 | import jax 23 | 24 | 25 | def mixup(images: chex.Array, 26 | labels: chex.Array, 27 | alpha: float = 1., 28 | beta: float = 1., 29 | rng: chex.PRNGKey = None) -> Tuple[chex.Array, chex.Array]: 30 | """Interpolates pairs of images in a minibatch to create new images. 31 | 32 | Source: https://arxiv.org/abs/1710.09412 33 | 34 | Args: 35 | images: Minibatch of images. 36 | labels: One-hot encoded labels for minibatch. 37 | alpha: Alpha parameter for the beta distribution which samples the interpolation 38 | weight. 39 | beta: Beta parameter for the beta distribution which samples the interpolation 40 | weight. 41 | rng: Random number generator state. 42 | 43 | Returns: 44 | Images resulting from the interpolation of pairs of images 45 | and their corresponding weighted labels. 46 | """ 47 | assert len(labels.shape) == 2, 'Labels need to represent one-hot encodings.' 48 | batch_size = images.shape[0] 49 | lmbda_rng, rng = jax.random.split(rng) 50 | lmbda = jax.random.beta(lmbda_rng, a=alpha, b=beta, shape=()) 51 | idx = jax.random.permutation(rng, batch_size) 52 | 53 | images_a = images 54 | images_b = images[idx, :, :, :]  # Already permuted; do not index again. 55 | images = lmbda * images_a + (1. - lmbda) * images_b 56 | labels = lmbda * labels + (1. 
- lmbda) * labels[idx, :] 57 | return images, labels 58 | -------------------------------------------------------------------------------- /distribution_shift_framework/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | astunparse==1.6.3 3 | cachetools==5.0.0 4 | certifi==2021.10.8 5 | charset-normalizer==2.0.12 6 | chex==0.1.0 7 | contextlib2==21.6.0 8 | dill==0.3.4 9 | dm-haiku==0.0.6 10 | dm-tree==0.1.6 11 | flatbuffers==2.0 12 | gast==0.5.3 13 | google-auth==2.6.0 14 | google-auth-oauthlib==0.4.6 15 | google-pasta==0.2.0 16 | googleapis-common-protos==1.55.0 17 | grpcio==1.44.0 18 | h5py==3.6.0 19 | idna==3.3 20 | importlib-metadata==4.11.1 21 | jax==0.3.1 22 | jaxlib==0.3.0 23 | jaxline==0.0.5 24 | jmp==0.0.2 25 | joblib==1.1.0 26 | keras==2.8.0 27 | Keras-Preprocessing==1.1.2 28 | libclang==13.0.0 29 | Markdown==3.3.6 30 | ml-collections==0.1.1 31 | numpy==1.22.2 32 | oauthlib==3.2.0 33 | opt-einsum==3.3.0 34 | optax==0.1.1 35 | pkg_resources==0.0.0 36 | promise==2.3 37 | protobuf==3.19.4 38 | pyasn1==0.4.8 39 | pyasn1-modules==0.2.8 40 | PyYAML==6.0 41 | requests==2.27.1 42 | requests-oauthlib==1.3.1 43 | rsa==4.8 44 | scikit-learn==1.0.2 45 | scipy==1.8.0 46 | six==1.16.0 47 | sklearn==0.0 48 | tabulate==0.8.9 49 | tensorboard==2.8.0 50 | tensorboard-data-server==0.6.1 51 | tensorboard-plugin-wit==1.8.1 52 | tensorflow==2.8.0 53 | tensorflow-datasets==4.5.2 54 | tensorflow-io-gcs-filesystem==0.24.0 55 | tensorflow-metadata==1.6.0 56 | termcolor==1.1.0 57 | tf-estimator-nightly==2.8.0.dev2021122109 58 | threadpoolctl==3.1.0 59 | toolz==0.11.2 60 | tqdm==4.62.3 61 | typing_extensions==4.1.1 62 | urllib3==1.26.8 63 | Werkzeug==2.0.3 64 | wrapt==1.13.3 65 | zipp==3.7.0 66 | -------------------------------------------------------------------------------- /distribution_shift_framework/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 
Copyright 2022 DeepMind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | set -euf -o pipefail # Stop at failure. 17 | 18 | python3 -m venv /tmp/distribution_shift_framework 19 | source /tmp/distribution_shift_framework/bin/activate 20 | pip install -U pip 21 | pip install -r distribution_shift_framework/requirements.txt 22 | 23 | python3 -m distribution_shift_framework.classification.experiment_lib_test 24 | --------------------------------------------------------------------------------
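The per-pixel arithmetic behind `_posterize` and `_solarize` in `pix/augment.py` can be checked without JAX. Below is a minimal pure-Python sketch (not part of the repository; `posterize_pixel` and `solarize_pixel` are hypothetical helper names introduced here) mirroring the bit mask `~(2**(8 - bits) - 1)` and the `(1 - x) * (x >= t) + x * (x < t)` blend:

```python
def posterize_pixel(value: int, bits: int) -> int:
    # Keep only the `bits` most significant bits of an 8-bit channel value,
    # zeroing the rest -- the same mask _posterize applies per channel.
    mask = ~((1 << (8 - bits)) - 1) & 0xFF
    return value & mask


def solarize_pixel(value: float, threshold: float) -> float:
    # Invert values at or above the threshold; leave the rest unchanged.
    return 1.0 - value if value >= threshold else value


# Keeping 4 bits of 183 (0b10110111) zeroes the low nibble -> 176 (0b10110000).
print(posterize_pixel(183, 4))
# 0.75 is above the 0.5 threshold, so it is inverted to 0.25.
print(solarize_pixel(0.75, 0.5))
```

The JAX versions apply exactly this logic elementwise over the image array (after scaling floats to [0, 255] integers for posterize), with `jax.vmap` handling the batched case.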