├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── config.py ├── data.py ├── model.py ├── model_test.py ├── runner.py └── utils.py /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreements 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. 12 | 13 | Please fill out either the individual or corporate Contributor License Agreement 14 | (CLA). 15 | 16 | * If you are an individual writing original source code and you're sure you 17 | own the intellectual property, then you'll need to sign an [individual 18 | CLA](https://cla.developers.google.com/about/google-individual). 19 | * If you work for a company that wants to allow you to contribute your work, 20 | then you'll need to sign a [corporate 21 | CLA](https://cla.developers.google.com/about/google-corporate). 22 | 23 | Follow either of the two links above to access the appropriate CLA and 24 | instructions for how to sign and return it. Once we receive it, we'll be able to 25 | accept your pull requests. 26 | 27 | ## Code reviews 28 | 29 | All submissions, including submissions by project members, require review. We 30 | use GitHub pull requests for this purpose. Consult 31 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 32 | information on using pull requests. 
33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 
40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2018 DeepMind Technologies Limited 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Meta-Learning with Latent Embedding Optimization 2 | 3 | ## Overview 4 | This repository contains the implementation of the meta-learning model 5 | described in the paper "[Meta-Learning with Latent Embedding 6 | Optimization](https://arxiv.org/abs/1807.05960)" by Rusu et al. It was posted 7 | on arXiv in July 2018 and will be presented at ICLR 2019. 8 | 9 | The method learns a data-dependent latent representation of model parameters and 10 | performs gradient-based meta-learning in this low-dimensional space. 11 | 12 | The code here doesn't include the (standard) method for pre-training the 13 | data embeddings. Instead, the trained embeddings are provided. 14 | 15 | Disclaimer: This is not an official Google product. 
16 | 17 | ## Running the code 18 | 19 | ### Setup 20 | To run the code, you first need to install: 21 | 22 | - [TensorFlow](https://www.tensorflow.org/install/) and [TensorFlow Probability](https://www.tensorflow.org/probability) (we used version 1.12), 23 | - [Sonnet](https://github.com/deepmind/sonnet) (we used version v1.29), and 24 | - [Abseil](https://github.com/abseil/abseil-py) (we use only the FLAGS module). 25 | 26 | ### Getting the data 27 | You need to download [the embeddings](http://storage.googleapis.com/leo-embeddings/embeddings.zip) and extract them on disk: 28 | 29 | ``` 30 | $ wget http://storage.googleapis.com/leo-embeddings/embeddings.zip 31 | $ unzip embeddings.zip 32 | $ EMBEDDINGS=`pwd`/embeddings 33 | ``` 34 | 35 | ### Running the code 36 | Then, clone this repository using: 37 | 38 | `$ git clone https://github.com/deepmind/leo` 39 | 40 | and run the code as: 41 | 42 | `$ python runner.py --data_path=$EMBEDDINGS` 43 | 44 | This will train the model for solving 5-way 1-shot miniImageNet classification. 45 | 46 | ### Hyperparameters 47 | To train the model on the tieredImageNet dataset or with a different number of 48 | training examples per class (K-shot), you can pass these parameters on the 49 | command line or in `config.py`, e.g.: 50 | 51 | `$ python runner.py --data_path=$EMBEDDINGS --dataset_name=tieredImageNet 52 | --num_tr_examples_per_class=5 --outer_lr=1e-4` 53 | 54 | See `config.py` for the list of options to set. 
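For orientation, the flags above determine the episode tensor shapes that `data.py` produces. A small, stdlib-only sketch of the shapes implied by the 5-way 1-shot defaults (variable names mirror `config.py`; `NDIM = 640` is the embedding dimensionality from `data.py`; this is illustrative, not code from the repo):

```python
# Episode shapes implied by the default flags (illustrative only).
num_classes = 5                  # N in N-way classification
num_tr_examples_per_class = 1    # K in K-shot classification
num_val_examples_per_class = 15
metatrain_batch_size = 12
NDIM = 640  # dimensionality of the pre-trained embeddings

# Shapes of the batched training/validation inputs (cf. DataProvider.get_batch):
tr_input_shape = (metatrain_batch_size, num_classes,
                  num_tr_examples_per_class, NDIM)
val_input_shape = (metatrain_batch_size, num_classes,
                   num_val_examples_per_class, NDIM)

print(tr_input_shape)   # (12, 5, 1, 640)
print(val_input_shape)  # (12, 5, 15, 640)
```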
55 | 56 | Comparison of paper and open-source implementations in terms of test set accuracy: 57 | 58 | | Implementation | miniImageNet 1-shot | miniImageNet 5-shot | tieredImageNet 1-shot | tieredImageNet 5-shot | 59 | | -----------------------| ------------------- | ------------------- | --------------------- | --------------------- | 60 | | `LEO Paper` | `61.76 ± 0.08%` | `77.59 ± 0.12%` | `66.33 ± 0.05%` | `81.44 ± 0.09%` | 61 | | `This code` | `61.89 ± 0.16%` | `77.65 ± 0.09%` | `66.25 ± 0.14%` | `81.77 ± 0.09%` | 62 | 63 | 64 | The hyperparameters we found working best for different setups are as follows: 65 | 66 | | Hyperparameter | miniImageNet 1-shot | miniImageNet 5-shot | tieredImageNet 1-shot | tieredImageNet 5-shot | 67 | | ------------------------------ | ------------------- | ------------------- | --------------------- | --------------------- | 68 | | `outer_lr` | `2.739071e-4` | `4.102361e-4` | `8.659053e-4` | `6.110314e-4` | 69 | | `l2_penalty_weight` | `3.623413e-10` | `8.540338e-9` | `4.148858e-10` | `1.690399e-10` | 70 | | `orthogonality_penalty_weight` | `0.188103` | `1.523998e-3` | `5.451078e-3` | `2.481216e-2` | 71 | | `dropout_rate` | `0.307651` | `0.300299` | `0.475126` | `0.415158` | 72 | | `kl_weight` | `0.756143` | `0.466387` | `2.034189e-3` | `1.622811` | 73 | | `encoder_penalty_weight` | `5.756821e-6` | `2.661608e-7` | `8.302962e-5` | `2.672450e-5` | 74 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # coding=utf8 2 | # Copyright 2018 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """A module containing just the configs for the different LEO parts.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from absl import flags 23 | 24 | 25 | FLAGS = flags.FLAGS 26 | flags.DEFINE_string("data_path", None, "Path to the dataset.") 27 | flags.DEFINE_string( 28 | "dataset_name", "miniImageNet", "Name of the dataset to " 29 | "train on, which will be mapped to data.MetaDataset.") 30 | flags.DEFINE_string( 31 | "embedding_crop", "center", "Type of the cropping, which " 32 | "will be mapped to data.EmbeddingCrop.") 33 | flags.DEFINE_boolean("train_on_val", False, "Whether to train on the " 34 | "validation data.") 35 | 36 | flags.DEFINE_integer( 37 | "inner_unroll_length", 5, "Number of unroll steps in the " 38 | "inner loop of leo (number of adaptation steps in the " 39 | "latent space).") 40 | flags.DEFINE_integer( 41 | "finetuning_unroll_length", 5, "Number of unroll steps " 42 | "in the loop performing finetuning (number of adaptation " 43 | "steps in the parameter space).") 44 | flags.DEFINE_integer("num_latents", 64, "The dimensionality of the latent " 45 | "space.") 46 | flags.DEFINE_float( 47 | "inner_lr_init", 1.0, "The initialization value for the " 48 | "learning rate of the inner loop of leo.") 49 | flags.DEFINE_float( 50 | "finetuning_lr_init", 0.001, "The initialization value for " 51 | "learning rate of the finetuning 
loop.") 52 | flags.DEFINE_float("dropout_rate", 0.5, "Rate of dropout: probability of " 53 | "dropping a given unit.") 54 | flags.DEFINE_float( 55 | "kl_weight", 1e-3, "The weight measuring importance of the " 56 | "KL in the final loss. β in the paper.") 57 | flags.DEFINE_float( 58 | "encoder_penalty_weight", 1e-9, "The weight measuring " 59 | "importance of the encoder penalty in the final loss. γ in " 60 | "the paper.") 61 | flags.DEFINE_float("l2_penalty_weight", 1e-8, "The weight measuring the " 62 | "importance of the l2 regularization in the final loss. λ₁ " 63 | "in the paper.") 64 | flags.DEFINE_float("orthogonality_penalty_weight", 1e-3, "The weight measuring " 65 | "the importance of the decoder orthogonality regularization " 66 | "in the final loss. λ₂ in the paper.") 67 | 68 | flags.DEFINE_integer( 69 | "num_classes", 5, "Number of classes, N in N-way classification.") 70 | flags.DEFINE_integer( 71 | "num_tr_examples_per_class", 1, "Number of training samples per class, " 72 | "K in K-shot classification.") 73 | flags.DEFINE_integer( 74 | "num_val_examples_per_class", 15, "Number of validation samples per class " 75 | "in a task instance.") 76 | flags.DEFINE_integer("metatrain_batch_size", 12, "Number of problem instances " 77 | "in a batch.") 78 | flags.DEFINE_integer("metavalid_batch_size", 200, "Number of meta-validation " 79 | "problem instances.") 80 | flags.DEFINE_integer("metatest_batch_size", 200, "Number of meta-testing " 81 | "problem instances.") 82 | flags.DEFINE_integer("num_steps_limit", int(1e5), "Number of steps to train " 83 | "for.") 84 | flags.DEFINE_float("outer_lr", 1e-4, "Outer (metatraining) loop learning " 85 | "rate.") 86 | flags.DEFINE_float( 87 | "gradient_threshold", 0.1, "The cutoff for the gradient " 88 | "clipping. Gradients will be clipped to " 89 | "[-gradient_threshold, gradient_threshold]") 90 | flags.DEFINE_float( 91 | "gradient_norm_threshold", 0.1, "The cutoff for clipping of " 92 | "the gradient norm. 
Gradient norm clipping will be applied " 93 | "after pointwise clipping (described above).") 94 | 95 | 96 | def get_data_config(): 97 | config = {} 98 | config["data_path"] = FLAGS.data_path 99 | config["dataset_name"] = FLAGS.dataset_name 100 | config["embedding_crop"] = FLAGS.embedding_crop 101 | config["train_on_val"] = FLAGS.train_on_val 102 | config["total_examples_per_class"] = 600 103 | return config 104 | 105 | 106 | def get_inner_model_config(): 107 | """Returns the config used to initialize LEO model.""" 108 | config = {} 109 | config["inner_unroll_length"] = FLAGS.inner_unroll_length 110 | config["finetuning_unroll_length"] = FLAGS.finetuning_unroll_length 111 | config["num_latents"] = FLAGS.num_latents 112 | config["inner_lr_init"] = FLAGS.inner_lr_init 113 | config["finetuning_lr_init"] = FLAGS.finetuning_lr_init 114 | config["dropout_rate"] = FLAGS.dropout_rate 115 | config["kl_weight"] = FLAGS.kl_weight 116 | config["encoder_penalty_weight"] = FLAGS.encoder_penalty_weight 117 | config["l2_penalty_weight"] = FLAGS.l2_penalty_weight 118 | config["orthogonality_penalty_weight"] = FLAGS.orthogonality_penalty_weight 119 | 120 | return config 121 | 122 | 123 | def get_outer_model_config(): 124 | """Returns the outer config file for N-way K-shot classification tasks.""" 125 | config = {} 126 | config["num_classes"] = FLAGS.num_classes 127 | config["num_tr_examples_per_class"] = FLAGS.num_tr_examples_per_class 128 | config["num_val_examples_per_class"] = FLAGS.num_val_examples_per_class 129 | config["metatrain_batch_size"] = FLAGS.metatrain_batch_size 130 | config["metavalid_batch_size"] = FLAGS.metavalid_batch_size 131 | config["metatest_batch_size"] = FLAGS.metatest_batch_size 132 | config["num_steps_limit"] = FLAGS.num_steps_limit 133 | config["outer_lr"] = FLAGS.outer_lr 134 | config["gradient_threshold"] = FLAGS.gradient_threshold 135 | config["gradient_norm_threshold"] = FLAGS.gradient_norm_threshold 136 | return config 137 | 
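The two clipping thresholds above compose: gradients are first clipped elementwise to [-gradient_threshold, gradient_threshold], and the result is then clipped by global norm, as the flag description states. A TensorFlow-free Python sketch of that order of operations (hypothetical helper, not part of the repo; in TF this would typically be `tf.clip_by_value` followed by `tf.clip_by_global_norm`):

```python
import math

def clip_gradients(grads, gradient_threshold=0.1, gradient_norm_threshold=0.1):
    # Step 1: pointwise clipping to [-gradient_threshold, gradient_threshold].
    clipped = [max(-gradient_threshold, min(gradient_threshold, g))
               for g in grads]
    # Step 2: global-norm clipping, applied after the pointwise clipping.
    norm = math.sqrt(sum(g * g for g in clipped))
    if norm > gradient_norm_threshold:
        clipped = [g * gradient_norm_threshold / norm for g in clipped]
    return clipped

grads = clip_gradients([0.5, -0.2, 0.05])
# Every element is within the pointwise bound, and the global norm of the
# result is at most gradient_norm_threshold.
print(grads)
```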
-------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Creates problem instances for LEO.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import os 23 | import pickle 24 | import random 25 | 26 | import enum 27 | import numpy as np 28 | import six 29 | import tensorflow as tf 30 | 31 | 32 | NDIM = 640 33 | 34 | ProblemInstance = collections.namedtuple( 35 | "ProblemInstance", 36 | ["tr_input", "tr_output", "tr_info", "val_input", "val_output", "val_info"]) 37 | 38 | 39 | class StrEnum(enum.Enum): 40 | """An Enum represented by a string.""" 41 | 42 | def __str__(self): 43 | return self.value 44 | 45 | def __repr__(self): 46 | return self.__str__() 47 | 48 | 49 | class MetaDataset(StrEnum): 50 | """Datasets supported by the DataProvider class.""" 51 | MINI = "miniImageNet" 52 | TIERED = "tieredImageNet" 53 | 54 | 55 | class EmbeddingCrop(StrEnum): 56 | """Embedding types supported by the DataProvider class.""" 57 | CENTER = "center" 58 | MULTIVIEW = "multiview" 59 | 60 | 61 | class MetaSplit(StrEnum): 62 | 
"""Meta-datasets split supported by the DataProvider class.""" 63 | TRAIN = "train" 64 | VALID = "val" 65 | TEST = "test" 66 | 67 | 68 | class DataProvider(object): 69 | """Creates problem instances from a specific split and dataset.""" 70 | 71 | def __init__(self, dataset_split, config, verbose=False): 72 | self._dataset_split = MetaSplit(dataset_split) 73 | self._config = config 74 | self._verbose = verbose 75 | self._check_config() 76 | 77 | self._index_data(self._load_data()) 78 | 79 | def _check_config(self): 80 | """Checks configuration arguments of constructor.""" 81 | self._config["dataset_name"] = MetaDataset(self._config["dataset_name"]) 82 | self._config["embedding_crop"] = EmbeddingCrop( 83 | self._config["embedding_crop"]) 84 | if self._config["dataset_name"] == MetaDataset.TIERED: 85 | error_message = "embedding_crop: {} not supported for {}".format( 86 | self._config["embedding_crop"], self._config["dataset_name"]) 87 | assert self._config[ 88 | "embedding_crop"] == EmbeddingCrop.CENTER, error_message 89 | 90 | def _load_data(self): 91 | """Loads data into memory and caches .""" 92 | raw_data = self._load( 93 | tf.gfile.Open(self._get_full_pickle_path(self._dataset_split), "rb")) 94 | if self._dataset_split == MetaSplit.TRAIN and self._config["train_on_val"]: 95 | valid_data = self._load( 96 | tf.gfile.Open(self._get_full_pickle_path(MetaSplit.VALID), "rb")) 97 | for key in valid_data: 98 | if self._verbose: 99 | tf.logging.info(str([key, raw_data[key].shape])) 100 | raw_data[key] = np.concatenate([raw_data[key], 101 | valid_data[key]], axis=0) 102 | if self._verbose: 103 | tf.logging.info(str([key, raw_data[key].shape])) 104 | 105 | if self._verbose: 106 | tf.logging.info( 107 | str([(k, np.shape(v)) for k, v in six.iteritems(raw_data)])) 108 | 109 | return raw_data 110 | 111 | def _load(self, opened_file): 112 | if six.PY2: 113 | result = pickle.load(opened_file) 114 | else: 115 | result = pickle.load(opened_file, encoding="latin1") # pylint: 
disable=unexpected-keyword-arg 116 | return result 117 | 118 | def _index_data(self, raw_data): 119 | """Builds an index of images embeddings by class.""" 120 | self._all_class_images = collections.OrderedDict() 121 | self._image_embedding = collections.OrderedDict() 122 | for i, k in enumerate(raw_data["keys"]): 123 | _, class_label, image_file = k.split("-") 124 | image_file_class_label = image_file.split("_")[0] 125 | assert class_label == image_file_class_label 126 | self._image_embedding[image_file] = raw_data["embeddings"][i] 127 | if class_label not in self._all_class_images: 128 | self._all_class_images[class_label] = [] 129 | self._all_class_images[class_label].append(image_file) 130 | 131 | self._check_data_index(raw_data) 132 | 133 | self._all_class_images = collections.OrderedDict([ 134 | (k, np.array(v)) for k, v in six.iteritems(self._all_class_images) 135 | ]) 136 | if self._verbose: 137 | tf.logging.info(str([len(raw_data), len(self._all_class_images), 138 | len(self._image_embedding)])) 139 | 140 | def _check_data_index(self, raw_data): 141 | """Performs checks of the data index and image counts per class.""" 142 | n = raw_data["keys"].shape[0] 143 | error_message = "{} != {}".format(len(self._image_embedding), n) 144 | assert len(self._image_embedding) == n, error_message 145 | error_message = "{} != {}".format(raw_data["embeddings"].shape[0], n) 146 | assert raw_data["embeddings"].shape[0] == n, error_message 147 | 148 | all_class_folders = list(self._all_class_images.keys()) 149 | error_message = "no duplicate class names" 150 | assert len(set(all_class_folders)) == len(all_class_folders), error_message 151 | image_counts = set([len(class_images) 152 | for class_images in self._all_class_images.values()]) 153 | error_message = ("len(image_counts) should have at least one element but " 154 | "is: {}").format(image_counts) 155 | assert len(image_counts) >= 1, error_message 156 | assert min(image_counts) > 0 157 | 158 | def 
_get_full_pickle_path(self, split_name): 159 | full_pickle_path = os.path.join( 160 | self._config["data_path"], 161 | str(self._config["dataset_name"]), 162 | str(self._config["embedding_crop"]), 163 | "{}_embeddings.pkl".format(split_name)) 164 | if self._verbose: 165 | tf.logging.info("get_one_emb_instance: folder_path: {}".format( 166 | full_pickle_path)) 167 | return full_pickle_path 168 | 169 | def get_instance(self, num_classes, tr_size, val_size): 170 | """Samples a random N-way K-shot classification problem instance. 171 | 172 | Args: 173 | num_classes: N in N-way classification. 174 | tr_size: K in K-shot; number of training examples per class. 175 | val_size: number of validation examples per class. 176 | 177 | Returns: 178 | A tuple with 6 Tensors with the following shapes: 179 | - tr_input: (num_classes, tr_size, NDIM): training image embeddings. 180 | - tr_output: (num_classes, tr_size, 1): training image labels. 181 | - tr_info: (num_classes, tr_size): training image file names. 182 | - val_input: (num_classes, val_size, NDIM): validation image embeddings. 183 | - val_output: (num_classes, val_size, 1): validation image labels. 184 | - val_info: (num_classes, val_size): validation image file names. 
185 | """ 186 | 187 | def _build_one_instance_py(): 188 | """Builds a random problem instance using data from specified classes.""" 189 | class_list = list(self._all_class_images.keys()) 190 | sample_count = (tr_size + val_size) 191 | shuffled_folders = class_list[:] 192 | random.shuffle(shuffled_folders) 193 | shuffled_folders = shuffled_folders[:num_classes] 194 | error_message = "len(shuffled_folders) {} is not num_classes: {}".format( 195 | len(shuffled_folders), num_classes) 196 | assert len(shuffled_folders) == num_classes, error_message 197 | image_paths = [] 198 | class_ids = [] 199 | embeddings = self._image_embedding 200 | for class_id, class_name in enumerate(shuffled_folders): 201 | all_images = self._all_class_images[class_name] 202 | all_images = np.random.choice(all_images, sample_count, replace=False) 203 | error_message = "{} == {} failed".format(len(all_images), sample_count) 204 | assert len(all_images) == sample_count, error_message 205 | image_paths.append(all_images) 206 | class_ids.append([[class_id]]*sample_count) 207 | 208 | label_array = np.array(class_ids, dtype=np.int32) 209 | if self._verbose: 210 | tf.logging.info(label_array.shape) 211 | if self._verbose: 212 | tf.logging.info(label_array.shape) 213 | path_array = np.array(image_paths) 214 | if self._verbose: 215 | tf.logging.info(path_array.shape) 216 | if self._verbose: 217 | tf.logging.info(path_array.shape) 218 | embedding_array = np.array([[embeddings[image_path] 219 | for image_path in class_paths] 220 | for class_paths in path_array]) 221 | if self._verbose: 222 | tf.logging.info(embedding_array.shape) 223 | return embedding_array, label_array, path_array 224 | 225 | output_list = tf.py_func(_build_one_instance_py, [], 226 | [tf.float32, tf.int32, tf.string]) 227 | instance_input, instance_output, instance_info = output_list 228 | instance_input = tf.nn.l2_normalize(instance_input, axis=-1) 229 | instance_info = tf.regex_replace(instance_info, "\x00*", "") 230 | 231 | if 
self._verbose: 232 | tf.logging.info("input_batch: {} ".format(instance_input.shape)) 233 | tf.logging.info("output_batch: {} ".format(instance_output.shape)) 234 | tf.logging.info("info_batch: {} ".format(instance_info.shape)) 235 | 236 | split_sizes = [tr_size, val_size] 237 | tr_input, val_input = tf.split(instance_input, split_sizes, axis=1) 238 | tr_output, val_output = tf.split(instance_output, split_sizes, axis=1) 239 | tr_info, val_info = tf.split(instance_info, split_sizes, axis=1) 240 | if self._verbose: 241 | tf.logging.info("tr_output: {} ".format(tr_output)) 242 | tf.logging.info("val_output: {}".format(val_output)) 243 | 244 | with tf.control_dependencies( 245 | self._check_labels(num_classes, tr_size, val_size, 246 | tr_output, val_output)): 247 | tr_output = tf.identity(tr_output) 248 | val_output = tf.identity(val_output) 249 | 250 | return tr_input, tr_output, tr_info, val_input, val_output, val_info 251 | 252 | def get_batch(self, batch_size, num_classes, tr_size, val_size, 253 | num_threads=10): 254 | """Returns a batch of random N-way K-shot classification problem instances. 255 | 256 | Args: 257 | batch_size: number of problem instances in the batch. 258 | num_classes: N in N-way classification. 259 | tr_size: K in K-shot; number of training examples per class. 260 | val_size: number of validation examples per class. 261 | num_threads: number of threads used to sample problem instances in 262 | parallel. 263 | 264 | Returns: 265 | A ProblemInstance of Tensors with the following shapes: 266 | - tr_input: (batch_size, num_classes, tr_size, NDIM): training image 267 | embeddings. 268 | - tr_output: (batch_size, num_classes, tr_size, 1): training image 269 | labels. 270 | - tr_info: (batch_size, num_classes, tr_size): training image file 271 | names. 272 | - val_input: (batch_size, num_classes, val_size, NDIM): validation 273 | image embeddings. 274 | - val_output: (batch_size, num_classes, val_size, 1): validation 275 | image labels. 
276 | - val_info: (batch_size, num_classes, val_size): validation image 277 | file names. 278 | """ 279 | if self._verbose: 280 | num_threads = 1 281 | one_instance = self.get_instance(num_classes, tr_size, val_size) 282 | 283 | tr_data_size = (num_classes, tr_size) 284 | val_data_size = (num_classes, val_size) 285 | task_batch = tf.train.shuffle_batch(one_instance, batch_size=batch_size, 286 | capacity=1000, min_after_dequeue=0, 287 | enqueue_many=False, 288 | shapes=[tr_data_size + (NDIM,), 289 | tr_data_size + (1,), 290 | tr_data_size, 291 | val_data_size + (NDIM,), 292 | val_data_size + (1,), 293 | val_data_size], 294 | num_threads=num_threads) 295 | 296 | if self._verbose: 297 | tf.logging.info(task_batch) 298 | 299 | return ProblemInstance(*task_batch) 300 | 301 | def _check_labels(self, num_classes, tr_size, val_size, 302 | tr_output, val_output): 303 | correct_label_sum = (num_classes*(num_classes-1))//2 304 | tr_label_sum = tf.reduce_sum(tr_output)/tr_size 305 | val_label_sum = tf.reduce_sum(val_output)/val_size 306 | all_label_asserts = [ 307 | tf.assert_equal(tf.to_int32(tr_label_sum), correct_label_sum), 308 | tf.assert_equal(tf.to_int32(val_label_sum), correct_label_sum), 309 | ] 310 | return all_label_asserts 311 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Code defining LEO inner loop. 16 | 17 | See "Meta-Learning with Latent Embedding Optimization" by Rusu et al. 18 | (https://arxiv.org/pdf/1807.05960.pdf). 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import numpy as np 26 | from six.moves import range 27 | from six.moves import zip 28 | import sonnet as snt 29 | import tensorflow as tf 30 | import tensorflow_probability as tfp 31 | 32 | import data as data_module 33 | 34 | 35 | def get_orthogonality_regularizer(orthogonality_penalty_weight): 36 | """Returns the orthogonality regularizer.""" 37 | def orthogonality(weight): 38 | """Calculates the layer-wise penalty encouraging orthogonality.""" 39 | with tf.name_scope(None, "orthogonality", [weight]) as name: 40 | w2 = tf.matmul(weight, weight, transpose_b=True) 41 | wn = tf.norm(weight, ord=2, axis=1, keepdims=True) + 1e-32 42 | correlation_matrix = w2 / tf.matmul(wn, wn, transpose_b=True) 43 | matrix_size = correlation_matrix.get_shape().as_list()[0] 44 | base_dtype = weight.dtype.base_dtype 45 | identity = tf.eye(matrix_size, dtype=base_dtype) 46 | weight_corr = tf.reduce_mean( 47 | tf.squared_difference(correlation_matrix, identity)) 48 | return tf.multiply( 49 | tf.cast(orthogonality_penalty_weight, base_dtype), 50 | weight_corr, 51 | name=name) 52 | 53 | return orthogonality 54 | 55 | 56 | class LEO(snt.AbstractModule): 57 | """Sonnet module implementing the inner loop of LEO.""" 58 | 59 | def __init__(self, config=None, use_64bits_dtype=True, name="leo"): 60 | super(LEO, self).__init__(name=name) 61 | 62 | self._float_dtype = tf.float64 if use_64bits_dtype else tf.float32 63 | self._int_dtype = tf.int64 if use_64bits_dtype else tf.int32 64 | 65 | 
self._inner_unroll_length = config["inner_unroll_length"] 66 | self._finetuning_unroll_length = config["finetuning_unroll_length"] 67 | self._inner_lr_init = config["inner_lr_init"] 68 | self._finetuning_lr_init = config["finetuning_lr_init"] 69 | self._num_latents = config["num_latents"] 70 | self._dropout_rate = config["dropout_rate"] 71 | 72 | self._kl_weight = config["kl_weight"] # beta 73 | self._encoder_penalty_weight = config["encoder_penalty_weight"] # gamma 74 | self._l2_penalty_weight = config["l2_penalty_weight"] # lambda_1 75 | # lambda_2 76 | self._orthogonality_penalty_weight = config["orthogonality_penalty_weight"] 77 | 78 | assert self._inner_unroll_length > 0, ("Positive unroll length is necessary" 79 | " to create the graph") 80 | 81 | def _build(self, data, is_meta_training=True): 82 | """Connects the LEO module to the graph, creating the variables. 83 | 84 | Args: 85 | data: A data_module.ProblemInstance containing Tensors with the 86 | following shapes: 87 | - tr_input: (N, K, dim) 88 | - tr_output: (N, K, 1) 89 | - tr_info: (N, K) 90 | - val_input: (N, K_valid, dim) 91 | - val_output: (N, K_valid, 1) 92 | - val_info: (N, K_valid) 93 | where N is the number of classes (as in N-way), K and K_valid are the 94 | numbers of training and validation examples per class within a 95 | problem instance (as in K-shot), and dim is the 96 | dimensionality of the embedding. 97 | is_meta_training: A boolean describing whether we run in training 98 | mode. 99 | 100 | Returns: 101 | Tensors with the inner validation loss of LEO (including both adaptation 102 | in the latent space and finetuning) and the validation accuracy.
103 | """ 104 | if isinstance(data, list): 105 | data = data_module.ProblemInstance(*data) 106 | self.is_meta_training = is_meta_training 107 | self.save_problem_instance_stats(data.tr_input) 108 | 109 | latents, kl = self.forward_encoder(data) 110 | tr_loss, adapted_classifier_weights, encoder_penalty = self.leo_inner_loop( 111 | data, latents) 112 | 113 | val_loss, val_accuracy = self.finetuning_inner_loop( 114 | data, tr_loss, adapted_classifier_weights) 115 | 116 | val_loss += self._kl_weight * kl 117 | val_loss += self._encoder_penalty_weight * encoder_penalty 118 | # The l2 regularization is is already added to the graph when constructing 119 | # the snt.Linear modules. We pass the orthogonality regularizer separately, 120 | # because it is not used in self.grads_and_vars. 121 | regularization_penalty = ( 122 | self._l2_regularization + self._decoder_orthogonality_reg) 123 | 124 | batch_val_loss = tf.reduce_mean(val_loss) 125 | batch_val_accuracy = tf.reduce_mean(val_accuracy) 126 | 127 | return batch_val_loss + regularization_penalty, batch_val_accuracy 128 | 129 | @snt.reuse_variables 130 | def leo_inner_loop(self, data, latents): 131 | with tf.variable_scope("leo_inner"): 132 | inner_lr = tf.get_variable( 133 | "lr", [1, 1, self._num_latents], 134 | dtype=self._float_dtype, 135 | initializer=tf.constant_initializer(self._inner_lr_init)) 136 | starting_latents = latents 137 | loss, _ = self.forward_decoder(data, latents) 138 | for _ in range(self._inner_unroll_length): 139 | loss_grad = tf.gradients(loss, latents) # dLtrain/dz 140 | latents -= inner_lr * loss_grad[0] 141 | loss, classifier_weights = self.forward_decoder(data, latents) 142 | 143 | if self.is_meta_training: 144 | encoder_penalty = tf.losses.mean_squared_error( 145 | labels=tf.stop_gradient(latents), predictions=starting_latents) 146 | encoder_penalty = tf.cast(encoder_penalty, self._float_dtype) 147 | else: 148 | encoder_penalty = tf.constant(0., self._float_dtype) 149 | 150 | return loss, 
classifier_weights, encoder_penalty 151 | 152 | @snt.reuse_variables 153 | def finetuning_inner_loop(self, data, leo_loss, classifier_weights): 154 | tr_loss = leo_loss 155 | with tf.variable_scope("finetuning"): 156 | finetuning_lr = tf.get_variable( 157 | "lr", [1, 1, self.embedding_dim], 158 | dtype=self._float_dtype, 159 | initializer=tf.constant_initializer(self._finetuning_lr_init)) 160 | for _ in range(self._finetuning_unroll_length): 161 | loss_grad = tf.gradients(tr_loss, classifier_weights) 162 | classifier_weights -= finetuning_lr * loss_grad[0] 163 | tr_loss, _ = self.calculate_inner_loss(data.tr_input, data.tr_output, 164 | classifier_weights) 165 | 166 | val_loss, val_accuracy = self.calculate_inner_loss( 167 | data.val_input, data.val_output, classifier_weights) 168 | return val_loss, val_accuracy 169 | 170 | @snt.reuse_variables 171 | def forward_encoder(self, data): 172 | encoder_outputs = self.encoder(data.tr_input) 173 | relation_network_outputs = self.relation_network(encoder_outputs) 174 | latent_dist_params = self.average_codes_per_class(relation_network_outputs) 175 | latents, kl = self.possibly_sample(latent_dist_params) 176 | return latents, kl 177 | 178 | @snt.reuse_variables 179 | def forward_decoder(self, data, latents): 180 | weights_dist_params = self.decoder(latents) 181 | # Default to glorot_initialization and not stddev=1. 182 | fan_in = self.embedding_dim.value 183 | fan_out = self.num_classes.value 184 | stddev_offset = np.sqrt(2. 
/ (fan_out + fan_in)) 185 | classifier_weights, _ = self.possibly_sample(weights_dist_params, 186 | stddev_offset=stddev_offset) 187 | tr_loss, _ = self.calculate_inner_loss(data.tr_input, data.tr_output, 188 | classifier_weights) 189 | return tr_loss, classifier_weights 190 | 191 | @snt.reuse_variables 192 | def encoder(self, inputs): 193 | with tf.variable_scope("encoder"): 194 | after_dropout = tf.nn.dropout(inputs, rate=self.dropout_rate) 195 | regularizer = tf.contrib.layers.l2_regularizer(self._l2_penalty_weight) 196 | initializer = tf.initializers.glorot_uniform(dtype=self._float_dtype) 197 | encoder_module = snt.Linear( 198 | self._num_latents, 199 | use_bias=False, 200 | regularizers={"w": regularizer}, 201 | initializers={"w": initializer}, 202 | ) 203 | outputs = snt.BatchApply(encoder_module)(after_dropout) 204 | return outputs 205 | 206 | @snt.reuse_variables 207 | def relation_network(self, inputs): 208 | with tf.variable_scope("relation_network"): 209 | regularizer = tf.contrib.layers.l2_regularizer(self._l2_penalty_weight) 210 | initializer = tf.initializers.glorot_uniform(dtype=self._float_dtype) 211 | relation_network_module = snt.nets.MLP( 212 | [2 * self._num_latents] * 3, 213 | use_bias=False, 214 | regularizers={"w": regularizer}, 215 | initializers={"w": initializer}, 216 | ) 217 | total_num_examples = self.num_examples_per_class*self.num_classes 218 | inputs = tf.reshape(inputs, [total_num_examples, self._num_latents]) 219 | 220 | left = tf.tile(tf.expand_dims(inputs, 1), [1, total_num_examples, 1]) 221 | right = tf.tile(tf.expand_dims(inputs, 0), [total_num_examples, 1, 1]) 222 | concat_codes = tf.concat([left, right], axis=-1) 223 | outputs = snt.BatchApply(relation_network_module)(concat_codes) 224 | outputs = tf.reduce_mean(outputs, axis=1) 225 | # 2 * latents, because we are returning means and variances of a Gaussian 226 | outputs = tf.reshape(outputs, [self.num_classes, 227 | self.num_examples_per_class, 228 | 2 * self._num_latents]) 
229 | 230 | return outputs 231 | 232 | @snt.reuse_variables 233 | def decoder(self, inputs): 234 | with tf.variable_scope("decoder"): 235 | l2_regularizer = tf.contrib.layers.l2_regularizer(self._l2_penalty_weight) 236 | orthogonality_reg = get_orthogonality_regularizer( 237 | self._orthogonality_penalty_weight) 238 | initializer = tf.initializers.glorot_uniform(dtype=self._float_dtype) 239 | # 2 * embedding_dim, because we are returning means and variances 240 | decoder_module = snt.Linear( 241 | 2 * self.embedding_dim, 242 | use_bias=False, 243 | regularizers={"w": l2_regularizer}, 244 | initializers={"w": initializer}, 245 | ) 246 | outputs = snt.BatchApply(decoder_module)(inputs) 247 | self._orthogonality_reg = orthogonality_reg(decoder_module.w) 248 | return outputs 249 | 250 | def average_codes_per_class(self, codes): 251 | codes = tf.reduce_mean(codes, axis=1, keep_dims=True) # K dimension 252 | # Keep the shape (N, K, *) 253 | codes = tf.tile(codes, [1, self.num_examples_per_class, 1]) 254 | return codes 255 | 256 | def possibly_sample(self, distribution_params, stddev_offset=0.): 257 | means, unnormalized_stddev = tf.split(distribution_params, 2, axis=-1) 258 | stddev = tf.exp(unnormalized_stddev) 259 | stddev -= (1. 
- stddev_offset) 260 | stddev = tf.maximum(stddev, 1e-10) 261 | distribution = tfp.distributions.Normal(loc=means, scale=stddev) 262 | if not self.is_meta_training: 263 | return means, tf.constant(0., dtype=self._float_dtype) 264 | 265 | samples = distribution.sample() 266 | kl_divergence = self.kl_divergence(samples, distribution) 267 | return samples, kl_divergence 268 | 269 | def kl_divergence(self, samples, normal_distribution): 270 | random_prior = tfp.distributions.Normal( 271 | loc=tf.zeros_like(samples), scale=tf.ones_like(samples)) 272 | kl = tf.reduce_mean( 273 | normal_distribution.log_prob(samples) - random_prior.log_prob(samples)) 274 | return kl 275 | 276 | def predict(self, inputs, weights): 277 | after_dropout = tf.nn.dropout(inputs, rate=self.dropout_rate) 278 | # This is the 3-dimensional equivalent of a matrix product, where we sum 279 | # over the last (embedding_dim) dimension. We get an [N, K, N, K] tensor. 280 | per_image_predictions = tf.einsum("ijk,lmk->ijlm", after_dropout, weights) 281 | 282 | # Predictions have shape [N, K, N]: for each image ([N, K] of them), what 283 | # is the probability of a given class (N)?
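As a standalone sanity check of the shape bookkeeping in `predict`, the same contraction can be reproduced in NumPy (a hypothetical sketch with made-up sizes; not part of model.py):

```python
import numpy as np

# Hypothetical sizes: N classes, K examples per class, dim-sized embeddings.
N, K, dim = 3, 2, 5
inputs = np.random.rand(N, K, dim)   # image embeddings
weights = np.random.rand(N, K, dim)  # per-class classifier weights

# Same contraction as tf.einsum("ijk,lmk->ijlm", ...): sum over the
# embedding dimension, producing an [N, K, N, K] tensor.
per_image = np.einsum("ijk,lmk->ijlm", inputs, weights)
print(per_image.shape)  # (3, 2, 3, 2)

# Averaging over the last axis leaves [N, K, N] per-image class logits,
# matching the comment above tf.reduce_mean in predict.
logits = per_image.mean(axis=-1)
print(logits.shape)  # (3, 2, 3)
```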
284 | predictions = tf.reduce_mean(per_image_predictions, axis=-1) 285 | return predictions 286 | 287 | def calculate_inner_loss(self, inputs, true_outputs, classifier_weights): 288 | model_outputs = self.predict(inputs, classifier_weights) 289 | model_predictions = tf.argmax( 290 | model_outputs, -1, output_type=self._int_dtype) 291 | accuracy = tf.contrib.metrics.accuracy(model_predictions, 292 | tf.squeeze(true_outputs, axis=-1)) 293 | 294 | return self.loss_fn(model_outputs, true_outputs), accuracy 295 | 296 | def save_problem_instance_stats(self, instance): 297 | num_classes, num_examples_per_class, embedding_dim = instance.get_shape() 298 | if hasattr(self, "num_classes"): 299 | assert self.num_classes == num_classes, ( 300 | "Given different number of classes (N in N-way) in consecutive runs.") 301 | if hasattr(self, "num_examples_per_class"): 302 | assert self.num_examples_per_class == num_examples_per_class, ( 303 | "Given different number of examples (K in K-shot) in consecutive" 304 | " runs.") 305 | if hasattr(self, "embedding_dim"): 306 | assert self.embedding_dim == embedding_dim, ( 307 | "Given different embedding dimension in consecutive runs.") 308 | 309 | self.num_classes = num_classes 310 | self.num_examples_per_class = num_examples_per_class 311 | self.embedding_dim = embedding_dim 312 | 313 | @property 314 | def dropout_rate(self): 315 | return self._dropout_rate if self.is_meta_training else 0.0 316 | 317 | def loss_fn(self, model_outputs, original_classes): 318 | original_classes = tf.squeeze(original_classes, axis=-1) 319 | # Tensorflow doesn't handle second order gradients of a sparse_softmax yet. 320 | one_hot_outputs = tf.one_hot(original_classes, depth=self.num_classes) 321 | return tf.nn.softmax_cross_entropy_with_logits_v2( 322 | labels=one_hot_outputs, logits=model_outputs) 323 | 324 | def grads_and_vars(self, metatrain_loss): 325 | """Computes gradients of metatrain_loss, avoiding NaN.
326 | 327 | Uses a fixed penalty of 1e-4 to enforce only the l2 regularization (and not 328 | minimize the loss) when metatrain_loss or any of its gradients with respect 329 | to trainable_vars are NaN. In practice, this approach pulls the variables 330 | back into a feasible region of the space when the loss or its gradients are 331 | not defined. 332 | 333 | Args: 334 | metatrain_loss: A tensor with the LEO meta-training loss. 335 | 336 | Returns: 337 | A tuple with: 338 | metatrain_gradients: A list of gradient tensors. 339 | metatrain_variables: A list of variables for this LEO model. 340 | """ 341 | metatrain_variables = self.trainable_variables 342 | metatrain_gradients = tf.gradients(metatrain_loss, metatrain_variables) 343 | 344 | nan_loss_or_grad = tf.logical_or( 345 | tf.is_nan(metatrain_loss), 346 | tf.reduce_any([tf.reduce_any(tf.is_nan(g)) 347 | for g in metatrain_gradients])) 348 | 349 | regularization_penalty = ( 350 | 1e-4 / self._l2_penalty_weight * self._l2_regularization) 351 | zero_or_regularization_gradients = [ 352 | g if g is not None else tf.zeros_like(v) 353 | for g, v in zip(tf.gradients(regularization_penalty, 354 | metatrain_variables), metatrain_variables)] 355 | 356 | metatrain_gradients = tf.cond(nan_loss_or_grad, 357 | lambda: zero_or_regularization_gradients, 358 | lambda: metatrain_gradients, strict=True) 359 | 360 | return metatrain_gradients, metatrain_variables 361 | 362 | @property 363 | def _l2_regularization(self): 364 | return tf.cast( 365 | tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 366 | dtype=self._float_dtype) 367 | 368 | @property 369 | def _decoder_orthogonality_reg(self): 370 | return self._orthogonality_reg 371 | -------------------------------------------------------------------------------- /model_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License,
Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tests for ml_leo.model.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import functools 22 | from absl.testing import parameterized 23 | import mock 24 | import numpy as np 25 | from six.moves import zip 26 | import sonnet as snt 27 | import tensorflow as tf 28 | 29 | import data 30 | import model 31 | 32 | # Adding float64 and 32 gives an error in TensorFlow. 
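The comment above motivates the `constant_float64` helper defined next: NumPy quietly upcasts mixed-precision values, while TensorFlow's binary ops raise on a float64/float32 mismatch, so the tests pin every constant to float64. A minimal NumPy-only contrast (a hedged sketch; TensorFlow itself is not exercised here):

```python
import numpy as np

# NumPy silently upcasts a float64/float32 mix to float64; TensorFlow
# instead raises a dtype error for the equivalent tf.add, which is why
# the tests below route all constants through constant_float64.
a = np.float64(1.0)
b = np.float32(2.0)
result = a + b
print(result.dtype)  # float64
```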
33 | constant_float64 = lambda x: tf.constant(x, dtype=tf.float64) 34 | 35 | 36 | def get_test_config(): 37 | """Returns the config used to initialize LEO model.""" 38 | config = {} 39 | config["inner_unroll_length"] = 3 40 | config["finetuning_unroll_length"] = 4 41 | config["inner_lr_init"] = 0.1 42 | config["finetuning_lr_init"] = 0.2 43 | config["num_latents"] = 1 44 | config["dropout_rate"] = 0.3 45 | config["kl_weight"] = 0.01 46 | config["encoder_penalty_weight"] = 0.01 47 | config["l2_penalty_weight"] = 0.01 48 | config["orthogonality_penalty_weight"] = 0.01 49 | 50 | return config 51 | 52 | 53 | def mockify_everything(test_function=None, 54 | mock_finetuning=True, 55 | mock_encdec=True): 56 | """Mockifies most of LEO's model functions to behave as identity.""" 57 | 58 | def inner_decorator(f): 59 | @functools.wraps(f) 60 | def mockified(*args, **kwargs): 61 | identity_mapping = lambda unused_self, inp, *args: tf.identity(inp) 62 | mock_encoder = mock.patch.object( 63 | model.LEO, "encoder", new=identity_mapping) 64 | mock_relation_network = mock.patch.object( 65 | model.LEO, "relation_network", new=identity_mapping) 66 | mock_decoder = mock.patch.object( 67 | model.LEO, "decoder", new=identity_mapping) 68 | mock_average = mock.patch.object( 69 | model.LEO, "average_codes_per_class", new=identity_mapping) 70 | mock_loss = mock.patch.object(model.LEO, "loss_fn", new=identity_mapping) 71 | 72 | float64_zero = constant_float64(0.) 73 | def identity_sample_fn(unused_self, inp, *unused_args, **unused_kwargs): 74 | return inp, float64_zero 75 | 76 | def mock_sample_with_split(unused_self, inp, *unused_args, 77 | **unused_kwargs): 78 | out = tf.split(inp, 2, axis=-1)[0] 79 | return out, float64_zero 80 | 81 | # When not mocking relation net, it will double the latents.
82 | mock_sample = mock.patch.object( 83 | model.LEO, 84 | "possibly_sample", 85 | new=identity_sample_fn if mock_encdec else mock_sample_with_split) 86 | 87 | def dummy_predict(unused_self, inputs, classifier_weights): 88 | return inputs * classifier_weights**2 89 | 90 | mock_predict = mock.patch.object(model.LEO, "predict", new=dummy_predict) 91 | 92 | mock_decoder_regularizer = mock.patch.object( 93 | model.LEO, "_decoder_orthogonality_reg", new=float64_zero) 94 | 95 | all_mocks = [mock_average, mock_loss, mock_predict, mock_sample] 96 | if mock_encdec: 97 | all_mocks.extend([ 98 | mock_encoder, 99 | mock_relation_network, 100 | mock_decoder, 101 | mock_decoder_regularizer, 102 | ]) 103 | if mock_finetuning: 104 | mock_finetuning_inner = mock.patch.object( 105 | model.LEO, 106 | "finetuning_inner_loop", 107 | new=lambda unused_self, d, l, adapted: (adapted, float64_zero)) 108 | all_mocks.append(mock_finetuning_inner) 109 | 110 | for m in all_mocks: 111 | m.start() 112 | 113 | f(*args, **kwargs) 114 | 115 | for m in all_mocks: 116 | m.stop() 117 | 118 | return mockified 119 | 120 | if test_function: 121 | # Decorator called with no arguments, so the function is passed 122 | return inner_decorator(test_function) 123 | return inner_decorator 124 | 125 | 126 | def _random_problem_instance(num_classes=7, 127 | num_examples_per_class=5, 128 | embedding_dim=17, use_64bits_dtype=True): 129 | inputs_dtype = tf.float64 if use_64bits_dtype else tf.float32 130 | inputs = tf.constant( 131 | np.random.random((num_classes, num_examples_per_class, embedding_dim)), 132 | dtype=inputs_dtype) 133 | outputs_dtype = tf.int64 if use_64bits_dtype else tf.int32 134 | outputs = tf.constant( 135 | np.random.randint( 136 | low=0, 137 | high=num_classes, 138 | size=(num_classes, num_examples_per_class, 1)), dtype=outputs_dtype) 139 | problem = data.ProblemInstance( 140 | tr_input=inputs, 141 | val_input=inputs, 142 | tr_info=inputs, 143 | tr_output=outputs, 144 | val_output=outputs, 145 | 
val_info=inputs) 146 | return problem 147 | 148 | 149 | class LEOTest(tf.test.TestCase, parameterized.TestCase): 150 | 151 | def setUp(self): 152 | super(LEOTest, self).setUp() 153 | self._problem = _random_problem_instance(5, 7, 4) 154 | # This doesn't call any function, so doesn't need the mocks to be started. 155 | self._config = get_test_config() 156 | self._leo = model.LEO(config=self._config) 157 | self.addCleanup(mock.patch.stopall) 158 | 159 | @mockify_everything 160 | def test_instantiate_leo(self): 161 | encoder_output = self._leo.encoder(5, 7) 162 | with self.session() as sess: 163 | encoder_output_ev = sess.run(encoder_output) 164 | 165 | self.assertEqual(encoder_output_ev, 5) 166 | 167 | @mockify_everything 168 | def test_inner_loop_adaptation(self): 169 | problem_instance = data.ProblemInstance( 170 | tr_input=constant_float64([[[4.]]]), 171 | tr_output=tf.constant([[[0]]], dtype=tf.int64), 172 | tr_info=[], 173 | val_input=[], 174 | val_output=[], 175 | val_info=[], 176 | ) 177 | # encoder = decoder = id 178 | # predict returns classifier_weights**2 * inputs = latents**2 * inputs 179 | # loss = id = inputs*latents**2 180 | # dl/dlatent = 2 * latent * inputs 181 | # 4 -> 4 - 0.1 * 2 * 4 * 4 = 0.8 182 | # 0.8 -> 0.8 - 0.1 * 2 * 0.8 * 4 = 0.16 183 | # 0.16 -> 0.16 - 0.1 * 2 * 0.16 * 4 = 0.032 184 | 185 | # is_meta_training=False disables kl and encoder penalties 186 | adapted_parameters, _ = self._leo(problem_instance, is_meta_training=False) 187 | 188 | with self.session() as sess: 189 | sess.run(tf.global_variables_initializer()) 190 | self.assertAllClose(sess.run(adapted_parameters), 0.032) 191 | 192 | @mockify_everything 193 | def test_map_input(self): 194 | problem = [ 195 | constant_float64([[[5.]]]), # tr_input 196 | tf.constant([[[0]]], dtype=tf.int64), # tr_output 197 | constant_float64([[[0]]]), # tr_info 198 | constant_float64([[[0.]]]), # val_input 199 | tf.constant([[[0]]], dtype=tf.int64), # val_output 200 | constant_float64([[[0]]]), #
val_info 201 | ] 202 | another_problem = [ 203 | constant_float64([[[4.]]]), 204 | tf.constant([[[0]]], dtype=tf.int64), 205 | constant_float64([[[0]]]), 206 | constant_float64([[[0.]]]), 207 | tf.constant([[[0]]], dtype=tf.int64), 208 | constant_float64([[[0]]]), 209 | ] 210 | # first dimension (list): different input kind (tr_input, val_output, etc.) 211 | # second dim: different problems; this has to be a tensor dim for map_fn 212 | # to split over it. 213 | # next three: (1, 1, 1) 214 | 215 | # map_fn cannot receive structured inputs (namedtuples). 216 | ins = [ 217 | tf.stack([in1, in2]) 218 | for in1, in2 in zip(problem, another_problem) 219 | ] 220 | 221 | two_adapted_params, _ = tf.map_fn( 222 | self._leo.__call__, ins, dtype=(tf.float64, tf.float64)) 223 | 224 | with self.session() as sess: 225 | sess.run(tf.global_variables_initializer()) 226 | output1, output2 = sess.run(two_adapted_params) 227 | self.assertGreater(abs(output1 - output2), 1e-3) 228 | 229 | @mockify_everything 230 | def test_setting_is_meta_training(self): 231 | self._leo(self._problem, is_meta_training=True) 232 | self.assertTrue(self._leo.is_meta_training) 233 | self._leo(self._problem, is_meta_training=False) 234 | self.assertFalse(self._leo.is_meta_training) 235 | 236 | @mockify_everything(mock_finetuning=False) 237 | def test_finetuning_improves_loss(self): 238 | # Create graph 239 | self._leo(self._problem) 240 | 241 | latents, _ = self._leo.forward_encoder(self._problem) 242 | leo_loss, adapted_classifier_weights, _ = self._leo.leo_inner_loop( 243 | self._problem, latents) 244 | leo_loss = tf.reduce_mean(leo_loss) 245 | finetuning_loss, _ = self._leo.finetuning_inner_loop( 246 | self._problem, leo_loss, adapted_classifier_weights) 247 | finetuning_loss = tf.reduce_mean(finetuning_loss) 248 | with self.session() as sess: 249 | sess.run(tf.global_variables_initializer()) 250 | leo_loss_ev, finetuning_loss_ev = sess.run([leo_loss, finetuning_loss]) 251 | self.assertGreater(leo_loss_ev -
1e-3, finetuning_loss_ev) 252 | 253 | @mockify_everything 254 | def test_gradients_dont_flow_through_input(self): 255 | # Create graph 256 | self._leo(self._problem) 257 | latents, _ = self._leo.forward_encoder(self._problem) 258 | grads = tf.gradients(self._problem.tr_input, latents) 259 | self.assertIsNone(grads[0]) 260 | 261 | @mockify_everything 262 | def test_inferring_embedding_dim(self): 263 | self._leo(self._problem) 264 | self.assertEqual(self._leo.embedding_dim, 4) 265 | 266 | @mockify_everything(mock_encdec=False, mock_finetuning=False) 267 | def test_variable_creation(self): 268 | self._leo(self._problem) 269 | encoder_variables = snt.get_variables_in_scope("leo/encoder") 270 | self.assertNotEmpty(encoder_variables) 271 | relation_network_variables = snt.get_variables_in_scope( 272 | "leo/relation_network") 273 | self.assertNotEmpty(relation_network_variables) 274 | decoder_variables = snt.get_variables_in_scope("leo/decoder") 275 | self.assertNotEmpty(decoder_variables) 276 | inner_lr = snt.get_variables_in_scope("leo/leo_inner") 277 | self.assertNotEmpty(inner_lr) 278 | finetuning_lr = snt.get_variables_in_scope("leo/finetuning") 279 | self.assertNotEmpty(finetuning_lr) 280 | self.assertSameElements( 281 | encoder_variables + relation_network_variables + decoder_variables + 282 | inner_lr + finetuning_lr, self._leo.trainable_variables) 283 | 284 | def test_graph_construction(self): 285 | self._leo(self._problem) 286 | 287 | def test_possibly_sample(self): 288 | # Embedding dimension has to be divisible by 2 here. 
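The divisibility requirement in the comment above comes from `possibly_sample`, which splits the last axis into concatenated means and unnormalized stddevs. A hypothetical NumPy sketch of that split (made-up values, not part of the test file):

```python
import numpy as np

# possibly_sample splits the trailing axis in half: first half means,
# second half unnormalized stddevs, so the embedding dim must be even.
params = np.arange(8.0).reshape(1, 1, 8)
means, raw_stddev = np.split(params, 2, axis=-1)
print(means.shape, raw_stddev.shape)  # (1, 1, 4) (1, 1, 4)

# As in model.possibly_sample (with stddev_offset=0): exponentiate,
# shift by (1 - offset), and floor so the scale stays strictly positive.
stddev = np.maximum(np.exp(raw_stddev) - 1.0, 1e-10)
```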
289 | self._leo(self._problem, is_meta_training=True) 290 | train_samples, train_kl = self._leo.possibly_sample(self._problem.tr_input) 291 | 292 | self._leo(self._problem, is_meta_training=False) 293 | test_samples, test_kl = self._leo.possibly_sample(self._problem.tr_input) 294 | 295 | with self.session() as sess: 296 | train_samples_ev1, test_samples_ev1 = sess.run( 297 | [train_samples, test_samples]) 298 | train_samples_ev2, test_samples_ev2 = sess.run( 299 | [train_samples, test_samples]) 300 | 301 | self.assertAllClose(test_samples_ev1, test_samples_ev2) 302 | self.assertGreater(abs(np.sum(train_samples_ev1 - train_samples_ev2)), 1.) 303 | 304 | train_kl_ev, test_kl_ev = sess.run([train_kl, test_kl]) 305 | self.assertNotEqual(train_kl_ev, 0.) 306 | self.assertEqual(test_kl_ev, 0.) 307 | 308 | def test_different_shapes(self): 309 | problem_instance2 = _random_problem_instance(5, 6, 13) 310 | 311 | self._leo(self._problem) 312 | with self.assertRaises(AssertionError): 313 | self._leo(problem_instance2) 314 | 315 | def test_encoder_penalty(self): 316 | self._leo(self._problem) # Sets is_meta_training 317 | latents, _ = self._leo.forward_encoder(self._problem) 318 | _, _, train_encoder_penalty = self._leo.leo_inner_loop( 319 | self._problem, latents) 320 | 321 | self._leo(self._problem, is_meta_training=False) 322 | _, _, test_encoder_penalty = self._leo.leo_inner_loop( 323 | self._problem, latents) 324 | 325 | with self.session() as sess: 326 | sess.run(tf.initializers.global_variables()) 327 | train_encoder_penalty_ev, test_encoder_penalty_ev = sess.run( 328 | [train_encoder_penalty, test_encoder_penalty]) 329 | self.assertGreater(train_encoder_penalty_ev, 1e-3) 330 | self.assertLess(test_encoder_penalty_ev, 1e-7) 331 | 332 | def test_construct_float32_leo_graph(self): 333 | leo = model.LEO(use_64bits_dtype=False, config=self._config) 334 | problem_instance_32_bits = _random_problem_instance(use_64bits_dtype=False) 335 | leo(problem_instance_32_bits) 336 | 337 
| 338 | if __name__ == "__main__": 339 | tf.test.main() 340 | -------------------------------------------------------------------------------- /runner.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """A binary building the graph and performing the optimization of LEO.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import functools 22 | import os 23 | import pickle 24 | 25 | from absl import flags 26 | from six.moves import zip 27 | import tensorflow as tf 28 | 29 | import config 30 | import data 31 | import model 32 | import utils 33 | 34 | FLAGS = flags.FLAGS 35 | flags.DEFINE_string("checkpoint_path", "/tmp/leo", "Path to restore from and " 36 | "save to checkpoints.") 37 | flags.DEFINE_integer( 38 | "checkpoint_steps", 1000, "The frequency, in number of " 39 | "steps, of saving the checkpoints.") 40 | flags.DEFINE_boolean("evaluation_mode", False, "Whether to run in an " 41 | "evaluation-only mode.") 42 | 43 | 44 | def _clip_gradients(gradients, gradient_threshold, gradient_norm_threshold): 45 | """Clips gradients by value and then by norm.""" 46 | if gradient_threshold > 0: 47 | gradients = [ 48 | 
tf.clip_by_value(g, -gradient_threshold, gradient_threshold) 49 | for g in gradients 50 | ] 51 | if gradient_norm_threshold > 0: 52 | gradients = [ 53 | tf.clip_by_norm(g, gradient_norm_threshold) for g in gradients 54 | ] 55 | return gradients 56 | 57 | 58 | def _construct_validation_summaries(metavalid_loss, metavalid_accuracy): 59 | tf.summary.scalar("metavalid_loss", metavalid_loss) 60 | tf.summary.scalar("metavalid_valid_accuracy", metavalid_accuracy) 61 | # The summaries are passed implicitly by TensorFlow. 62 | 63 | 64 | def _construct_training_summaries(metatrain_loss, metatrain_accuracy, 65 | model_grads, model_vars): 66 | tf.summary.scalar("metatrain_loss", metatrain_loss) 67 | tf.summary.scalar("metatrain_valid_accuracy", metatrain_accuracy) 68 | for g, v in zip(model_grads, model_vars): 69 | histogram_name = v.name.split(":")[0] 70 | tf.summary.histogram(histogram_name, v) 71 | histogram_name = "gradient/{}".format(histogram_name) 72 | tf.summary.histogram(histogram_name, g) 73 | 74 | 75 | def _construct_examples_batch(batch_size, split, num_classes, 76 | num_tr_examples_per_class, 77 | num_val_examples_per_class): 78 | data_provider = data.DataProvider(split, config.get_data_config()) 79 | examples_batch = data_provider.get_batch(batch_size, num_classes, 80 | num_tr_examples_per_class, 81 | num_val_examples_per_class) 82 | return utils.unpack_data(examples_batch) 83 | 84 | 85 | def _construct_loss_and_accuracy(inner_model, inputs, is_meta_training): 86 | """Returns batched loss and accuracy of the model ran on the inputs.""" 87 | call_fn = functools.partial( 88 | inner_model.__call__, is_meta_training=is_meta_training) 89 | per_instance_loss, per_instance_accuracy = tf.map_fn( 90 | call_fn, 91 | inputs, 92 | dtype=(tf.float32, tf.float32), 93 | back_prop=is_meta_training) 94 | loss = tf.reduce_mean(per_instance_loss) 95 | accuracy = tf.reduce_mean(per_instance_accuracy) 96 | return loss, accuracy 97 | 98 | 99 | def construct_graph(outer_model_config): 
100 | """Constructs the optimization graph.""" 101 | inner_model_config = config.get_inner_model_config() 102 | tf.logging.info("inner_model_config: {}".format(inner_model_config)) 103 | leo = model.LEO(inner_model_config, use_64bits_dtype=False) 104 | 105 | num_classes = outer_model_config["num_classes"] 106 | num_tr_examples_per_class = outer_model_config["num_tr_examples_per_class"] 107 | metatrain_batch = _construct_examples_batch( 108 | outer_model_config["metatrain_batch_size"], "train", num_classes, 109 | num_tr_examples_per_class, 110 | outer_model_config["num_val_examples_per_class"]) 111 | metatrain_loss, metatrain_accuracy = _construct_loss_and_accuracy( 112 | leo, metatrain_batch, True) 113 | 114 | metatrain_gradients, metatrain_variables = leo.grads_and_vars(metatrain_loss) 115 | 116 | # Avoids NaNs in summaries. 117 | metatrain_loss = tf.cond(tf.is_nan(metatrain_loss), 118 | lambda: tf.zeros_like(metatrain_loss), 119 | lambda: metatrain_loss) 120 | 121 | metatrain_gradients = _clip_gradients( 122 | metatrain_gradients, outer_model_config["gradient_threshold"], 123 | outer_model_config["gradient_norm_threshold"]) 124 | 125 | _construct_training_summaries(metatrain_loss, metatrain_accuracy, 126 | metatrain_gradients, metatrain_variables) 127 | optimizer = tf.train.AdamOptimizer( 128 | learning_rate=outer_model_config["outer_lr"]) 129 | global_step = tf.train.get_or_create_global_step() 130 | train_op = optimizer.apply_gradients( 131 | list(zip(metatrain_gradients, metatrain_variables)), global_step) 132 | 133 | data_config = config.get_data_config() 134 | tf.logging.info("data_config: {}".format(data_config)) 135 | total_examples_per_class = data_config["total_examples_per_class"] 136 | metavalid_batch = _construct_examples_batch( 137 | outer_model_config["metavalid_batch_size"], "val", num_classes, 138 | num_tr_examples_per_class, 139 | total_examples_per_class - num_tr_examples_per_class) 140 | metavalid_loss, metavalid_accuracy = 
_construct_loss_and_accuracy( 141 | leo, metavalid_batch, False) 142 | 143 | metatest_batch = _construct_examples_batch( 144 | outer_model_config["metatest_batch_size"], "test", num_classes, 145 | num_tr_examples_per_class, 146 | total_examples_per_class - num_tr_examples_per_class) 147 | _, metatest_accuracy = _construct_loss_and_accuracy( 148 | leo, metatest_batch, False) 149 | _construct_validation_summaries(metavalid_loss, metavalid_accuracy) 150 | 151 | return (train_op, global_step, metatrain_accuracy, metavalid_accuracy, 152 | metatest_accuracy) 153 | 154 | 155 | def run_training_loop(checkpoint_path): 156 | """Runs the training loop, either saving a checkpoint or evaluating it.""" 157 | outer_model_config = config.get_outer_model_config() 158 | tf.logging.info("outer_model_config: {}".format(outer_model_config)) 159 | (train_op, global_step, metatrain_accuracy, metavalid_accuracy, 160 | metatest_accuracy) = construct_graph(outer_model_config) 161 | 162 | num_steps_limit = outer_model_config["num_steps_limit"] 163 | best_metavalid_accuracy = 0. 164 | 165 | with tf.train.MonitoredTrainingSession( 166 | checkpoint_dir=checkpoint_path, 167 | save_summaries_steps=FLAGS.checkpoint_steps, 168 | log_step_count_steps=FLAGS.checkpoint_steps, 169 | save_checkpoint_steps=FLAGS.checkpoint_steps, 170 | summary_dir=checkpoint_path) as sess: 171 | if not FLAGS.evaluation_mode: 172 | global_step_ev = sess.run(global_step) 173 | while global_step_ev < num_steps_limit: 174 | if global_step_ev % FLAGS.checkpoint_steps == 0: 175 | # Just after saving checkpoint, calculate accuracy 10 times and save 176 | # the best checkpoint for early stopping. 
177 | metavalid_accuracy_ev = utils.evaluate_and_average( 178 | sess, metavalid_accuracy, 10) 179 | tf.logging.info("Step: {} meta-valid accuracy: {}".format( 180 | global_step_ev, metavalid_accuracy_ev)) 181 | 182 | if metavalid_accuracy_ev > best_metavalid_accuracy: 183 | utils.copy_checkpoint(checkpoint_path, global_step_ev, 184 | metavalid_accuracy_ev) 185 | best_metavalid_accuracy = metavalid_accuracy_ev 186 | 187 | _, global_step_ev, metatrain_accuracy_ev = sess.run( 188 | [train_op, global_step, metatrain_accuracy]) 189 | if global_step_ev % (FLAGS.checkpoint_steps // 2) == 0: 190 | tf.logging.info("Step: {} meta-train accuracy: {}".format( 191 | global_step_ev, metatrain_accuracy_ev)) 192 | else: 193 | assert not FLAGS.checkpoint_steps 194 | num_metatest_estimates = ( 195 | 10000 // outer_model_config["metatest_batch_size"]) 196 | 197 | test_accuracy = utils.evaluate_and_average(sess, metatest_accuracy, 198 | num_metatest_estimates) 199 | 200 | tf.logging.info("Metatest accuracy: %f", test_accuracy) 201 | with tf.gfile.Open( 202 | os.path.join(checkpoint_path, "test_accuracy"), "wb") as f: 203 | pickle.dump(test_accuracy, f) 204 | 205 | 206 | def main(argv): 207 | del argv # Unused. 208 | run_training_loop(FLAGS.checkpoint_path) 209 | 210 | 211 | if __name__ == "__main__": 212 | tf.app.run() 213 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Short utility functions for LEO.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import pickle 23 | 24 | from six.moves import range 25 | import tensorflow as tf 26 | 27 | import config 28 | import data 29 | 30 | 31 | def unpack_data(problem_instance): 32 | """Map data.ProblemInstance to a list of Tensors, to process with map_fn.""" 33 | if isinstance(problem_instance, data.ProblemInstance): 34 | return list(problem_instance) 35 | return problem_instance 36 | 37 | 38 | def copy_checkpoint(checkpoint_path, global_step, accuracy): 39 | """Copies the checkpoint to a separate directory.""" 40 | tmp_checkpoint_path = os.path.join(checkpoint_path, "tmp_best_checkpoint") 41 | best_checkpoint_path = os.path.join(checkpoint_path, "best_checkpoint") 42 | if _is_previous_accuracy_better(best_checkpoint_path, accuracy): 43 | tf.logging.info("Not copying the checkpoint: there is a better one from " 44 | "before a preemption.") 45 | return 46 | 47 | checkpoint_regex = os.path.join(checkpoint_path, 48 | "model.ckpt-{}.*".format(global_step)) 49 | checkpoint_files = tf.gfile.Glob(checkpoint_regex) 50 | graph_file = os.path.join(checkpoint_path, "graph.pbtxt") 51 | checkpoint_files.append(graph_file) 52 | 53 | _save_files_in_tmp_directory(tmp_checkpoint_path, checkpoint_files, accuracy) 54 | 55 | new_checkpoint_index_file = 
os.path.join(tmp_checkpoint_path, "checkpoint") 56 | with tf.gfile.Open(new_checkpoint_index_file, "w") as f: 57 | f.write("model_checkpoint_path: \"{}/model.ckpt-{}\"\n".format( 58 | best_checkpoint_path, global_step)) 59 | 60 | # We first copy the better checkpoint to a temporary directory, and only 61 | # once it is fully created do we move it into place, to avoid leaving an 62 | # inconsistent state if the job is preempted while copying the checkpoint. 63 | if tf.gfile.Exists(best_checkpoint_path): 64 | tf.gfile.DeleteRecursively(best_checkpoint_path) 65 | tf.gfile.Rename(tmp_checkpoint_path, best_checkpoint_path) 66 | tf.logging.info("Copied new best checkpoint with accuracy %.5f", accuracy) 67 | 68 | 69 | def _save_files_in_tmp_directory(tmp_checkpoint_path, checkpoint_files, 70 | accuracy): 71 | """Saves the checkpoint files and accuracy in a temporary directory.""" 72 | 73 | if tf.gfile.Exists(tmp_checkpoint_path): 74 | tf.logging.info("The temporary directory exists, because the job was preempted " 75 | "before it managed to move it. 
We're removing it.") 76 | tf.gfile.DeleteRecursively(tmp_checkpoint_path) 77 | tf.gfile.MkDir(tmp_checkpoint_path) 78 | 79 | def dump_in_best_checkpoint_path(obj, filename): 80 | full_path = os.path.join(tmp_checkpoint_path, filename) 81 | with tf.gfile.Open(full_path, "wb") as f: 82 | pickle.dump(obj, f) 83 | 84 | for file_ in checkpoint_files: 85 | just_filename = file_.split("/")[-1] 86 | tf.gfile.Copy( 87 | file_, 88 | os.path.join(tmp_checkpoint_path, just_filename), 89 | overwrite=False) 90 | dump_in_best_checkpoint_path(config.get_inner_model_config(), "inner_config") 91 | dump_in_best_checkpoint_path(config.get_outer_model_config(), "outer_config") 92 | dump_in_best_checkpoint_path(accuracy, "accuracy") 93 | 94 | 95 | def _is_previous_accuracy_better(best_checkpoint_path, accuracy): 96 | if not tf.gfile.Exists(best_checkpoint_path): 97 | return False 98 | 99 | previous_accuracy_file = os.path.join(best_checkpoint_path, "accuracy") 100 | with tf.gfile.Open(previous_accuracy_file, "rb") as f: 101 | previous_accuracy = pickle.load(f) 102 | 103 | return previous_accuracy > accuracy 104 | 105 | 106 | def evaluate_and_average(session, tensor, num_estimates): 107 | tensor_value_estimates = [session.run(tensor) for _ in range(num_estimates)] 108 | average_tensor_value = sum(tensor_value_estimates) / num_estimates 109 | return average_tensor_value 110 | --------------------------------------------------------------------------------
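As a note on `_clip_gradients` in runner.py: it clips each gradient elementwise by value first, then rescales any gradient whose L2 norm still exceeds a threshold. A minimal NumPy stand-in for the TF1 calls (`tf.clip_by_value`, `tf.clip_by_norm`) might look like the sketch below; the function and variable names mirror runner.py, but this is an illustration, not the repo's code.

```python
import numpy as np


def clip_gradients(gradients, gradient_threshold, gradient_norm_threshold):
    """Clips each gradient by value, then rescales it to a maximum L2 norm.

    NumPy sketch of runner.py's _clip_gradients, which uses
    tf.clip_by_value followed by tf.clip_by_norm.
    """
    if gradient_threshold > 0:
        gradients = [
            np.clip(g, -gradient_threshold, gradient_threshold)
            for g in gradients
        ]
    if gradient_norm_threshold > 0:
        clipped = []
        for g in gradients:
            norm = np.linalg.norm(g)
            # tf.clip_by_norm rescales a tensor only if its norm exceeds the
            # threshold; otherwise the tensor passes through unchanged.
            if norm > gradient_norm_threshold:
                g = g * (gradient_norm_threshold / norm)
            clipped.append(g)
        gradients = clipped
    return gradients


grads = [np.array([3.0, -4.0]), np.array([0.5, 0.5])]
out = clip_gradients(grads, gradient_threshold=2.0, gradient_norm_threshold=1.0)
```

Here the first gradient is clipped to `[2, -2]` by value and then scaled down to unit norm, while the second is small enough to pass through both stages untouched.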
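The best-checkpoint logic in utils.py (`copy_checkpoint` / `_save_files_in_tmp_directory`) relies on a stage-then-rename pattern: files are copied into a temporary directory, and only a final rename makes them visible as `best_checkpoint`, so a preemption mid-copy can never leave a half-written result behind. A pure-stdlib sketch of that pattern, with `tf.gfile` swapped for `os`/`shutil` and illustrative names, could look like this:

```python
import os
import pickle
import shutil
import tempfile


def copy_best_checkpoint(checkpoint_dir, files, accuracy):
    """Stages checkpoint files in a tmp directory, then renames it into place.

    Stdlib sketch of the preemption-safe pattern in utils.py; the directory
    names match the repo, everything else is illustrative.
    """
    tmp_path = os.path.join(checkpoint_dir, "tmp_best_checkpoint")
    best_path = os.path.join(checkpoint_dir, "best_checkpoint")

    # A leftover tmp directory means a previous attempt was preempted before
    # the rename; its contents may be incomplete, so start over.
    if os.path.exists(tmp_path):
        shutil.rmtree(tmp_path)
    os.mkdir(tmp_path)

    for path in files:
        shutil.copy(path, os.path.join(tmp_path, os.path.basename(path)))
    with open(os.path.join(tmp_path, "accuracy"), "wb") as f:
        pickle.dump(accuracy, f)

    # The rename is the commit point: best_path only ever appears fully built.
    if os.path.exists(best_path):
        shutil.rmtree(best_path)
    os.rename(tmp_path, best_path)


# Hypothetical usage with a dummy checkpoint file.
demo_dir = tempfile.mkdtemp()
ckpt_file = os.path.join(demo_dir, "model.ckpt-100.index")
with open(ckpt_file, "w") as f:
    f.write("weights")
copy_best_checkpoint(demo_dir, [ckpt_file], 0.57)
```

Renaming a directory is atomic on the same filesystem, which is what makes the commit point safe; the repo's `tf.gfile.Rename` plays the same role.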