├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── config
│   ├── cmd
│   │   ├── curriculum.txt
│   │   ├── es.txt
│   │   └── fixed_sharing.txt
│   ├── exps
│   │   ├── curriculum_baselines.py
│   │   ├── dec.py
│   │   ├── dist.py
│   │   ├── distillation_sampling.py
│   │   ├── es_debug.py
│   │   ├── es_dist.py
│   │   ├── es_hyperparams.py
│   │   ├── finetune.py
│   │   ├── fixed_sharing_baselines.py
│   │   ├── single_task_baselines.py
│   │   └── validate.py
│   └── opts
│       ├── meta.py
│       ├── parser.py
│       └── train.py
├── main.py
├── meta
│   ├── optim
│   │   ├── es.py
│   │   ├── grid.py
│   │   ├── optimizer.py
│   │   ├── random_search.py
│   │   └── runner.py
│   └── param
│       ├── cmd.py
│       ├── manager.py
│       └── partition.py
├── models
│   ├── distill.py
│   ├── masked_resnet.py
│   └── resnet.py
├── third_party
│   └── cutout
│       ├── LICENSE.md
│       └── cutout.py
├── train
│   ├── distill.py
│   └── multiclass.py
└── util
    ├── calc.py
    ├── datasets
    │   └── decathlon.py
    ├── prepare_submission.py
    ├── session_manager.py
    └── tensorboard_manager.py
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 | 
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 | 
6 | ## Contributor License Agreement
7 | 
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 | 
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 | 
18 | ## Code reviews
19 | 
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 | 
25 | ## Community Guidelines
26 | 
27 | This project follows
28 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
29 | 
30 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | 
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 | 
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 | 
8 | 1. Definitions.
9 | 
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 | 
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 | 
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 | 
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2019 Google LLC 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | https://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | 
204 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Feature Partitioning for Multi-Task Architectures
2 | 
3 | This framework can be used to reproduce all experiments performed in _“Feature
4 | Partitioning for Efficient Multi-Task Architectures”_. It provides a range of
5 | functionality for running and managing deep learning experiments. In
6 | particular, it helps manage the meta-optimization used for
7 | hyper-parameter tuning and architecture search.
8 | 
9 | _Please note, this is not an official Google product._
10 | 
11 | ## Cloud Instance Setup
12 | 
13 | Start a new Cloud Instance from _“Deep Learning Image: PyTorch
14 | 1.0.0”_. Almost everything needed to get the code up and running is
15 | included automatically with the Deep Learning Image.
16 | 
17 | The following instructions assume the git repository has been cloned and placed in the home directory as ~/mtl.
18 | 
19 | #### Data setup
20 | 
21 | Download and set up Visual Decathlon data and annotations:
22 | 
23 | ```
24 | wget http://www.robots.ox.ac.uk/~vgg/share/decathlon-1.0-devkit.tar.gz
25 | wget http://www.robots.ox.ac.uk/~vgg/share/decathlon-1.0-data.tar.gz
26 | tar zxf decathlon-1.0-devkit.tar.gz
27 | mv decathlon-1.0 ~/mtl/data/decathlon
28 | tar zxf decathlon-1.0-data.tar.gz -C ~/mtl/data/decathlon/data
29 | cd ~/mtl/data/decathlon/data
30 | for f in *.tar; do tar xf "$f"; done
31 | ```
32 | 
33 | ImageNet data must be set up separately; please check out [http://image-net.org/download-images](http://image-net.org/download-images).
34 | 
35 | #### Code setup
36 | 
37 | Add the following lines to your ~/.bashrc file:
38 | ```
39 | export PATH=~/.local/bin:$PATH
40 | export PYTHONPATH=~/mtl:$PYTHONPATH
41 | ulimit -n 2048
42 | ```
43 | 
44 | Then run the following:
45 | ```
46 | source ~/.bashrc
47 | pip install --upgrade torch torchvision tensorflow
48 | ```
49 | 
50 | All code was tested with Python 3.7 and PyTorch 1.0.
51 | 
52 | 
53 | ## Network Training
54 | 
55 | A variety of configuration files are available to run the different training procedures tested in the paper. Some examples include:
56 | 
57 | ```
58 | python main.py -e single_task_network --config exps.dec --model resnet --task_choice 0
59 | python main.py -e partitioned_mtl_network --config exps.dec --task_choice 1-2-3-4
60 | python main.py -e distillation_test --config exps.dist
61 | python main.py -e es_optimization_test --config exps.es_dist
62 | ```
63 | 
64 | The argument `-e` indicates the experiment name, and `--config` specifies the appropriate configuration file. Further details about network training can be found [here]().
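
#### Continuing or branching an experiment

Options for resuming training are defined in `config/opts/parser.py`: passing
`-c 1` continues an existing experiment under the same experiment name, while
`--branch <old_exp_id>` starts a new experiment initialized from a previous
one. Any flag set explicitly on the command line (for example, a new learning
rate) overrides the cached value. A sketch of typical usage, with illustrative
experiment names:

```
python main.py -e partitioned_mtl_network -c 1 --learning_rate .005
python main.py -e mtl_branch_test --branch partitioned_mtl_network --num_rounds 10
```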
65 | -------------------------------------------------------------------------------- /config/cmd/curriculum.txt: -------------------------------------------------------------------------------- 1 | --temperature .6 .05 2 log 0 2 | --curriculum_bias .05 .01 .5 log 0 3 | -------------------------------------------------------------------------------- /config/cmd/es.txt: -------------------------------------------------------------------------------- 1 | --learning_rate 1e-1 .01 1 log 0 2 | --momentum .9 .5 .99 linear 0 3 | --delta_size .1 .01 .15 log 0 4 | --num_deltas 8 8 12 linear 1 5 | --num_to_use 6 8 12 linear 1 6 | -------------------------------------------------------------------------------- /config/cmd/fixed_sharing.txt: -------------------------------------------------------------------------------- 1 | --share_amt .75 0 1 linear 0 2 | -------------------------------------------------------------------------------- /config/exps/curriculum_baselines.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | is_meta = True 17 | 18 | option_defaults = { 19 | 'metaoptimizer': 'grid', 20 | 'worker_mode': 'cmd', 21 | 'gpu_choice': '0,1,2,3', 22 | 'num_procs': 4, 23 | 'param': 'cmd', 24 | 'search': 'arg', 25 | 'cmd_config': 'curriculum', 26 | 'distribute': 1, 27 | 'n_steps': '4', 28 | } 29 | 30 | base_cmd = '--config exps.dec --suppress_output 1'.split(' ') 31 | -------------------------------------------------------------------------------- /config/exps/dec.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | option_defaults = { 17 | 'dataset': 'decathlon', 18 | 'model': 'masked_resnet', 19 | 'task_choice': '1-2-3-4-5-6-7-8-9', 20 | 'param_init': 'random', 21 | 'weight_decay': 1e-4, 22 | 'drop_rate': .3, 23 | 'num_rounds': 40, 24 | 'train_iters': 5000, 25 | 'drop_lr_iters': '150000', 26 | 'num_unique': 1, 27 | 'fixed_seed': 0, 28 | 'last_third_only': 1, 29 | 'imagenet_pretrained': 1, 30 | } 31 | -------------------------------------------------------------------------------- /config/exps/dist.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | option_defaults = { 17 | 'dataset': 'decathlon', 18 | 'model': 'distill', 19 | 'task': 'distill', 20 | 'curriculum': 'uniform', 21 | 'optimizer': 'SGD', 22 | 'param_init': 'random', 23 | 'rand_type': 'restrict', 24 | 'learning_rate': 1, 25 | 'batchsize': 4, 26 | 'valid_batchsize': 64, 27 | 'task_choice': '1-2-3-4-5-6-7-8-9', 28 | 'weight_decay': 1e-5, 29 | 'num_rounds': 1, 30 | 'train_iters': 3000, 31 | 'drop_lr_iters': '2000', 32 | 'last_layer_idx': 0, 33 | 'num_distill': 4, 34 | 'num_unique': 1, 35 | 'fixed_seed': 0, 36 | 'subsample_validation': 1, 37 | 'last_third_only': 1, 38 | } 39 | -------------------------------------------------------------------------------- /config/exps/distillation_sampling.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | is_meta = True 17 | 18 | option_defaults = { 19 | 'metaoptimizer': 'runner', 20 | 'worker_mode': 'cmd', 21 | 'num_procs': 8, 22 | 'gpu_choice': '0,1,2,3', 23 | 'param': '', 24 | 'num_samples': 3, 25 | 'distribute': 1, 26 | 'cleanup_experiment': 1, 27 | } 28 | 29 | base_cmd = '--config exps.dist --suppress_output 1'.split(' ') 30 | exp_list = [] 31 | exp_names = [] 32 | 33 | for i in range(3900): 34 | exp_list += ['--metaparam_load all_metaparams/%d/%d' % (i // 100, i % 100)] 35 | exp_names += ['m%d/%d' % (i // 100, i % 100)] 36 | exp_list = [e.split(' ') for e in exp_list] 37 | -------------------------------------------------------------------------------- /config/exps/es_debug.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | is_meta = True 17 | 18 | option_defaults = { 19 | 'metaoptimizer': 'es', 20 | 'worker_mode': 'debug', 21 | 'num_procs': 1, 22 | 'param': 'partition', 23 | 'search': 'partition', 24 | 'num_samples': 10000, 25 | 'task_choice': '0-1-2-3-4-5-6-7-8-9', 26 | 'learning_rate': 1, 27 | 'momentum': .8, 28 | 'delta_size': .005, 29 | 'num_deltas': 32, 30 | 'num_to_use': 24, 31 | } 32 | 33 | base_cmd = None 34 | -------------------------------------------------------------------------------- /config/exps/es_dist.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | is_meta = True 17 | 18 | option_defaults = { 19 | 'metaoptimizer': 'es', 20 | 'worker_mode': 'cmd', 21 | 'num_procs': 8, 22 | 'gpu_choice': '0,1,2,3', 23 | 'distribute': 1, 24 | 'param': 'partition', 25 | 'search': 'partition', 26 | 'num_samples': 10000, 27 | 'task_choice': '1-2-3-4-5-6-7-8-9', 28 | 'init_random': 32, 29 | 'learning_rate': .25, 30 | 'momentum': .8, 31 | 'delta_size': .04, 32 | 'num_deltas': 8, 33 | 'num_to_use': 7, 34 | 'num_unique': 1, 35 | 'cleanup_experiment': 1, 36 | 'do_weight_reg': 2, 37 | 'diag_weight_decay': .001, 38 | 'multiobjective': 0, 39 | } 40 | 41 | base_cmd = '--config exps.dist --suppress_output 1'.split(' ') 42 | -------------------------------------------------------------------------------- /config/exps/es_hyperparams.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | is_meta = True 17 | 18 | option_defaults = { 19 | 'metaoptimizer': 'random_search', 20 | 'worker_mode': 'cmd', 21 | 'num_procs': 1, 22 | 'param': 'cmd', 23 | 'search': 'arg', 24 | 'cmd_config': 'es', 25 | 'num_samples': 50, 26 | 'distribute': 0, 27 | } 28 | 29 | base_cmd = '--config exps.es_dist --num_samples 5000'.split(' ') 30 | -------------------------------------------------------------------------------- /config/exps/finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | option_defaults = { 17 | 'dataset': 'decathlon', 18 | 'model': 'resnet', 19 | 'task_choice': '1', 20 | 'weight_decay': 1e-4, 21 | 'drop_rate': .3, 22 | 'num_rounds': 10, 23 | 'train_iters': 5000, 24 | 'drop_lr_iters': '30000', 25 | 'imagenet_pretrained': 1, 26 | 'last_third_only': 1, 27 | 'num_data_threads': 8, 28 | 'fixed_seed': 0, 29 | } 30 | -------------------------------------------------------------------------------- /config/exps/fixed_sharing_baselines.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google Inc. 
All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | is_meta = True
17 | 
18 | option_defaults = {
19 |     'metaoptimizer': 'grid',
20 |     'worker_mode': 'cmd',
21 |     'gpu_choice': '0,1,2,3',
22 |     'num_procs': 4,
23 |     'param': 'cmd',
24 |     'search': 'arg',
25 |     'cmd_config': 'fixed_sharing',  # matches config/cmd/fixed_sharing.txt
26 |     'distribute': 1,
27 |     'n_steps': '5',
28 | }
29 | 
30 | base_cmd = '--config exps.dec --suppress_output 1 --param_init share_fixed'.split(' ')
--------------------------------------------------------------------------------
/config/exps/single_task_baselines.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | is_meta = True
17 | 
18 | option_defaults = {
19 |     'metaoptimizer': 'runner',
20 |     'worker_mode': 'cmd',
21 |     'num_procs': 4,
22 |     'gpu_choice': '0,1,2,3',
23 |     'param': '',
24 |     'num_samples': 3,
25 |     'distribute': 1,
26 | }
27 | 
28 | base_cmd = '--config exps.finetune --suppress_output 1'.split(' ')
29 | exp_list = []
30 | exp_names = []
31 | 
32 | tmp_exps = ['', '--last_third_only 0',
33 |             '--imagenet_pretrained 0 --last_third_only 0 --num_rounds 30 --drop_lr_iters 100000']
34 | tmp_names = ['', 'full', 'scratch']
35 | for task_idx in range(1, 10):
36 |   for exp_idx, exp_type in enumerate(tmp_exps):
37 |     exp_list += ['--task_choice %d %s' % (task_idx, exp_type)]
38 |     exp_names += ['t%d%s' % (task_idx, tmp_names[exp_idx])]
39 | exp_list = [e.split() for e in exp_list]  # split() drops the empty arg left by a blank exp_type
--------------------------------------------------------------------------------
/config/exps/validate.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | option_defaults = { 17 | 'dataset': 'decathlon', 18 | 'model': 'resnet', 19 | 'task_choice': '1', 20 | 'num_rounds': 1, 21 | 'train_iters': 0, 22 | 'last_third_only': 0, 23 | 'imagenet_pretrained': 0, 24 | } 25 | -------------------------------------------------------------------------------- /config/opts/meta.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Command line options for meta optimization.""" 17 | 18 | 19 | def setup_args(parser): 20 | parser.add_argument('--metaoptimizer', type=str, default='es') 21 | parser.add_argument('-n', '--num_samples', type=int, default=10000) 22 | parser.add_argument('--worker_mode', type=str, default='cmd') 23 | parser.add_argument('--num_procs', type=int, default=4) 24 | parser.add_argument('--distribute', type=int, default=1) 25 | parser.add_argument('--cleanup_experiment', type=int, default=0) 26 | 27 | # Maximize a reward or minimize a loss? 28 | parser.add_argument('--maximize', type=int, default=1) 29 | parser.add_argument('--worst_score', type=float, default=0) 30 | 31 | # Parameter options 32 | parser.add_argument('-p', '--param', type=str, default='partition') 33 | parser.add_argument('-s', '--search', type=str, default='partition') 34 | parser.add_argument('--cmd_config', type=str, default='') 35 | 36 | # Meta optimization debugging options 37 | parser.add_argument('--meta_eval_noise', type=float, default=0.) 38 | -------------------------------------------------------------------------------- /config/opts/parser.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Management of command line options.""" 17 | 18 | import argparse 19 | import importlib 20 | import os 21 | import pickle 22 | import sys 23 | 24 | from mtl.config.opts import meta as meta_opts 25 | from mtl.config.opts import train as train_opts 26 | 27 | 28 | def get_exp_root_dir(): 29 | """Defines path of experiment root directory.""" 30 | 31 | # Change this to save experiments somewhere else 32 | tmp_dir = os.path.dirname(__file__) 33 | tmp_dir = tmp_dir.split('/')[:-3] + ['exp'] 34 | return '/'.join(tmp_dir) 35 | 36 | 37 | def mkdir_p(dirname): 38 | """Make directory and any necessary parent directories.""" 39 | try: 40 | os.makedirs(dirname) 41 | except FileExistsError: 42 | pass 43 | 44 | 45 | def suppress_output(): 46 | """Suppress all further output to the console.""" 47 | f = open(os.devnull, 'w') 48 | sys.stdout = f 49 | sys.stderr = f 50 | 51 | 52 | def get_disp_opts(is_meta): 53 | """Defines which flags are displayed when starting an experiment.""" 54 | if is_meta: 55 | return [ 56 | 'exp_id', 57 | 'metaoptimizer', 58 | 'worker_mode', 59 | 'num_samples', 60 | 'param', 61 | 'search', 62 | ] 63 | else: 64 | return [ 65 | 'exp_id', 'task', 'model', 'dataset', 'optimizer', 'batchsize', 66 | 'learning_rate' 67 | ] 68 | 69 | 70 | def setup_init_args(parser): 71 | """Initial arguments that need to be parsed first.""" 72 | parser.add_argument('-e', '--exp_id', type=str, default='default') 73 | parser.add_argument('-x', '--is_meta', action='store_true') 74 | parser.add_argument('--config', type=str, default='exps.dec') 75 | parser.add_argument('--gpu_choice', type=str, default='0') 76 | parser.add_argument('--suppress_output', type=int, default=0) 77 | parser.add_argument('--task_choice', type=str, default='1-2-3-4-5-6-7-8-9') 78 | parser.add_argument('--fixed_seed', type=int, default=0) 79 | 80 | # Options to continue/branch off existing experiment 81 | g = parser.add_mutually_exclusive_group() 82 | g.add_argument('-c', '--continue_exp', type=int, default=0) 83 | g.add_argument('--branch', default='') 84 | 85 | 86 | def restore_flags(parser, init_flags): 87 | """Updates options to match cached values. 88 | 89 | When restoring options from a previously run experiment, we want to load all 90 | values and update any terms that have been manually changed on the command 91 | line (for example, when continuing with a new learning rate). 92 | 93 | Args: 94 | parser: Command line parser used to collect init_flags. 95 | init_flags: Current flags that will tell us whether or not to load previous 96 | values from another experiment. 
97 | """ 98 | 99 | init_flags.restore_session = None 100 | last_round = 0 101 | 102 | # Check if we need to load up a previous set of options 103 | if init_flags.continue_exp or init_flags.branch: 104 | if init_flags.continue_exp: 105 | tmp_exp_dir = init_flags.exp_id 106 | else: 107 | tmp_exp_dir = init_flags.branch 108 | 109 | tmp_exp_dir = '%s/%s' % (get_exp_root_dir(), tmp_exp_dir) 110 | 111 | # Load previous options 112 | with open(tmp_exp_dir + '/opts.p', 'rb') as f: 113 | flags = pickle.load(f) 114 | 115 | # Parse newly set options 116 | setup_parser(parser, init_flags, flags) 117 | new_flags, _ = parser.parse_known_args() 118 | 119 | # Check which flags have been manually set and update them 120 | opts = {} 121 | for val in sys.argv: 122 | if val == '--': 123 | break 124 | elif val and val[0] == '-': 125 | if val in flags.short_ref: 126 | tmp_arg = flags.short_ref[val] 127 | else: 128 | tmp_arg = val[2:] 129 | if tmp_arg in new_flags: 130 | opts[tmp_arg] = new_flags.__dict__[tmp_arg] 131 | 132 | if '--' in sys.argv: 133 | opts['unparsed'] = sys.argv[sys.argv.index('--'):] 134 | 135 | for opt in opts: 136 | flags.__dict__[opt] = opts[opt] 137 | 138 | flags.restore_session = tmp_exp_dir 139 | 140 | try: 141 | with open(tmp_exp_dir + '/last_round', 'r') as f: 142 | last_round = int(f.readline()) 143 | except: 144 | pass 145 | 146 | else: 147 | flags = init_flags 148 | 149 | if 'num_rounds' in flags: 150 | flags.last_round = last_round + flags.num_rounds 151 | 152 | return flags 153 | 154 | 155 | def add_extra_args(parser, files): 156 | """Add additional arguments defined in other files.""" 157 | 158 | for f in files: 159 | m = importlib.import_module(f) 160 | if 'setup_extra_args' in m.__dict__: 161 | m.setup_extra_args(parser) 162 | 163 | 164 | def setup_parser(parser, init_flags, ref_flags=None): 165 | """Setup appropriate arguments and defaults for command line parser.""" 166 | 167 | # Load config file 168 | if ref_flags is None: 169 | cfg = importlib.import_module('mtl.config.' + init_flags.config) 170 | is_meta = cfg.is_meta if 'is_meta' in cfg.__dict__ else init_flags.is_meta 171 | else: 172 | cfg = importlib.import_module('mtl.config.' + ref_flags.config) 173 | is_meta = ref_flags.is_meta 174 | 175 | parser.set_defaults(is_meta=is_meta) 176 | 177 | if is_meta: 178 | # Set up meta optimization options 179 | meta_opts.setup_args(parser) 180 | extra_arg_files = [['mtl.meta.optim', 'metaoptimizer'], 181 | ['mtl.meta.param', 'param']] 182 | else: 183 | # Set up network training options 184 | train_opts.setup_args(parser) 185 | extra_arg_files = [['mtl.models', 'model'], 186 | ['mtl.util.datasets', 'dataset'], ['mtl.train', 'task']] 187 | 188 | if ref_flags is None: 189 | for _, k in extra_arg_files: 190 | if k in cfg.option_defaults: 191 | parser.set_defaults(**{k: cfg.option_defaults[k]}) 192 | ref_flags, _ = parser.parse_known_args() 193 | 194 | # Add additional arguments 195 | add_extra_args(parser, [ 196 | '%s.%s' % (d, ref_flags.__dict__[k]) 197 | for d, k in extra_arg_files 198 | if ref_flags.__dict__[k] 199 | ]) 200 | 201 | # Update options/defaults according to config file 202 | if 'option_defaults' in cfg.__dict__: 203 | parser.set_defaults(**cfg.option_defaults) 204 | 205 | 206 | def parse_command_line(): 207 | """Parse command line and set up experiment options. 208 | 209 | Returns: 210 | An object with all options stored as attributes. 
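 
  Example (as invoked from main.py):
 
    opt = parser.parse_command_line()
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_choice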
211 | """ 212 | 213 | parser = argparse.ArgumentParser() 214 | setup_init_args(parser) 215 | init_flags, _ = parser.parse_known_args() 216 | 217 | # Check whether to restore previous options 218 | flags = restore_flags(parser, init_flags) 219 | 220 | if flags.restore_session is None: 221 | setup_parser(parser, init_flags) 222 | flags, unparsed = parser.parse_known_args() 223 | flags.unparsed = unparsed 224 | flags.restore_session = None 225 | if not flags.is_meta: 226 | flags.last_round = flags.num_rounds 227 | flags.short_ref = { 228 | a.option_strings[0]: a.option_strings[1][2:] 229 | for a in parser._actions 230 | if len(a.option_strings) == 2 231 | } 232 | 233 | # Save options 234 | flags.exp_root_dir = get_exp_root_dir() 235 | flags.data_dir = '' 236 | flags.exp_dir = '%s/%s' % (flags.exp_root_dir, flags.exp_id) 237 | mkdir_p(flags.exp_dir) 238 | with open('%s/opts.p' % flags.exp_dir, 'wb') as f: 239 | pickle.dump(flags, f) 240 | 241 | if not flags.is_meta: 242 | flags.iters = {'train': flags.train_iters} 243 | flags.drop_learning_rate = [] 244 | if flags.drop_lr_iters: 245 | flags.drop_learning_rate = list(map(int, flags.drop_lr_iters.split('-'))) 246 | 247 | if flags.suppress_output: 248 | suppress_output() 249 | print('---------------------------------------------') 250 | for tmp_opt in get_disp_opts(flags.is_meta): 251 | print('{:15s}: {}'.format(tmp_opt, flags.__dict__[tmp_opt])) 252 | print('---------------------------------------------') 253 | 254 | return flags 255 | 256 | 257 | if __name__ == '__main__': 258 | print(parse_command_line()) 259 | -------------------------------------------------------------------------------- /config/opts/train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ==============================================================================
15 | 
16 | """Command line options for network training."""
17 | 
18 | 
19 | def setup_args(parser):
20 |   # Dataset options
21 |   parser.add_argument('-d', '--dataset', type=str, default='decathlon')
22 |   parser.add_argument('--num_data_threads', type=int, default=1)
23 |   parser.add_argument('--train_on_valid', type=int, default=0)
24 |   parser.add_argument('--validate_on_train', type=int, default=0)
25 |   parser.add_argument('--use_test', type=int, default=0)
26 | 
27 |   # Training length options
28 |   parser.add_argument('--num_rounds', type=int, default=100)
29 |   parser.add_argument('--train_iters', type=int, default=4000)
30 |   parser.add_argument('--early_stop_thr', type=float, default=.0)
31 |   parser.add_argument('--curriculum', type=str, default='train_accuracy')
32 | 
33 |   # Training hyperparameters
34 |   parser.add_argument('--optimizer', type=str, default='SGD')
35 |   parser.add_argument('-l', '--learning_rate', type=float, default=.05)
36 |   parser.add_argument('--batchsize', type=int, default=64)
37 |   parser.add_argument('--valid_batchsize', type=int, default=0)
38 |   parser.add_argument('--momentum', type=float, default=.9)
39 |   parser.add_argument('--weight_decay', type=float, default=1e-4)
40 |   parser.add_argument('--clip_grad', type=float, default=0.)
41 |   parser.add_argument('--dropout', type=float, default=0.3)
42 |   parser.add_argument('--temperature', type=float, default=0.3)
43 |   parser.add_argument('--curriculum_bias', type=float, default=0.15)
44 | 
45 |   parser.add_argument('--drop_lr_iters', type=str, default='')
46 |   parser.add_argument('--drop_lr_factor', type=int, default=10)
47 | 
48 |   # Task options
49 |   parser.add_argument('-t', '--task', type=str, default='multiclass')
50 |   parser.add_argument('--metaparam', type=str, default='')
51 |   parser.add_argument('--metaparam_load', type=str, default='')
52 | 
53 |   # Model
54 |   parser.add_argument('-m', '--model', type=str, default='masked_resnet')
55 |   parser.add_argument('--pretrained', type=str, default='')
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """Main file for running experiments.
17 | """
18 | import importlib
19 | import os
20 | import subprocess
21 | 
22 | from mtl.config.opts import parser
23 | import numpy as np
24 | import torch
25 | import torch.multiprocessing as mp
26 | from tqdm import tqdm
27 | 
28 | 
29 | def train(opt):
30 |   """Run standard network training loop.
31 | 
32 |   Args:
33 |     opt: All experiment and training options (see mtl/config/opts).
34 |   """
35 | 
36 |   if opt.fixed_seed:
37 |     print('Fixing random seed')
38 |     np.random.seed(9999)
39 |     torch.manual_seed(9999)
40 |     torch.cuda.manual_seed(9999)
41 |     torch.backends.cudnn.deterministic = True
42 | 
43 |   ds = importlib.import_module('mtl.util.datasets.' + opt.dataset)
44 |   ds, dataloaders = ds.initialize(opt)
45 | 
46 |   task = importlib.import_module('mtl.train.' + opt.task)
47 |   sess = task.Task(opt, ds, dataloaders)
48 |   sess.cuda()
49 | 
50 |   splits = [s for s in ['train', 'valid'] if opt.iters[s] > 0]
51 |   start_round = opt.last_round - opt.num_rounds
52 | 
53 |   # Main training loop
54 |   for round_idx in range(start_round, opt.last_round):
55 | 
56 |     sess.valid_accuracy_track = [[] for _ in range(sess.num_tasks)]
57 |     for split in splits:
58 | 
59 |       print('Round %d: %s' % (round_idx, split))
60 |       train_flag = split == 'train'
61 |       sess.set_train_mode(train_flag)
62 | 
63 |       if split == 'valid':
64 |         sess.prediction_ref = [{} for _ in range(sess.num_tasks)]
65 | 
66 |       for step in tqdm(range(opt.iters[split]), ascii=True):
67 |         global_step = step + round_idx * opt.iters[split]
68 |         sess.run(split, global_step)
69 |         if train_flag: sess.update_weights()
70 | 
71 |         if (split == 'train' and opt.drop_learning_rate
72 |             and global_step in opt.drop_learning_rate):
73 |           opt.learning_rate /= opt.drop_lr_factor
74 |           print('Dropping learning rate to %g' % opt.learning_rate)
75 |           for opt_key in sess.checkpoint_ref['optim']:
76 |             for p in sess.__dict__[opt_key].param_groups:
77 |               p['lr'] = opt.learning_rate
78 | 
79 |         # Update Tensorboard (every 500 train steps, every 50 valid steps)
80 |         if global_step % 500 == 0 or (split == 'valid' and global_step % 50 == 0):
81 |           for i in range(len(ds[split])):
82 |             sess.tb.update(ds[split][i].task_name, split, global_step,
83 |                            sess.get_log_vals(split, i))
84 | 
85 |       torch.save({'preds': sess.prediction_ref},
86 |                  '%s/final_predictions' % opt.exp_dir)
87 | 
88 |     # Update accuracy history
89 |     sess.score = np.array([np.array(a).mean()
90 |                            for a in sess.valid_accuracy_track]).mean()
91 |     print('Score:', sess.score)
92 | 
93 |     for i in range(sess.num_tasks):
94 |       for s in splits:
95 |         if s == 'valid':
96 |           tmp_acc = np.array(sess.valid_accuracy_track[i]).mean()
97 |           sess.log['accuracy'][i][s] = tmp_acc
98 |           sess.log['accuracy_history'][i][s] += [sess.log['accuracy'][i][s]]
99 | 
100 |     sess.save(opt.exp_dir + '/snapshot')
101 |     with open(opt.exp_dir + '/last_round', 'w') as f:
102 |       f.write('%d\n' % (round_idx + 1))
103 | 
104 |     if (opt.iters['valid'] > 0 and sess.score < opt.early_stop_thr):
105 |       break
106 | 
107 | 
108 | def worker(opt, p_idx, cmd_queue, result_queue, debug_param):
109 |   """Worker thread for managing parallel experiment runs.
110 | 111 | Args: 112 | opt: Experiment options 113 | p_idx: Process index 114 | cmd_queue: Queue holding experiment commands to run 115 | result_queue: Queue to submit experiment results 116 | debug_param: Shared target for debugging meta-optimization 117 | """ 118 | gpus = list(map(int, opt.gpu_choice.split(','))) 119 | gpu_choice = gpus[p_idx % len(gpus)] 120 | np.random.seed() 121 | 122 | try: 123 | while True: 124 | msg = cmd_queue.get() 125 | if msg == 'DONE': break 126 | exp_count, mode, cmd, extra_args = msg 127 | 128 | if mode == 'debug': 129 | # Basic operation for debugging/sanity checking optimizers 130 | exp_id, param = cmd 131 | pred_param = param['partition'][0] 132 | 133 | triu_ = torch.Tensor(np.triu(np.ones(debug_param.shape))) 134 | score = -np.linalg.norm((debug_param - pred_param)*triu_) 135 | score += np.random.randn() * opt.meta_eval_noise 136 | 137 | tmp_acc = {'accuracy': [{'valid': score} for i in range(10)]} 138 | result = {'score': score, 'log': tmp_acc} 139 | 140 | elif mode == 'cmd': 141 | # Run a specified command 142 | tmp_cmd = cmd 143 | if opt.distribute: 144 | tmp_cmd += ['--gpu_choice', str(gpu_choice)] 145 | tmp_cmd += extra_args 146 | exp_id = tmp_cmd[tmp_cmd.index('-e') + 1] 147 | 148 | print('%d:' % p_idx, ' '.join(tmp_cmd)) 149 | subprocess.call(tmp_cmd) 150 | 151 | # Collect result 152 | log_path = '%s/%s/snapshot_extra' % (opt.exp_root_dir, exp_id) 153 | try: 154 | result = torch.load(log_path) 155 | except Exception as e: 156 | print('Error loading result:', repr(e)) 157 | result = None 158 | 159 | if opt.cleanup_experiment: 160 | # Remove extraneous files that take up disk space 161 | exp_dir = '%s/%s' % (opt.exp_root_dir, exp_id) 162 | cleanup_paths = [exp_dir + '/snapshot_optim', 163 | exp_dir + '/snapshot_model'] 164 | dir_files = os.listdir(exp_dir) 165 | tfevent_files = ['%s/%s' % (exp_dir, fn) 166 | for fn in dir_files if 'events' in fn] 167 | cleanup_paths += tfevent_files 168 | 169 | for cleanup in cleanup_paths: 170 | subprocess.call(['rm', cleanup]) 171 | 172 | result_queue.put([exp_id, result, exp_count]) 173 | 174 | except KeyboardInterrupt: 175 | print('Keyboard interrupt in process %d' % p_idx) 176 | finally: 177 | print('Exiting process %d' % p_idx) 178 | 179 | 180 | def main(): 181 | # Parse command line options 182 | opt = parser.parse_command_line() 183 | # Set GPU 184 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_choice 185 | 186 | if opt.is_meta: 187 | # Initialize queues 188 | cmd_queue = mp.Queue() 189 | result_queue = mp.Queue() 190 | 191 | # Set up target debug params 192 | debug_param = torch.rand(2, 10, 10) 193 | 194 | # Start workers 195 | workers = [] 196 | for i in range(opt.num_procs): 197 | worker_args = (opt, i, cmd_queue, result_queue, debug_param) 198 | worker_p = mp.Process(target=worker, args=worker_args) 199 | worker_p.daemon = True 200 | worker_p.start() 201 | workers += [worker_p] 202 | 203 | # Initialize and run meta optimizer 204 | metaoptim = importlib.import_module('mtl.meta.optim.' 
+ opt.metaoptimizer) 205 | metaoptim = metaoptim.MetaOptimizer(opt) 206 | metaoptim.run(cmd_queue, result_queue) 207 | 208 | # Clean up workers 209 | for i in range(opt.num_procs): 210 | cmd_queue.put('DONE') 211 | for worker_p in workers: 212 | worker_p.join() 213 | 214 | else: 215 | # Run standard network training 216 | train(opt) 217 | 218 | if __name__ == '__main__': 219 | main() 220 | -------------------------------------------------------------------------------- /meta/optim/es.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | from mtl.meta.optim import optimizer 17 | from mtl.util import tensorboard_manager 18 | import numpy as np 19 | import torch 20 | 21 | 22 | def setup_extra_args(parser): 23 | parser.add_argument('--optimizer', type=str, default='SGD') 24 | parser.add_argument('-l', '--learning_rate', type=float, default=1) 25 | parser.add_argument('--momentum', type=float, default=.5) 26 | parser.add_argument('--init_random', type=int, default=32) 27 | parser.add_argument('--num_deltas', type=int, default=16) 28 | parser.add_argument('--num_to_use', type=int, default=8) 29 | parser.add_argument('--delta_size', type=float, default=.05) 30 | parser.add_argument('--do_weight_reg', type=int, default=0) 31 | parser.add_argument('--diag_weight_decay', type=float, default=.01) 32 | parser.add_argument('--multiobjective', type=int, default=0) 33 | 34 | 35 | class MetaOptimizer(optimizer.MetaOptimizer): 36 | 37 | def __init__(self, opt): 38 | self.num_tasks = len(opt.task_choice.split('-')) 39 | super().__init__(opt) 40 | 41 | self.ref_params = self.curr_params._parameters[opt.search[0]] 42 | self.task_grad_masks = np.array([ 43 | self.ref_params.get_task_parameter_mask(i) 44 | for i in range(self.num_tasks) 45 | ]) 46 | 47 | def tensorboard_setup(self): 48 | self.task_names = ['total'] + ['task_%d' % i for i in range(self.num_tasks)] 49 | self.to_track = ['score', 'param_use'] 50 | self.tb = tensorboard_manager.TBManager(self.exp_dir, self.task_names, 51 | self.to_track, ['train']) 52 | 53 | def run(self, cmd_queue, result_queue): 54 | opt = self.opt 55 | 56 | self.base_cmd = self.cfg.base_cmd 57 | 58 | if '--' in opt.unparsed: 59 | # Additional arguments to pass on to experiments 60 | extra_args = opt.unparsed[opt.unparsed.index('--') + 1:] 61 | self.extra_child_args = extra_args 62 | 63 | if self.best_exp is None: 64 | # Sample a random set of parameters and keep best 65 | print('Testing initial random samples...') 66 | 67 | samples = [] 68 | for i in range(opt.init_random): 69 | samples += [self.init_sample()] 70 | self.submit_cmd(cmd_queue, str(i), param=samples[i]) 71 | 72 | _, scores = self.collect_batch_results(opt.init_random, result_queue) 73 | self.best_score = scores.max() 74 | self.best_exp = scores.argmax() 
75 | self.best_params = samples[self.best_exp] 76 | 77 | # Set up optimizer 78 | self.curr_params = self.best_params 79 | curr_params = self.curr_params.get_params(opt.search) 80 | curr_params.requires_grad = True 81 | self.setup_optimizer([curr_params]) 82 | 83 | # Save checkpoint 84 | self.score = self.best_score 85 | self.save(self.exp_dir + '/snapshot') 86 | 87 | # Main loop 88 | while self.exp_count < opt.num_samples: 89 | # Sample deltas 90 | deltas = [] 91 | p1 = curr_params.data.clone() 92 | for delta_idx in range(opt.num_deltas): 93 | p2 = self.ref_params.mutate(p1, delta=opt.delta_size) 94 | deltas += [np.array(p2 - p1)] 95 | deltas = np.stack(deltas) 96 | 97 | to_test = np.concatenate([deltas, -deltas], 0) 98 | samples, exp_ref = [], [] 99 | n_samples = to_test.shape[0] 100 | 101 | for delta_idx in range(n_samples): 102 | exp_id = str(delta_idx) 103 | sample = self.init_sample() 104 | params = curr_params.clone() 105 | params += torch.Tensor(to_test[delta_idx]).view(params.shape) 106 | params.data = params.data.clamp(*self.ref_params.valid_range) 107 | sample.update_params(opt.search, params) 108 | samples += [sample] 109 | exp_ref += [self.exp_count] 110 | 111 | self.submit_cmd(cmd_queue, exp_id, param=sample) 112 | 113 | # Collect results 114 | result_idx_ref, scores = self.collect_batch_results( 115 | n_samples, result_queue) 116 | self.score = scores.mean() 117 | print(self.exp_count, '%.3f' % self.score) 118 | 119 | # Calculate per-task scores 120 | task_scores = np.zeros((scores.shape[0], self.num_tasks)) 121 | for sample_idx, result_idx in enumerate(result_idx_ref): 122 | acc = self.results[result_idx][1]['log']['accuracy'] 123 | task_scores[sample_idx] = [acc_['valid'] for acc_ in acc] 124 | 125 | # Update tensorboard 126 | param_use = np.array( 127 | curr_params.view(self.ref_params.shape).data[0][0].diag()) 128 | param_use = [param_use.mean()] + list(param_use) 129 | tmp_scores = [self.score] + list(task_scores.mean(0)) 130 | for task_idx, task_name in enumerate(self.task_names): 131 | self.tb.update(task_name, 'train', self.exp_count, { 132 | 'score': tmp_scores[task_idx], 133 | 'param_use': param_use[task_idx] 134 | }) 135 | 136 | if scores.max() > self.best_score: 137 | self.best_score = scores.max() 138 | self.best_params = samples[scores.argmax()] 139 | self.best_exp = exp_ref[scores.argmax()] 140 | 141 | if not opt.multiobjective: 142 | # Single objective optimization 143 | 144 | # Normalize scores 145 | scores /= scores.std() + 1e-4 146 | scores = np.stack([scores[:opt.num_deltas], scores[opt.num_deltas:]], 1) 147 | 148 | # Get best deltas 149 | max_rewards = np.max(scores, axis=1) 150 | best_idxs = np.argsort(-max_rewards)[:opt.num_to_use] 151 | rewards = scores[best_idxs] 152 | tmp_deltas = np.array(deltas)[best_idxs] 153 | 154 | # Calculate weighted sum 155 | reward_diff = rewards[:, 0] - rewards[:, 1] 156 | result = -np.dot(reward_diff, tmp_deltas) / reward_diff.size 157 | 158 | else: 159 | # Multi-objective optimization 160 | 161 | # Normalize scores 162 | task_scores /= task_scores.std(0, keepdims=True) + 1e-4 163 | task_scores = np.stack( 164 | [task_scores[:opt.num_deltas], task_scores[opt.num_deltas:]], 165 | 2).transpose(1, 0, 2) 166 | 167 | # Per-task optimization 168 | task_results = [] 169 | 170 | for task_idx in range(self.num_tasks): 171 | tmp_scores = task_scores[task_idx] 172 | 173 | # Get best deltas 174 | max_rewards = np.max(tmp_scores, axis=1) 175 | best_idxs = np.argsort(-max_rewards)[:opt.num_to_use] 176 | rewards = tmp_scores[best_idxs] 
177 | tmp_deltas = np.array(deltas)[best_idxs] 178 | 179 | # Calculate weighted sum 180 | reward_diff = rewards[:, 0] - rewards[:, 1] 181 | result = -np.dot(reward_diff, tmp_deltas) / reward_diff.size 182 | result *= self.task_grad_masks[task_idx] 183 | task_results += [result] 184 | 185 | result = np.stack(task_results, 0).sum(0) 186 | result /= np.maximum(self.task_grad_masks.sum(0), 1) 187 | 188 | result = torch.Tensor(result) 189 | 190 | # Calculate regularization (only for diagonal terms in forward matrix) 191 | lmda_mat = torch.unsqueeze( 192 | torch.stack([ 193 | torch.eye(self.num_tasks), 194 | torch.zeros(self.num_tasks, self.num_tasks) 195 | ]), 0) 196 | lmda_mat = lmda_mat.repeat(opt.num_unique, 1, 1, 1).view(-1) 197 | lmda_mat *= opt.diag_weight_decay 198 | if opt.do_weight_reg == 1: 199 | # Do L1 regularization 200 | result += lmda_mat * torch.sign(curr_params.data.view(lmda_mat.shape)) 201 | elif opt.do_weight_reg == 2: 202 | # Do L2 regularization 203 | result += lmda_mat * curr_params.data.view(lmda_mat.shape) 204 | 205 | curr_params.grad = result.view(curr_params.shape) 206 | 207 | # Update parameters 208 | self.optim.step() 209 | curr_params.data = curr_params.data.clamp(*self.ref_params.valid_range) 210 | self.curr_params.update_params(opt.search, curr_params) 211 | 212 | # Save checkpoint 213 | self.save(self.exp_dir + '/snapshot') 214 | -------------------------------------------------------------------------------- /meta/optim/grid.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | from mtl.meta.optim import random_search 17 | import numpy as np 18 | import torch 19 | 20 | 21 | def setup_extra_args(parser): 22 | parser.add_argument('--n_steps', type=str, default='4') 23 | 24 | 25 | class MetaOptimizer(random_search.MetaOptimizer): 26 | 27 | def setup_samples(self): 28 | opt = self.opt 29 | 30 | if len(opt.n_steps) == 1: 31 | n_steps = [int(opt.n_steps)] * self.n_params 32 | else: 33 | n_steps = list(map(int, opt.n_steps.split('-'))) 34 | 35 | lspaces = [np.linspace(0, 1, n) for n in n_steps] 36 | all_params = np.meshgrid(*lspaces) 37 | all_params = [p.flatten() for p in all_params] 38 | all_params = np.stack(all_params, 1) 39 | 40 | all_samples = [] 41 | for p in all_params: 42 | sample = self.init_sample() 43 | sample.update_params(opt.search, torch.Tensor(p)) 44 | all_samples += [sample] 45 | 46 | return all_samples 47 | -------------------------------------------------------------------------------- /meta/optim/optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
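# Illustrative sketch (annotation, not part of the original file): the
# single-objective branch of MetaOptimizer.run in meta/optim/es.py above is an
# antithetic-sampling evolution-strategies update. Each delta is scored in the
# +delta and -delta directions, the strongest deltas are kept, and their
# reward differences weight a pseudo-gradient that is handed to the torch
# optimizer as curr_params.grad. A self-contained numpy version of that
# estimate (array names and sizes here are illustrative only):
import numpy as np

def es_step_direction(deltas, plus_scores, minus_scores, num_to_use):
  """Antithetic ES pseudo-gradient, mirroring es.py's single-objective path."""
  n = deltas.shape[0]
  scores = np.concatenate([plus_scores, minus_scores])
  scores = scores / (scores.std() + 1e-4)               # normalize rewards
  scores = np.stack([scores[:n], scores[n:]], 1)        # pair +d with -d
  best = np.argsort(-scores.max(axis=1))[:num_to_use]   # keep strongest deltas
  diff = scores[best, 0] - scores[best, 1]
  return -np.dot(diff, deltas[best]) / diff.size        # negated => .grad step

rng = np.random.RandomState(0)
d = rng.randn(16, 8) * .05                              # 16 deltas, 8 params
assert es_step_direction(d, rng.rand(16), rng.rand(16), 8).shape == (8,)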
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """ Base metaoptimizer class """ 17 | import importlib 18 | import os 19 | import subprocess 20 | 21 | from mtl.config.opts.parser import mkdir_p 22 | from mtl.util import session_manager 23 | import numpy as np 24 | import torch 25 | 26 | 27 | class MetaOptimizer(session_manager.SessionManager): 28 | 29 | def __init__(self, opt, ds=None, dataloaders=None): 30 | super().__init__(opt) 31 | 32 | self.cfg = importlib.import_module('mtl.config.' + opt.config) 33 | self.optim = None 34 | self.curr_params = None 35 | self.results = [] 36 | self.exp_count = 0 37 | self.worst_score = opt.worst_score 38 | self.best_score = self.worst_score 39 | self.best_exp = None 40 | self.best_params = None 41 | self.extra_child_args = [] 42 | 43 | self.parameter_setup() 44 | self.restore(opt.restore_session) 45 | 46 | def checkpoint_ref_setup(self): 47 | super().checkpoint_ref_setup() 48 | self.checkpoint_ref['best'] = ['best_exp', 'best_score', 'best_params'] 49 | self.checkpoint_ref['results'] = ['results', 'exp_count'] 50 | 51 | def parameter_setup(self): 52 | opt = self.opt 53 | if opt.param: 54 | param = importlib.import_module('mtl.meta.param.' + opt.param) 55 | self.param = param.Metaparam 56 | opt.search = opt.search.split('-') 57 | 58 | test_sample = self.init_sample() 59 | self.curr_params = test_sample 60 | self.best_params = test_sample 61 | self.n_params = test_sample.get_params(opt.search).nelement() 62 | 63 | else: 64 | # Not optimizing any parameters 65 | self.param = None 66 | self.n_params = 0 67 | 68 | def init_sample(self): 69 | return self.param(self.opt) 70 | 71 | def copy_sample(self, p, k=None): 72 | p_new = self.init_sample() 73 | p_new.update_params(k, p.get_params(k)) 74 | return p_new 75 | 76 | def submit_cmd(self, 77 | cmd_queue, 78 | exp_id, 79 | param=None, 80 | extra_args=None, 81 | extra_child_args=[], 82 | worker_mode=None): 83 | if worker_mode is None: 84 | worker_mode = self.opt.worker_mode 85 | sub_exp_id = '%s/%s' % (self.opt.exp_id, exp_id) 86 | extra_child_args = self.extra_child_args + extra_child_args 87 | 88 | if worker_mode == 'cmd': 89 | tmp_cmd = ['python', 'main.py', '-e', sub_exp_id] + self.base_cmd 90 | if extra_args is not None: 91 | tmp_cmd += extra_args 92 | 93 | # Initialize sub experiment directory 94 | tmp_dir = self.exp_root_dir + '/' + sub_exp_id 95 | mkdir_p(tmp_dir) 96 | 97 | if param is not None: 98 | if param.is_command: 99 | # Add extra arguments to new experiment 100 | param_file = '%s/params.txt' % self.exp_dir 101 | param_cmd = param.arg.get_cmd() 102 | with open(param_file, 'a') as f: 103 | f.write('%d %s\n' % (self.exp_count, ' '.join(param_cmd))) 104 | tmp_cmd += param_cmd 105 | 106 | else: 107 | # Point new experiment to metaparameters to load 108 | param_dir = '%s/params/%d/%d' % (self.opt.exp_id, 109 | self.exp_count // 100, 110 | self.exp_count % 100) 111 | param_path = '%s/%s' % (self.exp_root_dir, param_dir) 112 | mkdir_p(param_path) 113 | torch.save({'metaparams': param.state_dict()}, 114 | 
'%s/snapshot_meta' % param_path) 115 | tmp_cmd += ['--metaparam_load', param_dir] 116 | 117 | elif worker_mode == 'debug': 118 | tmp_cmd = [sub_exp_id, param.state_dict()] 119 | 120 | else: 121 | raise ValueError('Undefined worker mode: %s' % worker_mode) 122 | 123 | cmd_queue.put((self.exp_count, worker_mode, tmp_cmd, extra_child_args)) 124 | self.exp_count += 1 125 | 126 | def collect_batch_results(self, n, result_queue): 127 | opt = self.opt 128 | scores = np.zeros(n) 129 | result_idx_ref = np.zeros(n, int) 130 | 131 | for i in range(n): 132 | result = result_queue.get() 133 | self.results += [result] 134 | sample_id = int(result[0].split('/')[-1]) 135 | result_idx_ref[sample_id] = len(self.results) - 1 136 | if result[1] is not None: 137 | scores[sample_id] = result[1]['score'] 138 | else: 139 | scores[sample_id] = self.worst_score 140 | 141 | if not opt.maximize: 142 | scores = -scores 143 | return result_idx_ref, scores 144 | 145 | def run(self, cmd_queue, result_queue): 146 | return 147 | -------------------------------------------------------------------------------- /meta/optim/random_search.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
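# Illustrative sketch (annotation, not part of the original file): submit_cmd
# and collect_batch_results in meta/optim/optimizer.py above define the worker
# protocol used throughout this package. The optimizer puts
# (exp_count, worker_mode, cmd, extra_child_args) tuples on cmd_queue, workers
# put (exp_id, result_dict) tuples on result_queue, and main.py shuts workers
# down with a 'DONE' sentinel. A toy worker compatible with that protocol
# (the constant score and exp_id extraction are stand-ins, not project code):
import multiprocessing as mp

def toy_worker(cmd_queue, result_queue):
  while True:
    msg = cmd_queue.get()
    if msg == 'DONE':                            # shutdown sentinel (main.py)
      break
    exp_count, worker_mode, cmd, extra_args = msg
    exp_id = cmd[0] if worker_mode == 'debug' else cmd[cmd.index('-e') + 1]
    result_queue.put((exp_id, {'score': 0.0}))   # placeholder score

if __name__ == '__main__':
  cmd_q, res_q = mp.Queue(), mp.Queue()
  p = mp.Process(target=toy_worker, args=(cmd_q, res_q))
  p.start()
  cmd_q.put((0, 'debug', ['parent/0', {}], []))
  print(res_q.get())                             # ('parent/0', {'score': 0.0})
  cmd_q.put('DONE')
  p.join()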
14 | # ============================================================================== 15 | 16 | import numpy as np 17 | import subprocess 18 | import torch 19 | import argparse 20 | 21 | from mtl.meta.optim import optimizer 22 | 23 | 24 | class MetaOptimizer(optimizer.MetaOptimizer): 25 | 26 | def __init__(self, opt): 27 | self.exps_done = None 28 | self.exp_scores = None 29 | super().__init__(opt) 30 | 31 | def checkpoint_ref_setup(self): 32 | super().checkpoint_ref_setup() 33 | self.checkpoint_ref['extra'] = ['score', 'exps_done', 'exp_scores'] 34 | 35 | def setup_samples(self): 36 | all_samples = [] 37 | for i in range(self.opt.num_samples): 38 | all_samples += [self.init_sample()] 39 | return all_samples 40 | 41 | def run(self, cmd_queue, result_queue): 42 | opt = self.opt 43 | self.base_cmd = self.cfg.base_cmd 44 | 45 | samples = self.setup_samples() 46 | n_exps = len(samples) 47 | print(n_exps, 'parameterizations to test') 48 | 49 | for i, s in enumerate(samples): 50 | self.submit_cmd(cmd_queue, str(i), param=s) 51 | 52 | self.exp_scores = np.zeros(n_exps) 53 | self.exps_done = np.zeros(n_exps) 54 | 55 | for i in range(n_exps): 56 | result = result_queue.get() 57 | self.results += [result] 58 | sample_id = int(result[0].split('/')[-1]) 59 | self.exps_done[sample_id] = 1 60 | if result[1] is not None: 61 | self.exp_scores[sample_id] = result[1]['score'] 62 | else: 63 | self.exp_scores[sample_id] = self.worst_score 64 | 65 | self.best_score = self.exp_scores.max() 66 | self.best_exp = self.exp_scores.argmax() 67 | 68 | if n_exps < 100 or i % 50 == 0: 69 | self.save(self.exp_dir + '/snapshot') 70 | 71 | self.save(self.exp_dir + '/snapshot') 72 | -------------------------------------------------------------------------------- /meta/optim/runner.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import importlib 17 | 18 | from mtl.meta.optim import optimizer 19 | import numpy as np 20 | 21 | 22 | class MetaOptimizer(optimizer.MetaOptimizer): 23 | 24 | def __init__(self, opt): 25 | self.cfg = importlib.import_module('mtl.config.' 
+ opt.config) 26 | self.num_exps = len(self.cfg.exp_list) 27 | self.exps_done = np.zeros((opt.num_samples, self.num_exps)) 28 | self.exp_scores = np.zeros_like(self.exps_done) 29 | super().__init__(opt) 30 | 31 | def checkpoint_ref_setup(self): 32 | super().checkpoint_ref_setup() 33 | self.checkpoint_ref['extra'] = ['score', 'exps_done', 'exp_scores'] 34 | 35 | def run(self, cmd_queue, result_queue): 36 | opt = self.opt 37 | self.base_cmd = self.cfg.base_cmd 38 | 39 | # Set up any additional arguments to tack on to experiments 40 | extra_args = [] 41 | unparsed = opt.unparsed 42 | if '--' in unparsed: 43 | extra_args = unparsed[unparsed.index('--') + 1:] 44 | self.extra_child_args = extra_args 45 | 46 | print('Number of experiments:', self.num_exps) 47 | print('Number of trials:', opt.num_samples) 48 | print('Extra args:', extra_args) 49 | 50 | # Submit all experiments that haven't been run 51 | for trial_idx in range(opt.num_samples): 52 | for exp_idx, e in enumerate(self.cfg.exp_list): 53 | if not self.exps_done[trial_idx, exp_idx]: 54 | exp_id = 'trial_%d/%s' % (trial_idx, self.cfg.exp_names[exp_idx]) 55 | if '--' in e: 56 | extra_child_args = e[e.index('--'):] 57 | e = e[:e.index('--')] 58 | else: 59 | extra_child_args = [] 60 | 61 | self.submit_cmd( 62 | cmd_queue, 63 | exp_id, 64 | extra_args=e, 65 | extra_child_args=extra_child_args) 66 | 67 | # Collect results 68 | while self.exps_done.sum() != opt.num_samples * self.num_exps: 69 | result = result_queue.get() 70 | self.results += [result] 71 | 72 | exp_id = result[0].split('/') 73 | for tmp_idx, val in enumerate(exp_id): 74 | if 'trial_' in val: 75 | trial_idx = int(val.split('_')[-1]) 76 | ref_idx = tmp_idx + 1 77 | exp_idx = self.cfg.exp_names.index('/'.join(exp_id[ref_idx:])) 78 | 79 | score = result[1]['score'] 80 | self.exp_scores[trial_idx, exp_idx] = score 81 | self.exps_done[trial_idx, exp_idx] = 1 82 | 83 | if score > self.best_score: 84 | self.best_score = score 85 | self.best_exp = result[0] 86 | 87 | self.exp_count += 1 88 | 89 | print('Collected %s with score %.2f' % (result[0], score)) 90 | 91 | if opt.num_samples < 100 or self.exp_count % 20 == 0: 92 | self.save(self.exp_dir + '/snapshot') 93 | 94 | self.save(self.exp_dir + '/snapshot') 95 | -------------------------------------------------------------------------------- /meta/param/cmd.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
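# Illustrative sketch (annotation, not part of the original file): the
# Metaparam class below reads its search space from config/cmd/<name>.txt, one
# argument per line with the fields named in the parsing loop's comment. A
# hypothetical line in that format (the values are made up) and a minimal
# rerun of the same parsing logic:
line = '--learning_rate 0.1 0.0001 1.0 0\n'
vals = line[:-1].split(' ')
for i in range(1, 4):
  vals[i] = float(vals[i])    # default, min, max
vals[-1] = int(vals[-1])      # final integer flag (see get_cmd below)
assert vals == ['--learning_rate', 0.1, 0.0001, 1.0, 0]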
14 | # ============================================================================== 15 | 16 | """Parameter management of command line options (for hyperparam tuning).""" 17 | import os 18 | 19 | from mtl.meta.param import manager 20 | from mtl.util import calc 21 | import torch 22 | 23 | 24 | curr_dir = os.path.dirname(__file__) 25 | config_dir = os.path.join(curr_dir, '../../config/cmd') 26 | 27 | 28 | def setup_extra_args(parser): 29 | parser.add_argument('--param_init', type=str, default='random') 30 | 31 | 32 | class Metaparam(manager.Metaparam): 33 | 34 | def __init__(self, opt, config_fn=None): 35 | super().__init__() 36 | 37 | if config_fn is None: 38 | config_fn = opt.cmd_config 39 | config_file = '%s/%s.txt' % (config_dir, config_fn) 40 | args = [] 41 | with open(config_file) as f: 42 | for line in f: 43 | # arg name, default, min, max, log/linear 44 | vals = line[:-1].split(' ') 45 | for i in range(1, 4): 46 | vals[i] = float(vals[i]) 47 | vals[-1] = int(vals[-1]) 48 | args += [vals] 49 | 50 | self.arg = ArgManager(args, opt.param_init) 51 | self.is_command = True 52 | 53 | 54 | class ArgManager(manager.ParamManager): 55 | 56 | def __new__(cls, args, param_init='default'): 57 | return super().__new__(cls, shape=[len(args)]) 58 | 59 | def __init__(self, args, param_init='default'): 60 | super().__init__() 61 | self.arg_ref = args 62 | self.valid_range = [0, 1] 63 | self.data = self.get_default() 64 | if param_init == 'random': 65 | self.set_to_random() 66 | 67 | def get_default(self): 68 | vals = torch.zeros(self.shape) 69 | for i, ref in enumerate(self.arg_ref): 70 | vals[i] = calc.map_val(*ref[1:-1], invert=True) 71 | return vals 72 | 73 | def get_cmd(self): 74 | tmp_cmd = [] 75 | for i, ref in enumerate(self.arg_ref): 76 | tmp_cmd += [ref[0]] 77 | val = calc.map_val(self.data[i], *ref[2:-1]) 78 | if ref[-1]: 79 | tmp_cmd += ['%d' % int(val)] 80 | else: 81 | tmp_cmd += ['%.3g' % val] 82 | return tmp_cmd 83 | -------------------------------------------------------------------------------- /meta/param/manager.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
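# Illustrative sketch (annotation, not part of the original file): cmd.py
# above and ParamManager below both rely on mtl.util.calc.map_val to move
# between the normalized [0, 1] search space and an argument's real range.
# calc.py is not shown in this listing; assuming a plain affine map with the
# invert flag used by get_default, a compatible stand-in would be:
def map_val(x, v_min, v_max, invert=False):
  """Map x from [0, 1] into [v_min, v_max], or back again if invert is set."""
  if invert:
    return (x - v_min) / (v_max - v_min)   # real value -> normalized
  return v_min + x * (v_max - v_min)       # normalized -> real value

assert map_val(0.5, 0.0, 10.0) == 5.0
assert map_val(5.0, 0.0, 10.0, invert=True) == 0.5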
14 | # ==============================================================================
15 | 
16 | """Class to manage metaparameters."""
17 | 
18 | from copy import deepcopy
19 | 
20 | from mtl.util import calc
21 | import numpy as np
22 | import torch
23 | import torch.nn
24 | 
25 | 
26 | class ParamManager(torch.nn.Parameter):
27 | 
28 |   def __new__(cls, shape=None):
29 |     data = None if shape is None else torch.Tensor(*shape)
30 |     return super().__new__(cls, data=data, requires_grad=False)
31 | 
32 |   def __init__(self):
33 |     super().__init__()
34 |     self.is_continuous = True
35 |     self.valid_range = None
36 | 
37 |   def get_default(self):
38 |     return None
39 | 
40 |   def random_sample(self):
41 |     # Sample within defined range with appropriate scaling
42 |     vals = torch.rand(self.shape)
43 |     vals = calc.map_val(vals, *self.valid_range)
44 |     # Discretize if necessary
45 |     if not self.is_continuous:
46 |       vals = vals.long()  # torch tensors have no numpy-style astype
47 | 
48 |     return vals
49 | 
50 |   def set_to_default(self):
51 |     self.data[:] = self.get_default()
52 | 
53 |   def set_to_random(self):
54 |     self.data[:] = self.random_sample()
55 | 
56 |   def mutate(self, delta=.1):
57 |     return None  # no-op here; overridden by partition.PartitionManager
58 | 
59 |   def copy(self):
60 |     return deepcopy(self)
61 | 
62 | 
63 | class Metaparam(torch.nn.Module):
64 | 
65 |   def __init__(self):
66 |     super().__init__()
67 |     self.is_command = False
68 |     self.is_parameterized = None
69 |     self.model = None
70 | 
71 |   def parameters(self, keys=None):
72 |     for name, param in self.named_parameters():
73 |       tmp_name = name.split('.')[0]
74 |       if keys is None or tmp_name in keys:
75 |         yield param
76 | 
77 |   def get_params(self, keys):
78 |     return torch.nn.utils.parameters_to_vector(self.parameters(keys)).detach()
79 | 
80 |   def update_params(self, keys, data):
81 |     torch.nn.utils.vector_to_parameters(data, self.parameters(keys))
82 | 
83 |   def copy_from(self, src, keys):
84 |     self.update_params(keys, src.get_params(keys))
85 | 
86 |   def parameterize(self, model, search, inp_size):
87 |     # Loop through all keys, get data shape size
88 |     data_ref = [self._parameters[k].data for k in search]
89 |     dim_ref = [d.shape for d in data_ref]
90 |     # Determine output size (flattened/concatted data)
91 |     out_size = int(sum([np.prod(d) for d in dim_ref]))
92 |     # Initialize model
93 |     self.model = model(inp_size, out_size)
94 |     self.is_parameterized = search
95 |     self.is_command = False
96 | 
97 |   def reparameterize(self, x):
98 |     # Update metaparameters after forward call of model
99 |     new_params = self.model(x)
100 |     self.update_params(self.is_parameterized, new_params)
101 | 
--------------------------------------------------------------------------------
/meta/param/partition.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
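# Illustrative sketch (annotation, not part of the original file):
# Metaparam.get_params / update_params above round-trip module parameters
# through one flat vector, which is what lets the metaoptimizers treat
# heterogeneous parameter groups as a single search vector. A self-contained
# demonstration of that round trip with the same torch utilities:
import torch
import torch.nn as nn

m = nn.Linear(3, 2)                                   # 6 weights + 2 biases
flat = nn.utils.parameters_to_vector(m.parameters()).detach()
assert flat.nelement() == 8
nn.utils.vector_to_parameters(torch.ones(8), m.parameters())
assert torch.all(m.weight == 1) and torch.all(m.bias == 1)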
14 | # ============================================================================== 15 | 16 | """Feature partitioning parameter management.""" 17 | from mtl.meta.param import manager 18 | from mtl.util import calc 19 | import numpy as np 20 | import torch 21 | 22 | 23 | def setup_extra_args(parser): 24 | parser.add_argument('--param_init', type=str, default='random') 25 | parser.add_argument('--rand_type', type=str, default='restrict') 26 | parser.add_argument('--restrict_amt', type=float, default=.5) 27 | parser.add_argument('--num_unique', type=int, default=1) 28 | parser.add_argument('--mask_gradient', type=int, default=1) 29 | parser.add_argument('--share_amt', type=float, default=.75) 30 | 31 | 32 | def prepare_masks(p, num_feats): 33 | # Prepare masks 34 | num_tasks = p.num_tasks 35 | num_unique = p.num_unique 36 | mask_ref = [[{} for _ in range(num_unique)] for _ in range(num_tasks)] 37 | grad_mask_ref = [[{} for _ in range(num_unique)] for _ in range(num_tasks)] 38 | 39 | for t in range(num_tasks): 40 | for u in range(num_unique): 41 | for f in num_feats: 42 | # Forward mask 43 | tmp_mask = p.get_task_mask(t, f, m_idx=0, p_idx=u) 44 | mask_ref[t][u][f] = torch.Tensor(tmp_mask).view(1, -1, 1, 1) 45 | 46 | # Backward mask 47 | tmp_grad_mask = p.get_task_mask(t, f, m_idx=1, p_idx=u) 48 | grad_mask_ref[t][u][f] = torch.Tensor(tmp_grad_mask).view(-1, 1, 1, 1) 49 | 50 | return mask_ref, grad_mask_ref 51 | 52 | 53 | def masked_cnv(ref, cnv, unq_idx=0): 54 | num_f = cnv.out_channels 55 | return lambda x, y: cnv(x) * ref.mask_ref[y][unq_idx][num_f].cuda() 56 | 57 | 58 | class Metaparam(manager.Metaparam): 59 | 60 | def __init__(self, opt): 61 | super().__init__() 62 | self.partition = PartitionManager(opt) 63 | 64 | 65 | class PartitionManager(manager.ParamManager): 66 | 67 | def __new__(cls, opt): 68 | return super().__new__(cls, shape=None) 69 | 70 | def __init__(self, opt): 71 | super().__init__() 72 | self.param_init = opt.param_init 73 | self.rand_type = opt.rand_type 74 | self.restrict_amt = opt.restrict_amt 75 | self.share_amt = opt.share_amt 76 | self.valid_range = [0, 1] 77 | self.num_tasks = len(opt.task_choice.split('-')) 78 | self.num_unique = opt.num_unique 79 | shape = [self.num_unique, 2, self.num_tasks, self.num_tasks] 80 | self.data = torch.Tensor(*shape) 81 | self.data = self.get_default() 82 | self.mask = {} 83 | self.set_partition(opt.param_init) 84 | self.ref_triu_idxs = np.triu(np.ones(shape[1:])) != 0 85 | 86 | def set_partition(self, share_type=None): 87 | if share_type is None: 88 | share_type = self.param_init 89 | tmp_mat = None 90 | n = self.num_tasks 91 | ones = torch.ones((2, n, n)) 92 | eye = torch.eye(n) 93 | 94 | if share_type == 'share_all': 95 | tmp_mat = ones 96 | 97 | elif share_type == 'share_fixed': 98 | share_pct = self.share_amt 99 | indiv_pct = share_pct + (1 - share_pct) / n 100 | share_val = (n - 2) / (n - 1) 101 | tmp_mat = ones * share_val 102 | tmp_mat[0] += eye * (indiv_pct - share_val) 103 | tmp_mat[1] = 1 104 | 105 | elif share_type == 'share_fwd_only': 106 | tmp_mat = ones 107 | tmp_mat[1] = eye * (1. / n) 108 | 109 | elif share_type == 'independent': 110 | tmp_mat = ones 111 | tmp_mat[0] = eye * (1. 
/ n) 112 | tmp_mat[1] = eye 113 | 114 | elif share_type == 'random': 115 | self.set_to_random() 116 | 117 | else: 118 | raise ValueError('Undefined partition setting: %s' % share_type) 119 | 120 | if tmp_mat is not None: 121 | for i in range(self.num_unique): 122 | self.data[i] = tmp_mat 123 | 124 | def ignore_backward_mask(self): 125 | for d in self.data: 126 | d[1].fill_(1) 127 | 128 | def get_default(self): 129 | return torch.ones(self.shape) 130 | 131 | def set_to_random(self): 132 | super().set_to_random() 133 | self.reset_masks() 134 | 135 | def random_sample(self, rand_type=None): 136 | if rand_type is None: 137 | rand_type = self.rand_type 138 | tmp_mat = torch.rand(self.shape) 139 | n = self.num_tasks 140 | eye = torch.eye(n) 141 | 142 | if rand_type == 'restrict': 143 | mutate = torch.randn(n, n) * .2 144 | for m in tmp_mat: 145 | m[0] = eye * (self.restrict_amt + mutate) + (1 - eye) * m[0] 146 | m[1] = eye * (1 + torch.randn(n, n) * .15) + (1 - eye) * m[1] 147 | tmp_mat = tmp_mat.clamp(.05, 1) 148 | 149 | return tmp_mat 150 | 151 | def preprocess_mat(self, mat, diag=None): 152 | # Preserve diagonal and make matrix symmetric 153 | if diag is None: 154 | diag = mat.diag() 155 | mat = torch.triu(mat, 1) 156 | mat = diag.diag() + mat + mat.t() 157 | 158 | return mat.clamp(0, 1) 159 | 160 | def find_masks(self, p=None, n_feats=100, n_iters=100, p_idx=0): 161 | # Convert raw parameterization 162 | if p is None: 163 | p = self.data[p_idx] 164 | 165 | if self.param_init == 'share_fixed': 166 | amt_per_task = int( 167 | np.floor(n_feats * (1 - self.share_amt) / self.num_tasks)) 168 | amt_shared = n_feats - (amt_per_task * self.num_tasks) 169 | masks = [] 170 | for i in range(self.num_tasks): 171 | tmp_mask = np.zeros(n_feats) 172 | tmp_mask[i * amt_per_task:(i + 1) * amt_per_task] = 1 173 | tmp_mask[-amt_shared:] = 1 174 | masks += [tmp_mask] 175 | mask_f = np.stack(masks) 176 | mask_b = mask_f 177 | return np.stack([mask_f, mask_b], 0) 178 | 179 | else: 180 | return calc.find_masks(p, n_feats, n_iters) 181 | 182 | def init_mask(self, n_feats=100, n_iters=100): 183 | if n_feats not in self.mask: 184 | self.mask[n_feats] = [ 185 | self.find_masks(n_feats=n_feats, n_iters=n_iters, p_idx=p_idx) 186 | for p_idx in range(self.num_unique) 187 | ] 188 | self.mask[n_feats] = [m.copy() for m in self.mask[n_feats]] 189 | 190 | def get_task_mask(self, task_idx, n_feats, p_idx=0, m_idx=0): 191 | self.init_mask(n_feats) 192 | return self.mask[n_feats][p_idx][m_idx][task_idx] 193 | 194 | def reset_masks(self): 195 | self.mask = {} 196 | 197 | def mutate(self, data=None, delta=.1): 198 | v_min, v_max = self.valid_range 199 | update_self = data is None 200 | if update_self: 201 | data = self.data 202 | 203 | if isinstance(data, torch.Tensor): 204 | data = data.clone() 205 | data += torch.randn(data.shape) * delta 206 | data = data.clamp(v_min, v_max) 207 | else: 208 | data = data.copy() 209 | data += np.random.randn(*data.shape) * delta 210 | data = data.clip(v_min, v_max) 211 | 212 | if update_self: 213 | self.data = data 214 | return data 215 | 216 | def get_task_parameter_mask(self, task_idx): 217 | # Return a binary mask indicating which parameters are directly 218 | # associated with a particular task. 
219 |     v = np.zeros_like(self.ref_triu_idxs)
220 |     v[:, :, task_idx] = True
221 |     v[:, task_idx, :] = True
222 |     v = (v * self.ref_triu_idxs).astype(int)
223 |     tmp_p = np.expand_dims(v, 0).repeat(self.num_unique, 0)
224 | 
225 |     return tmp_p.flatten()
226 | 
--------------------------------------------------------------------------------
/models/distill.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | import copy
17 | 
18 | from mtl.models import masked_resnet
19 | from mtl.models import resnet
20 | import torch
21 | 
22 | 
23 | def setup_extra_args(parser):
24 |   masked_resnet.setup_extra_args(parser)
25 |   parser.add_argument('--last_layer_idx', type=int, default=1)
26 |   parser.add_argument('--num_distill', type=int, default=2)
27 |   parser.add_argument('--finetune_fc', type=int, default=0)
28 |   parser.set_defaults(last_third_only=1)
29 | 
30 | 
31 | class RefNet(resnet.Net):
32 | 
33 |   def __init__(self, opt, ds, sub_task_choice=None):
34 |     super().__init__(opt, ds, sub_task_choice=sub_task_choice)  # forward the requested task subset
35 |     self.last_layer_idx = opt.last_layer_idx
36 | 
37 |   def forward(self, x):
38 |     x = self.resnet.pre_layers_conv(x)
39 |     x = self.resnet.layer1(x)
40 |     x = self.resnet.layer2(x)
41 |     for i in range(self.last_layer_idx):
42 |       x = self.resnet.layer3[i](x)
43 | 
44 |     return x
45 | 
46 | 
47 | class Net(masked_resnet.Net):
48 | 
49 |   def __init__(self, opt, ds, metaparams=None, masks=None):
50 |     super().__init__(opt, ds, metaparams, masks)
51 | 
52 |     # Layer distillation options
53 |     self.last_layer_idx = opt.last_layer_idx
54 |     self.num_distill = opt.num_distill
55 |     self.finetune_fc = opt.finetune_fc
56 | 
57 |     # Reference performance for each task
58 |     task_idxs = list(map(int, opt.task_choice.split('-')))
59 |     self.task_low = [0 for i in task_idxs]
60 |     self.task_high = [
61 |         [.63, .55, .80, 1, .51, 1, .85, .89, .96, .85][i] for i in task_idxs
62 |     ]
63 | 
64 |     # Load pretrained models
65 |     self.ref_models = []
66 |     tmp_opt = copy.deepcopy(opt)
67 |     tmp_opt.imagenet_pretrained = 0
68 | 
69 |     for i, task_idx in enumerate(task_idxs):
70 |       r = RefNet(tmp_opt, ds, sub_task_choice=[i])
71 |       pretrained = torch.load(
72 |           '%s/finetuned/%d/snapshot_model' % (opt.exp_root_dir, task_idx))
73 |       r.load_state_dict(pretrained['model'], strict=False)
74 |       r.cuda()
75 |       r.eval()
76 |       for p in r.parameters():
77 |         p.requires_grad = False
78 |       self.ref_models += [r]
79 | 
80 |     # If finetuning fc layers, copy fc weights from reference
81 |     if opt.finetune_fc:
82 |       for i in range(self.num_tasks):
83 |         fc_name = 'out_%s' % self.task_names[i]
84 |         fc_ref = self.ref_models[i]._modules[fc_name]
85 |         self._modules[fc_name].load_state_dict(fc_ref.state_dict())
86 | 
87 |   def forward(self, x, task_idx, split, global_step):
88 |     begin_idx = self.last_layer_idx
89 |     end_idx = begin_idx
+ self.num_distill 90 | 91 | # Initial pass through task-specific resnet 92 | x = self.ref_models[task_idx](x) 93 | ref_feats = [x] 94 | matched_feats = [x] 95 | 96 | # Pass through teacher and shared layers 97 | ref_layer = self.ref_models[task_idx].resnet.layer3 98 | ref_end_bn = self.ref_models[task_idx].resnet.end_bns[0] 99 | shared_layer = self.resnet.layer3 100 | shared_end_bn = self.resnet.end_bns[task_idx] 101 | 102 | for l_idx in range(begin_idx, end_idx): 103 | ref_feats += [ref_layer[l_idx](ref_feats[-1], 0)] 104 | matched_feats += [shared_layer[l_idx](matched_feats[-1], task_idx)] 105 | if l_idx == 3: 106 | ref_feats[-1] = ref_end_bn(ref_feats[-1]) 107 | matched_feats[-1] = shared_end_bn(matched_feats[-1]) 108 | 109 | # Run through rest of pre-trained network 110 | x = matched_feats[-1] 111 | for l_idx in range(end_idx, 4): 112 | x = ref_layer[l_idx](x, 0) 113 | if end_idx < 4: 114 | x = ref_end_bn(x) 115 | 116 | x = self.resnet.avgpool(x) 117 | x = x.view(x.size(0), -1) 118 | 119 | # Final fully connected layer 120 | fc_name = 'out_%s' % self.task_names[task_idx] 121 | if not self.finetune_fc: 122 | final_fc = self.ref_models[task_idx]._modules[fc_name] 123 | else: 124 | final_fc = self._modules[fc_name] 125 | 126 | x = final_fc(x) 127 | 128 | ref_feats = torch.cat(ref_feats[1:], 1) 129 | matched_feats = torch.cat(matched_feats[1:], 1) 130 | 131 | return ref_feats, matched_feats, x 132 | 133 | 134 | def initialize(*args, **kargs): 135 | return Net(*args, **kargs) 136 | -------------------------------------------------------------------------------- /models/masked_resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
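# Illustrative sketch (annotation, not part of the original file):
# partition.masked_cnv above wraps a convolution in a closure that multiplies
# its output by a per-task channel mask; that is the mechanism the Net below
# installs into ResNet blocks. A stripped-down CPU version of the pattern
# (the project code moves masks to GPU with .cuda()):
import torch
import torch.nn as nn

cnv = nn.Conv2d(3, 4, 3, padding=1, bias=False)
masks = {0: torch.tensor([1., 1., 0., 0.]).view(1, -1, 1, 1),   # task 0
         1: torch.tensor([0., 0., 1., 1.]).view(1, -1, 1, 1)}   # task 1
masked = lambda x, task_idx: cnv(x) * masks[task_idx]

x = torch.randn(1, 3, 8, 8)
assert torch.all(masked(x, 0)[:, 2:] == 0)   # task 0 sees only channels 0-1
assert torch.all(masked(x, 1)[:, :2] == 0)   # task 1 sees only channels 2-3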
14 | # ============================================================================== 15 | 16 | from mtl.meta.param import partition 17 | from mtl.models import resnet 18 | import numpy as np 19 | 20 | 21 | def setup_extra_args(parser): 22 | resnet.setup_extra_args(parser) 23 | partition.setup_extra_args(parser) 24 | parser.set_defaults(metaparam='partition', last_third_only=1) 25 | 26 | 27 | class Net(resnet.Net): 28 | """ResNet that supports masking for feature partitioning.""" 29 | 30 | def __init__(self, opt, ds, metaparams=None, masks=None): 31 | super(Net, self).__init__(opt, ds) 32 | 33 | self.meta = [metaparams] 34 | if self.meta[0] is None: 35 | self.meta = [partition.Metaparam(opt)] 36 | share = self.meta[0].partition 37 | self.num_unique = share.num_unique 38 | self.bw_ref = [] 39 | 40 | layer_ref = self.resnet.get_layer_ref() 41 | num_feats = self.resnet.num_feat_ref 42 | if opt.bottleneck_ratio != 1: 43 | num_feats += [int(f * opt.bottleneck_ratio) for f in num_feats] 44 | 45 | if masks is None: 46 | # Prepare all masks 47 | self.resnet.mask_ref, self.grad_mask_ref = partition.prepare_masks( 48 | share, num_feats) 49 | else: 50 | self.resnet.mask_ref, self.grad_mask_ref = masks 51 | 52 | repeat_rate = int(np.ceil(len(layer_ref) / self.num_unique)) 53 | unq_idx_ref = [i // repeat_rate for i in range(len(layer_ref))] 54 | 55 | # Convert all layers 56 | for l_idx, l in enumerate(layer_ref): 57 | unq_idx = unq_idx_ref[l_idx] 58 | tmp_m = self.resnet.get_module(l) 59 | l_name = l[-1] 60 | cnv = tmp_m._modules[l_name] 61 | 62 | if 'conv1' in l and (not opt.last_third_only or 'layer3' in l): 63 | # Apply mask to first convolution in ResBlock 64 | tmp_m.conv = partition.masked_cnv(self.resnet, cnv, unq_idx) 65 | 66 | # Save a reference for doing gradient masking 67 | bw_masks = [ 68 | self.grad_mask_ref[i][unq_idx][cnv.out_channels] 69 | for i in range(self.num_tasks) 70 | ] 71 | self.bw_ref += [[cnv, bw_masks]] 72 | 73 | 74 | def initialize(*args, **kargs): 75 | return Net(*args, **kargs) 76 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import math 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | 22 | 23 | def setup_extra_args(parser): 24 | parser.add_argument('--imagenet_pretrained', type=int, default=1) 25 | parser.add_argument('--separate_bn', type=int, default=1) 26 | parser.add_argument('--num_blocks', type=int, default=4) 27 | parser.add_argument('--bottleneck_ratio', type=float, default=1.) 28 | parser.add_argument('--network_size_factor', type=float, default=1.) 
29 | parser.add_argument('--last_third_only', type=int, default=0) 30 | 31 | 32 | class ConvBN(nn.Module): 33 | """Combined convolution/batchnorm, with support for separate normalization. 34 | 35 | Using a wrapper for the convolution operation makes it easier to 36 | add task-conditioned auxiliary functions to augment intermediate 37 | activations as we will be doing with the masks. This will make it 38 | easier to load and exchange weights across models whether or not 39 | they were trained with masks. 40 | """ 41 | 42 | def __init__(self, in_channels, out_channels, stride=1, num_tasks=1): 43 | super(ConvBN, self).__init__() 44 | 45 | self.conv_aux = nn.Conv2d( 46 | in_channels, out_channels, 3, stride, padding=1, bias=False) 47 | self.conv = self.identity_fn(self.conv_aux) 48 | self.bns = nn.ModuleList( 49 | [nn.BatchNorm2d(out_channels) for i in range(num_tasks)]) 50 | 51 | def identity_fn(self, cnv): 52 | return lambda x, y: cnv(x) 53 | 54 | def forward(self, x, task_idx=0): 55 | return self.bns[task_idx](self.conv(x, task_idx)) 56 | 57 | 58 | class ResBlock(nn.Module): 59 | 60 | def __init__(self, 61 | in_channels, 62 | out_channels, 63 | stride=1, 64 | shortcut=0, 65 | num_tasks=1, 66 | bottleneck=1): 67 | super(ResBlock, self).__init__() 68 | 69 | f = int(out_channels * bottleneck) 70 | self.conv1 = ConvBN(in_channels, f, stride, num_tasks) 71 | self.conv2 = ConvBN(f, out_channels, 1, num_tasks) 72 | 73 | self.shortcut = shortcut 74 | if shortcut: 75 | self.avgpool = nn.AvgPool2d(2) 76 | 77 | def forward(self, x, task_idx=0): 78 | y = self.conv1(x, task_idx) 79 | y = self.conv2(F.relu(y, inplace=True), task_idx) 80 | 81 | if self.shortcut: 82 | x = self.avgpool(x) 83 | x = torch.cat((x, x * 0), 1) 84 | 85 | return F.relu(x + y) 86 | 87 | 88 | class ResNet(nn.Module): 89 | 90 | def __init__(self, size_factor=1, num_blocks=4, num_tasks=1, bottleneck=1): 91 | super(ResNet, self).__init__() 92 | f = [int(size_factor * 2**(i + 5)) for i in range(4)] 93 | 94 | self.num_feat_ref = f 95 | self.num_tasks = num_tasks 96 | self.num_blocks = num_blocks 97 | self.pre_layers_conv = ConvBN(3, f[0], 1, num_tasks) 98 | 99 | for i in range(1, 4): 100 | tmp_bottleneck = 1 if i < 3 else bottleneck 101 | layers = [ResBlock(f[i - 1], f[i], 2, 1, num_tasks, tmp_bottleneck)] 102 | for j in range(1, num_blocks): 103 | layers += [ResBlock(f[i], f[i], 1, 0, num_tasks, tmp_bottleneck)] 104 | 105 | self.add_module('layer%d' % i, nn.Sequential(*layers)) 106 | 107 | self.end_bns = nn.ModuleList([ 108 | nn.Sequential(nn.BatchNorm2d(f[-1]), nn.ReLU(True)) 109 | for i in range(num_tasks) 110 | ]) 111 | self.avgpool = nn.AdaptiveAvgPool2d(1) 112 | 113 | # Weight initialization 114 | for m in self.modules(): 115 | if isinstance(m, nn.Conv2d): 116 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 117 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 118 | elif isinstance(m, nn.BatchNorm2d): 119 | m.weight.data.fill_(1) 120 | m.bias.data.zero_() 121 | 122 | def get_layer_ref(self): 123 | layer_ref = [] 124 | for k in self.state_dict().keys(): 125 | if 'conv' in k and 'weight' in k and 'bns' not in k: 126 | layer_ref += [k.split('.')[:-1]] 127 | 128 | return layer_ref 129 | 130 | def get_module(self, l): 131 | tmp_m = self 132 | if len(l) > 1: 133 | tmp_m = tmp_m._modules[l[0]] 134 | if len(l) > 2: 135 | tmp_m = tmp_m[int(l[1])] 136 | if len(l) > 3: 137 | tmp_m = tmp_m._modules[l[2]] 138 | if len(l) > 4: 139 | tmp_m = tmp_m[int(l[3])] 140 | 141 | return tmp_m 142 | 143 | def forward(self, x, task_idx): 144 | x = self.pre_layers_conv(x, task_idx) 145 | for l in [self.layer1, self.layer2, self.layer3]: 146 | for m in l: 147 | x = m(x, task_idx) 148 | 149 | x = self.end_bns[task_idx](x) 150 | x = self.avgpool(x) 151 | x = x.view(x.size(0), -1) 152 | 153 | return x 154 | 155 | 156 | class Net(nn.Module): 157 | 158 | def __init__(self, opt, ds, metaparams=None, masks=None, 159 | sub_task_choice=None): 160 | super(Net, self).__init__() 161 | self.dropout = opt.dropout 162 | self.separate_bn = opt.separate_bn 163 | 164 | if sub_task_choice is None: 165 | sub_task_choice = [i for i in range(len(ds['train']))] 166 | self.num_tasks = len(sub_task_choice) 167 | self.num_out = [ds['train'][i].num_out for i in sub_task_choice] 168 | self.task_names = [ds['train'][i].task_name for i in sub_task_choice] 169 | 170 | # Initialize ResNet 171 | self.resnet = ResNet(opt.network_size_factor, opt.num_blocks, 172 | self.num_tasks if opt.separate_bn else 1, 173 | opt.bottleneck_ratio) 174 | 175 | # Final fully connected layers 176 | f = int(256 * opt.network_size_factor) 177 | self.final_n_feat = f 178 | for t in range(self.num_tasks): 179 | self.add_module('out_%s' % self.task_names[t], 180 | nn.Linear(f, self.num_out[t])) 181 | 182 | if opt.imagenet_pretrained: 183 | print('Loading pretrained Imagenet model.') 184 | imgnet_path = '%s/finetuned/0/snapshot_model' % opt.exp_root_dir 185 | pretrained = torch.load(imgnet_path) 186 | if opt.bottleneck_ratio != 1: 187 | pretrained['model'] = { 188 | k: v for k, v in pretrained['model'].items() if 'layer3' not in k 189 | } 190 | self.load_state_dict(pretrained['model'], strict=False) 191 | 192 | # Copy batchnorm weights in all layers 193 | for l in self.resnet.get_layer_ref(): 194 | tmp_m = self.resnet.get_module(l) 195 | if len(tmp_m.bns) > 1: 196 | for bn_idx in range(1, len(tmp_m.bns)): 197 | tmp_m.bns[bn_idx].load_state_dict(tmp_m.bns[0].state_dict()) 198 | 199 | if opt.last_third_only: 200 | # Do not update conv weights of first two-thirds of model (still update BN) 201 | layer_names = ['pre_layers_conv', 'layer1', 'layer2'] 202 | 203 | resnet_params = [] 204 | for l_name in layer_names: 205 | l = self.resnet._modules[l_name] 206 | for k in l.state_dict(): 207 | if 'conv_aux' in k: 208 | tmp_m = self.resnet.get_module([l_name] + k.split('.')[:-1]) 209 | tmp_m.conv_aux.weight.requires_grad = False 210 | resnet_params += [p for p in l.parameters() if len(p.shape) == 1] 211 | 212 | resnet_params += [p for p in self.resnet.layer3.parameters()] 213 | 214 | else: 215 | # Train the full model 216 | resnet_params = [p for p in self.resnet.parameters()] 217 | 218 | self.net_parameters = [] 219 | for t in range(self.num_tasks): 220 | task_out = self._modules['out_%s' % self.task_names[t]] 221 | task_params = [p for p in task_out.parameters()] 222 | self.net_parameters += [resnet_params + task_params] 223 | 224 | 
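  # forward routes a batch through the shared trunk using the task's own
  # batchnorm statistics (when separate_bn is set) and finishes with that
  # task's private fully connected head.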
def forward(self, x, task_idx, split, global_step): 225 | x = self.resnet(x, task_idx if self.separate_bn else 0) 226 | if split == 'train' and self.dropout: 227 | x = F.dropout(x, self.dropout) 228 | x = self._modules['out_%s' % self.task_names[task_idx]](x) 229 | 230 | return x 231 | 232 | 233 | def initialize(*args, **kargs): 234 | return Net(*args, **kargs) 235 | -------------------------------------------------------------------------------- /third_party/cutout/LICENSE.md: -------------------------------------------------------------------------------- 1 | Educational Community License, Version 2.0 (ECL-2.0) 2 | 3 | Version 2.0, April 2007 4 | 5 | http://www.osedu.org/licenses/ 6 | 7 | The Educational Community License version 2.0 ("ECL") consists of the Apache 2.0 license, modified to change the scope of the patent grant in section 3 to be specific to the needs of the education communities using this license. The original Apache 2.0 license can be found at: http://www.apache.org/licenses /LICENSE-2.0 8 | 9 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 10 | 11 | 1. Definitions. 12 | 13 | "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. 14 | 15 | "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. 18 | 19 | "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 20 | 21 | "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 22 | 23 | "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. 24 | 25 | "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). 26 | 27 | "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 28 | 29 | "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 30 | 31 | "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 32 | 33 | 2. Grant of Copyright License. 34 | 35 | Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 36 | 37 | 3. Grant of Patent License. 38 | 39 | Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. Any patent license granted hereby with respect to contributions by an individual employed by an institution or organization is limited to patent claims where the individual that is the author of the Work is also the inventor of the patent claims licensed, and where the organization or institution has the right to grant such license under applicable grant and research funding agreements. No other express or implied licenses are granted. 40 | 41 | 4. Redistribution. 
42 | 43 | You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: 44 | 45 | You must give any other recipients of the Work or Derivative Works a copy of this License; and You must cause any modified files to carry prominent notices stating that You changed the files; and You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 46 | 47 | 5. Submission of Contributions. 48 | 49 | Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 50 | 51 | 6. Trademarks. 52 | 53 | This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 54 | 55 | 7. Disclaimer of Warranty. 56 | 57 | Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 58 | 59 | 8. Limitation of Liability. 
60 | 61 | In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 62 | 63 | 9. Accepting Warranty or Additional Liability. 64 | 65 | While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. 66 | 67 | END OF TERMS AND CONDITIONS 68 | 69 | APPENDIX: How to apply the Educational Community License to your work 70 | 71 | To apply the Educational Community License to your work, attach 72 | the following boilerplate notice, with the fields enclosed by 73 | brackets "[]" replaced with your own identifying information. 74 | (Don't include the brackets!) The text should be enclosed in the 75 | appropriate comment syntax for the file format. We also recommend 76 | that a file or class name and description of purpose be included on 77 | the same "printed page" as the copyright notice for easier 78 | identification within third-party archives. 79 | 80 | Copyright [yyyy] [name of copyright owner] Licensed under the 81 | Educational Community License, Version 2.0 (the "License"); you may 82 | not use this file except in compliance with the License. You may 83 | obtain a copy of the License at 84 | 85 | http://www.osedu.org/licenses /ECL-2.0 86 | 87 | Unless required by applicable law or agreed to in writing, 88 | software distributed under the License is distributed on an "AS IS" 89 | BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 90 | or implied. See the License for the specific language governing 91 | permissions and limitations under the License. 92 | -------------------------------------------------------------------------------- /third_party/cutout/cutout.py: -------------------------------------------------------------------------------- 1 | """Cutout data augmentation. 2 | 3 | From: https://github.com/uoguelph-mlrg/Cutout 4 | """ 5 | 6 | import torch 7 | import numpy as np 8 | 9 | 10 | class Cutout(object): 11 | """Randomly mask out one or more patches from an image. 12 | 13 | Args: 14 | n_holes (int): Number of patches to cut out of each image. 15 | length (int): The length (in pixels) of each square patch. 16 | """ 17 | def __init__(self, n_holes, length): 18 | self.n_holes = n_holes 19 | self.length = length 20 | 21 | def __call__(self, img): 22 | """ 23 | Args: 24 | img (Tensor): Tensor image of size (C, H, W). 25 | Returns: 26 | Tensor: Image with n_holes of dimension length x length cut out of it. 
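        Example (torchvision assumed available; not part of the original file):
            >>> from torchvision import transforms
            >>> tf = transforms.Compose([transforms.ToTensor(),
            ...                          Cutout(n_holes=1, length=8)])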
27 | """ 28 | h = img.size(1) 29 | w = img.size(2) 30 | 31 | mask = np.ones((h, w), np.float32) 32 | 33 | for n in range(self.n_holes): 34 | y = np.random.randint(h) 35 | x = np.random.randint(w) 36 | 37 | y1 = int(np.clip(y - self.length / 2, 0, h)) 38 | y2 = int(np.clip(y + self.length / 2, 0, h)) 39 | x1 = int(np.clip(x - self.length / 2, 0, w)) 40 | x2 = int(np.clip(x + self.length / 2, 0, w)) 41 | 42 | mask[y1: y2, x1: x2] = 0. 43 | 44 | mask = torch.from_numpy(mask) 45 | mask = mask.expand_as(img) 46 | img = img * mask 47 | 48 | return img 49 | -------------------------------------------------------------------------------- /train/distill.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | from mtl.train import multiclass 17 | import numpy as np 18 | import torch 19 | 20 | 21 | def setup_extra_args(parser): 22 | parser.add_argument('--dist_loss_wt', type=float, default=1) 23 | parser.add_argument('--class_loss_wt', type=float, default=0) 24 | 25 | 26 | class Task(multiclass.Task): 27 | """Distillation training manager.""" 28 | 29 | def __init__(self, opt, ds, dataloaders): 30 | super().__init__(opt, ds, dataloaders) 31 | self.ce_loss = torch.nn.CrossEntropyLoss() 32 | self.mse_loss = torch.nn.MSELoss() 33 | 34 | def run(self, split, step): 35 | opt, ds = self.opt, self.ds 36 | self.step = step 37 | self.split = split 38 | 39 | # Sample task 40 | task_idx = self.sample_task() 41 | self.task_idx = task_idx 42 | self.curr_task = ds['train'][task_idx].task_name 43 | 44 | # Get samples + model output 45 | inp, label, _ = self.get_next_sample(split, task_idx) 46 | ref_feats, pred_feats, pred = self.model(inp, task_idx, split, step) 47 | 48 | # Calculate loss 49 | _, class_preds = torch.max(pred, 1) 50 | t_min, t_max = self.model.task_low[task_idx], self.model.task_high[task_idx] 51 | accuracy = class_preds.eq(label).float().mean() 52 | accuracy = (accuracy - t_min) / (t_max - t_min) 53 | 54 | class_loss = self.ce_loss(pred, label) 55 | distill_loss = self.mse_loss(pred_feats, ref_feats.detach()) 56 | 57 | self.net_loss = 0 58 | if opt.dist_loss_wt: 59 | self.net_loss += opt.dist_loss_wt * distill_loss 60 | if opt.class_loss_wt: 61 | self.net_loss += opt.class_loss_wt * class_loss 62 | 63 | if split == 'valid': 64 | self.valid_accuracy_track[task_idx] += [accuracy.data.item()] 65 | self.update_log('accuracy', accuracy.data.item()) 66 | self.update_log('network_loss', self.net_loss.data.item()) 67 | self.score = np.array([d['valid'] for d in self.log['accuracy']]).mean() 68 | 69 | self.global_trained_steps += 1 70 | self.task_trained_steps[task_idx] += 1 71 | -------------------------------------------------------------------------------- /train/multiclass.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2019 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import importlib 17 | 18 | from mtl.meta.param import partition 19 | from mtl.util import calc 20 | from mtl.util import session_manager 21 | from mtl.util import tensorboard_manager 22 | import numpy as np 23 | import torch 24 | 25 | 26 | class Task(session_manager.SessionManager): 27 | """Basic network training manager.""" 28 | 29 | def __init__(self, opt, ds, dataloaders): 30 | super().__init__(opt, ds, dataloaders) 31 | self.opt = opt 32 | self.ds = ds 33 | self.dataloaders = dataloaders 34 | self.is_training = True 35 | splits = ['train', 'valid'] 36 | 37 | # Task reference information 38 | self.num_tasks = len(ds['train']) 39 | self.task_idx_ref = list(map(int, opt.task_choice.split('-'))) 40 | 41 | self.targets = ds['train'][0].targets 42 | task_ds_size = [len(d) for d in ds['train']] 43 | self.task_ds_size = task_ds_size 44 | self.task_proportion = [t / sum(task_ds_size) for t in task_ds_size] 45 | 46 | self.task_num_out = [d.num_out for d in ds['train']] 47 | self.task_trained_steps = [0 for _ in range(self.num_tasks)] 48 | self.global_trained_steps = 0 49 | 50 | # Initialize running averages for logging 51 | self.log = {} 52 | for k in self.to_track: 53 | self.log[k] = [ 54 | {s: 0 for s in splits} for _ in range(self.num_tasks) 55 | ] 56 | self.log['%s_history' % k] = [ 57 | {s: [] for s in splits} for _ in range(self.num_tasks) 58 | ] 59 | 60 | self.score = 0 61 | 62 | # Set up dataset iterators 63 | self.iters = {s: [] for s in splits} 64 | for split in self.iters: 65 | for task_idx in range(self.num_tasks): 66 | self.iters[split] += [iter(self.dataloaders[split][task_idx])] 67 | 68 | # Set up metaparameters 69 | self.setup_metaparams(opt, ds) 70 | if opt.metaparam_load: 71 | del self.to_load[self.to_load.index('meta')] 72 | metaparam_path = '%s/%s/snapshot' % (opt.exp_root_dir, opt.metaparam_load) 73 | print('Loading parameters from... 
(%s)' % metaparam_path) 74 | self.load(metaparam_path, groups=['meta']) 75 | 76 | # Set up model and optimization and loss 77 | self.setup_model(opt, ds) 78 | self.setup_optimizer(self.model.net_parameters, True) 79 | self.loss_fn = torch.nn.CrossEntropyLoss() 80 | 81 | # Check for a fixed curriculum 82 | if opt.curriculum == 'fixed': 83 | print('Loading fixed curriculum.') 84 | self.fixed_curriculum = ds['train'][0].load_fixed_curriculum() 85 | 86 | # Load pretrained model weights, restore previous checkpoint 87 | if opt.pretrained: 88 | tmp_path = '%s/%s/snapshot_model' % (opt.exp_root_dir, opt.pretrained) 89 | print('Loading pretrained model weights from:', tmp_path) 90 | pretrained = torch.load(tmp_path) 91 | self.model.load_state_dict(pretrained['model'], strict=False) 92 | 93 | self.restore(opt.restore_session, self.to_load) 94 | 95 | def tensorboard_setup(self): 96 | self.to_track = ['network_loss', 'accuracy', 'decathlon'] 97 | self.task_names = [d.task_name for d in self.ds['train']] 98 | self.tb = tensorboard_manager.TBManager(self.exp_dir, self.task_names, 99 | self.to_track) 100 | 101 | def checkpoint_ref_setup(self): 102 | # Define key/value pairs for checkpoint management 103 | self.checkpoint_ref = { 104 | 'model': ['model'], 105 | 'meta': ['metaparams'], 106 | 'extra': ['log', 'score', 'task_trained_steps', 'global_trained_steps'] 107 | } 108 | self.to_load = list(self.checkpoint_ref.keys()) 109 | 110 | def setup_metaparams(self, opt, ds): 111 | if opt.metaparam == 'partition': 112 | self.metaparams = partition.Metaparam(opt) 113 | self.masks = None 114 | self.checkpoint_ref['model'] += ['masks'] 115 | 116 | else: 117 | self.metaparams = None 118 | 119 | def setup_model(self, opt, ds): 120 | model = importlib.import_module('mtl.models.' + opt.model) 121 | 122 | if opt.metaparam == 'partition': 123 | if opt.restore_session is not None: 124 | # Restore session snapshot 125 | print('Restoring previous masks... 
(%s/snapshot)' % opt.restore_session)
126 | checkpoint = torch.load('%s/snapshot_model' % opt.restore_session)
127 | self.masks = checkpoint['masks']
128 | 
129 | self.model = model.initialize(
130 | opt, ds, metaparams=self.metaparams, masks=self.masks)
131 | self.masks = [self.model.resnet.mask_ref, self.model.grad_mask_ref]
132 | 
133 | else:
134 | self.model = model.initialize(opt, ds)
135 | 
136 | def cuda(self):
137 | self.model.cuda()
138 | 
139 | def set_train_mode(self, train_flag):
140 | self.is_training = train_flag
141 | if train_flag:
142 | self.model.train()
143 | else:
144 | self.model.eval()
145 | 
146 | def get_log_vals(self, split, task_idx):
147 | return {k: self.log[k][task_idx][split] for k in self.to_track}
148 | 
149 | def update_log(self, k, v):
150 | tmp_log = self.log[k][self.task_idx]
151 | if self.step == 0:
152 | tmp_v = v
153 | else:
154 | tmp_v = calc.running_avg(tmp_log[self.split], v)
155 | tmp_log[self.split] = tmp_v
156 | 
157 | def sample_task(self):
158 | task_idxs = np.arange(self.num_tasks)
159 | curr = self.opt.curriculum
160 | 
161 | if self.split == 'train':
162 | if curr == 'fixed':
163 | # Fixed training curriculum
164 | return self.fixed_curriculum[self.step % len(self.fixed_curriculum)]
165 | 
166 | elif curr == 'uniform':
167 | # Uniformly iterate through tasks
168 | return self.step % self.num_tasks
169 | 
170 | elif curr == 'proportional':
171 | # Sample tasks proportional to dataset size
172 | return np.random.choice(task_idxs, p=self.task_proportion)
173 | 
174 | elif curr == 'train_accuracy':
175 | # Sample tasks based on relative training accuracies
176 | train_acc = np.array([d['train'] for d in self.log['accuracy']])
177 | train_acc = np.log(1 - train_acc + self.opt.curriculum_bias)
178 | task_dist = calc.softmax(train_acc / self.opt.temperature)
179 | return np.random.choice(task_idxs, p=task_dist)
180 | 
181 | else:
182 | # Undefined, raise error
183 | raise ValueError('Undefined task curriculum: %s' % curr)
184 | 
185 | else:
186 | # Validation (order doesn't matter, just hit all validation samples)
187 | return self.opt.valid_iter_ref[self.step % self.opt.iters['valid']]
188 | 
189 | def get_next_sample(self, split, task_idx):
190 | reset_iter = False
191 | try:
192 | sample = next(self.iters[split][task_idx])
193 | except StopIteration:
194 | reset_iter = True
195 | 
196 | if reset_iter:
197 | self.iters[split][task_idx] = iter(self.dataloaders[split][task_idx])
198 | sample = next(self.iters[split][task_idx])
199 | 
200 | inp = sample['img'].cuda()
201 | label = sample['label'].view(-1).cuda()
202 | idxs = sample['index']
203 | 
204 | return inp, label, idxs
205 | 
206 | def run(self, split, step):
207 | self.step = step
208 | self.split = split
209 | 
210 | # Sample task
211 | task_idx = self.sample_task()
212 | self.task_idx = task_idx
213 | self.curr_task = self.ds['train'][task_idx].task_name
214 | 
215 | # Get samples + model output
216 | inp, label, idxs = self.get_next_sample(split, task_idx)
217 | pred = self.model(inp, task_idx, split, step)
218 | 
219 | # Calculate loss
220 | _, class_preds = torch.max(pred, 1)
221 | accuracy = class_preds.eq(label).float().mean()
222 | dec_score, _ = calc.decathlon_score(
223 | accuracy.data, task_idxs=[self.task_idx_ref[task_idx]])
224 | self.net_loss = self.loss_fn(pred, label)
225 | 
226 | # Track accuracy and cache predictions
227 | if split == 'valid':
228 | self.valid_accuracy_track[task_idx] += [accuracy.data.item()]
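# These cached per-image predictions are presumably dumped to 'final_predictions' and later consumed by util/prepare_submission.py.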
229 | for batch_idx, tmp_idx in enumerate(idxs):
230 | self.prediction_ref[task_idx][
231 | tmp_idx.item()] = class_preds.data[batch_idx].item()
232 | 
233 | self.update_log('accuracy', accuracy.data.item())
234 | self.update_log('network_loss', self.net_loss.data.item())
235 | self.update_log('decathlon', dec_score)
236 | self.score = np.array([d['valid'] for d in self.log['accuracy']]).mean()
237 | 
238 | self.global_trained_steps += 1
239 | self.task_trained_steps[task_idx] += 1
240 | 
241 | def update_weights(self):
242 | opt = self.opt
243 | t_idx = self.task_idx
244 | 
245 | # Set up optimizer and learning rate
246 | net_optimizer = self.__dict__['optim_%d' % t_idx]
247 | lr = opt.learning_rate
248 | for p in net_optimizer.param_groups:
249 | p['lr'] = lr
250 | 
251 | # Calculate gradients
252 | net_optimizer.zero_grad()
253 | self.net_loss.backward()
254 | 
255 | if opt.clip_grad:
256 | p = self.model.net_parameters[t_idx]
257 | torch.nn.utils.clip_grad_norm_(p, opt.clip_grad)
258 | 
259 | if opt.metaparam == 'partition' and opt.mask_gradient:
260 | # Loop through parameters and multiply by the appropriate mask, zeroing gradients on channels assigned to other tasks
261 | for m, bw_masks in self.model.bw_ref:
262 | bw_mask = bw_masks[t_idx].cuda()
263 | m.weight.grad *= bw_mask
264 | if 'bias' in m._parameters and m.bias is not None:
265 | m.bias.grad *= bw_mask[:, 0]
266 | 
267 | # Do weight update
268 | net_optimizer.step()
269 | 
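# Illustrative sketch of how this Task is driven from outside (the loop itself lives elsewhere in the repo; the step variable is an assumption):
#   task.run('train', step)   # sample a task, forward pass, set self.net_loss
#   task.update_weights()     # per-task optimizer step, optionally with masked grads
--------------------------------------------------------------------------------
/util/calc.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google Inc. All Rights Reserved.
2 | # 
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | # 
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | # 
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.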
14 | # ==============================================================================
15 | 
16 | import numpy as np
17 | import torch
18 | import torch.nn.functional as F
19 | 
20 | 
21 | def running_avg(x, y, k=.99):
22 | return k * x + (1 - k) * y
23 | 
24 | 
25 | def softmax(x, d=-1):
26 | tmp = np.exp(np.array(x) - np.max(x, axis=d, keepdims=True))  # subtract max for numerical stability; same result
27 | return tmp / tmp.sum(axis=d, keepdims=True)
28 | 
29 | 
30 | def sigmoid(x):
31 | return 1 / (1 + np.exp(-x))
32 | 
33 | 
34 | def inv_sigmoid(x):
35 | return np.log(x / (1 - x))
36 | 
37 | 
38 | def dist(a, b):
39 | return np.linalg.norm(a - b)
40 | 
41 | 
42 | def smooth_arr(arr, window=3):
43 | to_flatten = False
44 | if arr.ndim == 1:
45 | to_flatten = True
46 | arr = np.expand_dims(arr, 1)
47 | 
48 | pad = window // 2
49 | tmp_arr = F.pad(
50 | torch.unsqueeze(torch.Tensor(arr.T), 0), [pad, pad], mode='reflect')
51 | tmp_arr = np.array(F.avg_pool1d(tmp_arr, window, stride=1).data)
52 | tmp_arr = tmp_arr[0].T
53 | 
54 | if to_flatten:
55 | tmp_arr = tmp_arr[:, 0]
56 | 
57 | return tmp_arr
58 | 
59 | 
# Visual Decathlon scoring: each task contributes up to 1000 points. With
# gamma = 2 and alpha = 1000 * (2 * baseline_err)^-2, matching the single-task
# baseline error earns 250 points; twice the baseline error (or worse) earns 0.
60 | def decathlon_score(scores, task_idxs=None):
61 | if task_idxs is None:
62 | task_idxs = [i for i in range(10)]
63 | baseline_err = 1 - np.array([
64 | 59.87, 60.34, 82.12, 92.82, 55.53, 97.53, 81.41, 87.69, 96.55, 51.20
65 | ]) / 100
66 | baseline_err = baseline_err[task_idxs]
67 | num_tasks = len(task_idxs)
68 | 
69 | max_err = 2 * baseline_err
70 | gamma_vals = np.ones(num_tasks) * 2
71 | alpha_vals = 1000 * (max_err)**(-gamma_vals)
72 | 
73 | err = 1 - scores
74 | if num_tasks == 1:
75 | err = [err]
76 | 
77 | all_scores = []
78 | for i in range(num_tasks):
79 | all_scores += [alpha_vals[i] * max(0, max_err[i] - err[i])**gamma_vals[i]]
80 | return sum(all_scores), all_scores
81 | 
82 | 
83 | def rescale(x, min_val, max_val, invert=False):
84 | if not invert:
85 | return x * (max_val - min_val) + min_val
86 | else:
87 | return (x - min_val) / (max_val - min_val)
88 | 
89 | 
90 | def pow10(x, min_val, max_val, invert=False):
91 | log_fn = np.log if type(x) is float else torch.log
92 | 
93 | if not invert:
94 | return 10**rescale(x,
95 | np.log(min_val) / np.log(10),
96 | np.log(max_val) / np.log(10))
97 | else:
98 | return rescale(
99 | log_fn(x) / np.log(10),
100 | np.log(min_val) / np.log(10),
101 | np.log(max_val) / np.log(10), invert)
102 | 
103 | 
104 | def map_val(x, min_val, max_val, scale='linear', invert=False):
105 | if scale == 'log':
106 | map_fn = pow10
107 | elif scale == 'linear':
108 | map_fn = rescale
else:
raise ValueError('Undefined scale: %s' % scale)
109 | return map_fn(x, min_val, max_val, invert)
110 | 
111 | 
112 | def reverse_tensor(t, dim):
113 | return t.index_select(dim, torch.arange(t.shape[dim] - 1, -1, -1).long())
114 | 
115 | 
116 | def convert_mat_aux(m, d, min_, max_, invert=False):
117 | if invert:
118 | m = (m - min_) / (max_ - min_ + 1e-5)
119 | else:
120 | m = m * (max_ - min_) + min_
121 | m = np.triu(m, 1)
122 | return m + m.T + np.diag(d)
123 | 
124 | 
125 | def convert_mat(mat, invert=False):
126 | if mat.dim() == 3:
127 | # mat is 2 x n x n
128 | # where mat[0] is the forward matrix, and mat[1] is the backward one
129 | mat = np.array(mat)
130 | 
131 | # Convert forward matrix. The overlap between tasks i and j is bounded by
# the Frechet-style limits [max(0, d_i + d_j - 1), min(d_i, d_j)].
132 | d_f = mat[0].diagonal()
133 | min_ = np.maximum(0, np.add.outer(d_f, d_f) - 1)
134 | max_ = np.minimum.outer(d_f, d_f)
135 | m_f = convert_mat_aux(mat[0], d_f, min_, max_, invert=invert)
136 | 
137 | # Convert backward matrix
138 | d_b = mat[1].diagonal()
139 | if not invert:
140 | d_b = d_b * d_f
141 | tmp_m = mat[0] if invert else m_f
142 | min_ = np.maximum(0,
143 | np.add.outer(d_b, d_b) - np.add.outer(d_f, d_f) + tmp_m)
144 | max_ = np.minimum(tmp_m, np.minimum.outer(d_b, d_b))
145 | if invert:
146 | d_b = d_b / d_f
147 | m_b = convert_mat_aux(mat[1], d_b, min_, max_, invert=invert)
148 | 
149 | tmp_mat = np.stack([m_f, m_b], 0)
150 | tmp_mat = np.round(tmp_mat * 1000) / 1000
151 | return torch.Tensor(tmp_mat)
152 | 
153 | else:
154 | result = [convert_mat(m, invert=invert) for m in mat]
155 | return torch.stack(result)
156 | 
157 | 
158 | def mask_solver(p, n_iters=10, n_feats=100, filt=None):
159 | pw = (p * n_feats + 1e-3).astype(int)
160 | diag = pw.diagonal()
161 | n_tasks = p.shape[0]
162 | all_idxs = np.arange(n_feats)
163 | 
164 | mask = np.zeros((n_tasks, n_feats))
165 | mask[0][:pw[0, 0]] = 1
166 | 
167 | p1_ = pw / np.maximum(1, np.outer(diag, diag))
168 | p2_ = -pw / np.maximum(1, np.outer(n_feats - diag, diag))
169 | p2_ += 1 / np.maximum(1, (n_feats - diag.reshape(-1, 1)))
170 | 
171 | for curr_t in range(1, n_tasks):
172 | prob_dist = np.ones(n_feats)
173 | if filt is not None:
174 | prob_dist *= filt[curr_t]
175 | 
176 | for cmp_t in range(curr_t):
177 | prob_dist *= p1_[cmp_t, curr_t] * mask[cmp_t] + p2_[cmp_t, curr_t] * (
178 | 1 - mask[cmp_t])
179 | 
180 | if prob_dist.sum() == 0:
181 | prob_dist = np.ones(n_feats)
182 | 
183 | prob_dist /= prob_dist.sum()
184 | 
185 | scores = np.zeros((n_iters, curr_t))
186 | best_score_dist = 999
187 | n_to_choose = min((prob_dist > 0).sum(), diag[curr_t])
188 | if n_to_choose > 0:
189 | for i in range(n_iters):
190 | sample_idxs = np.random.choice(
191 | all_idxs, n_to_choose, replace=False, p=prob_dist)
192 | tmp_row = np.zeros(n_feats)
193 | tmp_row[sample_idxs] = 1
194 | 
195 | scores[i] = np.dot(mask[:curr_t], tmp_row)
196 | score_dist = np.linalg.norm(scores[i] - pw[:curr_t, curr_t])
197 | if score_dist < best_score_dist:
198 | best_score_dist = score_dist
199 | mask[curr_t] = tmp_row
200 | 
201 | return mask
202 | 
203 | 
204 | def find_masks(p, n_feats=100, n_iters=8, return_scores=True):  # NOTE: return_scores is currently unused
205 | # Convert raw parameterization
206 | p = np.array(convert_mat(p))
207 | kargs = {'n_feats': n_feats, 'n_iters': n_iters}
208 | mask_f = mask_solver(p[0], **kargs)
209 | mask_b = mask_solver(p[1], filt=mask_f, **kargs)
210 | 
211 | # Sort mask channels
212 | count_row_f = mask_f.sum(0, keepdims=True)
213 | count_row_b = mask_b.sum(0, keepdims=True)
214 | tmp_mask = np.concatenate([mask_b, count_row_b, mask_f, count_row_f], 0)
215 | tmp_idx = np.lexsort(tmp_mask)
216 | tmp_mask = tmp_mask[:, tmp_idx]
217 | 
218 | n = mask_b.shape[0]
219 | mask_b = tmp_mask[:n]
220 | mask_f = tmp_mask[n + 1:2 * n + 1]
221 | masks = np.stack([mask_f, mask_b], 0)
222 | 
223 | return masks
224 | 
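# Illustrative sketch of the expected shapes (assumptions inferred from the code above, not taken from the original repo):
#   p = torch.rand(2, 4, 4)            # forward/backward sharing params for 4 tasks
#   masks = find_masks(p, n_feats=16)  # -> (2, 4, 16) binary forward/backward masks
--------------------------------------------------------------------------------
/util/datasets/decathlon.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google Inc. All Rights Reserved.
2 | # 
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | # 
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | # 
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.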
14 | # ============================================================================== 15 | 16 | """Visual Decathlon dataset reference file.""" 17 | import json 18 | import os 19 | import pickle 20 | 21 | import imageio 22 | from mtl.third_party.cutout import cutout 23 | import numpy as np 24 | import PIL 25 | import torch 26 | from torch.utils.data import DataLoader 27 | from torch.utils.data import Dataset 28 | from torchvision import transforms 29 | 30 | 31 | def setup_extra_args(parser): 32 | parser.add_argument('--input_res', type=int, default=72) 33 | parser.add_argument('--subsample_validation', type=int, default=0) 34 | 35 | 36 | task_names = [ 37 | 'imagenet12', 'aircraft', 'cifar100', 'daimlerpedcls', 'dtd', 'gtsrb', 38 | 'vgg-flowers', 'omniglot', 'svhn', 'ucf101' 39 | ] 40 | 41 | 42 | def collect_and_parse_annot(annot_dir, task_name, split): 43 | # Collect image paths, img_ids, cat_id, task_specific_cat, ht, wd 44 | ann = {} 45 | key_map = {'train': 'train', 'valid': 'val', 'test': 'test_stripped'} 46 | split = key_map[split] 47 | 48 | annot_fn = '%s/annotations/%s_%s.json' % (annot_dir, task_name, split) 49 | with open(annot_fn, 'r') as f: 50 | annot = json.load(f) 51 | 52 | num_imgs = len(annot['images']) 53 | keys = ['path', 'ht', 'wd', 'label', 'cat_id', 'img_id'] 54 | ann = {k: np.zeros(num_imgs, int) for k in keys[1:]} 55 | ann[keys[0]] = [] 56 | 57 | num_cats = len(annot['categories']) 58 | ann['num_cats'] = num_cats 59 | ann['cat_id_ref'] = np.zeros(num_cats, int) 60 | ann['cat_label_ref'] = [] 61 | for i in range(num_cats): 62 | ann['cat_id_ref'][i] = annot['categories'][i]['id'] 63 | ann['cat_label_ref'] += [annot['categories'][i]['name']] 64 | 65 | ann['num_imgs'] = num_imgs 66 | for i in range(num_imgs): 67 | im_annot = annot['images'][i] 68 | ann['path'] += [im_annot['file_name']] 69 | ann['ht'][i] = im_annot['height'] 70 | ann['wd'][i] = im_annot['width'] 71 | if 'test' in split: 72 | # For test images, put in all zeros as filler 73 | ann['img_id'][i] = im_annot['id'] 74 | ann['cat_id'][i] = 0 75 | ann['label'][i] = 0 76 | else: 77 | gt_annot = annot['annotations'][i] 78 | ann['img_id'][i] = gt_annot['image_id'] 79 | ann['cat_id'][i] = gt_annot['category_id'] 80 | ann['label'][i] = gt_annot['category_id'] % 1e5 - 1 81 | 82 | return ann 83 | 84 | 85 | def combine_annot(dt, dv): 86 | new_annot = {} 87 | for k in ['ht', 'wd', 'label', 'cat_id', 'img_id']: 88 | new_annot[k] = np.concatenate([dt[k], dv[k]], 0) 89 | new_annot['path'] = dt['path'] + dv['path'] 90 | new_annot['num_imgs'] = dt['num_imgs'] + dv['num_imgs'] 91 | for k in ['num_cats', 'cat_id_ref', 'cat_label_ref']: 92 | new_annot[k] = dt[k] 93 | return new_annot 94 | 95 | 96 | def get_annot_from_idxs(annot, idxs): 97 | new_annot = {} 98 | for k in ['ht', 'wd', 'label', 'cat_id', 'img_id']: 99 | new_annot[k] = annot[k][idxs] 100 | new_annot['path'] = [annot['path'][i] for i in idxs] 101 | new_annot['num_imgs'] = len(idxs) 102 | for k in ['num_cats', 'cat_id_ref', 'cat_label_ref']: 103 | new_annot[k] = annot[k] 104 | return new_annot 105 | 106 | 107 | class DecathlonDataset(Dataset): 108 | 109 | def __init__(self, opt, task_idx, is_train, annot=None, augment=None): 110 | if augment is None: 111 | augment = is_train 112 | task_name = task_names[task_idx] 113 | 114 | self.data_dir = opt.data_dir 115 | self.task_idx = task_idx 116 | self.task_name = task_name 117 | self.is_train = is_train 118 | self.targets = np.array( 119 | [59.87, 60.34, 82.12, 92.82, 55.53, 97.53, 81.41, 87.69, 96.55, 51.20]) 120 | 121 | if 
annot is None:
122 | if is_train or opt.validate_on_train:
123 | # Load training annotations
124 | annot = collect_and_parse_annot(opt.data_dir, task_name, 'train')
125 | if opt.train_on_valid:
126 | valid_annot = collect_and_parse_annot(opt.data_dir, task_name,
127 | 'valid')
128 | annot = combine_annot(annot, valid_annot)
129 | 
130 | else:
131 | # Load validation/test annotations
132 | annot = collect_and_parse_annot(opt.data_dir, task_name,
133 | 'test' if opt.use_test else 'valid')
134 | 
135 | if opt.subsample_validation and not opt.use_test:
136 | num_valid = annot['num_imgs']
137 | if task_name != 'imagenet12':
138 | # Subsample to 3k or less (indices chosen randomly)
139 | tmp_filename = '%s/%s_valid_subset.npy' % (opt.data_dir, task_name)
140 | try:
141 | tmp_idxs = np.load(tmp_filename)
142 | except IOError:
# No cached subset yet; draw one and save it for reuse
143 | tmp_idxs = np.random.permutation(np.arange(annot['num_imgs']))
144 | tmp_idxs = tmp_idxs[:3000]
145 | np.save(tmp_filename, tmp_idxs)
146 | else:
147 | # Subsample ImageNet validation uniformly by factor of 10
148 | tmp_idxs = np.arange(0, num_valid, 10)
149 | 
150 | annot = get_annot_from_idxs(annot, tmp_idxs)
151 | 
152 | self.num_out = annot['num_cats']
153 | self.num_imgs = annot['num_imgs']
154 | self.annot = annot
155 | 
156 | # Setup image transformation
157 | no_flip = [5, 7, 8]
158 | scale_factor = 1.0
159 | 
160 | if augment:
161 | tmp_transform = [
162 | transforms.Resize(int(opt.input_res * scale_factor)),
163 | transforms.RandomCrop(opt.input_res, padding=4),
164 | transforms.ColorJitter(0.3, 0.3, 0.3, 0.1)
165 | ]
166 | if task_idx not in no_flip:
167 | # Ignore datasets like SVHN (where left/right matters for digits)
168 | tmp_transform += [transforms.RandomHorizontalFlip()]
169 | 
170 | tmp_transform = transforms.Compose(tmp_transform)
171 | else:
172 | tmp_transform = transforms.Compose([
173 | transforms.Resize(int(opt.input_res * scale_factor)),
174 | transforms.CenterCrop(opt.input_res)
175 | ])
176 | 
177 | to_pil = transforms.ToPILImage()
178 | to_tensor = transforms.ToTensor()
179 | 
180 | with open(opt.data_dir + '/decathlon_mean_ref.p', 'rb') as f:
181 | # Precomputed mean/variance of decathlon images
182 | ref_mean_std = pickle.load(f)
183 | ds_mean = ref_mean_std['mean'][self.task_name]
184 | ds_std = ref_mean_std['std'][self.task_name]
185 | normalize = transforms.Normalize(mean=ds_mean, std=ds_std)
186 | 
187 | self.transform = transforms.Compose(
188 | [to_pil, tmp_transform, to_tensor, normalize])
189 | if augment:
190 | self.transform.transforms.append(
191 | cutout.Cutout(n_holes=1, length=opt.input_res // 4))
192 | 
193 | def load_fixed_curriculum(self):
194 | return np.load(self.data_dir + '/dec_order.npy')
195 | 
196 | def load_image(self, idx):
197 | imgpath = '%s/%s' % (self.data_dir, self.annot['path'][idx])
198 | tmp_im = imageio.imread(imgpath).astype(float) / 255
199 | if tmp_im.ndim == 2:
200 | tmp_im = np.tile(np.expand_dims(tmp_im, 2), [1, 1, 3])
201 | 
202 | return tmp_im
203 | 
204 | def __len__(self):
205 | return self.num_imgs
206 | 
207 | def __getitem__(self, idx):
208 | img = self.load_image(idx)
209 | img = torch.Tensor(img).permute(2, 0, 1) # HWC --> CHW
210 | return {
211 | 'img': self.transform(img),
212 | 'label': self.annot['label'][idx],
213 | 'index': idx
214 | }
215 | 
216 | 
217 | def initialize(opt):
218 | if not opt.data_dir:
219 | # Default data directory
220 | curr_dir = os.path.dirname(__file__)
221 | opt.data_dir = os.path.join(curr_dir, '../../../data/decathlon')
222 | 
223 | datasets = {'train': [], 'valid': []}
224 | dataloaders = {'train': [], 'valid': []}
225 | task_idxs = list(map(int, opt.task_choice.split('-')))
226 | 
227 | # Check whether to run a different batchsize during validation
228 | valid_bs = opt.valid_batchsize if opt.valid_batchsize else opt.batchsize
229 | 
230 | print('Training on:')
231 | valid_iter_ref = []
232 | for i, task_idx in enumerate(task_idxs):
233 | for split in datasets:
234 | is_train = split == 'train'
235 | datasets[split] += [DecathlonDataset(opt, task_idx, is_train)]
236 | dataloaders[split] += [
237 | DataLoader(
238 | datasets[split][i],
239 | batch_size=opt.batchsize if is_train else valid_bs,
240 | shuffle=not opt.use_test,
241 | num_workers=opt.num_data_threads)
242 | ]
243 | print(task_names[task_idx],
244 | [len(datasets[split][-1]) for split in datasets])
245 | valid_iter_ref += [len(datasets['valid'][-1])]
246 | 
247 | n_valid_samples = sum(valid_iter_ref)
248 | valid_iter_ref = [int(np.ceil(v / valid_bs)) for v in valid_iter_ref]
249 | opt.iters['valid'] = sum(valid_iter_ref)
250 | tmp_iter_ref = np.concatenate(
251 | [np.array([i] * v) for i, v in enumerate(valid_iter_ref)])
252 | opt.valid_iter_ref = np.random.permutation(tmp_iter_ref)
253 | opt.valid_iters = opt.iters['valid']
254 | print('%d validation samples available w/ batchsize %d (%d valid iters)' %
255 | (n_valid_samples, valid_bs, opt.iters['valid']))
256 | 
257 | return datasets, dataloaders
258 | 
--------------------------------------------------------------------------------
/util/prepare_submission.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google Inc. All Rights Reserved.
2 | # 
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | # 
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | # 
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """Prepare json file for Decathlon submission server."""
17 | import json
18 | 
19 | from mtl.config.opts import parser
20 | from mtl.util.datasets import decathlon
21 | import torch
22 | 
23 | 
24 | opt = parser.parse_command_line()
25 | ds, _ = decathlon.initialize(opt)
26 | ds = ds['valid']
27 | 
28 | # Load predictions
29 | imgnet_p = torch.load(opt.exp_root_dir +
30 | '/imgnet_test/final_predictions')['preds']
31 | p = torch.load(opt.exp_dir + '/final_predictions')['preds']
32 | all_preds = imgnet_p + p
33 | 
34 | tmp_result = []
35 | for task_idx in range(10):
36 | n_ims = len(ds[task_idx])
37 | cat_ref = ds[task_idx].annot['cat_id_ref']
38 | im_ids = ds[task_idx].annot['img_id']
39 | for im_idx in range(n_ims):
40 | tmp_result += [{
41 | 'image_id': int(im_ids[im_idx]),
42 | 'category_id': int(cat_ref[all_preds[task_idx][im_idx]])
43 | }]
44 | 
45 | with open(opt.exp_dir + '/test_submission.json', 'w') as f:
46 | json.dump(tmp_result, f)
47 | 
--------------------------------------------------------------------------------
/util/session_manager.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google Inc.
All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Class for managing training and meta-optimization sessions.""" 17 | import subprocess 18 | import torch 19 | 20 | 21 | class SessionManager(): 22 | """Base class for managing experiments.""" 23 | 24 | def __init__(self, opt, ds=None, dataloaders=None): 25 | self.opt = opt 26 | self.ds = ds 27 | self.dataloaders = dataloaders 28 | self.score = 0 29 | 30 | # Absolute paths for experiment information 31 | self.exp_root_dir = opt.exp_root_dir 32 | self.exp_dir = opt.exp_dir 33 | 34 | self.checkpoint_ref_setup() 35 | self.tensorboard_setup() 36 | 37 | def checkpoint_ref_setup(self): 38 | self.checkpoint_ref = {'extra': ['score']} 39 | 40 | def tensorboard_setup(self): 41 | return 42 | 43 | def setup_optimizer(self, params, multi=False): 44 | """Initialize parameter optimizer. 45 | 46 | Args: 47 | params: Parameters to optimize. 48 | multi: Optimizing multiple sets of parameters. If true, keep track of 49 | gradient statistics separately. 50 | """ 51 | opt = self.opt 52 | optim_choice = opt.optimizer 53 | optim_fn = torch.optim.__dict__[optim_choice] 54 | 55 | # Setup optimizer arguments 56 | optim_kargs = {'lr': opt.learning_rate} 57 | if optim_choice == 'SGD': 58 | if 'momentum' in opt: 59 | optim_kargs['momentum'] = opt.momentum 60 | optim_kargs['nesterov'] = True 61 | if 'weight_decay' in opt: 62 | optim_kargs['weight_decay'] = opt.weight_decay 63 | 64 | elif optim_choice == 'RMSprop': 65 | optim_kargs['momentum'] = 0. 66 | optim_kargs['eps'] = 0.1 67 | 68 | # Initialize optimizers 69 | if multi: 70 | for i, p in enumerate(params): 71 | self.__dict__['optim_%d' % i] = optim_fn(p, **optim_kargs) 72 | self.checkpoint_ref['optim'] = [ 73 | 'optim_%d' % i for i in range(self.num_tasks) 74 | ] 75 | 76 | else: 77 | self.optim = optim_fn(params, **optim_kargs) 78 | self.checkpoint_ref['optim'] = ['optim'] 79 | 80 | def checkpoint(self, path, key_ref=None, groups=None, action='save'): 81 | """Manage loading and restoring experiment checkpoints. 82 | 83 | Args: 84 | path: File path to save/load checkpoint. 85 | key_ref: Dictionary listing all parts of session to save. 86 | groups: Which subset of key_ref to save. 
87 | action: 'save' or 'load'
88 | """
89 | if key_ref is None:
90 | key_ref = self.checkpoint_ref
91 | if groups is None:
92 | groups = key_ref.keys()
93 | 
94 | for group in groups:
95 | if action == 'save':
96 | to_, from_ = {}, self.__dict__
97 | elif action == 'load':
98 | to_, from_ = self.__dict__, torch.load('%s_%s' % (path, group))
99 | 
100 | for k in key_ref[group]:
101 | try:
102 | # Check whether or not to use a state_dict
103 | tmp_obj = from_[k] if action == 'save' else to_[k]
104 | use_state_dict = callable(getattr(tmp_obj, 'state_dict'))
105 | except AttributeError:
106 | use_state_dict = False
107 | 
108 | if use_state_dict:
109 | if action == 'save':
110 | to_[k] = from_[k].state_dict()
111 | elif action == 'load':
112 | kargs = {'strict': False} if 'model' in k else {}
113 | to_[k].load_state_dict(from_[k], **kargs)
114 | 
115 | else:
116 | to_[k] = from_[k]
117 | 
118 | if action == 'save':
119 | torch.save(to_, '%s_%s' % (path, group))
120 | 
121 | def save(self, path, key_ref=None, groups=None):
122 | self.checkpoint(path, key_ref, groups, 'save')
123 | 
124 | def load(self, path, key_ref=None, groups=None):
125 | self.checkpoint(path, key_ref, groups, 'load')
126 | 
127 | def restore(self, path, groups=None):
128 | if path:
129 | print('Restoring previous session... (%s/snapshot)' % path)
130 | self.load(path + '/snapshot', groups=groups)
131 | 
132 | def clear_checkpoints(self, path, groups=('model', 'optim')):
133 | for g in groups:
134 | try:
135 | subprocess.check_call(['rm', path + '/snapshot_%s' % g])  # raises on failure, unlike subprocess.call
136 | except Exception as e:
137 | print('Error clearing snapshot %s:' % g, repr(e))
138 | 
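# Illustrative sketch of the checkpoint API above (the path is an assumption, not from the original repo):
#   sm.save('/tmp/exp/snapshot')                    # writes snapshot_model, snapshot_optim, ...
#   sm.load('/tmp/exp/snapshot', groups=['model'])  # restore only the model group
--------------------------------------------------------------------------------
/util/tensorboard_manager.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google Inc. All Rights Reserved.
2 | # 
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | # 
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | # 
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.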
14 | # ==============================================================================
15 | 
16 | """Class definition for Tensorboard wrapper."""
17 | 
18 | import os
19 | import tensorflow as tf
20 | import torch
21 | 
22 | 
23 | class TBManager:
24 | """Simple wrapper for Tensorboard."""
25 | 
26 | def __init__(self, exp_dir, task_names, to_track, splits=('train', 'valid')):
27 | # Set up summaries
28 | summaries = {}
29 | placeholders = {}
30 | 
31 | with tf.device('/cpu:0'):
32 | for task in task_names:
33 | summaries[task] = {}
34 | 
35 | for s in splits:
36 | tmp_summaries = []
37 | 
38 | for k in to_track:
39 | if k not in placeholders:
40 | placeholders[k] = tf.placeholder(tf.float32, [])
41 | tmp_summaries += [
42 | tf.summary.scalar('%s_%s_%s' % (task, s, k), placeholders[k])
43 | ]
44 | 
45 | summaries[task][s] = tf.summary.merge(tmp_summaries)
46 | 
47 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Suppress TF warnings
48 | 
49 | config = tf.ConfigProto()
50 | config.gpu_options.allow_growth = True
51 | self.sess = tf.Session(config=config)
52 | 
53 | self.writer = tf.summary.FileWriter(exp_dir, self.sess.graph)
54 | self.summaries = summaries
55 | self.placeholders = placeholders
56 | 
57 | def update(self, task, split, step, vals):
58 | """Write an update to events file.
59 | 
60 | Args:
61 | task: Name of the task to write a summary for
62 | split: Specify 'train' or 'valid'
63 | step: Current training iteration
64 | vals: Dictionary with values to update
65 | """
66 | 
67 | # Get summary update for Tensorboard
68 | feed_dict = {
69 | self.placeholders[k]: v.cpu() if isinstance(v, torch.Tensor) else v
70 | for k, v in vals.items()
71 | }
72 | summary = self.sess.run(self.summaries[task][split], feed_dict=feed_dict)
73 | 
74 | # Log data
75 | self.writer.add_summary(summary, step)
76 | self.writer.flush()
77 | 
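# Illustrative usage sketch (task names, tracked keys, and values below are assumptions, not from the original repo):
#   tb = TBManager('/tmp/exp', ['cifar100'], ['accuracy'])
#   tb.update('cifar100', 'train', step=0, vals={'accuracy': 0.42})
--------------------------------------------------------------------------------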