├── .gitignore ├── AUTHORS ├── CONTRIBUTING.md ├── Intro_to_Metadataset.ipynb ├── LICENSE ├── Leaderboard.ipynb ├── MANIFEST.in ├── README.md ├── VTAB-plus-MD.md ├── doc ├── dataset_conversion.md └── reproducing_best_results.md ├── meta_dataset ├── __init__.py ├── analysis │ ├── __init__.py │ ├── select_best_model.py │ └── transferability │ │ ├── __init__.py │ │ ├── criticality.py │ │ └── leep.py ├── analyze.py ├── data │ ├── __init__.py │ ├── config.py │ ├── dataset_spec.py │ ├── decoder.py │ ├── decoder_test.py │ ├── dump_and_read_episodes_test.py │ ├── dump_episodes.py │ ├── imagenet_specification.py │ ├── imagenet_specification_test.py │ ├── imagenet_stats.py │ ├── learning_spec.py │ ├── pipeline.py │ ├── pipeline_test.py │ ├── providers.py │ ├── read_episodes.py │ ├── reader.py │ ├── reader_test.py │ ├── sampling.py │ ├── sampling_test.py │ ├── sur_decoder.py │ ├── test_utils.py │ ├── tfds │ │ ├── README.md │ │ ├── __init__.py │ │ ├── api.py │ │ ├── constants.py │ │ ├── example_generators.py │ │ ├── md_tfds.py │ │ ├── meta_dataset_test.py │ │ └── test_utils.py │ └── utils.py ├── dataset_conversion │ ├── ImageNet_CUBirds_duplicates.txt │ ├── ImageNet_Caltech101_duplicates.txt │ ├── ImageNet_Caltech256_duplicates.txt │ ├── TrafficSign_labels.txt │ ├── VggFlower_labels.txt │ ├── __init__.py │ ├── check_dataset_consistency.py │ ├── convert_datasets_to_records.py │ ├── dataset_specs │ │ ├── aircraft_dataset_spec.json │ │ ├── cu_birds_dataset_spec.json │ │ ├── dtd_dataset_spec.json │ │ ├── fungi_dataset_spec.json │ │ ├── ilsvrc_2012_dataset_spec.json │ │ ├── ilsvrc_2012_num_leaf_images.json │ │ ├── ilsvrc_2012_v2_dataset_spec.json │ │ ├── ilsvrc_2012_v2_num_leaf_images.json │ │ ├── mscoco_dataset_spec.json │ │ ├── omniglot_dataset_spec.json │ │ ├── quickdraw_dataset_spec.json │ │ ├── traffic_sign_dataset_spec.json │ │ └── vgg_flower_dataset_spec.json │ ├── dataset_to_records.py │ └── splits │ │ ├── aircraft_splits.json │ │ ├── cu_birds_splits.json │ │ ├── 
dtd_splits.json │ │ ├── fungi_splits.json │ │ ├── mscoco_splits.json │ │ ├── quickdraw_splits.json │ │ ├── traffic_sign_splits.json │ │ └── vgg_flower_splits.json ├── distribute_utils.py ├── learn │ └── gin │ │ ├── best │ │ ├── baseline_all.gin │ │ ├── baseline_imagenet.gin │ │ ├── baselinefinetune_all.gin │ │ ├── baselinefinetune_cosine_mini_imagenet_fiveshot.gin │ │ ├── baselinefinetune_cosine_mini_imagenet_oneshot.gin │ │ ├── baselinefinetune_imagenet.gin │ │ ├── baselinefinetune_mini_imagenet_fiveshot.gin │ │ ├── baselinefinetune_mini_imagenet_oneshot.gin │ │ ├── flute.gin │ │ ├── flute_init_from_imagenet.gin │ │ ├── flute_init_from_scratch.gin │ │ ├── maml_all.gin │ │ ├── maml_all_from_scratch.gin │ │ ├── maml_imagenet.gin │ │ ├── maml_imagenet_from_scratch.gin │ │ ├── maml_init_with_proto_all.gin │ │ ├── maml_init_with_proto_all_from_scratch.gin │ │ ├── maml_init_with_proto_imagenet.gin │ │ ├── maml_init_with_proto_imagenet_from_scratch.gin │ │ ├── maml_init_with_proto_inference_all.gin │ │ ├── maml_init_with_proto_inference_imagenet.gin │ │ ├── matching_all.gin │ │ ├── matching_all_from_scratch.gin │ │ ├── matching_imagenet.gin │ │ ├── matching_imagenet_from_scratch.gin │ │ ├── matching_inference_all.gin │ │ ├── matching_inference_imagenet.gin │ │ ├── pretrain_imagenet_convnet.gin │ │ ├── pretrain_imagenet_resnet.gin │ │ ├── pretrain_imagenet_wide_resnet.gin │ │ ├── pretrained_convnet.gin │ │ ├── pretrained_resnet.gin │ │ ├── pretrained_wide_resnet.gin │ │ ├── prototypical_all.gin │ │ ├── prototypical_all_from_scratch.gin │ │ ├── prototypical_imagenet.gin │ │ ├── prototypical_imagenet_from_scratch.gin │ │ ├── prototypical_inference_all.gin │ │ ├── prototypical_inference_imagenet.gin │ │ ├── relationnet_all.gin │ │ ├── relationnet_all_from_scratch.gin │ │ ├── relationnet_imagenet.gin │ │ ├── relationnet_imagenet_from_scratch.gin │ │ └── relationnet_mini_imagenet_fiveshot.gin │ │ ├── default │ │ ├── baseline_all.gin │ │ ├── baseline_cosine_imagenet.gin │ │ ├── 
baseline_imagenet.gin │ │ ├── baseline_mini_imagenet_fiveshot.gin │ │ ├── baseline_mini_imagenet_oneshot.gin │ │ ├── baselinefinetune_all.gin │ │ ├── baselinefinetune_cosine_imagenet.gin │ │ ├── baselinefinetune_cosine_mini_imagenet_fiveshot.gin │ │ ├── baselinefinetune_cosine_mini_imagenet_oneshot.gin │ │ ├── baselinefinetune_imagenet.gin │ │ ├── baselinefinetune_mini_imagenet_fiveshot.gin │ │ ├── baselinefinetune_mini_imagenet_oneshot.gin │ │ ├── crosstransformer_imagenet.gin │ │ ├── crosstransformer_simclreps_imagenet.gin │ │ ├── debug_proto_fungi.gin │ │ ├── debug_proto_mini_imagenet.gin │ │ ├── flute.gin │ │ ├── flute_dataset_classifier.gin │ │ ├── maml_all.gin │ │ ├── maml_imagenet.gin │ │ ├── maml_init_with_proto_all.gin │ │ ├── maml_init_with_proto_imagenet.gin │ │ ├── maml_init_with_proto_inference_all.gin │ │ ├── maml_init_with_proto_inference_imagenet.gin │ │ ├── maml_init_with_proto_mini_imagenet_fiveshot.gin │ │ ├── maml_init_with_proto_mini_imagenet_oneshot.gin │ │ ├── maml_mini_imagenet_fiveshot.gin │ │ ├── maml_mini_imagenet_oneshot.gin │ │ ├── maml_protonet_all.gin │ │ ├── matching_all.gin │ │ ├── matching_imagenet.gin │ │ ├── matching_inference_all.gin │ │ ├── matching_inference_imagenet.gin │ │ ├── matching_mini_imagenet_fiveshot.gin │ │ ├── matching_mini_imagenet_oneshot.gin │ │ ├── pretrained_resnet34_224.gin │ │ ├── prototypical_all.gin │ │ ├── prototypical_imagenet.gin │ │ ├── prototypical_inference_all.gin │ │ ├── prototypical_inference_imagenet.gin │ │ ├── prototypical_mini_imagenet_fiveshot.gin │ │ ├── prototypical_mini_imagenet_oneshot.gin │ │ ├── relationnet_all.gin │ │ ├── relationnet_imagenet.gin │ │ ├── relationnet_mini_imagenet_fiveshot.gin │ │ ├── relationnet_mini_imagenet_oneshot.gin │ │ └── resnet34_stride16.gin │ │ ├── learners │ │ ├── baseline_config.gin │ │ ├── baseline_cosine_config.gin │ │ ├── baselinefinetune_config.gin │ │ ├── baselinefinetune_cosine_config.gin │ │ ├── crosstransformer_config.gin │ │ ├── learner_config.gin 
│ │ ├── maml_config.gin │ │ ├── maml_init_with_proto_config.gin │ │ ├── matching_config.gin │ │ ├── prototypical_config.gin │ │ └── relationnet_config.gin │ │ ├── metadataset_v2 │ │ ├── baseline_all.gin │ │ ├── baseline_imagenet.gin │ │ ├── baselinefinetune_all.gin │ │ ├── baselinefinetune_imagenet.gin │ │ ├── best │ │ │ ├── baseline_imagenet.gin │ │ │ ├── baselinefinetune_all.gin │ │ │ ├── baselinefinetune_imagenet.gin │ │ │ ├── maml_init_with_proto_all.gin │ │ │ ├── maml_init_with_proto_imagenet.gin │ │ │ ├── prototypical_all.gin │ │ │ └── prototypical_imagenet.gin │ │ ├── maml_init_with_proto_all.gin │ │ ├── maml_init_with_proto_imagenet.gin │ │ └── prototypical_all.gin │ │ └── setups │ │ ├── all.gin │ │ ├── all_datasets.gin │ │ ├── all_v2.gin │ │ ├── data_config.gin │ │ ├── data_config_common.gin │ │ ├── data_config_feature.gin │ │ ├── data_config_flute.gin │ │ ├── data_config_no_decoder.gin │ │ ├── data_config_string.gin │ │ ├── data_config_tfds.gin │ │ ├── fixed_way_and_shot.gin │ │ ├── imagenet.gin │ │ ├── imagenet_v2.gin │ │ ├── mini_imagenet.gin │ │ ├── sur_data_config.gin │ │ ├── trainer_config.gin │ │ ├── trainer_config_debug.gin │ │ ├── trainer_config_flute.gin │ │ └── variable_way_and_shot.gin ├── learners │ ├── __init__.py │ ├── base.py │ ├── base_test.py │ ├── baseline_learners.py │ ├── baseline_learners_test.py │ ├── experimental │ │ ├── __init__.py │ │ ├── base.py │ │ ├── metric_learners.py │ │ ├── metric_learners_test.py │ │ ├── optimization_learners.py │ │ └── optimization_learners_test.py │ ├── metric_learners.py │ ├── metric_learners_test.py │ ├── optimization_learners.py │ └── optimization_learners_test.py ├── models │ ├── __init__.py │ ├── experimental │ │ ├── __init__.py │ │ ├── parameter_adapter.py │ │ ├── reparameterizable_backbones.py │ │ ├── reparameterizable_backbones_test.py │ │ ├── reparameterizable_base.py │ │ ├── reparameterizable_base_test.py │ │ └── reparameterizable_distributions.py │ ├── functional_backbones.py │ └── 
functional_classifiers.py ├── test_utils.py ├── train.py ├── train_flute.py ├── trainer.py ├── trainer_flute.py └── trainer_test.py ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *.pyc 3 | __pycache__ 4 | *.swp 5 | -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | # This is the list of the Meta-Dataset Project authors for copyright purposes. 2 | # 3 | # This does not necessarily list everyone who has contributed code, since in 4 | # some cases, their employer may be the copyright holder. To see the full list 5 | # of contributors, see the revision history in source control. 6 | Google LLC 7 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to <https://cla.developers.google.com/> to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. 
Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows 28 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). 29 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include meta_dataset/learn *.gin 2 | recursive-include meta_dataset/dataset_conversion *.json -------------------------------------------------------------------------------- /meta_dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Meta-Dataset Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | -------------------------------------------------------------------------------- /meta_dataset/analysis/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Meta-Dataset Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | -------------------------------------------------------------------------------- /meta_dataset/analysis/transferability/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Meta-Dataset Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | -------------------------------------------------------------------------------- /meta_dataset/analysis/transferability/leep.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Meta-Dataset Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
def compute_leep(source_predictions,
                 target_onehot_labels):
  """Compute LEEP using `source_predictions` and `target_onehot_labels`.

  Args:
    source_predictions: Predictions over the source label set for a batch of
      input data from the target dataset, indexed as
      `[example, source_label]`. They are normalized in log space below, i.e.
      they are treated as (possibly unnormalized) log-probabilities.
    target_onehot_labels: One-hot labels from the target label set, indexed as
      `[example, target_label]`.

  Returns:
    The log expected empirical prediction measure (LEEP) on the given batch.

  Raises:
    ValueError: If the two inputs do not have the same leading dimension.
  """
  # Compare the static leading dimensions directly; the float tensor created
  # below is only meant for arithmetic, not for a Python-level comparison.
  num_examples = source_predictions.shape[0]
  if target_onehot_labels.shape[0] != num_examples:
    raise ValueError('`source_predictions` and `target_onehot_labels` must '
                     'represent the same number of examples.')
  batch_size = tf.cast(num_examples, tf.float32)

  # Normalize the predicted probabilities in log space.
  source_predictions -= tf.math.reduce_logsumexp(
      source_predictions, axis=1, keepdims=True)

  # p_model(y, z | x_i) for source labels `z`, target labels `y`, and each `i`.
  log_per_example_full_joint = tf.einsum('iz,iy->iyz', source_predictions,
                                         target_onehot_labels)

  # Workaround for one-hot indexing in log space: entries that do not belong
  # to the example's own target label get log-probability -inf.
  log_per_example_full_joint = (
      tf.where(
          tf.expand_dims(target_onehot_labels != 0, axis=2),
          log_per_example_full_joint,
          tf.ones_like(log_per_example_full_joint) * -np.inf))

  # Re-normalize the joint.
  log_per_example_full_joint -= tf.math.reduce_logsumexp(
      log_per_example_full_joint, axis=(1, 2), keepdims=True)

  # Average example-wise probabilities (Eq. (1) in the paper).
  log_full_joint = tf.math.reduce_logsumexp(
      log_per_example_full_joint, axis=0) - tf.math.log(batch_size)

  # p_model(z) for source labels `z`, marginalizing out target labels `y`.
  log_full_source_marginal = tf.math.reduce_logsumexp(
      source_predictions, axis=0) - tf.math.log(batch_size)

  # p_model(y | z) for target labels `y`, conditioning on source labels `z`.
  log_full_conditional = (
      log_full_joint - tf.expand_dims(log_full_source_marginal, axis=0))

  # p_model(y = y_i | z) for datapoint `i`. The row of each example's label is
  # selected by index instead of multiplying by the one-hot labels:
  # `log_full_conditional` is -inf for every target label that does not occur
  # in the batch, and `0 * -inf` would turn the whole result into NaN.
  log_predicted_conditional = tf.gather(
      log_full_conditional, tf.argmax(target_onehot_labels, axis=1))

  # p_model(y = y_i | z) * p_model(z) for datapoint `i`.
  log_predicted_joint = (log_predicted_conditional + source_predictions)

  # p_model(y = y_i) for datapoint `i`.
  log_predicted_marginal = tf.math.reduce_logsumexp(log_predicted_joint, axis=1)

  return tf.reduce_sum(log_predicted_marginal) / batch_size
class DecoderTest(tf.test.TestCase):
  """Round-trip tests for the decoders in `meta_dataset.data.decoder`."""

  @staticmethod
  def _random_image(side):
    """Returns a random `[side, side, 3]` image of unsigned bytes."""
    return np.random.randint(
        low=0, high=255, size=[side, side, 3]).astype(np.ubyte)

  @staticmethod
  def _png_example(pixels):
    """Serializes `pixels` as a PNG image Example with a zero label."""
    png_bytes = dataset_to_records.encode_image(pixels, image_format='PNG')
    zero_label = np.zeros(1).astype(np.int64)
    return dataset_to_records.make_example([
        ('image', 'bytes', [png_bytes]), ('label', 'int64', [zero_label])
    ])

  def test_string_decoder(self):
    """The string returned by `StringDecoder` decodes to the source image."""
    pixels = self._random_image(32)
    serialized = self._png_example(pixels)

    # Decode the example to an image string, then the string to pixels.
    decode_op = tf.image.decode_image(decoder.StringDecoder()(serialized))

    # Assert perfect reconstruction.
    with self.session(use_gpu=False) as sess:
      roundtrip = sess.run(decode_op)
      self.assertAllClose(pixels, roundtrip)

  def test_image_decoder(self):
    """`ImageDecoder` output equals the source image rescaled to [-1, 1]."""
    side = 84
    pixels = self._random_image(side)
    serialized = self._png_example(pixels)

    decode_op = decoder.ImageDecoder(image_size=side)(serialized)

    # The expected value maps byte intensities from [0, 255] to [-1, 1].
    expected = 2 * (pixels.astype(np.float32) / 255.0 - 0.5)
    with self.session(use_gpu=False) as sess:
      reconstructed = sess.run(decode_op)
      self.assertAllClose(expected, reconstructed)

  def test_feature_decoder(self):
    """`FeatureDecoder` returns exactly the embedding that was written."""
    embedding_len = 64
    embedding = np.random.randn(embedding_len).astype(np.float32)
    zero_label = np.zeros(1).astype(np.int64)

    serialized = dataset_to_records.make_example([
        ('image/embedding', 'float32', embedding),
        ('image/class/label', 'int64', [zero_label]),
    ])

    decode_op = decoder.FeatureDecoder(feat_len=embedding_len)(serialized)

    # Assert perfect reconstruction.
    with self.session(use_gpu=False) as sess:
      reconstructed = sess.run(decode_op)
      self.assertAllEqual(reconstructed, embedding)
17 | 18 | Episodes are stored as a pair of `{episode_number}-train.tfrecords` and 19 | `{episode_number}-test.tfrecords` files, each of which contains serialized 20 | TFExample strings for the support and query set, respectively. 21 | 22 | python -m meta_dataset.data.dump_episodes \ 23 | --gin_config=meta_dataset/learn/gin/setups/\ 24 | data_config_string.gin --gin_config=meta_dataset/learn/gin/\ 25 | setups/variable_way_and_shot.gin \ 26 | --gin_bindings="DataConfig.num_prefetch=<num_prefetch>" 27 | """ 28 | import json 29 | import os 30 | from absl import app 31 | from absl import flags 32 | from absl import logging 33 | 34 | import gin 35 | from meta_dataset.data import config 36 | from meta_dataset.data import dataset_spec as dataset_spec_lib 37 | from meta_dataset.data import learning_spec 38 | from meta_dataset.data import pipeline 39 | from meta_dataset.data import utils 40 | import tensorflow.compat.v1 as tf 41 | 42 | tf.enable_eager_execution() 43 | 44 | flags.DEFINE_multi_string('gin_config', None, 45 | 'List of paths to the config files.') 46 | flags.DEFINE_multi_string('gin_bindings', None, 47 | 'List of Gin parameter bindings.') 48 | flags.DEFINE_string('output_dir', '/tmp/cached_episodes/', 49 | 'Root directory for saving episodes.') 50 | flags.DEFINE_integer('num_episodes', 600, 'Number of episodes to sample.') 51 | flags.DEFINE_string('dataset_name', 'omniglot', 'Dataset name to create ' 52 | 'episodes from.') 53 | flags.DEFINE_enum_class('split', learning_spec.Split.TEST, learning_spec.Split, 54 | 'See learning_spec.Split for ' 55 | 'allowed values.') 56 | flags.DEFINE_boolean( 57 | 'ignore_dag_ontology', False, 'If True the dag ontology' 58 | ' for Imagenet dataset is not used.') 59 | flags.DEFINE_boolean( 60 | 'ignore_bilevel_ontology', False, 'If True the bilevel' 61 | ' sampling for Omniglot dataset is not used.') 62 | tf.flags.DEFINE_string('records_root_dir', '', 63 | 'Root directory containing a subdirectory per dataset.') 64 | FLAGS = flags.FLAGS 65 | 66 
def main(unused_argv):
  """Samples episodes from one dataset and dumps them under the output dir.

  For every sampled episode, the support and query sets are written as a pair
  of tfrecords files. Once all episodes are written, a JSON file mapping each
  file's base name to its label counts is saved at the path returned by
  `utils.get_info_path`.

  Args:
    unused_argv: Positional command-line arguments; ignored.
  """
  output_dir = FLAGS.output_dir
  logging.info(output_dir)
  tf.io.gfile.makedirs(output_dir)
  gin.parse_config_files_and_bindings(
      FLAGS.gin_config, FLAGS.gin_bindings, finalize_config=True)

  records_dir = os.path.join(FLAGS.records_root_dir, FLAGS.dataset_name)
  dataset_spec = dataset_spec_lib.load_dataset_spec(records_dir)
  data_config = config.DataConfig()
  episode_descr_config = config.EpisodeDescriptionConfig()

  # Each ontology only applies to its own dataset, and can be switched off
  # with the corresponding flag.
  is_imagenet = FLAGS.dataset_name in ('ilsvrc_2012', 'ilsvrc_2012_v2')
  is_omniglot = FLAGS.dataset_name == 'omniglot'
  use_dag_ontology = is_imagenet and not FLAGS.ignore_dag_ontology
  use_bilevel_ontology = is_omniglot and not FLAGS.ignore_bilevel_ontology

  data_pipeline = pipeline.make_one_source_episode_pipeline(
      dataset_spec,
      use_dag_ontology=use_dag_ontology,
      use_bilevel_ontology=use_bilevel_ontology,
      split=FLAGS.split,
      episode_descr_config=episode_descr_config,
      # TODO(evcu) Maybe set the following to 0 to prevent shuffling and check
      # reproducibility of dumping.
      shuffle_buffer_size=data_config.shuffle_buffer_size,
      read_buffer_size_bytes=data_config.read_buffer_size_bytes,
      num_prefetch=data_config.num_prefetch)

  label_counts_by_file = {}
  # The second element of each item (the dataset number) is ignored since a
  # single dataset is loaded.
  episodes = enumerate(data_pipeline.take(FLAGS.num_episodes))
  for episode_number, (episode, _) in episodes:
    logging.info('Dumping episode %d', episode_number)
    support_imgs, support_labels, _, query_imgs, query_labels, _ = episode
    support_path = utils.get_file_path(output_dir, episode_number, 'train')
    query_path = utils.get_file_path(output_dir, episode_number, 'test')
    utils.dump_as_tfrecord(support_path, support_imgs, support_labels)
    utils.dump_as_tfrecord(query_path, query_imgs, query_labels)
    label_counts_by_file[os.path.basename(support_path)] = (
        utils.get_label_counts(support_labels))
    label_counts_by_file[os.path.basename(query_path)] = (
        utils.get_label_counts(query_labels))

  serialized_counts = json.dumps(label_counts_by_file, indent=2)
  with tf.io.gfile.GFile(utils.get_info_path(output_dir), 'w') as f:
    f.write(serialized_counts)
class Split(enum.Enum):
  """The possible data splits."""
  TRAIN = 0
  VALID = 1
  TEST = 2


class BatchSpecification(
    collections.namedtuple('BatchSpecification', 'split, batch_size')):
  """The specification of a batch.

  Args:
    split: the Split from which to pick data.
    batch_size: an int, the number of (image, label) pairs in the batch.
  """


class EpisodeSpecification(
    collections.namedtuple(
        'EpisodeSpecification',
        'split, num_classes, num_train_examples, num_test_examples')):
  """The specification of an episode.

  Args:
    split: A Split from which to pick data.
    num_classes: The number of classes in the episode, or None for variable.
    num_train_examples: The number of examples to use per class in the train
      phase, or None for variable.
    num_test_examples: the number of examples to use per class in the test
      phase, or None for variable.
  """
class PipelineTest(tf.test.TestCase, parameterized.TestCase):
  """Tests reading episodes back through the multi-source pipeline."""

  def _assert_features_match_classes(self, class_features, features,
                                     class_ids):
    """Checks that each feature occurs in `class_features` under its class.

    Args:
      class_features: Array of the written features, stacked per class so that
        the leading axis is the class id and the last axis the feature vector.
      features: Feature vectors read back from an episode.
      class_ids: Class ids read back from the episode, aligned with `features`.
    """
    for feat, class_id in zip(list(features), list(class_ids)):
      # Sum of absolute differences: this is zero only for an exact match.
      # Taking the absolute value after summing would let differences of
      # opposite sign cancel out and report a spurious match.
      abs_err = np.sum(np.abs(class_features - feat[None][None]), axis=-1)
      # Make sure the feature is present in the original data.
      self.assertEqual(abs_err.min(), 0.0)
      found_class_id = np.where(abs_err == 0.0)[0][0]
      self.assertEqual(found_class_id, class_id)

  def _assert_examples_match_classes(self, class_examples, examples,
                                     class_ids):
    """Checks that each example occurs in `class_examples` under its class.

    Args:
      class_examples: Array of the written examples, stacked per class so that
        the leading axis is the class id.
      examples: Examples read back from an episode.
      class_ids: Class ids read back from the episode, aligned with `examples`.
    """
    for example, class_id in zip(list(examples), list(class_ids)):
      found_class_id = np.where(class_examples == example)[0][0]
      self.assertEqual(found_class_id, class_id)

  @parameterized.named_parameters(
      ('FeatureDecoder', 'feature',
       'meta_dataset/learn/gin/setups/data_config_feature.gin'),
      ('NoDecoder', 'none',
       'meta_dataset/learn/gin/setups/data_config_no_decoder.gin'
      ),
  )
  def test_make_multisource_episode_pipeline_feature(self, decoder_type,
                                                     config_file_path):
    """Writes feature records, reads an episode back and checks its labels.

    Args:
      decoder_type: Either 'feature' or 'none'; selects which check is run on
        the episode contents.
      config_file_path: Path of the Gin file parsed before building the
        pipeline.
    """
    # Create some feature records and write them to a temp directory.
    feat_size = 64
    num_examples = 100
    num_classes = 10
    output_path = self.get_temp_dir()
    gin.parse_config_file(config_file_path)

    # 1-Write feature records to temp directory.
    self.rng = np.random.RandomState(0)
    class_features = []
    class_examples = []
    for class_id in range(num_classes):
      features = self.rng.randn(num_examples, feat_size).astype(np.float32)
      label = np.array(class_id).astype(np.int64)
      output_file = os.path.join(output_path, str(class_id) + '.tfrecords')
      examples = test_utils.write_feature_records(features, label, output_file)
      class_examples.append(examples)
      class_features.append(features)
    class_examples = np.stack(class_examples)
    class_features = np.stack(class_features)

    # 2-Read records back using multi-source pipeline.
    # DatasetSpecification to use in tests
    dataset_spec = DatasetSpecification(
        name=None,
        classes_per_split={
            learning_spec.Split.TRAIN: 5,
            learning_spec.Split.VALID: 2,
            learning_spec.Split.TEST: 3
        },
        images_per_class={i: num_examples for i in range(num_classes)},
        class_names=None,
        path=output_path,
        file_pattern='{}.tfrecords')

    # Duplicate the dataset to simulate reading from multiple datasets.
    use_bilevel_ontology_list = [False] * 2
    use_dag_ontology_list = [False] * 2
    all_dataset_specs = [dataset_spec] * 2

    fixed_ways_shots = config.EpisodeDescriptionConfig(
        num_query=5, num_support=5, num_ways=5)

    dataset_episodic = pipeline.make_multisource_episode_pipeline(
        dataset_spec_list=all_dataset_specs,
        use_dag_ontology_list=use_dag_ontology_list,
        use_bilevel_ontology_list=use_bilevel_ontology_list,
        episode_descr_config=fixed_ways_shots,
        split=learning_spec.Split.TRAIN,
        image_size=None)

    episode, _ = self.evaluate(
        dataset_episodic.make_one_shot_iterator().get_next())

    # 3-Check that support and query items are present in the written data and
    # have the correct corresponding label.
    support_items, support_class_ids = episode[0], episode[2]
    query_items, query_class_ids = episode[3], episode[5]

    if decoder_type == 'feature':
      self._assert_features_match_classes(class_features, support_items,
                                          support_class_ids)
      self._assert_features_match_classes(class_features, query_items,
                                          query_class_ids)
    elif decoder_type == 'none':
      self._assert_examples_match_classes(class_examples, support_items,
                                          support_class_ids)
      self._assert_examples_match_classes(class_examples, query_items,
                                          query_class_ids)
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Module responsible for decoding images as in SUR (Dvornik et al). 17 | 18 | MIT License 19 | 20 | Copyright (c) 2021 Nikita Dvornik 21 | 22 | Permission is hereby granted, free of charge, to any person obtaining a copy 23 | of this software and associated documentation files (the "Software"), to deal 24 | in the Software without restriction, including without limitation the rights 25 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 26 | copies of the Software, and to permit persons to whom the Software is 27 | furnished to do so, subject to the following conditions: 28 | 29 | The above copyright notice and this permission notice shall be included in all 30 | copies or substantial portions of the Software. 31 | 32 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 33 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 34 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 35 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 36 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 37 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 38 | SOFTWARE. 
@gin.configurable
class SURDataAugmentation(object):
  """Configuration for the perturbations applied by `SURImageDecoder`.

  This is a plain value holder: each `enable_*` flag switches one perturbation
  on or off and the accompanying value sets its strength.
  """

  def __init__(self, enable_jitter, jitter_amount, enable_gaussian_noise,
               gaussian_noise_std, enable_random_flip, enable_random_brightness,
               random_brightness_delta, enable_random_contrast,
               random_contrast_delta, enable_random_hue, random_hue_delta,
               enable_random_saturation, random_saturation_delta):
    """Initializes a SURDataAugmentation.

    Args:
      enable_jitter: bool, whether to reflect-pad the image and take a random
        crop of the original size.
      jitter_amount: int, number of pixels padded on each side of the height
        and width before cropping.
      enable_gaussian_noise: bool, whether to add Gaussian noise.
      gaussian_noise_std: float, scale multiplying the standard normal noise.
      enable_random_flip: bool, whether to randomly flip left/right.
      enable_random_brightness: bool, whether to randomly change brightness.
      random_brightness_delta: float, `max_delta` passed to
        `tf.image.random_brightness`.
      enable_random_contrast: bool, whether to randomly change contrast.
      random_contrast_delta: float, the contrast factor is drawn from
        `[1 - delta, 1 + delta]`.
      enable_random_hue: bool, whether to randomly change hue.
      random_hue_delta: float, `max_delta` passed to `tf.image.random_hue`.
      enable_random_saturation: bool, whether to randomly change saturation.
      random_saturation_delta: float, the saturation factor is drawn from
        `[1 - delta, 1 + delta]`.
    """
    self.enable_jitter = enable_jitter
    self.jitter_amount = jitter_amount
    self.enable_gaussian_noise = enable_gaussian_noise
    self.gaussian_noise_std = gaussian_noise_std
    self.enable_random_flip = enable_random_flip
    self.enable_random_brightness = enable_random_brightness
    self.random_brightness_delta = random_brightness_delta
    self.enable_random_contrast = enable_random_contrast
    self.random_contrast_delta = random_contrast_delta
    self.enable_random_hue = enable_random_hue
    self.random_hue_delta = random_hue_delta
    self.enable_random_saturation = enable_random_saturation
    self.random_saturation_delta = random_saturation_delta


@gin.configurable
class SURImageDecoder(object):
  """Decodes serialized examples into augmented float images, as in SUR."""
  out_type = tf.float32

  def __init__(self, image_size=None, data_augmentation=None):
    """Class constructor.

    Args:
      image_size: int, desired image size. The extracted image will be resized
        to `[image_size, image_size]`. Must be set before the decoder is
        called.
      data_augmentation: A SURDataAugmentation object with parameters for
        perturbing the images, or None to disable augmentation.
    """

    self.image_size = image_size
    self.data_augmentation = data_augmentation

  def __call__(self, example_string):
    """Processes a single example string.

    Extracts and processes the image, and ignores the label. We assume that the
    image has three channels.

    Args:
      example_string: str, an Example protocol buffer.

    Returns:
      image_rescaled: the image, resized to `image_size x image_size` and
        rescaled to [-1, 1]. Note that Gaussian data augmentation may cause
        values to go beyond this range.

    Raises:
      ValueError: if `image_size` has not been set.
    """
    # Fail fast with a readable message: the resize and the jitter crop below
    # both need a concrete size and otherwise fail inside TensorFlow with an
    # unrelated-looking conversion error.
    if self.image_size is None:
      raise ValueError('SURImageDecoder requires `image_size` to be set.')

    image_string = tf.parse_single_example(
        example_string,
        features={
            'image': tf.FixedLenFeature([], dtype=tf.string),
            'label': tf.FixedLenFeature([], tf.int64)
        })['image']
    image_decoded = tf.image.decode_image(image_string, channels=3)
    image_decoded.set_shape([None, None, 3])
    image_resized = tf.image.resize_images(
        image_decoded, [self.image_size, self.image_size],
        method=tf.image.ResizeMethod.BILINEAR,
        align_corners=True)
    image = tf.cast(image_resized, tf.float32)

    # Photometric and flip perturbations run on the resized float image,
    # before it is rescaled below.
    if self.data_augmentation is not None:
      if self.data_augmentation.enable_random_brightness:
        delta = self.data_augmentation.random_brightness_delta
        image = tf.image.random_brightness(image, delta)

      if self.data_augmentation.enable_random_saturation:
        delta = self.data_augmentation.random_saturation_delta
        image = tf.image.random_saturation(image, 1 - delta, 1 + delta)

      if self.data_augmentation.enable_random_contrast:
        delta = self.data_augmentation.random_contrast_delta
        image = tf.image.random_contrast(image, 1 - delta, 1 + delta)

      if self.data_augmentation.enable_random_hue:
        delta = self.data_augmentation.random_hue_delta
        image = tf.image.random_hue(image, delta)

      if self.data_augmentation.enable_random_flip:
        image = tf.image.random_flip_left_right(image)

    image = 2 * (image / 255.0 - 0.5)  # Rescale to [-1, 1].

    # Noise and jitter are applied after rescaling.
    if self.data_augmentation is not None:
      if self.data_augmentation.enable_gaussian_noise:
        image = image + tf.random_normal(
            tf.shape(image)) * self.data_augmentation.gaussian_noise_std

      if self.data_augmentation.enable_jitter:
        j = self.data_augmentation.jitter_amount
        # Pad height and width by `j` pixels on each side, then crop back to
        # the target size at a random offset.
        paddings = tf.constant([[j, j], [j, j], [0, 0]])
        image = tf.pad(image, paddings, 'REFLECT')
        image = tf.image.random_crop(image,
                                     [self.image_size, self.image_size, 3])

    return image
# DatasetSpecification to use in tests
DATASET_SPEC = DatasetSpecification(
    name=None,
    classes_per_split={
        Split.TRAIN: 15,
        Split.VALID: 5,
        Split.TEST: 10
    },
    images_per_class=dict(enumerate([10, 20, 30] * 10)),
    class_names=['%d' % i for i in range(30)],
    path=None,
    file_pattern='{}.tfrecords')

# Define defaults for the input pipeline.
MIN_WAYS = 5
MAX_WAYS_UPPER_BOUND = 50
MAX_NUM_QUERY = 10
MAX_SUPPORT_SET_SIZE = 500
MAX_SUPPORT_SIZE_CONTRIB_PER_CLASS = 100
MIN_LOG_WEIGHT = np.log(0.5)
MAX_LOG_WEIGHT = np.log(2)


def set_episode_descr_config_defaults():
  """Sets default values for EpisodeDescriptionConfig using gin.

  Binds the module-level defaults above in the root gin scope, and binds the
  same sampling bounds to None in the `none` scope.
  """
  gin.parse_config('import meta_dataset.data.config')

  gin.bind_parameter('EpisodeDescriptionConfig.num_ways', None)
  gin.bind_parameter('EpisodeDescriptionConfig.num_support', None)
  gin.bind_parameter('EpisodeDescriptionConfig.num_query', None)
  gin.bind_parameter('EpisodeDescriptionConfig.min_ways', MIN_WAYS)
  gin.bind_parameter('EpisodeDescriptionConfig.max_ways_upper_bound',
                     MAX_WAYS_UPPER_BOUND)
  gin.bind_parameter('EpisodeDescriptionConfig.max_num_query', MAX_NUM_QUERY)
  gin.bind_parameter('EpisodeDescriptionConfig.max_support_set_size',
                     MAX_SUPPORT_SET_SIZE)
  gin.bind_parameter(
      'EpisodeDescriptionConfig.max_support_size_contrib_per_class',
      MAX_SUPPORT_SIZE_CONTRIB_PER_CLASS)
  gin.bind_parameter('EpisodeDescriptionConfig.min_log_weight', MIN_LOG_WEIGHT)
  gin.bind_parameter('EpisodeDescriptionConfig.max_log_weight', MAX_LOG_WEIGHT)
  gin.bind_parameter('EpisodeDescriptionConfig.ignore_dag_ontology', False)
  gin.bind_parameter('EpisodeDescriptionConfig.ignore_bilevel_ontology', False)
  gin.bind_parameter('EpisodeDescriptionConfig.ignore_hierarchy_probability',
                     0.)
  gin.bind_parameter('EpisodeDescriptionConfig.simclr_episode_fraction', 0.)

  # Following is set in a different scope.
  gin.bind_parameter('none/EpisodeDescriptionConfig.min_ways', None)
  gin.bind_parameter('none/EpisodeDescriptionConfig.max_ways_upper_bound', None)
  gin.bind_parameter('none/EpisodeDescriptionConfig.max_num_query', None)
  gin.bind_parameter('none/EpisodeDescriptionConfig.max_support_set_size', None)
  gin.bind_parameter(
      'none/EpisodeDescriptionConfig.max_support_size_contrib_per_class', None)
  gin.bind_parameter('none/EpisodeDescriptionConfig.min_log_weight', None)
  gin.bind_parameter('none/EpisodeDescriptionConfig.max_log_weight', None)


def write_feature_records(features, label, output_path):
  """Creates a record file from features and labels.

  Args:
    features: An [n, m] numpy array of features.
    label: An integer, the label common to all records.
    output_path: A string specifying the location of the record.

  Returns:
    serialized_examples: list of the serialized examples written by the
      writer, in the order they were written.
  """
  serialized_examples = []
  # The context manager closes the writer even if serializing a row raises,
  # so a failing test does not leave an open file handle behind.
  with tf.python_io.TFRecordWriter(output_path) as writer:
    for feat in features:
      # Write the example.
      serialized_example = dataset_to_records.make_example([
          ('image/embedding', 'float32', feat.tolist()),
          ('image/class/label', 'int64', [label])
      ])
      writer.write(serialized_example)
      serialized_examples.append(serialized_example)

  return serialized_examples
# Per-source description of the meta-splits and their class counts. When
# 'same_across_md_versions' is False, the per-source entries are dictionaries
# keyed by Meta-Dataset version instead of plain tuples.
_SOURCES = {
    'aircraft': {
        'same_across_md_versions': True,
        'meta_splits': ('train', 'valid', 'test'),
        'num_classes': (70, 15, 15)
    },
    'cu_birds': {
        'same_across_md_versions': True,
        'meta_splits': ('train', 'valid', 'test'),
        'num_classes': (140, 30, 30)
    },
    'dtd': {
        'same_across_md_versions': True,
        'meta_splits': ('train', 'valid', 'test'),
        'num_classes': (33, 7, 7)
    },
    'fungi': {
        'same_across_md_versions': True,
        'meta_splits': ('train', 'valid', 'test'),
        'num_classes': (994, 200, 200)
    },
    'ilsvrc_2012': {
        'same_across_md_versions': False,
        'meta_splits': {'v1': ('train', 'valid', 'test'), 'v2': ('train',)},
        'num_classes': {'v1': (712, 158, 130), 'v2': (1000,)},
        'remap_labels': {'v1': False, 'v2': True},
        'md_source': {'v1': 'ilsvrc_2012', 'v2': 'ilsvrc_2012_v2'},
    },
    'mscoco': {
        'same_across_md_versions': True,
        'meta_splits': ('valid', 'test'),
        'num_classes': (40, 40)
    },
    'omniglot': {
        'same_across_md_versions': True,
        'meta_splits': ('train', 'valid', 'test'),
        'num_classes': (883, 81, 659)
    },
    'quickdraw': {
        'same_across_md_versions': True,
        'meta_splits': ('train', 'valid', 'test'),
        'num_classes': (241, 52, 52)
    },
    'traffic_sign': {
        'same_across_md_versions': True,
        'meta_splits': ('test',),
        'num_classes': (42,)
    },
    'vgg_flower': {
        'same_across_md_versions': False,
        'meta_splits': {'v1': ('train', 'valid', 'test')},
        'num_classes': {'v1': (71, 15, 16)},
    },
}


def make_class_dataset_comparison_test_cases():
  """Builds one parameterized test case per (source, version, meta-split).

  Returns:
    A list of tuples `(test_name, source, md_source, md_version, meta_split,
    num_labels, offset, remap_labels)`, where `offset` is the cumulative class
    count of the meta-splits listed before this one for the same source and
    version.
  """
  testcases = []
  for source, info in _SOURCES.items():
    # Normalize both layouts to "dictionary keyed by version".
    if info['same_across_md_versions']:
      splits_by_version = {'v1': info['meta_splits'], 'v2': info['meta_splits']}
      classes_by_version = {'v1': info['num_classes'],
                            'v2': info['num_classes']}
    else:
      splits_by_version = info['meta_splits']
      classes_by_version = info['num_classes']

    for md_version, split_names in splits_by_version.items():
      class_counts = classes_by_version[md_version]
      # The first split starts at 0, each later one where the previous ended.
      offsets = np.cumsum((0,) + class_counts[:-1])

      remap_labels = False
      if 'remap_labels' in info:
        remap_labels = info['remap_labels'][md_version]

      md_source = source
      if 'md_source' in info:
        md_source = info['md_source'][md_version]

      testcases.extend(
          (f'{source}_{md_version}_{meta_split}', source, md_source,
           md_version, meta_split, num_labels, offset, remap_labels)
          for meta_split, num_labels, offset in zip(split_names, class_counts,
                                                    offsets))

  return testcases


def parse_example(example_string):
  """Parses a serialized example into its 'image' and 'label' features."""
  feature_spec = {
      'image': tf.io.FixedLenFeature([], dtype=tf.string),
      'label': tf.io.FixedLenFeature([], tf.int64),
  }
  return tf.io.parse_single_example(example_string, features=feature_spec)
TRAIN_SUFFIX = 'train.tfrecords'
TEST_SUFFIX = 'test.tfrecords'
# For output file.
FILE_NAME_TEMPLATE = 'episode-{:04d}-{}'


def get_label_counts(labels):
  """Returns a JSON-compatible mapping from label to number of occurrences.

  Args:
    labels: an object whose `numpy()` method returns an iterable of labels.

  Returns:
    A dictionary keyed by the string form of each label.
  """
  label_histogram = collections.Counter(labels.numpy())
  # JSON does not support integer keys.
  return {str(label): count for label, count in label_histogram.items()}


def dump_as_tfrecord(path, images, labels):
  """Writes parallel `images` and `labels` as examples to a record file.

  Args:
    path: location of the record file to create.
    images: iterable of items exposing `numpy()`.
    labels: iterable of items exposing `numpy()`, parallel to `images`.
  """
  logging.info('Dumping records to: %s', path)
  with tf.io.TFRecordWriter(path) as writer:
    for image, label in zip(images, labels):
      image_value = image.numpy()
      label_value = label.numpy()
      dataset_to_records.write_example(image_value, label_value, writer)


def get_file_path(folder, idx, split):
  """Returns the record path of episode `idx` for the given split.

  Args:
    folder: directory holding the episode files.
    idx: integer episode index, zero-padded to four digits in the file name.
    split: either 'train' or 'test'.

  Raises:
    ValueError: if `split` is neither 'train' nor 'test'.
  """
  if split != 'train' and split != 'test':
    raise ValueError('Split: %s, should be train or test.' % split)
  suffix = TRAIN_SUFFIX if split == 'train' else TEST_SUFFIX
  file_name = FILE_NAME_TEMPLATE.format(idx, suffix)
  return os.path.join(folder, file_name)


def get_info_path(folder):
  """Returns the path of the per-class image count file inside `folder`."""
  info_file_name = 'images_per_class.json'
  return os.path.join(folder, info_file_name)
Red_Breasted_Merganser_0001_79199.jpg 24 | n01537544/n01537544_8418.JPEG # Indigo_Bunting_0060_14495.jpg 25 | n02058221/n02058221_10440.JPEG # Laysan_Albatross_0053_543.jpg 26 | n01531178/n01531178_1502.JPEG # American_Goldfinch_0018_32324.jpg 27 | n01855032/n01855032_2846.JPEG # Red_Breasted_Merganser_0034_79292.jpg 28 | n01847000/n01847000_9239.JPEG # Mallard_0067_77623.jpg 29 | n01855032/n01855032_8051.JPEG # Red_Breasted_Merganser_0083_79562.jpg 30 | n02058221/n02058221_16869.JPEG # Laysan_Albatross_0049_918.jpg 31 | n02058221/n02058221_29424.JPEG # Black_Footed_Albatross_0002_55.jpg 32 | n01855032/n01855032_1136.JPEG # Red_Breasted_Merganser_0012_79425.jpg 33 | n01534433/n01534433_8318.JPEG # Dark_Eyed_Junco_0031_66785.jpg 34 | n01534433/n01534433_22902.JPEG # Dark_Eyed_Junco_0037_66321.jpg 35 | n01537544/n01537544_10504.JPEG # Indigo_Bunting_0031_13300.jpg 36 | n01580077/n01580077_5633.JPEG # Blue_Jay_0049_63082.jpg 37 | n01534433/n01534433_11948.JPEG # Dark_Eyed_Junco_0111_66488.jpg 38 | n01537544/n01537544_16.JPEG # Indigo_Bunting_0010_13000.jpg 39 | n01855032/n01855032_6416.JPEG # Red_Breasted_Merganser_0004_79232.jpg 40 | n01855032/n01855032_1358.JPEG # Red_Breasted_Merganser_0045_79358.jpg 41 | n01833805/n01833805_18429.JPEG # Ruby_Throated_Hummingbird_0040_57982.jpg 42 | n01531178/n01531178_8148.JPEG # American_Goldfinch_0116_31943.jpg 43 | n01580077/n01580077_8031.JPEG # Blue_Jay_0068_61543.jpg 44 | n01537544/n01537544_2813.JPEG # Indigo_Bunting_0073_13933.jpg 45 | n01534433/n01534433_420.JPEG # Dark_Eyed_Junco_0104_67820.jpg 46 | -------------------------------------------------------------------------------- /meta_dataset/dataset_conversion/ImageNet_Caltech101_duplicates.txt: -------------------------------------------------------------------------------- 1 | ## File name in ImageNet # File name in Caltech101. '_m' suffix indicates the mirrored image was a match. 
2 | n04552348/n04552348_16200.JPEG # airplanes/image_0266.jpg 3 | n02363005/n02363005_3962.JPEG # beaver/image_0003_m.jpg 4 | n02363005/n02363005_3752.JPEG # beaver/image_0004.jpg 5 | n02363005/n02363005_143.JPEG # beaver/image_0005_m.jpg 6 | n02363005/n02363005_4158.JPEG # beaver/image_0006.jpg 7 | n02363005/n02363005_2918.JPEG # beaver/image_0014.jpg 8 | n02363005/n02363005_3752.JPEG # beaver/image_0022.jpg 9 | n02363005/n02363005_2733.JPEG # beaver/image_0027.jpg 10 | n02363005/n02363005_3366.JPEG # beaver/image_0030_m.jpg 11 | n02363005/n02363005_899.JPEG # beaver/image_0033_m.jpg 12 | n02363005/n02363005_6314.JPEG # beaver/image_0042.jpg 13 | n02363005/n02363005_1129.JPEG # beaver/image_0046.jpg 14 | n03991062/n03991062_4071.JPEG # bonsai/image_0037.jpg 15 | n02950826/n02950826_14812.JPEG # cannon/image_0011_m.jpg 16 | n04099969/n04099969_15758.JPEG # chair/image_0028_m.jpg 17 | n02125311/n02125311_8090.JPEG # cougar_body/image_0002_m.jpg 18 | n02125311/n02125311_12897.JPEG # cougar_body/image_0011.jpg 19 | n02125311/n02125311_4683.JPEG # cougar_body/image_0014_m.jpg 20 | n02125311/n02125311_8090.JPEG # cougar_body/image_0026_m.jpg 21 | n02125311/n02125311_6360.JPEG # cougar_body/image_0033_m.jpg 22 | n02125311/n02125311_8090.JPEG # cougar_body/image_0035_m.jpg 23 | n02125311/n02125311_43757.JPEG # cougar_body/image_0039.jpg 24 | n02125311/n02125311_21461.JPEG # cougar_body/image_0040_m.jpg 25 | n02125311/n02125311_518.JPEG # cougar_body/image_0046.jpg 26 | n02125311/n02125311_11820.JPEG # cougar_face/image_0056.jpg 27 | n01985128/n01985128_16.JPEG # crayfish/image_0002_m.jpg 28 | n01985128/n01985128_18494.JPEG # crayfish/image_0021_m.jpg 29 | n01985128/n01985128_9302.JPEG # crayfish/image_0049_m.jpg 30 | n01984695/n01984695_23488.JPEG # crayfish/image_0057.jpg 31 | n01697457/n01697457_2629.JPEG # crocodile_head/image_0004.jpg 32 | n02110341/n02110341_6788.JPEG # dalmatian/image_0055.jpg 33 | n02504458/n02504458_14537.JPEG # elephant/image_0025.jpg 34 | 
n02504013/n02504013_165.JPEG # elephant/image_0045.jpg 35 | n03452741/n03452741_9398.JPEG # grand_piano/image_0047.jpg 36 | n02443114/n02443114_421.JPEG # hedgehog/image_0002_m.jpg 37 | n02346627/n02346627_13777.JPEG # hedgehog/image_0014.jpg 38 | n02346627/n02346627_16530.JPEG # hedgehog/image_0054_m.jpg 39 | n02128385/n02128385_26353.JPEG # Leopards/image_0051.jpg 40 | n02437616/n02437616_2924.JPEG # llama/image_0006.jpg 41 | n02437616/n02437616_6670.JPEG # llama/image_0018_m.jpg 42 | n01983481/n01983481_604.JPEG # lobster/image_0003.jpg 43 | n01983481/n01983481_16549.JPEG # lobster/image_0007.jpg 44 | n01968897/n01968897_4233.JPEG # nautilus/image_0001.jpg 45 | n01968897/n01968897_2749.JPEG # nautilus/image_0007.jpg 46 | n01968897/n01968897_7568.JPEG # nautilus/image_0014.jpg 47 | n01968897/n01968897_2639.JPEG # nautilus/image_0019.jpg 48 | n01968897/n01968897_12974.JPEG # nautilus/image_0020.jpg 49 | n01968897/n01968897_8442.JPEG # nautilus/image_0028.jpg 50 | n01968897/n01968897_1579.JPEG # nautilus/image_0029.jpg 51 | n01968897/n01968897_12974.JPEG # nautilus/image_0033.jpg 52 | n01968897/n01968897_22259.JPEG # nautilus/image_0043.jpg 53 | n01968897/n01968897_10474.JPEG # nautilus/image_0047.jpg 54 | n01968897/n01968897_8901.JPEG # nautilus/image_0054.jpg 55 | n01968897/n01968897_11310.JPEG # nautilus/image_0055.jpg 56 | n01873310/n01873310_11884.JPEG # platypus/image_0004.jpg 57 | n01873310/n01873310_12404.JPEG # platypus/image_0004_m.jpg 58 | n01873310/n01873310_3541.JPEG # platypus/image_0005.jpg 59 | n01873310/n01873310_12404.JPEG # platypus/image_0008_m.jpg 60 | n01873310/n01873310_12691.JPEG # platypus/image_0017_m.jpg 61 | n01873310/n01873310_1543.JPEG # platypus/image_0019.jpg 62 | n01873310/n01873310_12801.JPEG # platypus/image_0021_m.jpg 63 | n01873310/n01873310_13795.JPEG # platypus/image_0025.jpg 64 | n01873310/n01873310_3149.JPEG # platypus/image_0026_m.jpg 65 | n04086273/n04086273_6486.JPEG # revolver/image_0032.jpg 66 | 
n04086273/n04086273_23080.JPEG # revolver/image_0046_m.jpg 67 | n04086273/n04086273_23336.JPEG # revolver/image_0056_m.jpg 68 | n01770393/n01770393_19263.JPEG # scorpion/image_0006_m.jpg 69 | n01770393/n01770393_14293.JPEG # scorpion/image_0024.jpg 70 | n01770393/n01770393_10779.JPEG # scorpion/image_0031_m.jpg 71 | n01770393/n01770393_1713.JPEG # scorpion/image_0049_m.jpg 72 | n01770393/n01770393_13927.JPEG # scorpion/image_0056.jpg 73 | n01770393/n01770393_19263.JPEG # scorpion/image_0072_m.jpg 74 | n01770393/n01770393_19263.JPEG # scorpion/image_0076_m.jpg 75 | n01770393/n01770393_3256.JPEG # scorpion/image_0081_m.jpg 76 | n04147183/n04147183_5706.JPEG # schooner/image_0007.jpg 77 | n04147183/n04147183_13667.JPEG # schooner/image_0013.jpg 78 | n04147183/n04147183_14898.JPEG # schooner/image_0029.jpg 79 | n04147183/n04147183_13627.JPEG # schooner/image_0032.jpg 80 | n04147183/n04147183_17475.JPEG # schooner/image_0039.jpg 81 | n04147183/n04147183_2764.JPEG # schooner/image_0049.jpg 82 | n04147183/n04147183_5881.JPEG # schooner/image_0051.jpg 83 | n04147183/n04147183_5898.JPEG # schooner/image_0058.jpg 84 | n04254680/n04254680_1090.JPEG # soccer_ball/image_0001.jpg 85 | n04254680/n04254680_577.JPEG # soccer_ball/image_0014.jpg 86 | n04254680/n04254680_34.JPEG # soccer_ball/image_0027.jpg 87 | n02317335/n02317335_35.JPEG # starfish/image_0046.jpg 88 | n01776313/n01776313_37040.JPEG # tick/image_0043.jpg 89 | n01776313/n01776313_37040.JPEG # tick/image_0049_m.jpg 90 | n04336792/n04336792_10301.JPEG # wheelchair/image_0025.jpg 91 | n02128925/n02128925_15131.JPEG # wild_cat/image_0030.jpg 92 | n02124075/n02124075_3497.JPEG # wild_cat/image_0031.jpg 93 | n02124075/n02124075_10583.JPEG # wild_cat/image_0033_m.jpg 94 | -------------------------------------------------------------------------------- /meta_dataset/dataset_conversion/TrafficSign_labels.txt: -------------------------------------------------------------------------------- 1 | # From: 
https://github.com/navoshta/traffic-signs/blob/master/signnames.csv 2 | Speed limit (20km/h) 3 | Speed limit (30km/h) 4 | Speed limit (50km/h) 5 | Speed limit (60km/h) 6 | Speed limit (70km/h) 7 | Speed limit (80km/h) 8 | End of speed limit (80km/h) 9 | Speed limit (100km/h) 10 | Speed limit (120km/h) 11 | No passing 12 | No passing for vehicles over 3.5 metric tons 13 | Right-of-way at the next intersection 14 | Priority road 15 | Yield 16 | Stop 17 | No vehicles 18 | Vehicles over 3.5 metric tons prohibited 19 | No entry 20 | General caution 21 | Dangerous curve to the left 22 | Dangerous curve to the right 23 | Double curve 24 | Bumpy road 25 | Slippery road 26 | Road narrows on the right 27 | Road work 28 | Traffic signals 29 | Pedestrians 30 | Children crossing 31 | Bicycles crossing 32 | Beware of ice/snow 33 | Wild animals crossing 34 | End of all speed and passing limits 35 | Turn right ahead 36 | Turn left ahead 37 | Ahead only 38 | Go straight or right 39 | Go straight or left 40 | Keep right 41 | Keep left 42 | Roundabout mandatory 43 | End of no passing 44 | End of no passing by vehicles over 3.5 metric tons 45 | -------------------------------------------------------------------------------- /meta_dataset/dataset_conversion/VggFlower_labels.txt: -------------------------------------------------------------------------------- 1 | # From: https://gist.github.com/JosephKJ/94c7728ed1a8e0cd87fe6a029769cde1 2 | pink primrose 3 | hard-leaved pocket orchid 4 | canterbury bells 5 | sweet pea 6 | english marigold 7 | tiger lily 8 | moon orchid 9 | bird of paradise 10 | monkshood 11 | globe thistle 12 | snapdragon 13 | colt's foot 14 | king protea 15 | spear thistle 16 | yellow iris 17 | globe-flower 18 | purple coneflower 19 | peruvian lily 20 | balloon flower 21 | giant white arum lily 22 | fire lily 23 | pincushion flower 24 | fritillary 25 | red ginger 26 | grape hyacinth 27 | corn poppy 28 | prince of wales feathers 29 | stemless gentian 30 | artichoke 31 | 
sweet william 32 | carnation 33 | garden phlox 34 | love in the mist 35 | mexican aster 36 | alpine sea holly 37 | ruby-lipped cattleya 38 | cape flower 39 | great masterwort 40 | siam tulip 41 | lenten rose 42 | barbeton daisy 43 | daffodil 44 | sword lily 45 | poinsettia 46 | bolero deep blue 47 | wallflower 48 | marigold 49 | buttercup 50 | oxeye daisy 51 | common dandelion 52 | petunia 53 | wild pansy 54 | primula 55 | sunflower 56 | pelargonium 57 | bishop of llandaff 58 | gaura 59 | geranium 60 | orange dahlia 61 | pink-yellow dahlia? 62 | cautleya spicata 63 | japanese anemone 64 | black-eyed susan 65 | silverbush 66 | californian poppy 67 | osteospermum 68 | spring crocus 69 | bearded iris 70 | windflower 71 | tree poppy 72 | gazania 73 | azalea 74 | water lily 75 | rose 76 | thorn apple 77 | morning glory 78 | passion flower 79 | lotus 80 | toad lily 81 | anthurium 82 | frangipani 83 | clematis 84 | hibiscus 85 | columbine 86 | desert-rose 87 | tree mallow 88 | magnolia 89 | cyclamen 90 | watercress 91 | canna lily 92 | hippeastrum 93 | bee balm 94 | ball moss 95 | foxglove 96 | bougainvillea 97 | camellia 98 | mallow 99 | mexican petunia 100 | bromelia 101 | blanket flower 102 | trumpet creeper 103 | blackberry lily 104 | -------------------------------------------------------------------------------- /meta_dataset/dataset_conversion/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Meta-Dataset Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | -------------------------------------------------------------------------------- /meta_dataset/dataset_conversion/dataset_specs/aircraft_dataset_spec.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "aircraft", 3 | "classes_per_split": { 4 | "TRAIN": 70, 5 | "VALID": 15, 6 | "TEST": 15 7 | }, 8 | "images_per_class": { 9 | "0": 100, 10 | "1": 100, 11 | "2": 100, 12 | "3": 100, 13 | "4": 100, 14 | "5": 100, 15 | "6": 100, 16 | "7": 100, 17 | "8": 100, 18 | "9": 100, 19 | "10": 100, 20 | "11": 100, 21 | "12": 100, 22 | "13": 100, 23 | "14": 100, 24 | "15": 100, 25 | "16": 100, 26 | "17": 100, 27 | "18": 100, 28 | "19": 100, 29 | "20": 100, 30 | "21": 100, 31 | "22": 100, 32 | "23": 100, 33 | "24": 100, 34 | "25": 100, 35 | "26": 100, 36 | "27": 100, 37 | "28": 100, 38 | "29": 100, 39 | "30": 100, 40 | "31": 100, 41 | "32": 100, 42 | "33": 100, 43 | "34": 100, 44 | "35": 100, 45 | "36": 100, 46 | "37": 100, 47 | "38": 100, 48 | "39": 100, 49 | "40": 100, 50 | "41": 100, 51 | "42": 100, 52 | "43": 100, 53 | "44": 100, 54 | "45": 100, 55 | "46": 100, 56 | "47": 100, 57 | "48": 100, 58 | "49": 100, 59 | "50": 100, 60 | "51": 100, 61 | "52": 100, 62 | "53": 100, 63 | "54": 100, 64 | "55": 100, 65 | "56": 100, 66 | "57": 100, 67 | "58": 100, 68 | "59": 100, 69 | "60": 100, 70 | "61": 100, 71 | "62": 100, 72 | "63": 100, 73 | "64": 100, 74 | "65": 100, 75 | "66": 100, 76 | "67": 100, 77 | "68": 100, 78 | "69": 100, 79 | "70": 100, 80 | "71": 100, 81 | "72": 100, 82 
| "73": 100, 83 | "74": 100, 84 | "75": 100, 85 | "76": 100, 86 | "77": 100, 87 | "78": 100, 88 | "79": 100, 89 | "80": 100, 90 | "81": 100, 91 | "82": 100, 92 | "83": 100, 93 | "84": 100, 94 | "85": 100, 95 | "86": 100, 96 | "87": 100, 97 | "88": 100, 98 | "89": 100, 99 | "90": 100, 100 | "91": 100, 101 | "92": 100, 102 | "93": 100, 103 | "94": 100, 104 | "95": 100, 105 | "96": 100, 106 | "97": 100, 107 | "98": 100, 108 | "99": 100 109 | }, 110 | "class_names": { 111 | "0": "A340-300", 112 | "1": "A318", 113 | "2": "Falcon 2000", 114 | "3": "F-16A/B", 115 | "4": "F/A-18", 116 | "5": "C-130", 117 | "6": "MD-80", 118 | "7": "BAE 146-200", 119 | "8": "777-200", 120 | "9": "747-400", 121 | "10": "Cessna 172", 122 | "11": "An-12", 123 | "12": "A330-300", 124 | "13": "A321", 125 | "14": "Fokker 100", 126 | "15": "Fokker 50", 127 | "16": "DHC-1", 128 | "17": "Fokker 70", 129 | "18": "A340-200", 130 | "19": "DC-6", 131 | "20": "747-200", 132 | "21": "Il-76", 133 | "22": "747-300", 134 | "23": "Model B200", 135 | "24": "Saab 340", 136 | "25": "Cessna 560", 137 | "26": "Dornier 328", 138 | "27": "E-195", 139 | "28": "ERJ 135", 140 | "29": "747-100", 141 | "30": "737-600", 142 | "31": "C-47", 143 | "32": "DR-400", 144 | "33": "ATR-72", 145 | "34": "A330-200", 146 | "35": "727-200", 147 | "36": "737-700", 148 | "37": "PA-28", 149 | "38": "ERJ 145", 150 | "39": "737-300", 151 | "40": "767-300", 152 | "41": "737-500", 153 | "42": "737-200", 154 | "43": "DHC-6", 155 | "44": "Falcon 900", 156 | "45": "DC-3", 157 | "46": "Eurofighter Typhoon", 158 | "47": "Challenger 600", 159 | "48": "Hawk T1", 160 | "49": "A380", 161 | "50": "777-300", 162 | "51": "E-190", 163 | "52": "DHC-8-100", 164 | "53": "Cessna 525", 165 | "54": "Metroliner", 166 | "55": "EMB-120", 167 | "56": "Tu-134", 168 | "57": "Embraer Legacy 600", 169 | "58": "Gulfstream IV", 170 | "59": "Tu-154", 171 | "60": "MD-87", 172 | "61": "A300B4", 173 | "62": "A340-600", 174 | "63": "A340-500", 175 | "64": "MD-11", 176 | 
"65": "707-320", 177 | "66": "Cessna 208", 178 | "67": "Global Express", 179 | "68": "A319", 180 | "69": "DH-82", 181 | "70": "737-900", 182 | "71": "757-300", 183 | "72": "767-200", 184 | "73": "A310", 185 | "74": "A320", 186 | "75": "BAE 146-300", 187 | "76": "CRJ-900", 188 | "77": "DC-10", 189 | "78": "DC-8", 190 | "79": "DC-9-30", 191 | "80": "DHC-8-300", 192 | "81": "Gulfstream V", 193 | "82": "SR-20", 194 | "83": "Tornado", 195 | "84": "Yak-42", 196 | "85": "737-400", 197 | "86": "737-800", 198 | "87": "757-200", 199 | "88": "767-400", 200 | "89": "ATR-42", 201 | "90": "BAE-125", 202 | "91": "Beechcraft 1900", 203 | "92": "Boeing 717", 204 | "93": "CRJ-200", 205 | "94": "CRJ-700", 206 | "95": "E-170", 207 | "96": "L-1011", 208 | "97": "MD-90", 209 | "98": "Saab 2000", 210 | "99": "Spitfire" 211 | }, 212 | "path": "/path/to/records/aircraft", 213 | "file_pattern": "{}.tfrecords", 214 | "__class__": "DatasetSpecification" 215 | } 216 | -------------------------------------------------------------------------------- /meta_dataset/dataset_conversion/dataset_specs/dtd_dataset_spec.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "dtd", 3 | "classes_per_split": { 4 | "TRAIN": 33, 5 | "VALID": 7, 6 | "TEST": 7 7 | }, 8 | "images_per_class": { 9 | "0": 120, 10 | "1": 120, 11 | "2": 120, 12 | "3": 120, 13 | "4": 120, 14 | "5": 120, 15 | "6": 120, 16 | "7": 120, 17 | "8": 120, 18 | "9": 120, 19 | "10": 120, 20 | "11": 120, 21 | "12": 120, 22 | "13": 120, 23 | "14": 120, 24 | "15": 120, 25 | "16": 120, 26 | "17": 120, 27 | "18": 120, 28 | "19": 120, 29 | "20": 120, 30 | "21": 120, 31 | "22": 120, 32 | "23": 120, 33 | "24": 120, 34 | "25": 120, 35 | "26": 120, 36 | "27": 120, 37 | "28": 120, 38 | "29": 120, 39 | "30": 120, 40 | "31": 120, 41 | "32": 120, 42 | "33": 120, 43 | "34": 120, 44 | "35": 120, 45 | "36": 120, 46 | "37": 120, 47 | "38": 120, 48 | "39": 120, 49 | "40": 120, 50 | "41": 120, 51 | "42": 120, 52 | 
"43": 120, 53 | "44": 120, 54 | "45": 120, 55 | "46": 120 56 | }, 57 | "class_names": { 58 | "0": "chequered", 59 | "1": "braided", 60 | "2": "interlaced", 61 | "3": "matted", 62 | "4": "honeycombed", 63 | "5": "marbled", 64 | "6": "veined", 65 | "7": "frilly", 66 | "8": "zigzagged", 67 | "9": "cobwebbed", 68 | "10": "pitted", 69 | "11": "waffled", 70 | "12": "fibrous", 71 | "13": "flecked", 72 | "14": "grooved", 73 | "15": "potholed", 74 | "16": "blotchy", 75 | "17": "stained", 76 | "18": "crystalline", 77 | "19": "dotted", 78 | "20": "striped", 79 | "21": "swirly", 80 | "22": "meshed", 81 | "23": "bubbly", 82 | "24": "studded", 83 | "25": "pleated", 84 | "26": "lacelike", 85 | "27": "polka-dotted", 86 | "28": "perforated", 87 | "29": "freckled", 88 | "30": "smeared", 89 | "31": "cracked", 90 | "32": "wrinkled", 91 | "33": "gauzy", 92 | "34": "grid", 93 | "35": "lined", 94 | "36": "paisley", 95 | "37": "porous", 96 | "38": "scaly", 97 | "39": "spiralled", 98 | "40": "banded", 99 | "41": "bumpy", 100 | "42": "crosshatched", 101 | "43": "knitted", 102 | "44": "sprinkled", 103 | "45": "stratified", 104 | "46": "woven" 105 | }, 106 | "path": "/path/to/records/dtd", 107 | "file_pattern": "{}.tfrecords", 108 | "__class__": "DatasetSpecification" 109 | } 110 | -------------------------------------------------------------------------------- /meta_dataset/dataset_conversion/dataset_specs/mscoco_dataset_spec.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "mscoco", 3 | "classes_per_split": { 4 | "TRAIN": 0, 5 | "VALID": 40, 6 | "TEST": 40 7 | }, 8 | "images_per_class": { 9 | "67": 2918, 10 | "7": 5508, 11 | "71": 8652, 12 | "29": 5805, 13 | "48": 10806, 14 | "6": 4768, 15 | "49": 6587, 16 | "50": 9509, 17 | "51": 8147, 18 | "58": 24342, 19 | "70": 5779, 20 | "69": 38491, 21 | "27": 15714, 22 | "40": 7113, 23 | "41": 43867, 24 | "1": 8725, 25 | "42": 5135, 26 | "43": 6069, 27 | "2": 4571, 28 | "45": 10759, 29 | "0": 
262465, 30 | "46": 1983, 31 | "12": 11431, 32 | "53": 6496, 33 | "16": 6347, 34 | "22": 4373, 35 | "72": 4192, 36 | "32": 6434, 37 | "77": 2637, 38 | "36": 6334, 39 | "39": 1954, 40 | "44": 9973, 41 | "3": 12884, 42 | "4": 1865, 43 | "5": 1285, 44 | "47": 9838, 45 | "8": 5513, 46 | "52": 5131, 47 | "54": 2682, 48 | "55": 6646, 49 | "15": 2685, 50 | "17": 9076, 51 | "56": 3276, 52 | "18": 3747, 53 | "19": 5543, 54 | "20": 6126, 55 | "57": 4812, 56 | "59": 7913, 57 | "60": 20650, 58 | "61": 5479, 59 | "21": 7770, 60 | "62": 6165, 61 | "63": 14358, 62 | "64": 9458, 63 | "65": 5851, 64 | "66": 6399, 65 | "23": 7308, 66 | "24": 7852, 67 | "68": 5821, 68 | "25": 7179, 69 | "26": 6353, 70 | "28": 4157, 71 | "30": 4970, 72 | "73": 2262, 73 | "31": 5703, 74 | "74": 2855, 75 | "33": 1673, 76 | "34": 3334, 77 | "75": 225, 78 | "76": 5610, 79 | "35": 24715, 80 | "78": 6613, 81 | "79": 1481, 82 | "37": 4793, 83 | "38": 198, 84 | "11": 8720, 85 | "13": 12354, 86 | "14": 6192, 87 | "10": 5303, 88 | "9": 1294 89 | }, 90 | "class_names": { 91 | "0": "person", 92 | "1": "motorcycle", 93 | "2": "train", 94 | "3": "traffic light", 95 | "4": "fire hydrant", 96 | "5": "parking meter", 97 | "6": "cat", 98 | "7": "dog", 99 | "8": "elephant", 100 | "9": "bear", 101 | "10": "zebra", 102 | "11": "backpack", 103 | "12": "umbrella", 104 | "13": "handbag", 105 | "14": "suitcase", 106 | "15": "snowboard", 107 | "16": "sports ball", 108 | "17": "kite", 109 | "18": "baseball glove", 110 | "19": "skateboard", 111 | "20": "surfboard", 112 | "21": "knife", 113 | "22": "sandwich", 114 | "23": "broccoli", 115 | "24": "carrot", 116 | "25": "donut", 117 | "26": "cake", 118 | "27": "dining table", 119 | "28": "toilet", 120 | "29": "tv", 121 | "30": "laptop", 122 | "31": "remote", 123 | "32": "cell phone", 124 | "33": "microwave", 125 | "34": "oven", 126 | "35": "book", 127 | "36": "clock", 128 | "37": "teddy bear", 129 | "38": "hair drier", 130 | "39": "toothbrush", 131 | "40": "bicycle", 132 | "41": 
"car", 133 | "42": "airplane", 134 | "43": "bus", 135 | "44": "truck", 136 | "45": "boat", 137 | "46": "stop sign", 138 | "47": "bench", 139 | "48": "bird", 140 | "49": "horse", 141 | "50": "sheep", 142 | "51": "cow", 143 | "52": "giraffe", 144 | "53": "tie", 145 | "54": "frisbee", 146 | "55": "skis", 147 | "56": "baseball bat", 148 | "57": "tennis racket", 149 | "58": "bottle", 150 | "59": "wine glass", 151 | "60": "cup", 152 | "61": "fork", 153 | "62": "spoon", 154 | "63": "bowl", 155 | "64": "banana", 156 | "65": "apple", 157 | "66": "orange", 158 | "67": "hot dog", 159 | "68": "pizza", 160 | "69": "chair", 161 | "70": "couch", 162 | "71": "potted plant", 163 | "72": "bed", 164 | "73": "mouse", 165 | "74": "keyboard", 166 | "75": "toaster", 167 | "76": "sink", 168 | "77": "refrigerator", 169 | "78": "vase", 170 | "79": "scissors" 171 | }, 172 | "path": "/path/to/records/mscoco", 173 | "file_pattern": "{}.tfrecords", 174 | "__class__": "DatasetSpecification" 175 | } 176 | -------------------------------------------------------------------------------- /meta_dataset/dataset_conversion/dataset_specs/traffic_sign_dataset_spec.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "traffic_sign", 3 | "classes_per_split": { 4 | "TRAIN": 0, 5 | "VALID": 0, 6 | "TEST": 43 7 | }, 8 | "images_per_class": { 9 | "0": 210, 10 | "1": 2220, 11 | "2": 2250, 12 | "3": 1410, 13 | "4": 1980, 14 | "5": 1860, 15 | "6": 420, 16 | "7": 1440, 17 | "8": 1410, 18 | "9": 1470, 19 | "10": 2010, 20 | "11": 1320, 21 | "12": 2100, 22 | "13": 2160, 23 | "14": 780, 24 | "15": 630, 25 | "16": 420, 26 | "17": 1110, 27 | "18": 1200, 28 | "19": 210, 29 | "20": 360, 30 | "21": 330, 31 | "22": 390, 32 | "23": 510, 33 | "24": 270, 34 | "25": 1500, 35 | "26": 600, 36 | "27": 240, 37 | "28": 540, 38 | "29": 270, 39 | "30": 450, 40 | "31": 780, 41 | "32": 240, 42 | "33": 689, 43 | "34": 420, 44 | "35": 1200, 45 | "36": 390, 46 | "37": 210, 47 | "38": 
2070, 48 | "39": 300, 49 | "40": 360, 50 | "41": 240, 51 | "42": 240 52 | }, 53 | "class_names": { 54 | "0": "00.Speed limit (20km/h)", 55 | "1": "01.Speed limit (30km/h)", 56 | "2": "02.Speed limit (50km/h)", 57 | "3": "03.Speed limit (60km/h)", 58 | "4": "04.Speed limit (70km/h)", 59 | "5": "05.Speed limit (80km/h)", 60 | "6": "06.End of speed limit (80km/h)", 61 | "7": "07.Speed limit (100km/h)", 62 | "8": "08.Speed limit (120km/h)", 63 | "9": "09.No passing", 64 | "10": "10.No passing for vehicles over 3.5 metric tons", 65 | "11": "11.Right-of-way at the next intersection", 66 | "12": "12.Priority road", 67 | "13": "13.Yield", 68 | "14": "14.Stop", 69 | "15": "15.No vehicles", 70 | "16": "16.Vehicles over 3.5 metric tons prohibited", 71 | "17": "17.No entry", 72 | "18": "18.General caution", 73 | "19": "19.Dangerous curve to the left", 74 | "20": "20.Dangerous curve to the right", 75 | "21": "21.Double curve", 76 | "22": "22.Bumpy road", 77 | "23": "23.Slippery road", 78 | "24": "24.Road narrows on the right", 79 | "25": "25.Road work", 80 | "26": "26.Traffic signals", 81 | "27": "27.Pedestrians", 82 | "28": "28.Children crossing", 83 | "29": "29.Bicycles crossing", 84 | "30": "30.Beware of ice/snow", 85 | "31": "31.Wild animals crossing", 86 | "32": "32.End of all speed and passing limits", 87 | "33": "33.Turn right ahead", 88 | "34": "34.Turn left ahead", 89 | "35": "35.Ahead only", 90 | "36": "36.Go straight or right", 91 | "37": "37.Go straight or left", 92 | "38": "38.Keep right", 93 | "39": "39.Keep left", 94 | "40": "40.Roundabout mandatory", 95 | "41": "41.End of no passing", 96 | "42": "42.End of no passing by vehicles over 3.5 metric tons" 97 | }, 98 | "path": "/path/to/records/traffic_sign", 99 | "file_pattern": "{}.tfrecords", 100 | "__class__": "DatasetSpecification" 101 | } 102 | -------------------------------------------------------------------------------- /meta_dataset/dataset_conversion/dataset_specs/vgg_flower_dataset_spec.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "name": "vgg_flower", 3 | "classes_per_split": { 4 | "TRAIN": 71, 5 | "VALID": 15, 6 | "TEST": 16 7 | }, 8 | "images_per_class": { 9 | "0": 82, 10 | "1": 56, 11 | "2": 105, 12 | "3": 85, 13 | "4": 78, 14 | "5": 87, 15 | "6": 130, 16 | "7": 40, 17 | "8": 56, 18 | "9": 52, 19 | "10": 52, 20 | "11": 63, 21 | "12": 85, 22 | "13": 42, 23 | "14": 128, 24 | "15": 251, 25 | "16": 137, 26 | "17": 50, 27 | "18": 154, 28 | "19": 171, 29 | "20": 71, 30 | "21": 45, 31 | "22": 40, 32 | "23": 49, 33 | "24": 41, 34 | "25": 62, 35 | "26": 258, 36 | "27": 54, 37 | "28": 48, 38 | "29": 109, 39 | "30": 87, 40 | "31": 41, 41 | "32": 54, 42 | "33": 108, 43 | "34": 75, 44 | "35": 66, 45 | "36": 71, 46 | "37": 40, 47 | "38": 46, 48 | "39": 60, 49 | "40": 82, 50 | "41": 42, 51 | "42": 45, 52 | "43": 40, 53 | "44": 93, 54 | "45": 107, 55 | "46": 120, 56 | "47": 96, 57 | "48": 85, 58 | "49": 86, 59 | "50": 194, 60 | "51": 40, 61 | "52": 61, 62 | "53": 61, 63 | "54": 67, 64 | "55": 92, 65 | "56": 76, 66 | "57": 54, 67 | "58": 49, 68 | "59": 78, 69 | "60": 166, 70 | "61": 58, 71 | "62": 66, 72 | "63": 59, 73 | "64": 46, 74 | "65": 63, 75 | "66": 40, 76 | "67": 49, 77 | "68": 56, 78 | "69": 41, 79 | "70": 114, 80 | "71": 45, 81 | "72": 41, 82 | "73": 85, 83 | "74": 91, 84 | "75": 41, 85 | "76": 67, 86 | "77": 93, 87 | "78": 109, 88 | "79": 67, 89 | "80": 55, 90 | "81": 112, 91 | "82": 131, 92 | "83": 58, 93 | "84": 66, 94 | "85": 48, 95 | "86": 65, 96 | "87": 46, 97 | "88": 49, 98 | "89": 49, 99 | "90": 43, 100 | "91": 67, 101 | "92": 127, 102 | "93": 59, 103 | "94": 40, 104 | "95": 196, 105 | "96": 102, 106 | "97": 63, 107 | "98": 184, 108 | "99": 162, 109 | "100": 91, 110 | "101": 82 111 | }, 112 | "class_names": { 113 | "0": "090.canna lily", 114 | "1": "038.great masterwort", 115 | "2": "080.anthurium", 116 | "3": "030.sweet william", 117 | "4": "029.artichoke", 118 | "5": "012.colt's foot", 119 | 
"6": "043.sword lily", 120 | "7": "027.prince of wales feathers", 121 | "8": "004.sweet pea", 122 | "9": "064.silverbush", 123 | "10": "031.carnation", 124 | "11": "099.bromelia", 125 | "12": "008.bird of paradise", 126 | "13": "067.spring crocus", 127 | "14": "095.bougainvillea", 128 | "15": "077.passion flower", 129 | "16": "078.lotus", 130 | "17": "061.cautleya spicata", 131 | "18": "088.cyclamen", 132 | "19": "074.rose", 133 | "20": "055.pelargonium", 134 | "21": "032.garden phlox", 135 | "22": "021.fire lily", 136 | "23": "013.king protea", 137 | "24": "079.toad lily", 138 | "25": "070.tree poppy", 139 | "26": "051.petunia", 140 | "27": "069.windflower", 141 | "28": "014.spear thistle", 142 | "29": "060.pink-yellow dahlia?", 143 | "30": "011.snapdragon", 144 | "31": "039.siam tulip", 145 | "32": "063.black-eyed susan", 146 | "33": "037.cape flower", 147 | "34": "036.ruby-lipped cattleya", 148 | "35": "028.stemless gentian", 149 | "36": "048.buttercup", 150 | "37": "007.moon orchid", 151 | "38": "093.ball moss", 152 | "39": "002.hard-leaved pocket orchid", 153 | "40": "018.peruvian lily", 154 | "41": "024.red ginger", 155 | "42": "006.tiger lily", 156 | "43": "003.canterbury bells", 157 | "44": "044.poinsettia", 158 | "45": "076.morning glory", 159 | "46": "075.thorn apple", 160 | "47": "072.azalea", 161 | "48": "052.wild pansy", 162 | "49": "084.columbine", 163 | "50": "073.water lily", 164 | "51": "034.mexican aster", 165 | "52": "054.sunflower", 166 | "53": "066.osteospermum", 167 | "54": "059.orange dahlia", 168 | "55": "050.common dandelion", 169 | "56": "091.hippeastrum", 170 | "57": "068.bearded iris", 171 | "58": "100.blanket flower", 172 | "59": "071.gazania", 173 | "60": "081.frangipani", 174 | "61": "101.trumpet creeper", 175 | "62": "092.bee balm", 176 | "63": "022.pincushion flower", 177 | "64": "033.love in the mist", 178 | "65": "087.magnolia", 179 | "66": "001.pink primrose", 180 | "67": "049.oxeye daisy", 181 | "68": "020.giant white arum 
lily", 182 | "69": "025.grape hyacinth", 183 | "70": "058.geranium", 184 | "71": "010.globe thistle", 185 | "72": "016.globe-flower", 186 | "73": "017.purple coneflower", 187 | "74": "023.fritillary", 188 | "75": "026.corn poppy", 189 | "76": "047.marigold", 190 | "77": "053.primula", 191 | "78": "056.bishop of llandaff", 192 | "79": "057.gaura", 193 | "80": "062.japanese anemone", 194 | "81": "082.clematis", 195 | "82": "083.hibiscus", 196 | "83": "086.tree mallow", 197 | "84": "097.mallow", 198 | "85": "102.blackberry lily", 199 | "86": "005.english marigold", 200 | "87": "009.monkshood", 201 | "88": "015.yellow iris", 202 | "89": "019.balloon flower", 203 | "90": "035.alpine sea holly", 204 | "91": "040.lenten rose", 205 | "92": "041.barbeton daisy", 206 | "93": "042.daffodil", 207 | "94": "045.bolero deep blue", 208 | "95": "046.wallflower", 209 | "96": "065.californian poppy", 210 | "97": "085.desert-rose", 211 | "98": "089.watercress", 212 | "99": "094.foxglove", 213 | "100": "096.camellia", 214 | "101": "098.mexican petunia" 215 | }, 216 | "path": "/path/to/records/vgg_flower", 217 | "file_pattern": "{}.tfrecords", 218 | "__class__": "DatasetSpecification" 219 | } 220 | -------------------------------------------------------------------------------- /meta_dataset/dataset_conversion/splits/aircraft_splits.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": [ 3 | "A340-300", 4 | "A318", 5 | "Falcon 2000", 6 | "F-16A/B", 7 | "F/A-18", 8 | "C-130", 9 | "MD-80", 10 | "BAE 146-200", 11 | "777-200", 12 | "747-400", 13 | "Cessna 172", 14 | "An-12", 15 | "A330-300", 16 | "A321", 17 | "Fokker 100", 18 | "Fokker 50", 19 | "DHC-1", 20 | "Fokker 70", 21 | "A340-200", 22 | "DC-6", 23 | "747-200", 24 | "Il-76", 25 | "747-300", 26 | "Model B200", 27 | "Saab 340", 28 | "Cessna 560", 29 | "Dornier 328", 30 | "E-195", 31 | "ERJ 135", 32 | "747-100", 33 | "737-600", 34 | "C-47", 35 | "DR-400", 36 | "ATR-72", 37 | "A330-200", 
38 | "727-200", 39 | "737-700", 40 | "PA-28", 41 | "ERJ 145", 42 | "737-300", 43 | "767-300", 44 | "737-500", 45 | "737-200", 46 | "DHC-6", 47 | "Falcon 900", 48 | "DC-3", 49 | "Eurofighter Typhoon", 50 | "Challenger 600", 51 | "Hawk T1", 52 | "A380", 53 | "777-300", 54 | "E-190", 55 | "DHC-8-100", 56 | "Cessna 525", 57 | "Metroliner", 58 | "EMB-120", 59 | "Tu-134", 60 | "Embraer Legacy 600", 61 | "Gulfstream IV", 62 | "Tu-154", 63 | "MD-87", 64 | "A300B4", 65 | "A340-600", 66 | "A340-500", 67 | "MD-11", 68 | "707-320", 69 | "Cessna 208", 70 | "Global Express", 71 | "A319", 72 | "DH-82" 73 | ], 74 | "test": [ 75 | "737-400", 76 | "737-800", 77 | "757-200", 78 | "767-400", 79 | "ATR-42", 80 | "BAE-125", 81 | "Beechcraft 1900", 82 | "Boeing 717", 83 | "CRJ-200", 84 | "CRJ-700", 85 | "E-170", 86 | "L-1011", 87 | "MD-90", 88 | "Saab 2000", 89 | "Spitfire" 90 | ], 91 | "valid": [ 92 | "737-900", 93 | "757-300", 94 | "767-200", 95 | "A310", 96 | "A320", 97 | "BAE 146-300", 98 | "CRJ-900", 99 | "DC-10", 100 | "DC-8", 101 | "DC-9-30", 102 | "DHC-8-300", 103 | "Gulfstream V", 104 | "SR-20", 105 | "Tornado", 106 | "Yak-42" 107 | ] 108 | } -------------------------------------------------------------------------------- /meta_dataset/dataset_conversion/splits/dtd_splits.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": [ 3 | "chequered", 4 | "braided", 5 | "interlaced", 6 | "matted", 7 | "honeycombed", 8 | "marbled", 9 | "veined", 10 | "frilly", 11 | "zigzagged", 12 | "cobwebbed", 13 | "pitted", 14 | "waffled", 15 | "fibrous", 16 | "flecked", 17 | "grooved", 18 | "potholed", 19 | "blotchy", 20 | "stained", 21 | "crystalline", 22 | "dotted", 23 | "striped", 24 | "swirly", 25 | "meshed", 26 | "bubbly", 27 | "studded", 28 | "pleated", 29 | "lacelike", 30 | "polka-dotted", 31 | "perforated", 32 | "freckled", 33 | "smeared", 34 | "cracked", 35 | "wrinkled" 36 | ], 37 | "test": [ 38 | "banded", 39 | "bumpy", 40 | 
"crosshatched", 41 | "knitted", 42 | "sprinkled", 43 | "stratified", 44 | "woven" 45 | ], 46 | "valid": [ 47 | "gauzy", 48 | "grid", 49 | "lined", 50 | "paisley", 51 | "porous", 52 | "scaly", 53 | "spiralled" 54 | ] 55 | } -------------------------------------------------------------------------------- /meta_dataset/dataset_conversion/splits/mscoco_splits.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": [], 3 | "test": [ 4 | "bicycle", 5 | "car", 6 | "airplane", 7 | "bus", 8 | "truck", 9 | "boat", 10 | "stop sign", 11 | "bench", 12 | "bird", 13 | "horse", 14 | "sheep", 15 | "cow", 16 | "giraffe", 17 | "tie", 18 | "frisbee", 19 | "skis", 20 | "baseball bat", 21 | "tennis racket", 22 | "bottle", 23 | "wine glass", 24 | "cup", 25 | "fork", 26 | "spoon", 27 | "bowl", 28 | "banana", 29 | "apple", 30 | "orange", 31 | "hot dog", 32 | "pizza", 33 | "chair", 34 | "couch", 35 | "potted plant", 36 | "bed", 37 | "mouse", 38 | "keyboard", 39 | "toaster", 40 | "sink", 41 | "refrigerator", 42 | "vase", 43 | "scissors" 44 | ], 45 | "valid": [ 46 | "person", 47 | "motorcycle", 48 | "train", 49 | "traffic light", 50 | "fire hydrant", 51 | "parking meter", 52 | "cat", 53 | "dog", 54 | "elephant", 55 | "bear", 56 | "zebra", 57 | "backpack", 58 | "umbrella", 59 | "handbag", 60 | "suitcase", 61 | "snowboard", 62 | "sports ball", 63 | "kite", 64 | "baseball glove", 65 | "skateboard", 66 | "surfboard", 67 | "knife", 68 | "sandwich", 69 | "broccoli", 70 | "carrot", 71 | "donut", 72 | "cake", 73 | "dining table", 74 | "toilet", 75 | "tv", 76 | "laptop", 77 | "remote", 78 | "cell phone", 79 | "microwave", 80 | "oven", 81 | "book", 82 | "clock", 83 | "teddy bear", 84 | "hair drier", 85 | "toothbrush" 86 | ] 87 | } -------------------------------------------------------------------------------- /meta_dataset/dataset_conversion/splits/traffic_sign_splits.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "train": [], 3 | "test": [ 4 | "00.Speed limit (20km/h)", 5 | "01.Speed limit (30km/h)", 6 | "02.Speed limit (50km/h)", 7 | "03.Speed limit (60km/h)", 8 | "04.Speed limit (70km/h)", 9 | "05.Speed limit (80km/h)", 10 | "06.End of speed limit (80km/h)", 11 | "07.Speed limit (100km/h)", 12 | "08.Speed limit (120km/h)", 13 | "09.No passing", 14 | "10.No passing for vehicles over 3.5 metric tons", 15 | "11.Right-of-way at the next intersection", 16 | "12.Priority road", 17 | "13.Yield", 18 | "14.Stop", 19 | "15.No vehicles", 20 | "16.Vehicles over 3.5 metric tons prohibited", 21 | "17.No entry", 22 | "18.General caution", 23 | "19.Dangerous curve to the left", 24 | "20.Dangerous curve to the right", 25 | "21.Double curve", 26 | "22.Bumpy road", 27 | "23.Slippery road", 28 | "24.Road narrows on the right", 29 | "25.Road work", 30 | "26.Traffic signals", 31 | "27.Pedestrians", 32 | "28.Children crossing", 33 | "29.Bicycles crossing", 34 | "30.Beware of ice/snow", 35 | "31.Wild animals crossing", 36 | "32.End of all speed and passing limits", 37 | "33.Turn right ahead", 38 | "34.Turn left ahead", 39 | "35.Ahead only", 40 | "36.Go straight or right", 41 | "37.Go straight or left", 42 | "38.Keep right", 43 | "39.Keep left", 44 | "40.Roundabout mandatory", 45 | "41.End of no passing", 46 | "42.End of no passing by vehicles over 3.5 metric tons" 47 | ], 48 | "valid": [] 49 | } 50 | -------------------------------------------------------------------------------- /meta_dataset/dataset_conversion/splits/vgg_flower_splits.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": [ 3 | "090.canna lily", 4 | "038.great masterwort", 5 | "080.anthurium", 6 | "030.sweet william", 7 | "029.artichoke", 8 | "012.colt's foot", 9 | "043.sword lily", 10 | "027.prince of wales feathers", 11 | "004.sweet pea", 12 | "064.silverbush", 13 | "031.carnation", 14 
| "099.bromelia", 15 | "008.bird of paradise", 16 | "067.spring crocus", 17 | "095.bougainvillea", 18 | "077.passion flower", 19 | "078.lotus", 20 | "061.cautleya spicata", 21 | "088.cyclamen", 22 | "074.rose", 23 | "055.pelargonium", 24 | "032.garden phlox", 25 | "021.fire lily", 26 | "013.king protea", 27 | "079.toad lily", 28 | "070.tree poppy", 29 | "051.petunia", 30 | "069.windflower", 31 | "014.spear thistle", 32 | "060.pink-yellow dahlia?", 33 | "011.snapdragon", 34 | "039.siam tulip", 35 | "063.black-eyed susan", 36 | "037.cape flower", 37 | "036.ruby-lipped cattleya", 38 | "028.stemless gentian", 39 | "048.buttercup", 40 | "007.moon orchid", 41 | "093.ball moss", 42 | "002.hard-leaved pocket orchid", 43 | "018.peruvian lily", 44 | "024.red ginger", 45 | "006.tiger lily", 46 | "003.canterbury bells", 47 | "044.poinsettia", 48 | "076.morning glory", 49 | "075.thorn apple", 50 | "072.azalea", 51 | "052.wild pansy", 52 | "084.columbine", 53 | "073.water lily", 54 | "034.mexican aster", 55 | "054.sunflower", 56 | "066.osteospermum", 57 | "059.orange dahlia", 58 | "050.common dandelion", 59 | "091.hippeastrum", 60 | "068.bearded iris", 61 | "100.blanket flower", 62 | "071.gazania", 63 | "081.frangipani", 64 | "101.trumpet creeper", 65 | "092.bee balm", 66 | "022.pincushion flower", 67 | "033.love in the mist", 68 | "087.magnolia", 69 | "001.pink primrose", 70 | "049.oxeye daisy", 71 | "020.giant white arum lily", 72 | "025.grape hyacinth", 73 | "058.geranium" 74 | ], 75 | "valid": [ 76 | "010.globe thistle", 77 | "016.globe-flower", 78 | "017.purple coneflower", 79 | "023.fritillary", 80 | "026.corn poppy", 81 | "047.marigold", 82 | "053.primula", 83 | "056.bishop of llandaff", 84 | "057.gaura", 85 | "062.japanese anemone", 86 | "082.clematis", 87 | "083.hibiscus", 88 | "086.tree mallow", 89 | "097.mallow", 90 | "102.blackberry lily" 91 | ], 92 | "test": [ 93 | "005.english marigold", 94 | "009.monkshood", 95 | "015.yellow iris", 96 | "019.balloon flower", 97 | 
def _pad(tensor, total_size, position):
  """Embed `tensor` in a block of zeros along its first axis.

  Only the first axis grows: the result has `total_size` entries along it,
  and all remaining dimensions are those of `tensor`. The values of `tensor`
  start at index `position`; everything before and after them is zero.

  Args:
    tensor: tensor to pad.
    total_size: size of the first axis of the result.
    position: index along the first axis of the result where `tensor` starts.

  Returns:
    output: The zero-padded tensor.
  """
  full_shape = tf.shape(tensor)
  num_rows = full_shape[0]
  trailing_dims = full_shape[1:]

  def _zero_rows(count):
    # `count` zero rows whose remaining dimensions and dtype match `tensor`.
    return tf.zeros(
        tf.concat([[count], trailing_dims], axis=0), dtype=tensor.dtype)

  leading = _zero_rows(position)
  trailing = _zero_rows(total_size - position - num_rows)
  return tf.concat([leading, tensor, trailing], axis=0)


def aggregate(tensor):
  """Concatenate the values of a tensor from every distributed replica.

  When there is no replica context, the input tensor is returned unchanged.
  Otherwise each replica places its values at its own offset inside a zero
  tensor that is large enough for all replicas, and a sum all-reduce of these
  padded tensors produces the concatenation.

  Args:
    tensor: tensor to aggregate.

  Returns:
    output: A single tensor holding the values from all replicas,
      concatenated along the first axis and ordered by replica id within the
      sync group.
  """
  ctx = tf.distribute.get_replica_context()
  if not ctx:
    return tensor

  replica_id = ctx.replica_id_in_sync_group

  # Exchange first-axis sizes: every replica writes its own size into its slot
  # of a zero vector, so the summed vector lists the sizes of all replicas.
  local_count = tf.shape(tensor)[0:1]
  counts = ctx.all_reduce(
      'sum', _pad(local_count, ctx.num_replicas_in_sync, replica_id))

  # Prepending 0 before the cumulative sum makes entry i the combined size of
  # all replicas with an id below i, i.e. where replica i starts in the output.
  offsets = tf.cumsum(tf.concat([[0], counts], axis=0))
  start = tf.gather(offsets, replica_id)

  placed = _pad(tensor, tf.reduce_sum(counts), start)
  return ctx.all_reduce('sum', placed)
9 | BaselineLearner.knn_distance = 'cosine' 10 | BaselineLearner.cosine_classifier = True 11 | BaselineLearner.cosine_logits_multiplier = 1 12 | BaselineLearner.use_weight_norm = True 13 | 14 | # Data hypers. 15 | DataConfig.image_height = 126 16 | 17 | # Training hypers (not needed for eval). 18 | Trainer.decay_every = 500 19 | Trainer.decay_rate = 0.8778059962506467 20 | Trainer.learning_rate = 0.000253906846867988 21 | weight_decay = 0.00002393929026012612 22 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/baseline_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/baseline_config.gin' 3 | 4 | # Backbone hypers. 5 | Learner.embedding_fn = @wide_resnet 6 | weight_decay = 0.000007138118976497546 7 | 8 | Trainer.checkpoint_to_restore = '' 9 | Trainer.pretrained_source = 'scratch' 10 | 11 | # Model hypers. 12 | BaselineLearner.knn_distance = 'cosine' 13 | BaselineLearner.cosine_classifier = False 14 | BaselineLearner.cosine_logits_multiplier = 10 15 | BaselineLearner.use_weight_norm = False 16 | 17 | # Data hypers. 18 | DataConfig.image_height = 126 19 | 20 | # Training hypers (not needed for eval). 21 | Trainer.decay_every = 10000 22 | Trainer.decay_rate = 0.7294597641152971 23 | Trainer.learning_rate = 0.007634189137886614 24 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/baselinefinetune_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin' 3 | BatchSplitReaderGetReader.add_dataset_offset = True 4 | 5 | # Backbone hypers. 6 | include 'meta_dataset/learn/gin/best_v2/pretrained_wide_resnet.gin' 7 | 8 | # Model hypers. 
9 | BaselineLearner.cosine_classifier = False 10 | BaselineLearner.use_weight_norm = True 11 | BaselineLearner.cosine_logits_multiplier = 1 12 | BaselineFinetuneLearner.num_finetune_steps = 200 13 | BaselineFinetuneLearner.finetune_lr = 0.01 14 | BaselineFinetuneLearner.finetune_all_layers = True 15 | BaselineFinetuneLearner.finetune_with_adam = True 16 | 17 | # Data hypers. 18 | DataConfig.image_height = 84 19 | 20 | # Training hypers (not needed for eval). 21 | Trainer.decay_every = 5000 22 | Trainer.decay_rate = 0.5559080744371039 23 | Trainer.learning_rate = 0.0027015533546616804 24 | weight_decay = 0.00002266979856832968 25 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/baselinefinetune_cosine_mini_imagenet_fiveshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet_five_way_five_shot.gin' 2 | include 'meta_dataset/learn/gin/learners/baselinefinetune_cosine_config.gin' 3 | 4 | Learner.embedding_fn = @wide_resnet 5 | weight_decay = 0 6 | 7 | DataConfig.image_height = 84 8 | Trainer.pretrained_source = 'scratch' 9 | 10 | BaselineFinetuneLearner.num_finetune_steps = 200 11 | BaselineFinetuneLearner.finetune_lr = 0.01 12 | BaselineFinetuneLearner.finetune_with_adam = False 13 | BaselineLearner.use_weight_norm = False 14 | BaselineFinetuneLearner.finetune_all_layers = False 15 | BaselineLearner.cosine_logits_multiplier = 10 16 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/baselinefinetune_cosine_mini_imagenet_oneshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet_five_way_five_shot.gin' 2 | EpisodeDescriptionConfig.num_support = 1 # Change 5-shot to 1-shot. 
3 | include 'meta_dataset/learn/gin/learners/baselinefinetune_cosine_config.gin' 4 | 5 | Learner.embedding_fn = @wide_resnet 6 | weight_decay = 0 7 | 8 | DataConfig.image_height = 84 9 | Trainer.pretrained_source = 'scratch' 10 | 11 | BaselineFinetuneLearner.num_finetune_steps = 200 12 | BaselineFinetuneLearner.finetune_lr = 0.01 13 | BaselineFinetuneLearner.finetune_with_adam = False 14 | BaselineLearner.use_weight_norm = False 15 | BaselineFinetuneLearner.finetune_all_layers = False 16 | BaselineLearner.cosine_logits_multiplier = 10 17 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/baselinefinetune_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/best_v2/pretrained_wide_resnet.gin' 6 | 7 | # Model hypers. 8 | BaselineLearner.cosine_classifier = False 9 | BaselineLearner.use_weight_norm = True 10 | BaselineLearner.cosine_logits_multiplier = 1 11 | BaselineFinetuneLearner.num_finetune_steps = 200 12 | BaselineFinetuneLearner.finetune_lr = 0.01 13 | BaselineFinetuneLearner.finetune_all_layers = True 14 | BaselineFinetuneLearner.finetune_with_adam = True 15 | 16 | # Data hypers. 17 | DataConfig.image_height = 84 18 | 19 | # Training hypers (not needed for eval). 
20 | Trainer.decay_every = 5000 21 | Trainer.decay_rate = 0.5559080744371039 22 | Trainer.learning_rate = 0.0027015533546616804 23 | weight_decay = 0.00002266979856832968 24 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/baselinefinetune_mini_imagenet_fiveshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | EpisodeDescriptionConfig.num_support = 5 3 | 4 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin' 5 | 6 | Learner.embedding_fn = @wide_resnet 7 | weight_decay = 0 8 | 9 | DataConfig.image_height = 84 10 | Trainer.pretrained_source = 'scratch' 11 | 12 | BaselineFinetuneLearner.num_finetune_steps = 200 13 | BaselineFinetuneLearner.finetune_lr = 0.01 14 | BaselineFinetuneLearner.finetune_with_adam = False 15 | BaselineLearner.use_weight_norm = False 16 | BaselineFinetuneLearner.finetune_all_layers = False 17 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/baselinefinetune_mini_imagenet_oneshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet_five_way_five_shot.gin' 2 | EpisodeDescriptionConfig.num_support = 1 # Change 5-shot to 1-shot. 
3 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin' 4 | 5 | Learner.embedding_fn = @wide_resnet 6 | weight_decay = 0 7 | 8 | DataConfig.image_height = 84 9 | Trainer.pretrained_source = 'scratch' 10 | 11 | BaselineFinetuneLearner.num_finetune_steps = 200 12 | BaselineFinetuneLearner.finetune_lr = 0.01 13 | BaselineFinetuneLearner.finetune_with_adam = False 14 | BaselineLearner.use_weight_norm = False 15 | BaselineFinetuneLearner.finetune_all_layers = False 16 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/flute.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/learner_config.gin' 3 | include 'meta_dataset/learn/gin/setups/trainer_config_flute.gin' 4 | include 'meta_dataset/learn/gin/setups/data_config_flute.gin' 5 | 6 | # Learner to use for evaluation (the `train_learner_class` isn't used now). 7 | Trainer_flute.train_learner_class = @DatasetConditionalBaselineLearner 8 | Trainer_flute.eval_learner_class = @FLUTEFiLMLearner 9 | 10 | # The path to the dataset classifier checkpoint to restore (this is required if 11 | # using the `blender` or `hard blender` as the `film_init` heuristic). 12 | Trainer_flute.dataset_classifier_to_restore = "path/to/trained/dataset/classifier" 13 | 14 | # FLUTE FiLM Learner settings. 15 | FLUTEFiLMLearner.film_init = 'blender' 16 | FLUTEFiLMLearner.num_steps = 6 17 | FLUTEFiLMLearner.lr = 0.005 18 | 19 | # Backbone settings. 20 | FLUTEFiLMLearner.embedding_fn = @flute_resnet 21 | bn_wrapper.batch_norm_fn = @bn_flute_eval 22 | bn_wrapper.num_film_sets = %num_film_sets 23 | dataset_classifier.weight_decay = %weight_decay 24 | dataset_classifier.num_datasets = %num_film_sets 25 | flute_resnet.weight_decay = %weight_decay 26 | weight_decay = 0. 
27 | num_film_sets = 8 28 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/flute_init_from_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/best_v2/flute.gin' 2 | 3 | Trainer_flute.dataset_classifier_to_restore = "" 4 | FLUTEFiLMLearner.film_init = 'imagenet' 5 | FLUTEFiLMLearner.num_steps = 2 6 | FLUTEFiLMLearner.lr = 0.001 7 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/flute_init_from_scratch.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/best_v2/flute.gin' 2 | 3 | Trainer_flute.dataset_classifier_to_restore = "" 4 | FLUTEFiLMLearner.film_init = 'scratch' 5 | FLUTEFiLMLearner.num_steps = 4 6 | FLUTEFiLMLearner.lr = 0.005 7 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/maml_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/maml_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/best_v2/pretrained_resnet.gin' 6 | 7 | # Model hypers. (new results) 8 | MAMLLearner.first_order = True 9 | MAMLLearner.alpha = 0.09915203576061224 10 | MAMLLearner.additional_evaluation_update_steps = 5 11 | MAMLLearner.num_update_steps = 10 12 | 13 | # Data hypers. 14 | DataConfig.image_height = 126 15 | 16 | # Training hypers (not needed for eval). 
17 | Trainer.decay_learning_rate = True 18 | Trainer.decay_every = 10000 19 | Trainer.decay_rate = 0.8354967819706964 20 | Trainer.learning_rate = 0.00023702989294963628 21 | weight_decay = 0.00030572401935585053 22 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/maml_all_from_scratch.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/maml_config.gin' 3 | 4 | # Backbone hypers. 5 | Learner.embedding_fn = @resnet 6 | 7 | Trainer.pretrained_source = 'scratch' 8 | Trainer.checkpoint_to_restore = '' 9 | 10 | # Model hypers. 11 | MAMLLearner.first_order = True 12 | MAMLLearner.alpha = 0.01406233807768908 13 | MAMLLearner.additional_evaluation_update_steps = 5 14 | MAMLLearner.num_update_steps = 6 15 | 16 | # Data hypers. 17 | DataConfig.image_height = 126 18 | 19 | # Training hypers (not needed for eval). 20 | Trainer.decay_learning_rate = True 21 | Trainer.decay_every = 10000 22 | Trainer.decay_rate = 0.5575524279100452 23 | Trainer.learning_rate = 0.0004067024273450244 24 | weight_decay = 0.0000033242172150107816 25 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/maml_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/maml_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/best_v2/pretrained_resnet.gin' 6 | 7 | # Model hypers. 8 | MAMLLearner.first_order = True 9 | MAMLLearner.alpha = 0.09915203576061224 10 | MAMLLearner.additional_evaluation_update_steps = 5 11 | MAMLLearner.num_update_steps = 10 12 | 13 | # Data hypers. 14 | DataConfig.image_height = 126 15 | 16 | # Training hypers (not needed for eval). 
17 | Trainer.decay_learning_rate = True 18 | Trainer.decay_every = 10000 19 | Trainer.decay_rate = 0.8354967819706964 20 | Trainer.learning_rate = 0.0003601303690709516 21 | weight_decay = 0.00030572401935585053 22 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/maml_imagenet_from_scratch.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/maml_config.gin' 3 | 4 | # Backbone hypers. 5 | Learner.embedding_fn = @four_layer_convnet 6 | 7 | Trainer.pretrained_source = 'scratch' 8 | Trainer.checkpoint_to_restore = '' 9 | 10 | # Model hypers. 11 | MAMLLearner.first_order = True 12 | MAMLLearner.alpha = 0.05602373320602529 13 | MAMLLearner.additional_evaluation_update_steps = 5 14 | MAMLLearner.num_update_steps = 10 15 | 16 | # Data hypers. 17 | DataConfig.image_height = 84 18 | 19 | # Training hypers (not needed for eval). 20 | Trainer.decay_learning_rate = True 21 | Trainer.decay_every = 5000 22 | Trainer.decay_rate = 0.9153310958769965 23 | Trainer.learning_rate = 0.00012385665511953882 24 | weight_decay = 0.0006858703351826797 25 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/maml_init_with_proto_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/maml_init_with_proto_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/best_v2/pretrained_resnet.gin' 6 | 7 | # Model hypers. 8 | MAMLLearner.first_order = True 9 | MAMLLearner.alpha = 0.01988024603751885 10 | MAMLLearner.additional_evaluation_update_steps = 5 11 | MAMLLearner.num_update_steps = 1 12 | 13 | # Data hypers. 14 | DataConfig.image_height = 126 15 | 16 | # Training hypers (not needed for eval). 
17 | Trainer.decay_learning_rate = True 18 | Trainer.decay_every = 500 19 | Trainer.decay_rate = 0.7659662681067452 20 | Trainer.learning_rate = 0.00029809808680211807 21 | weight_decay = 0.0007159940362690606 22 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/maml_init_with_proto_all_from_scratch.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/maml_init_with_proto_config.gin' 3 | 4 | # Backbone hypers. 5 | Learner.embedding_fn = @resnet 6 | 7 | Trainer.pretrained_source = 'scratch' 8 | Trainer.checkpoint_to_restore = '' 9 | 10 | # Model hypers. 11 | MAMLLearner.first_order = True 12 | MAMLLearner.alpha = 0.01139772733084185 13 | MAMLLearner.additional_evaluation_update_steps = 5 14 | MAMLLearner.num_update_steps = 1 15 | 16 | # Data hypers. 17 | DataConfig.image_height = 126 18 | 19 | # Training hypers (not needed for eval). 20 | Trainer.decay_learning_rate = True 21 | Trainer.decay_every = 10000 22 | Trainer.decay_rate = 0.6731268939317062 23 | Trainer.learning_rate = 0.0007252750256444753 24 | weight_decay = 0.00005323094246436839 25 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/maml_init_with_proto_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/maml_init_with_proto_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/best_v2/pretrained_resnet.gin' 6 | 7 | # Model hypers. 8 | MAMLLearner.first_order = True 9 | MAMLLearner.alpha = 0.005435022808033229 10 | MAMLLearner.additional_evaluation_update_steps = 0 11 | MAMLLearner.num_update_steps = 6 12 | 13 | # Data hypers. 
14 | DataConfig.image_height = 126 15 | 16 | # Training hypers (not needed for eval). 17 | Trainer.decay_learning_rate = True 18 | Trainer.decay_every = 2500 19 | Trainer.decay_rate = 0.6477898086638092 20 | Trainer.learning_rate = 0.00036339913514891586 21 | weight_decay = 0.000013656044331235537 22 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/maml_init_with_proto_imagenet_from_scratch.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/maml_init_with_proto_config.gin' 3 | 4 | # Backbone hypers. 5 | 6 | Learner.embedding_fn = @resnet 7 | 8 | Trainer.pretrained_source = 'scratch' 9 | Trainer.checkpoint_to_restore = '' 10 | 11 | # Model hypers. 12 | MAMLLearner.first_order = True 13 | MAMLLearner.alpha = 0.31106175977182243 14 | MAMLLearner.additional_evaluation_update_steps = 5 15 | MAMLLearner.num_update_steps = 6 16 | 17 | # Data hypers. 18 | DataConfig.image_height = 126 19 | 20 | # Training hypers (not needed for eval). 21 | Trainer.decay_learning_rate = True 22 | Trainer.decay_every = 5000 23 | Trainer.decay_rate = 0.6431136271727287 24 | Trainer.learning_rate = 0.0007181155997029211 25 | weight_decay = 0.00003630199690303937 26 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/maml_init_with_proto_inference_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/maml_init_with_proto_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/best_v2/pretrained_resnet.gin' 6 | 7 | # Model hypers. 
8 | MAMLLearner.first_order = True 9 | MAMLLearner.alpha = 0.061445396411177286 10 | MAMLLearner.additional_evaluation_update_steps = 0 11 | MAMLLearner.num_update_steps = 1 12 | 13 | # Data hypers. 14 | DataConfig.image_height = 126 15 | 16 | # Training hypers (not needed for eval). 17 | Trainer.decay_learning_rate = True 18 | Trainer.decay_every = 5000 19 | Trainer.decay_rate = 0.8449806376151254 20 | Trainer.learning_rate = 0.0008118174317224871 21 | weight_decay = 0.0000038819970639091496 22 | 23 | # Baseline hypers (just for the record). 24 | BaselineLearner.cosine_logits_multiplier = 1 25 | BaselineLearner.use_weight_norm = True 26 | BaselineLearner.knn_distance = 'l2' 27 | BaselineLearner.cosine_classifier = False 28 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/maml_init_with_proto_inference_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/maml_init_with_proto_config.gin' 3 | 4 | # Backbone hypers. 5 | Learner.embedding_fn = @resnet 6 | 7 | Trainer.pretrained_source = 'scratch' 8 | Trainer.checkpoint_to_restore = '' 9 | 10 | # Model hypers. 11 | MAMLLearner.first_order = True 12 | MAMLLearner.alpha = 0.00577257710776957 13 | MAMLLearner.additional_evaluation_update_steps = 0 14 | MAMLLearner.num_update_steps = 1 15 | 16 | # Data hypers. 17 | DataConfig.image_height = 126 18 | 19 | # Training hypers (not needed for eval). 20 | Trainer.decay_learning_rate = True 21 | Trainer.decay_every = 500 22 | Trainer.decay_rate = 0.885662482266546 23 | Trainer.learning_rate = 0.00025036275525430426 24 | weight_decay = 0.003952085241872012 25 | 26 | # Baseline hypers (just for the record). 
27 | BaselineLearner.cosine_logits_multiplier = 10 28 | BaselineLearner.use_weight_norm = True 29 | BaselineLearner.knn_distance = 'l2' 30 | BaselineLearner.cosine_classifier = True 31 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/matching_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/matching_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/best_v2/pretrained_convnet.gin' 6 | 7 | # Data hypers. 8 | DataConfig.image_height = 84 9 | 10 | # Training hypers (not needed for eval). 11 | Trainer.decay_learning_rate = True 12 | Trainer.decay_every = 500 13 | Trainer.decay_rate = 0.5937020776489796 14 | Trainer.learning_rate = 0.005052178216688174 15 | weight_decay = 0.00000392173790384195 16 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/matching_all_from_scratch.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/matching_config.gin' 3 | 4 | # Backbone hypers. 5 | Learner.embedding_fn = @four_layer_convnet 6 | 7 | Trainer.pretrained_source = 'scratch' 8 | Trainer.checkpoint_to_restore = '' 9 | 10 | # Data hypers. 11 | DataConfig.image_height = 126 12 | 13 | # Training hypers (not needed for eval). 
14 | Trainer.decay_learning_rate = True 15 | Trainer.decay_every = 2500 16 | Trainer.decay_rate = 0.9197652570309498 17 | Trainer.learning_rate = 0.002748105034397538 18 | weight_decay = 0.0000013254789476822292 19 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/matching_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/matching_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/best_v2/pretrained_resnet.gin' 6 | 7 | # Data hypers. 8 | DataConfig.image_height = 126 9 | 10 | # Training hypers (not needed for eval). 11 | Trainer.decay_learning_rate = True 12 | Trainer.decay_every = 500 13 | Trainer.decay_rate = 0.915193802145601 14 | Trainer.learning_rate = 0.0012064626897259694 15 | weight_decay = 0.0000885787420909229 16 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/matching_imagenet_from_scratch.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/matching_config.gin' 3 | 4 | # Backbone hypers. 5 | Learner.embedding_fn = @resnet 6 | 7 | Trainer.pretrained_source = 'scratch' 8 | Trainer.checkpoint_to_restore = '' 9 | 10 | # Data hypers. 11 | DataConfig.image_height = 126 12 | 13 | # Training hypers (not needed for eval). 
14 | Trainer.decay_learning_rate = True 15 | Trainer.decay_every = 2500 16 | Trainer.decay_rate = 0.8333411536286996 17 | Trainer.learning_rate = 0.0006229660387662655 18 | weight_decay = 0.00018036259587809225 19 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/matching_inference_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/matching_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/best_v2/pretrained_resnet.gin' 6 | 7 | # Data hypers. 8 | DataConfig.image_height = 84 9 | 10 | # Training hypers (not needed for eval). 11 | Trainer.decay_learning_rate = True 12 | Trainer.decay_every = 1000 13 | Trainer.decay_rate = 0.5568532448398684 14 | Trainer.learning_rate = 0.0009946451213635535 15 | weight_decay = 0.000006779214221539446 16 | 17 | # Baseline hypers (just for the record). 18 | BaselineLearner.cosine_logits_multiplier = 100 19 | BaselineLearner.use_weight_norm = False 20 | BaselineLearner.knn_distance = 'cosine' 21 | BaselineLearner.cosine_classifier = False 22 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/matching_inference_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/matching_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/best_v2/pretrained_resnet.gin' 6 | 7 | # Data hypers. 8 | DataConfig.image_height = 84 9 | 10 | # Training hypers (not needed for eval). 
11 | Trainer.decay_learning_rate = True 12 | Trainer.decay_every = 10000 13 | Trainer.decay_rate = 0.9714872833376629 14 | Trainer.learning_rate = 0.002670412664551181 15 | weight_decay = 0.000013646011503835137 16 | 17 | # Baseline hypers (just for the record). 18 | BaselineLearner.cosine_logits_multiplier = 1 19 | BaselineLearner.use_weight_norm = True 20 | BaselineLearner.knn_distance = 'cosine' 21 | BaselineLearner.cosine_classifier = False 22 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/pretrain_imagenet_convnet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/baseline_config.gin' 3 | 4 | # Backbone hypers. 5 | Learner.embedding_fn = @four_layer_convnet 6 | 7 | Trainer.checkpoint_to_restore = '' 8 | Trainer.pretrained_source = 'scratch' 9 | 10 | # Model hypers. 11 | BaselineLearner.knn_distance = 'cosine' 12 | BaselineLearner.cosine_classifier = True 13 | BaselineLearner.cosine_logits_multiplier = 1 14 | BaselineLearner.use_weight_norm = True 15 | 16 | # Data hypers. 17 | DataConfig.image_height = 84 18 | 19 | # Training hypers (not needed for eval). 20 | Trainer.num_updates = 50000 21 | Trainer.decay_every = 1000 22 | Trainer.decay_rate = 0.9105573818947892 23 | Trainer.learning_rate = 0.008644114436633987 24 | weight_decay = 0.000005171477829794739 25 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/pretrain_imagenet_resnet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/baseline_config.gin' 3 | 4 | # Backbone hypers. 
5 | Learner.embedding_fn = @resnet 6 | 7 | Trainer.checkpoint_to_restore = '' 8 | Trainer.pretrained_source = 'scratch' 9 | 10 | # Model hypers. 11 | BaselineLearner.knn_distance = 'cosine' 12 | BaselineLearner.cosine_classifier = False 13 | BaselineLearner.cosine_logits_multiplier = 2 14 | BaselineLearner.use_weight_norm = False 15 | 16 | # Data hypers. 17 | DataConfig.image_height = 126 18 | 19 | # Training hypers (not needed for eval). 20 | Trainer.num_updates = 50000 21 | Trainer.decay_every = 1000 22 | Trainer.decay_rate = 0.9967524905880909 23 | Trainer.learning_rate = 0.00375640851370052 24 | weight_decay = 0.00002628042826116842 25 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/pretrain_imagenet_wide_resnet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/baseline_config.gin' 3 | 4 | # Backbone hypers. 5 | Learner.embedding_fn = @wide_resnet 6 | 7 | Trainer.checkpoint_to_restore = '' 8 | Trainer.pretrained_source = 'scratch' 9 | 10 | # Model hypers. 11 | BaselineLearner.knn_distance = 'cosine' 12 | BaselineLearner.cosine_classifier = True 13 | BaselineLearner.cosine_logits_multiplier = 1 14 | BaselineLearner.use_weight_norm = True 15 | 16 | # Data hypers. 17 | DataConfig.image_height = 126 18 | 19 | # Training hypers (not needed for eval). 20 | Trainer.num_updates = 50000 21 | Trainer.decay_every = 100 22 | Trainer.decay_rate = 0.5082121576573064 23 | Trainer.learning_rate = 0.007084776688116927 24 | weight_decay = 0.000005078192976503067 25 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/pretrained_convnet.gin: -------------------------------------------------------------------------------- 1 | # Best pre-trained convnet. Update the path to the checkpoint. 
2 | Learner.embedding_fn = @four_layer_convnet 3 | 4 | Trainer.checkpoint_to_restore = '/path/to/checkpoints/pretrain_imagenet_convnet/model_42500.ckpt' 5 | Trainer.pretrained_source = 'imagenet' 6 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/pretrained_resnet.gin: -------------------------------------------------------------------------------- 1 | # Best pre-trained resnet. Update the path to the checkpoint. 2 | Learner.embedding_fn = @resnet 3 | 4 | Trainer.checkpoint_to_restore = '/path/to/checkpoints/pretrain_imagenet_resnet/model_27500.ckpt' 5 | Trainer.pretrained_source = 'imagenet' 6 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/pretrained_wide_resnet.gin: -------------------------------------------------------------------------------- 1 | # Best pre-trained wide resnet. Update the path to the checkpoint. 2 | Learner.embedding_fn = @wide_resnet 3 | 4 | Trainer.checkpoint_to_restore = '/path/to/checkpoints/pretrain_imagenet_wide_resnet/model_46000.ckpt' 5 | Trainer.pretrained_source = 'imagenet' 6 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/prototypical_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/best_v2/pretrained_resnet.gin' 6 | 7 | # Data hypers. 8 | DataConfig.image_height = 126 9 | 10 | # Training hypers (not needed for eval). 
11 | Trainer.decay_learning_rate = True 12 | Trainer.decay_every = 100 13 | Trainer.decay_rate = 0.5475646651850319 14 | Trainer.learning_rate = 0.0018231563858704268 15 | weight_decay = 0.0074472606534577565 16 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/prototypical_all_from_scratch.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin' 3 | 4 | # Backbone hypers. 5 | 6 | Learner.embedding_fn = @resnet 7 | 8 | Trainer.pretrained_source = 'scratch' 9 | Trainer.checkpoint_to_restore = '' 10 | 11 | # Data hypers. 12 | DataConfig.image_height = 126 13 | 14 | # Training hypers (not needed for eval). 15 | Trainer.decay_learning_rate = True 16 | Trainer.decay_every = 2500 17 | Trainer.decay_rate = 0.8333411536286996 18 | Trainer.learning_rate = 0.0006229660387662655 19 | weight_decay = 0.00018036259587809225 20 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/prototypical_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/best_v2/pretrained_resnet.gin' 6 | 7 | # Data hypers. 8 | DataConfig.image_height = 126 9 | 10 | # Training hypers (not needed for eval). 
11 | Trainer.decay_learning_rate = True 12 | Trainer.decay_every = 500 13 | Trainer.decay_rate = 0.915193802145601 14 | Trainer.learning_rate = 0.0012064626897259694 15 | weight_decay = 0.0000885787420909229 16 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/prototypical_imagenet_from_scratch.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin' 3 | 4 | # Backbone hypers. 5 | Learner.embedding_fn = @resnet 6 | 7 | Trainer.pretrained_source = 'scratch' 8 | Trainer.checkpoint_to_restore = '' 9 | 10 | # Data hypers. 11 | DataConfig.image_height = 126 12 | 13 | # Training hypers (not needed for eval). 14 | Trainer.decay_learning_rate = True 15 | Trainer.decay_every = 2500 16 | Trainer.decay_rate = 0.8333411536286996 17 | Trainer.learning_rate = 0.0006229660387662655 18 | weight_decay = 0.00018036259587809225 19 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/prototypical_inference_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/best_v2/pretrained_wide_resnet.gin' 6 | 7 | # Data hypers. 8 | DataConfig.image_height = 126 9 | 10 | # Training hypers (not needed for eval). 11 | Trainer.decay_learning_rate = True 12 | Trainer.decay_every = 100 13 | Trainer.decay_rate = 0.7426809516243701 14 | Trainer.learning_rate = 0.0009325756201058525 15 | weight_decay = 0.00003386806355382518 16 | 17 | # Baseline hypers (just for the record). 
18 | BaselineLearner.cosine_logits_multiplier = 2 19 | BaselineLearner.use_weight_norm = True 20 | BaselineLearner.knn_distance = 'l2' 21 | BaselineLearner.cosine_classifier = False 22 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/prototypical_inference_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/best_v2/pretrained_wide_resnet.gin' 6 | 7 | # Data hypers. 8 | DataConfig.image_height = 126 9 | 10 | # Training hypers (not needed for eval). 11 | Trainer.decay_learning_rate = True 12 | Trainer.decay_every = 100 13 | Trainer.decay_rate = 0.7426809516243701 14 | Trainer.learning_rate = 0.0009325756201058525 15 | weight_decay = 0.00003386806355382518 16 | 17 | # Baseline hypers (just for the record). 18 | BaselineLearner.cosine_logits_multiplier = 2 19 | BaselineLearner.use_weight_norm = True 20 | BaselineLearner.knn_distance = 'l2' 21 | BaselineLearner.cosine_classifier = False 22 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/relationnet_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/relationnet_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/best_v2/pretrained_convnet.gin' 6 | 7 | # Data hypers. 8 | DataConfig.image_height = 126 9 | 10 | # Training hypers (not needed for eval). 
11 | Trainer.decay_learning_rate = True 12 | Trainer.decay_every = 5000 13 | Trainer.decay_rate = 0.8707355191010226 14 | Trainer.learning_rate = 0.000906783323100297 15 | weight_decay = 0.0000617576791048944 16 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/relationnet_all_from_scratch.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/relationnet_config.gin' 3 | 4 | # Backbone hypers. 5 | Learner.embedding_fn = @four_layer_convnet 6 | 7 | Trainer.pretrained_source = 'scratch' 8 | Trainer.checkpoint_to_restore = '' 9 | 10 | # Data hypers. 11 | DataConfig.image_height = 126 12 | 13 | # Training hypers (not needed for eval). 14 | Trainer.decay_learning_rate = True 15 | Trainer.decay_every = 2500 16 | Trainer.decay_rate = 0.9197652570309498 17 | Trainer.learning_rate = 0.002748105034397538 18 | weight_decay = 0.0000013254789476822292 19 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/relationnet_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/relationnet_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/best_v2/pretrained_convnet.gin' 6 | 7 | # Data hypers. 8 | DataConfig.image_height = 126 9 | 10 | # Training hypers (not needed for eval). 
11 | Trainer.decay_learning_rate = True 12 | Trainer.decay_every = 5000 13 | Trainer.decay_rate = 0.8707355191010226 14 | Trainer.learning_rate = 0.000906783323100297 15 | weight_decay = 0.0000617576791048944 16 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/relationnet_imagenet_from_scratch.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/relationnet_config.gin' 3 | 4 | # Backbone hypers. 5 | Learner.embedding_fn = @four_layer_convnet 6 | 7 | Trainer.pretrained_source = 'scratch' 8 | Trainer.checkpoint_to_restore = '' 9 | 10 | # Data hypers. 11 | DataConfig.image_height = 126 12 | 13 | # Training hypers (not needed for eval). 14 | Trainer.decay_learning_rate = True 15 | Trainer.decay_every = 2500 16 | Trainer.decay_rate = 0.9197652570309498 17 | Trainer.learning_rate = 0.002748105034397538 18 | weight_decay = 0.0000013254789476822292 19 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/best/relationnet_mini_imagenet_fiveshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/relationnet_config.gin' 3 | 4 | ## Hyperparameters used to train the best model 5 | Trainer.checkpoint_to_restore = '' 6 | Trainer.pretrained_source = 'scratch' 7 | weight_decay = 1e-6 8 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/baseline_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | 3 | include 'meta_dataset/learn/gin/learners/baseline_config.gin' 4 | Trainer.learning_rate = 1e-5 5 | BatchSplitReaderGetReader.add_dataset_offset 
= True 6 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/baseline_cosine_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | 3 | include 'meta_dataset/learn/gin/learners/baseline_cosine_config.gin' 4 | Trainer.learning_rate = 1e-5 5 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/baseline_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | 3 | include 'meta_dataset/learn/gin/learners/baseline_config.gin' 4 | Trainer.learning_rate = 1e-5 5 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/baseline_mini_imagenet_fiveshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | EpisodeDescriptionConfig.num_support = 5 3 | 4 | include 'meta_dataset/learn/gin/learners/baseline_config.gin' 5 | Trainer.learning_rate = 1e-4 6 | Trainer.decay_learning_rate = True 7 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/baseline_mini_imagenet_oneshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | EpisodeDescriptionConfig.num_support = 1 3 | 4 | include 'meta_dataset/learn/gin/learners/baseline_config.gin' 5 | Trainer.learning_rate = 1e-4 6 | Trainer.decay_learning_rate = True 7 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/baselinefinetune_all.gin: -------------------------------------------------------------------------------- 1 | include 
'meta_dataset/learn/gin/setups/all.gin' 2 | 3 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin' 4 | Trainer.learning_rate = 1e-5 5 | BatchSplitReaderGetReader.add_dataset_offset = True 6 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/baselinefinetune_cosine_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | 3 | include 'meta_dataset/learn/gin/learners/baselinefinetune_cosine_config.gin' 4 | Trainer.learning_rate = 1e-5 5 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/baselinefinetune_cosine_mini_imagenet_fiveshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | EpisodeDescriptionConfig.num_support = 5 3 | 4 | include 'meta_dataset/learn/gin/learners/baselinefinetune_cosine_config.gin' 5 | Trainer.learning_rate = 1e-4 6 | Trainer.decay_learning_rate = True 7 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/baselinefinetune_cosine_mini_imagenet_oneshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | EpisodeDescriptionConfig.num_support = 1 3 | 4 | include 'meta_dataset/learn/gin/learners/baselinefinetune_cosine_config.gin' 5 | Trainer.learning_rate = 1e-4 6 | Trainer.decay_learning_rate = True 7 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/baselinefinetune_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | 3 | include 
'meta_dataset/learn/gin/learners/baselinefinetune_config.gin' 4 | Trainer.learning_rate = 1e-5 5 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/baselinefinetune_mini_imagenet_fiveshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | EpisodeDescriptionConfig.num_support = 5 3 | 4 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin' 5 | Trainer.learning_rate = 1e-4 6 | Trainer.decay_learning_rate = True 7 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/baselinefinetune_mini_imagenet_oneshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | EpisodeDescriptionConfig.num_support = 1 3 | 4 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin' 5 | Trainer.learning_rate = 1e-4 6 | Trainer.decay_learning_rate = True 7 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/crosstransformer_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/crosstransformer_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/default/resnet34_stride16.gin' 6 | 7 | # Data hypers. 8 | DataConfig.image_height = 224 9 | train/EpisodeDescriptionConfig.ignore_hierarchy_probability = 0.5 10 | resnet34.deeplab_alignment = False 11 | 12 | # Training hypers (not needed for eval). 
13 | Trainer.num_updates = 100000 14 | Trainer.normalized_gradient_descent = True 15 | Trainer.decay_learning_rate = True 16 | Trainer.decay_every = 2000 17 | Trainer.decay_rate = 0.915193802145601 18 | Trainer.learning_rate = 0.0006 19 | Trainer.num_updates = 100000 20 | weight_decay = 0.0000885787420909229 21 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/crosstransformer_simclreps_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/crosstransformer_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/default/resnet34_stride16.gin' 6 | 7 | # Data hypers. 8 | DataConfig.image_height = 224 9 | train/EpisodeDescriptionConfig.ignore_hierarchy_probability = 0.5 10 | resnet34.deeplab_alignment = False 11 | 12 | # Training hypers (not needed for eval). 13 | Trainer.normalized_gradient_descent = True 14 | Trainer.decay_learning_rate = True 15 | Trainer.decay_every = 4000 16 | Trainer.decay_rate = 0.915193802145601 17 | Trainer.learning_rate = 0.0006 18 | Trainer.num_updates = 400000 19 | Trainer.enable_tf_optimizations = False 20 | weight_decay = 0.0000885787420909229 21 | 22 | train/EpisodeDescriptionConfig.simclr_episode_fraction = 0.5 23 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/debug_proto_fungi.gin: -------------------------------------------------------------------------------- 1 | # 'fungi' has many classes (> 1000, more than ImageNet) and a wide spread of 2 | # class sizes (from 6 to 442), while being overall not too big (13GB on disk), 3 | # which makes it a good candidate to profile imbalanced datasets with many 4 | # classes locally. 
5 | benchmark.datasets = 'fungi' 6 | include 'meta_dataset/learn/gin/setups/data_config.gin' 7 | include 'meta_dataset/learn/gin/setups/trainer_config_debug.gin' 8 | include 'meta_dataset/learn/gin/setups/variable_way_and_shot.gin' 9 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin' 10 | Trainer.data_config = @DataConfig() 11 | 12 | # The total number of updates is 100, so do not checkpoint or validate during profiling. 13 | Trainer.checkpoint_every = 1000 14 | Trainer.validate_every = 1000 15 | Trainer.log_every = 10 16 | 17 | Trainer.train_episode_config = @train/EpisodeDescriptionConfig() 18 | Trainer.eval_episode_config = @EpisodeDescriptionConfig() 19 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/debug_proto_mini_imagenet.gin: -------------------------------------------------------------------------------- 1 | benchmark.datasets = 'mini_imagenet' 2 | include 'meta_dataset/learn/gin/setups/data_config.gin' 3 | include 'meta_dataset/learn/gin/setups/trainer_config_debug.gin' 4 | include 'meta_dataset/learn/gin/setups/variable_way_and_shot.gin' 5 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin' 6 | Trainer.data_config = @DataConfig() 7 | Trainer.learning_rate = 1e-4 8 | 9 | Trainer.train_episode_config = @train/EpisodeDescriptionConfig() 10 | Trainer.eval_episode_config = @EpisodeDescriptionConfig() 11 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/flute.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/learner_config.gin' 3 | include 'meta_dataset/learn/gin/setups/trainer_config_flute.gin' 4 | include 'meta_dataset/learn/gin/setups/data_config_flute.gin' 5 | 6 | # Learners to use at train and validation. 
7 | Trainer_flute.train_learner_class = @DatasetConditionalBaselineLearner 8 | Trainer_flute.eval_learner_class = @DatasetConditionalPrototypicalNetworkLearner 9 | BaselineLearner.cosine_classifier = True 10 | BaselineLearner.cosine_logits_multiplier = 10. 11 | BaselineLearner.use_weight_norm = False 12 | BatchSplitReaderGetReader.add_dataset_offset = True 13 | BaselineLearner.knn_in_fc = False 14 | BaselineLearner.knn_distance = 'l2' # Does not matter. 15 | 16 | # Optimization. 17 | Trainer_flute.optimizer_type = 'momentum' 18 | Trainer_flute.learn_rate_scheduler = 'cosine_decay_restarts' 19 | Trainer_flute.decay_learning_rate = True 20 | Trainer_flute.sample_half_from_imagenet = True 21 | Trainer_flute.meta_batch_size = 8 22 | Trainer_flute.batch_size = 16 23 | Trainer_flute.learning_rate = 0.01 24 | Trainer_flute.decay_every = 5000 25 | Trainer_flute.num_updates = 640000 26 | Trainer_flute.validate_every = 1000 27 | Trainer_flute.checkpoint_every = 2000 28 | 29 | # Weight decay and learner settings. 30 | separate_head_linear_classifier.learnable_scale = True 31 | separate_head_linear_classifier.weight_decay = %weight_decay 32 | flute_resnet.weight_decay = %weight_decay 33 | weight_decay = 7e-4 34 | Learner.transductive_batch_norm = False 35 | Learner.backprop_through_moments = True 36 | 37 | # Backbone settings. 38 | Learner.embedding_fn = @flute_resnet 39 | bn_wrapper.batch_norm_fn = @bn_flute_train 40 | bn_flute_train.film_weight_decay = 0.0001 41 | DatasetConditionalBaselineLearner.num_sets = %num_film_sets 42 | DatasetConditionalPrototypicalNetworkLearner.num_sets = %num_film_sets 43 | bn_wrapper.num_film_sets = %num_film_sets 44 | num_film_sets = 8 45 | 46 | # Data settings. 47 | # Validate on the same datasets that are used for training, in the same order, so 48 | # that the ground-truth source ID can be used for validation-time forward passes 49 | # without remapping. 
50 | benchmark.train_datasets = 'ilsvrc_2012,aircraft,cu_birds,omniglot,quickdraw,vgg_flower,dtd,fungi' 51 | benchmark.eval_datasets = 'ilsvrc_2012,aircraft,cu_birds,omniglot,quickdraw,vgg_flower,dtd,fungi' 52 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/flute_dataset_classifier.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/learner_config.gin' 3 | include 'meta_dataset/learn/gin/setups/trainer_config_flute.gin' 4 | include 'meta_dataset/learn/gin/setups/data_config_flute.gin' 5 | 6 | # Learners to use at train and validation. 7 | Trainer_flute.train_learner_class = @DatasetLearner 8 | Trainer_flute.eval_learner_class = @DatasetLearner 9 | BatchSplitReaderGetReader.add_dataset_offset = True 10 | Learner.transductive_batch_norm = False 11 | Learner.backprop_through_moments = True 12 | 13 | # Optimization. 14 | Trainer_flute.optimizer_type = 'adam' 15 | Trainer_flute.learn_rate_scheduler = 'cosine_decay' 16 | Trainer_flute.decay_learning_rate = True 17 | Trainer_flute.sample_half_from_imagenet = False 18 | Trainer_flute.meta_batch_size = 8 19 | Trainer_flute.batch_size = 16 20 | Trainer_flute.learning_rate = 0.001 21 | Trainer_flute.decay_every = 3000 22 | Trainer_flute.num_updates = 14000 23 | 24 | # Backbone settings. 25 | Learner.embedding_fn = @dataset_classifier 26 | dataset_classifier.weight_decay = %weight_decay 27 | dataset_classifier.num_datasets = %num_film_sets 28 | num_film_sets = 8 29 | weight_decay = 7e-4 30 | 31 | # Data settings. 32 | # Validate on the same datasets that are used for training, in the same order, so 33 | # that the ground-truth source ID can be used for validation-time forward passes 34 | # without remapping. 
35 | benchmark.train_datasets = 'ilsvrc_2012,aircraft,cu_birds,omniglot,quickdraw,vgg_flower,dtd,fungi' 36 | benchmark.eval_datasets = 'ilsvrc_2012,aircraft,cu_birds,omniglot,quickdraw,vgg_flower,dtd,fungi' 37 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/maml_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | 3 | include 'meta_dataset/learn/gin/learners/maml_config.gin' 4 | Trainer.learning_rate = 1e-3 5 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/maml_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | 3 | include 'meta_dataset/learn/gin/learners/maml_config.gin' 4 | Trainer.learning_rate = 1e-3 5 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/maml_init_with_proto_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | 3 | include 'meta_dataset/learn/gin/learners/maml_config.gin' 4 | Trainer.learning_rate = 1e-3 5 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/maml_init_with_proto_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | 3 | include 'meta_dataset/learn/gin/learners/maml_config.gin' 4 | Trainer.learning_rate = 1e-3 5 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/maml_init_with_proto_inference_all.gin: -------------------------------------------------------------------------------- 1 | include 
'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/learner_config.gin' 3 | BatchSplitReaderGetReader.add_dataset_offset = True 4 | Trainer.train_learner_class = @BaselineLearner 5 | Trainer.eval_learner_class = @MAMLLearner 6 | # The following line is what makes this proto-MAML. 7 | MAMLLearner.proto_maml_fc_layer_init = True 8 | weight_decay = 1e-4 9 | BaselineLearner.knn_in_fc = False 10 | MAMLLearner.debug = False 11 | 12 | BaselineLearner.transductive_batch_norm = False 13 | BaselineLearner.backprop_through_moments = True 14 | 15 | MAMLLearner.transductive_batch_norm = False 16 | MAMLLearner.backprop_through_moments = True 17 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/maml_init_with_proto_inference_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/learner_config.gin' 3 | Trainer.train_learner_class = @BaselineLearner 4 | Trainer.eval_learner_class = @MAMLLearner 5 | # The following line is what makes this proto-MAML. 
6 | MAMLLearner.proto_maml_fc_layer_init = True 7 | weight_decay = 1e-4 8 | BaselineLearner.knn_in_fc = False 9 | MAMLLearner.debug = False 10 | 11 | BaselineLearner.transductive_batch_norm = False 12 | BaselineLearner.backprop_through_moments = True 13 | 14 | MAMLLearner.transductive_batch_norm = False 15 | MAMLLearner.backprop_through_moments = True 16 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/maml_init_with_proto_mini_imagenet_fiveshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | EpisodeDescriptionConfig.num_support = 5 3 | 4 | include 'meta_dataset/learn/gin/learners/maml_config.gin' 5 | Trainer.learning_rate = 1e-4 6 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/maml_init_with_proto_mini_imagenet_oneshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | EpisodeDescriptionConfig.num_support = 1 3 | 4 | include 'meta_dataset/learn/gin/learners/maml_config.gin' 5 | Trainer.learning_rate = 1e-4 6 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/maml_mini_imagenet_fiveshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | EpisodeDescriptionConfig.num_support = 5 3 | 4 | include 'meta_dataset/learn/gin/learners/maml_config.gin' 5 | Trainer.learning_rate = 1e-4 6 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/maml_mini_imagenet_oneshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | 
EpisodeDescriptionConfig.num_support = 1 3 | 4 | include 'meta_dataset/learn/gin/learners/maml_config.gin' 5 | Trainer.learning_rate = 1e-4 6 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/maml_protonet_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | 3 | include 'meta_dataset/learn/gin/learners/maml_config.gin' 4 | Trainer.learning_rate = 1e-3 5 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/matching_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | 3 | include 'meta_dataset/learn/gin/learners/matching_config.gin' 4 | Trainer.learning_rate = 1e-3 5 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/matching_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | 3 | include 'meta_dataset/learn/gin/learners/matching_config.gin' 4 | Trainer.learning_rate = 1e-3 5 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/matching_inference_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/learner_config.gin' 3 | BatchSplitReaderGetReader.add_dataset_offset = True 4 | Trainer.train_learner_class = @BaselineLearner 5 | Trainer.eval_learner_class = @MatchingNetworkLearner 6 | weight_decay = 1e-4 7 | BaselineLearner.knn_in_fc = False 8 | MatchingNetworkLearner.exact_cosine_distance = False 9 | 10 | BaselineLearner.transductive_batch_norm = False 11 | BaselineLearner.backprop_through_moments = True 
12 | 13 | MatchingNetworkLearner.transductive_batch_norm = False 14 | MatchingNetworkLearner.backprop_through_moments = True 15 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/matching_inference_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/learner_config.gin' 3 | Trainer.train_learner_class = @BaselineLearner 4 | Trainer.eval_learner_class = @MatchingNetworkLearner 5 | weight_decay = 1e-4 6 | BaselineLearner.knn_in_fc = False 7 | MatchingNetworkLearner.exact_cosine_distance = False 8 | 9 | BaselineLearner.transductive_batch_norm = False 10 | BaselineLearner.backprop_through_moments = True 11 | 12 | MatchingNetworkLearner.transductive_batch_norm = False 13 | MatchingNetworkLearner.backprop_through_moments = True 14 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/matching_mini_imagenet_fiveshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | EpisodeDescriptionConfig.num_support = 5 3 | 4 | include 'meta_dataset/learn/gin/learners/matching_config.gin' 5 | Trainer.learning_rate = 1e-4 6 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/matching_mini_imagenet_oneshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | EpisodeDescriptionConfig.num_support = 1 3 | 4 | include 'meta_dataset/learn/gin/learners/matching_config.gin' 5 | Trainer.learning_rate = 1e-4 6 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/pretrained_resnet34_224.gin: 
-------------------------------------------------------------------------------- 1 | # Best pre-trained resnet. Update the path to the checkpoint. 2 | Learner.embedding_fn = @resnet34 3 | Trainer.checkpoint_to_restore = '/path/to/checkpoints/pretrain_imagenet_resnet/model_27500.ckpt' 4 | Trainer.pretrained_source = 'imagenet' 5 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/prototypical_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | 3 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin' 4 | Trainer.learning_rate = 1e-3 5 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/prototypical_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | 3 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin' 4 | Trainer.learning_rate = 1e-3 5 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/prototypical_inference_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/learner_config.gin' 3 | BatchSplitReaderGetReader.add_dataset_offset = True 4 | Trainer.train_learner_class = @BaselineLearner 5 | Trainer.eval_learner_class = @PrototypicalNetworkLearner 6 | weight_decay = 1e-4 7 | BaselineLearner.knn_in_fc = False 8 | 9 | BaselineLearner.transductive_batch_norm = False 10 | BaselineLearner.backprop_through_moments = True 11 | 12 | PrototypicalNetworkLearner.transductive_batch_norm = False 13 | PrototypicalNetworkLearner.backprop_through_moments = True 14 | 
-------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/prototypical_inference_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/learner_config.gin' 3 | Trainer.train_learner_class = @BaselineLearner 4 | Trainer.eval_learner_class = @PrototypicalNetworkLearner 5 | weight_decay = 1e-4 6 | BaselineLearner.knn_in_fc = False 7 | 8 | BaselineLearner.transductive_batch_norm = False 9 | BaselineLearner.backprop_through_moments = True 10 | 11 | PrototypicalNetworkLearner.transductive_batch_norm = False 12 | PrototypicalNetworkLearner.backprop_through_moments = True 13 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/prototypical_mini_imagenet_fiveshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | EpisodeDescriptionConfig.num_support = 5 3 | train/EpisodeDescriptionConfig.num_ways = 20 4 | 5 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin' 6 | Trainer.learning_rate = 1e-4 7 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/prototypical_mini_imagenet_oneshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | EpisodeDescriptionConfig.num_support = 1 3 | train/EpisodeDescriptionConfig.num_ways = 30 4 | 5 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin' 6 | Trainer.learning_rate = 1e-4 7 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/relationnet_all.gin: -------------------------------------------------------------------------------- 1 | 
include 'meta_dataset/learn/gin/setups/all.gin' 2 | 3 | include 'meta_dataset/learn/gin/learners/relationnet_config.gin' 4 | 5 | # There are more parameters than usual, so allow more updates. 6 | Trainer.num_updates = 100000 7 | 8 | Trainer.learning_rate = 1e-3 9 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/relationnet_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/relationnet_config.gin' 3 | 4 | # There are more parameters than usual, so allow more updates. 5 | Trainer.num_updates = 100000 6 | 7 | Trainer.learning_rate = 1e-3 8 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/relationnet_mini_imagenet_fiveshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/relationnet_config.gin' 3 | 4 | Trainer.learning_rate = 1e-3 5 | Trainer.decay_every = 100000 6 | Trainer.decay_rate = 0.5 7 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/relationnet_mini_imagenet_oneshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | EpisodeDescriptionConfig.num_support = 1 3 | include 'meta_dataset/learn/gin/learners/relationnet_config.gin' 4 | 5 | Trainer.learning_rate = 1e-3 6 | Trainer.decay_every = 100000 7 | Trainer.decay_rate = 0.5 8 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/default/resnet34_stride16.gin: -------------------------------------------------------------------------------- 1 | Learner.embedding_fn = @resnet34 2 | 
Trainer.pretrained_source = 'imagenet' 3 | resnet34.max_stride = 16 4 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/learners/baseline_config.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/learners/learner_config.gin' 2 | Trainer.train_learner_class = @BaselineLearner 3 | Trainer.eval_learner_class = @BaselineLearner 4 | 5 | Learner.embedding_fn = @four_layer_convnet 6 | weight_decay = 1e-4 7 | 8 | BaselineLearner.knn_in_fc = False 9 | BaselineLearner.knn_distance = 'l2' 10 | BaselineLearner.cosine_classifier = False 11 | BaselineLearner.cosine_logits_multiplier = None 12 | BaselineLearner.use_weight_norm = False 13 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/learners/baseline_cosine_config.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/learners/baseline_config.gin' 2 | 3 | BaselineLearner.cosine_classifier = True 4 | BaselineLearner.cosine_logits_multiplier = 1 5 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/learners/baselinefinetune_config.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/learners/baseline_config.gin' 2 | Trainer.train_learner_class = @BaselineFinetuneLearner 3 | Trainer.eval_learner_class = @BaselineFinetuneLearner 4 | weight_decay = 1e-4 5 | BaselineFinetuneLearner.num_finetune_steps = 5 6 | BaselineFinetuneLearner.finetune_lr = 1e-4 7 | BaselineFinetuneLearner.finetune_all_layers = False 8 | BaselineFinetuneLearner.finetune_with_adam = False 9 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/learners/baselinefinetune_cosine_config.gin: 
-------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin' 2 | 3 | BaselineLearner.cosine_classifier = True 4 | BaselineLearner.cosine_logits_multiplier = 1 5 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/learners/crosstransformer_config.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/learners/learner_config.gin' 2 | Trainer.train_learner_class = @CrossTransformerLearner 3 | Trainer.eval_learner_class = @CrossTransformerLearner 4 | Trainer.normalized_gradient_descent = True 5 | Learner.embedding_fn = 'four_layer_convnet' 6 | weight_decay = 1e-4 7 | CrossTransformerLearner.tformer_weight_decay = %weight_decay 8 | bn.use_ema=True 9 | Trainer.distribute = True 10 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/learners/learner_config.gin: -------------------------------------------------------------------------------- 1 | # Backbone regularization settings. 2 | linear_classifier.weight_decay = %weight_decay 3 | fully_connected_network.weight_decay = %weight_decay 4 | four_layer_convnet.weight_decay = %weight_decay 5 | relationnet_convnet.weight_decay = %weight_decay 6 | relation_module.weight_decay = %weight_decay 7 | resnet.weight_decay = %weight_decay 8 | resnet34.weight_decay = %weight_decay 9 | wide_resnet.weight_decay = %weight_decay 10 | MAMLLearner.classifier_weight_decay = %weight_decay 11 | weight_decay = 0.001 12 | 13 | # Training settings. 
14 | Trainer.checkpoint_to_restore = '' 15 | Trainer.learning_rate = 1e-4 16 | Trainer.decay_learning_rate = False 17 | Trainer.decay_every = 5000 18 | Trainer.decay_rate = 0.5 19 | Trainer.normalized_gradient_descent = False 20 | Trainer.pretrained_source = '' 21 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/learners/maml_config.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/learners/learner_config.gin' 2 | Trainer.train_learner_class = @MAMLLearner 3 | Trainer.eval_learner_class = @MAMLLearner 4 | Trainer.decay_learning_rate = True 5 | 6 | Learner.embedding_fn = @four_layer_convnet 7 | weight_decay = 1e-4 8 | 9 | MAMLLearner.num_update_steps = 5 10 | MAMLLearner.additional_evaluation_update_steps = 0 11 | MAMLLearner.first_order = True 12 | MAMLLearner.alpha = 0.01 13 | MAMLLearner.adapt_batch_norm = False 14 | MAMLLearner.debug = False 15 | MAMLLearner.zero_fc_layer = True 16 | MAMLLearner.proto_maml_fc_layer_init = False 17 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/learners/maml_init_with_proto_config.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/learners/learner_config.gin' 2 | Trainer.train_learner_class = @MAMLLearner 3 | Trainer.decay_learning_rate = True 4 | Trainer.eval_learner_class = @MAMLLearner 5 | 6 | Learner.embedding_fn = @four_layer_convnet 7 | weight_decay = 1e-4 8 | 9 | MAMLLearner.num_update_steps = 5 10 | MAMLLearner.additional_evaluation_update_steps = 0 11 | MAMLLearner.first_order = True 12 | MAMLLearner.alpha = 0.01 13 | MAMLLearner.adapt_batch_norm = False 14 | MAMLLearner.debug = False 15 | MAMLLearner.zero_fc_layer = True 16 | MAMLLearner.proto_maml_fc_layer_init = True 17 | -------------------------------------------------------------------------------- 
/meta_dataset/learn/gin/learners/matching_config.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/learners/learner_config.gin' 2 | Trainer.train_learner_class = @MatchingNetworkLearner 3 | Trainer.eval_learner_class = @MatchingNetworkLearner 4 | 5 | Trainer.decay_learning_rate = True 6 | Learner.embedding_fn = @four_layer_convnet 7 | weight_decay = 1e-4 8 | 9 | MatchingNetworkLearner.exact_cosine_distance = False 10 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/learners/prototypical_config.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/learners/learner_config.gin' 2 | Trainer.train_learner_class = @PrototypicalNetworkLearner 3 | Trainer.eval_learner_class = @PrototypicalNetworkLearner 4 | Trainer.decay_learning_rate = True 5 | 6 | Learner.embedding_fn = @four_layer_convnet 7 | weight_decay = 1e-4 8 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/learners/relationnet_config.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/learners/learner_config.gin' 2 | Trainer.train_learner_class = @RelationNetworkLearner 3 | Trainer.eval_learner_class = @RelationNetworkLearner 4 | Trainer.decay_learning_rate = True 5 | 6 | Learner.embedding_fn = @relationnet_convnet 7 | Learner.transductive_batch_norm = True 8 | weight_decay = 1e-6 9 | 10 | # Allow transductive batch norm to be faithful with the original Relation Net 11 | # implementation, even though this gives it an advantage over the rest of the 12 | # models we used on Meta-Dataset. 
13 | Learner.transductive_batch_norm = True 14 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/metadataset_v2/baseline_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all_v2.gin' 2 | include 'meta_dataset/learn/gin/learners/baseline_config.gin' 3 | Trainer.learning_rate = 1e-5 4 | BatchSplitReaderGetReader.add_dataset_offset = True 5 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/metadataset_v2/baseline_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet_v2.gin' 2 | include 'meta_dataset/learn/gin/learners/baseline_config.gin' 3 | Trainer.learning_rate = 1e-5 4 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/metadataset_v2/baselinefinetune_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all_v2.gin' 2 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin' 3 | Trainer.learning_rate = 1e-5 4 | BatchSplitReaderGetReader.add_dataset_offset = True 5 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/metadataset_v2/baselinefinetune_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet_v2.gin' 2 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin' 3 | Trainer.learning_rate = 1e-5 4 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/metadataset_v2/best/baseline_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 
'meta_dataset/learn/gin/metadataset_v2/baseline_imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/baseline_config.gin' 3 | 4 | BatchSplitReaderGetReader.add_dataset_offset = True 5 | 6 | # Backbone hypers. 7 | Learner.embedding_fn = @resnet 8 | Trainer.checkpoint_to_restore = '' 9 | Trainer.pretrained_source = '' 10 | 11 | # Model hypers. 12 | BaselineLearner.knn_distance = 'cosine' 13 | BaselineLearner.cosine_classifier = False 14 | BaselineLearner.cosine_logits_multiplier = 2 15 | BaselineLearner.use_weight_norm = False 16 | 17 | Trainer.decay_every = 1000 18 | Trainer.decay_rate = 0.5979159492081371 19 | Trainer.learning_rate = 0.00047244647904730503 20 | weight_decay = 0.026388517138594258 21 | DataConfig.image_height = 126 22 | DataAugmentation.jitter_amount = 6 23 | DataAugmentation.gaussian_noise_std = 0.17564536824131866 24 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/metadataset_v2/best/baselinefinetune_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all_v2.gin' 2 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin' 3 | # Backbone hypers. 4 | Learner.embedding_fn = @resnet 5 | weight_decay = 0.0001 6 | 7 | # Model hypers. 8 | BaselineLearner.cosine_classifier = True 9 | BaselineLearner.cosine_logits_multiplier = 10 10 | BaselineLearner.knn_distance = 'l2' 11 | BaselineLearner.knn_in_fc = False 12 | BaselineLearner.use_weight_norm = True 13 | BaselineFinetuneLearner.finetune_all_layers = True 14 | BaselineFinetuneLearner.finetune_lr = 0.01 15 | BaselineFinetuneLearner.finetune_with_adam = True 16 | BaselineFinetuneLearner.num_finetune_steps = 100 17 | 18 | # Training hypers (not needed for eval). 
19 | Trainer.decay_every = 500 20 | Trainer.decay_learning_rate = False 21 | Trainer.decay_rate = 0.5278384940678894 22 | Trainer.learning_rate = 3.4293725734843445e-06 23 | Trainer.pretrained_source = 'imagenet' 24 | Trainer.checkpoint_to_restore = 'path/to/pretrained_checkpoint' 25 | 26 | 27 | DataConfig.image_height = 126 28 | batch/DataAugmentation.gaussian_noise_std = 0.026413512951864337 29 | batch/DataAugmentation.jitter_amount = 5 30 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/metadataset_v2/best/baselinefinetune_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet_v2.gin' 2 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin' 3 | # Backbone hypers. 4 | Learner.embedding_fn = @resnet 5 | weight_decay = 0.0001 6 | 7 | # Model hypers. 8 | BaselineLearner.cosine_classifier = True 9 | BaselineLearner.cosine_logits_multiplier = 10 10 | BaselineLearner.knn_distance = 'l2' 11 | BaselineLearner.knn_in_fc = False 12 | BaselineLearner.use_weight_norm = True 13 | BaselineFinetuneLearner.finetune_all_layers = True 14 | BaselineFinetuneLearner.finetune_lr = 0.01 15 | BaselineFinetuneLearner.finetune_with_adam = True 16 | BaselineFinetuneLearner.num_finetune_steps = 100 17 | 18 | # Training hypers (not needed for eval). 
19 | Trainer.decay_every = 500 20 | Trainer.decay_learning_rate = False 21 | Trainer.decay_rate = 0.5278384940678894 22 | Trainer.learning_rate = 3.4293725734843445e-06 23 | Trainer.pretrained_source = 'imagenet' 24 | Trainer.checkpoint_to_restore = 'path/to/pretrained_checkpoint' 25 | 26 | DataConfig.image_height = 126 27 | batch/DataAugmentation.gaussian_noise_std = 0.026413512951864337 28 | batch/DataAugmentation.jitter_amount = 5 29 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/metadataset_v2/best/maml_init_with_proto_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all_v2.gin' 2 | include 'meta_dataset/learn/gin/learners/maml_init_with_proto_config.gin' 3 | # Backbone hypers. 4 | Learner.embedding_fn = @resnet 5 | weight_decay = 0.0001 6 | Trainer.pretrained_source = 'imagenet' 7 | Trainer.checkpoint_to_restore = 'path/to/pretrained_checkpoint' 8 | 9 | # Model hypers. 10 | MAMLLearner.adapt_batch_norm = False 11 | MAMLLearner.additional_evaluation_update_steps = 0 12 | MAMLLearner.alpha = 0.005435022808033229 13 | MAMLLearner.first_order = True 14 | MAMLLearner.num_update_steps = 10 15 | MAMLLearner.proto_maml_fc_layer_init = True 16 | MAMLLearner.zero_fc_layer = True 17 | # Training hypers (not needed for eval). 
18 | Trainer.decay_every = 1000 19 | Trainer.decay_learning_rate = True 20 | Trainer.decay_rate = 0.6477898086638092 21 | Trainer.learning_rate = 0.00036339913514891586 22 | 23 | DataConfig.image_height = 126 24 | support/DataAugmentation.gaussian_noise_std = 0.4658549336962272 25 | support/DataAugmentation.jitter_amount = 0 26 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/metadataset_v2/best/maml_init_with_proto_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet_v2.gin' 2 | include 'meta_dataset/learn/gin/learners/maml_init_with_proto_config.gin' 3 | # Backbone hypers. 4 | Learner.embedding_fn = @resnet 5 | weight_decay = 0.0001 6 | Trainer.pretrained_source = 'imagenet' 7 | Trainer.checkpoint_to_restore = 'path/to/pretrained_checkpoint' 8 | 9 | # Model hypers. 10 | MAMLLearner.adapt_batch_norm = False 11 | MAMLLearner.additional_evaluation_update_steps = 90 12 | MAMLLearner.alpha = 0.01 13 | MAMLLearner.first_order = True 14 | MAMLLearner.num_update_steps = 10 15 | MAMLLearner.proto_maml_fc_layer_init = True 16 | MAMLLearner.zero_fc_layer = True 17 | # Training hypers (not needed for eval). 18 | Trainer.decay_every = 1000 19 | Trainer.decay_learning_rate = True 20 | Trainer.decay_rate = 0.6477898086638092 21 | Trainer.learning_rate = 0.00036339913514891586 22 | 23 | DataConfig.image_height = 126 24 | support/DataAugmentation.gaussian_noise_std = 0.4658549336962272 25 | support/DataAugmentation.jitter_amount = 0 26 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/metadataset_v2/best/prototypical_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all_v2.gin' 2 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin' 3 | # Backbone hypers. 
4 | Learner.embedding_fn = @resnet 5 | Trainer.pretrained_source = 'imagenet' 6 | Trainer.checkpoint_to_restore = 'path/to/pretrained_checkpoint' 7 | 8 | # Training hypers (not needed for eval). 9 | Trainer.decay_every = 500 10 | Trainer.decay_learning_rate = True 11 | Trainer.decay_rate = 0.885662482266546 12 | Trainer.learning_rate = 0.00025036275525430426 13 | Learner.backprop_through_moments = True 14 | Learner.transductive_batch_norm = False 15 | weight_decay = 0.0001 16 | 17 | DataConfig.image_height = 126 18 | support/DataAugmentation.gaussian_noise_std = 0.15335348868374565 19 | support/DataAugmentation.jitter_amount = 5 20 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/metadataset_v2/best/prototypical_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet_v2.gin' 2 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin' 3 | # Backbone hypers. 4 | Learner.embedding_fn = @resnet 5 | Trainer.pretrained_source = 'imagenet' 6 | Trainer.checkpoint_to_restore = 'path/to/pretrained_checkpoint' 7 | 8 | # Training hypers (not needed for eval). 
9 | Trainer.decay_every = 500 10 | Trainer.decay_learning_rate = True 11 | Trainer.decay_rate = 0.885662482266546 12 | Trainer.learning_rate = 0.00025036275525430426 13 | Learner.backprop_through_moments = True 14 | Learner.transductive_batch_norm = False 15 | weight_decay = 0.0001 16 | 17 | DataConfig.image_height = 126 18 | support/DataAugmentation.gaussian_noise_std = 0.15335348868374565 19 | support/DataAugmentation.jitter_amount = 5 20 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/metadataset_v2/maml_init_with_proto_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all_v2.gin' 2 | include 'meta_dataset/learn/gin/learners/maml_config.gin' 3 | Trainer.learning_rate = 1e-3 4 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/metadataset_v2/maml_init_with_proto_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet_v2.gin' 2 | include 'meta_dataset/learn/gin/learners/maml_config.gin' 3 | Trainer.learning_rate = 1e-3 4 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/metadataset_v2/prototypical_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all_v2.gin' 2 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin' 3 | Trainer.learning_rate = 1e-3 4 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/setups/all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all_datasets.gin' 2 | include 'meta_dataset/learn/gin/setups/data_config.gin' 3 | include 'meta_dataset/learn/gin/setups/trainer_config.gin' 4 
| include 'meta_dataset/learn/gin/setups/variable_way_and_shot.gin' 5 | Trainer.data_config = @DataConfig() 6 | Trainer.train_episode_config = @train/EpisodeDescriptionConfig() 7 | Trainer.eval_episode_config = @EpisodeDescriptionConfig() 8 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/setups/all_datasets.gin: -------------------------------------------------------------------------------- 1 | benchmark.train_datasets = 'ilsvrc_2012,aircraft,cu_birds,omniglot,quickdraw,vgg_flower,dtd,fungi' 2 | benchmark.eval_datasets = 'ilsvrc_2012' 3 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/setups/all_v2.gin: -------------------------------------------------------------------------------- 1 | benchmark.train_datasets = 'ilsvrc_2012_v2,aircraft,cu_birds,omniglot,quickdraw,dtd,fungi' 2 | benchmark.eval_datasets = 'fungi,aircraft,quickdraw,omniglot,mscoco,cu_birds,dtd' 3 | include 'meta_dataset/learn/gin/setups/data_config.gin' 4 | include 'meta_dataset/learn/gin/setups/trainer_config.gin' 5 | include 'meta_dataset/learn/gin/setups/variable_way_and_shot.gin' 6 | Trainer.data_config = @DataConfig() 7 | Trainer.train_episode_config = @train/EpisodeDescriptionConfig() 8 | Trainer.eval_episode_config = @EpisodeDescriptionConfig() 9 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/setups/data_config.gin: -------------------------------------------------------------------------------- 1 | import meta_dataset.data.config 2 | import meta_dataset.data.decoder 3 | include 'meta_dataset/learn/gin/setups/data_config_common.gin' 4 | 5 | # If we decode features then change the lines below to use FeatureDecoder. 
6 | process_dumped_episode.support_decoder = @support/ImageDecoder() 7 | process_episode.support_decoder = @support/ImageDecoder() 8 | support/ImageDecoder.data_augmentation = @support/DataAugmentation() 9 | support/DataAugmentation.enable_jitter = True 10 | support/DataAugmentation.jitter_amount = 0 11 | support/DataAugmentation.enable_gaussian_noise = True 12 | support/DataAugmentation.gaussian_noise_std = 0.0 13 | 14 | process_dumped_episode.query_decoder = @query/ImageDecoder() 15 | process_episode.query_decoder = @query/ImageDecoder() 16 | query/ImageDecoder.data_augmentation = @query/DataAugmentation() 17 | query/DataAugmentation.enable_jitter = False 18 | query/DataAugmentation.jitter_amount = 0 19 | query/DataAugmentation.enable_gaussian_noise = False 20 | query/DataAugmentation.gaussian_noise_std = 0.0 21 | 22 | # It is possible to override the support and query decoders for the meta-train 23 | # and / or meta-evaluation phases using the scopes 'train' or 'evaluation' 24 | # (thanks to trainer.py that will appropriately pick the scope for each phase). 25 | # The following commented-out lines show an example: 26 | # train/process_episode.support_decoder = @train/support/ImageDecoder() 27 | # train/support/ImageDecoder.image_size = ... 28 | 29 | process_batch.batch_decoder = @batch/ImageDecoder() 30 | batch/ImageDecoder.data_augmentation = @batch/DataAugmentation() 31 | batch/DataAugmentation.enable_jitter = True 32 | batch/DataAugmentation.jitter_amount = 0 33 | batch/DataAugmentation.enable_gaussian_noise = True 34 | batch/DataAugmentation.gaussian_noise_std = 0.0 35 | 36 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/setups/data_config_common.gin: -------------------------------------------------------------------------------- 1 | # This gin file doesn't set process_episode.support_decoder and .query_decoder 2 | # therefore the image strings are not decoded. 
This gin file can be used directly 3 | # if the episodes are needed without decoding. 4 | import meta_dataset.data.config 5 | # Default values for sampling variable shots / ways. 6 | EpisodeDescriptionConfig.min_ways = 5 7 | EpisodeDescriptionConfig.max_ways_upper_bound = 50 8 | EpisodeDescriptionConfig.max_num_query = 10 9 | # For weak shot experiments where we have missing class data. 10 | # This should not affect any other experiments. 11 | EpisodeDescriptionConfig.min_examples_in_class = 0 12 | EpisodeDescriptionConfig.max_support_set_size = 500 13 | EpisodeDescriptionConfig.max_support_size_contrib_per_class = 100 14 | EpisodeDescriptionConfig.min_log_weight = -0.69314718055994529 # np.log(0.5) 15 | EpisodeDescriptionConfig.max_log_weight = 0.69314718055994529 # np.log(2) 16 | EpisodeDescriptionConfig.ignore_dag_ontology = False 17 | EpisodeDescriptionConfig.ignore_bilevel_ontology = False 18 | EpisodeDescriptionConfig.ignore_hierarchy_probability = 0.0 19 | 20 | # By default don't use SimCLR Episodes. 21 | EpisodeDescriptionConfig.simclr_episode_fraction = 0.0 22 | # It is possible to override some of the above defaults only for meta-training. 23 | # An example is shown in the following two commented-out lines. 24 | # train/EpisodeDescriptionConfig.min_ways = 5 25 | # train/EpisodeDescriptionConfig.max_ways_upper_bound = 50 26 | 27 | # Other default values for the data pipeline. 
28 | DataConfig.image_height = 84 29 | DataConfig.shuffle_buffer_size = 1000 30 | DataConfig.read_buffer_size_bytes = 1048576 # 1 MB (1024**2) 31 | DataConfig.num_prefetch = 64 32 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/setups/data_config_feature.gin: -------------------------------------------------------------------------------- 1 | import meta_dataset.data.config 2 | import meta_dataset.data.decoder 3 | import meta_dataset.data.pipeline 4 | include 'meta_dataset/learn/gin/setups/data_config_common.gin' 5 | 6 | process_episode.support_decoder = @FeatureDecoder() 7 | process_episode.query_decoder = @FeatureDecoder() 8 | process_batch.batch_decoder = @FeatureDecoder() 9 | FeatureDecoder.feat_len = 64 10 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/setups/data_config_flute.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/data_config_common.gin' 2 | # Do SUR-like image processing and data augmentation. This differs from the 3 | # default Meta-Dataset image processing in that more augmentations are added 4 | # (e.g. 
random flips, random brightness, contrast, saturation etc) 5 | include 'meta_dataset/learn/gin/setups/sur_data_config.gin' 6 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/setups/data_config_no_decoder.gin: -------------------------------------------------------------------------------- 1 | import meta_dataset.data.config 2 | import meta_dataset.data.decoder 3 | import meta_dataset.data.pipeline 4 | include 'meta_dataset/learn/gin/setups/data_config_common.gin' 5 | 6 | process_episode.support_decoder = None 7 | process_episode.query_decoder = None 8 | process_batch.batch_decoder = None 9 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/setups/data_config_string.gin: -------------------------------------------------------------------------------- 1 | import meta_dataset.data.config 2 | import meta_dataset.data.decoder 3 | import meta_dataset.data.pipeline 4 | include 'meta_dataset/learn/gin/setups/data_config_common.gin' 5 | 6 | process_episode.support_decoder = @StringDecoder() 7 | process_episode.query_decoder = @StringDecoder() 8 | process_batch.batch_decoder = @StringDecoder() 9 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/setups/data_config_tfds.gin: -------------------------------------------------------------------------------- 1 | import meta_dataset.data.config 2 | import meta_dataset.data.decoder 3 | import meta_dataset.data.pipeline 4 | 5 | # Variable shots / ways. 6 | EpisodeDescriptionConfig.num_ways = None 7 | EpisodeDescriptionConfig.num_support = None 8 | EpisodeDescriptionConfig.num_query = None 9 | 10 | # Default values for sampling variable shots / ways. 
11 | EpisodeDescriptionConfig.min_ways = 5 12 | EpisodeDescriptionConfig.max_ways_upper_bound = 50 13 | EpisodeDescriptionConfig.max_num_query = 10 14 | 15 | EpisodeDescriptionConfig.min_examples_in_class = 0 16 | EpisodeDescriptionConfig.max_support_set_size = 500 17 | EpisodeDescriptionConfig.max_support_size_contrib_per_class = 100 18 | EpisodeDescriptionConfig.min_log_weight = -0.69314718055994529 # np.log(0.5) 19 | EpisodeDescriptionConfig.max_log_weight = 0.69314718055994529 # np.log(2) 20 | EpisodeDescriptionConfig.ignore_dag_ontology = False 21 | EpisodeDescriptionConfig.ignore_bilevel_ontology = False 22 | EpisodeDescriptionConfig.ignore_hierarchy_probability = 0.0 23 | 24 | # By default don't use SimCLR Episodes. 25 | EpisodeDescriptionConfig.simclr_episode_fraction = 0.0 26 | 27 | # Other default values for the data pipeline. 28 | DataConfig.image_height = 84 29 | DataConfig.shuffle_buffer_size = 1000 30 | DataConfig.read_buffer_size_bytes = 1048576 # 1 MB (1024**2) 31 | DataConfig.num_prefetch = 64 32 | 33 | process_dumped_episode.support_decoder = @support/ImageDecoder() 34 | process_episode.support_decoder = @support/ImageDecoder() 35 | support/ImageDecoder.data_augmentation = @support/DataAugmentation() 36 | support/DataAugmentation.enable_jitter = True 37 | support/DataAugmentation.jitter_amount = 0 38 | support/DataAugmentation.enable_gaussian_noise = True 39 | support/DataAugmentation.gaussian_noise_std = 0.0 40 | 41 | process_dumped_episode.query_decoder = @query/ImageDecoder() 42 | process_episode.query_decoder = @query/ImageDecoder() 43 | query/ImageDecoder.data_augmentation = @query/DataAugmentation() 44 | query/DataAugmentation.enable_jitter = False 45 | query/DataAugmentation.jitter_amount = 0 46 | query/DataAugmentation.enable_gaussian_noise = False 47 | query/DataAugmentation.gaussian_noise_std = 0.0 48 | 49 | 50 | ImageDecoder.skip_tfexample_decoding = True 51 | 
-------------------------------------------------------------------------------- /meta_dataset/learn/gin/setups/fixed_way_and_shot.gin: -------------------------------------------------------------------------------- 1 | EpisodeDescriptionConfig.num_ways = 5 2 | EpisodeDescriptionConfig.num_support = 5 3 | EpisodeDescriptionConfig.num_query = 15 4 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/setups/imagenet.gin: -------------------------------------------------------------------------------- 1 | benchmark.train_datasets = 'ilsvrc_2012' 2 | benchmark.eval_datasets = 'ilsvrc_2012' 3 | include 'meta_dataset/learn/gin/setups/data_config.gin' 4 | include 'meta_dataset/learn/gin/setups/trainer_config.gin' 5 | include 'meta_dataset/learn/gin/setups/variable_way_and_shot.gin' 6 | Trainer.data_config = @DataConfig() 7 | Trainer.train_episode_config = @train/EpisodeDescriptionConfig() 8 | Trainer.eval_episode_config = @EpisodeDescriptionConfig() 9 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/setups/imagenet_v2.gin: -------------------------------------------------------------------------------- 1 | benchmark.train_datasets = 'ilsvrc_2012_v2' 2 | benchmark.eval_datasets = 'fungi,aircraft,quickdraw,omniglot,mscoco,cu_birds,dtd' 3 | include 'meta_dataset/learn/gin/setups/data_config.gin' 4 | include 'meta_dataset/learn/gin/setups/trainer_config.gin' 5 | include 'meta_dataset/learn/gin/setups/variable_way_and_shot.gin' 6 | Trainer.data_config = @DataConfig() 7 | Trainer.train_episode_config = @train/EpisodeDescriptionConfig() 8 | Trainer.eval_episode_config = @EpisodeDescriptionConfig() 9 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/setups/mini_imagenet.gin: -------------------------------------------------------------------------------- 1 | benchmark.train_datasets = 'mini_imagenet' 2 | 
benchmark.eval_datasets = 'mini_imagenet' 3 | include 'meta_dataset/learn/gin/setups/data_config.gin' 4 | include 'meta_dataset/learn/gin/setups/trainer_config.gin' 5 | include 'meta_dataset/learn/gin/setups/fixed_way_and_shot.gin' 6 | Trainer.data_config = @DataConfig() 7 | Trainer.train_episode_config = @train/EpisodeDescriptionConfig() 8 | Trainer.eval_episode_config = @EpisodeDescriptionConfig() 9 | DataConfig.image_height = 84 10 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/setups/sur_data_config.gin: -------------------------------------------------------------------------------- 1 | import meta_dataset.data.sur_decoder 2 | import meta_dataset.data.config 3 | import meta_dataset.data.decoder 4 | 5 | # Other default values for the data pipeline. 6 | DataConfig.image_height = 84 7 | DataConfig.shuffle_buffer_size = 1000 8 | DataConfig.read_buffer_size_bytes = 1048576 # 1 MB (1024**2) 9 | DataConfig.num_prefetch = 400 10 | SURImageDecoder.image_size = 84 11 | 12 | # If we decode features then change the lines below to use FeatureDecoder. 
13 | process_episode.support_decoder = @support/SURImageDecoder() 14 | support/SURImageDecoder.data_augmentation = @support/SURDataAugmentation() 15 | support/SURDataAugmentation.enable_jitter = True 16 | support/SURDataAugmentation.jitter_amount = 0 17 | support/SURDataAugmentation.enable_gaussian_noise = True 18 | support/SURDataAugmentation.gaussian_noise_std = 0.0 19 | support/SURDataAugmentation.enable_random_flip = False 20 | support/SURDataAugmentation.enable_random_brightness = False 21 | support/SURDataAugmentation.random_brightness_delta = 0 22 | support/SURDataAugmentation.enable_random_contrast = False 23 | support/SURDataAugmentation.random_contrast_delta = 0 24 | support/SURDataAugmentation.enable_random_hue = False 25 | support/SURDataAugmentation.random_hue_delta = 0 26 | support/SURDataAugmentation.enable_random_saturation = False 27 | support/SURDataAugmentation.random_saturation_delta = 0 28 | 29 | process_episode.query_decoder = @query/SURImageDecoder() 30 | query/SURImageDecoder.data_augmentation = @query/SURDataAugmentation() 31 | query/SURDataAugmentation.enable_jitter = False 32 | query/SURDataAugmentation.jitter_amount = 0 33 | query/SURDataAugmentation.enable_gaussian_noise = False 34 | query/SURDataAugmentation.gaussian_noise_std = 0.0 35 | query/SURDataAugmentation.enable_random_flip = False 36 | query/SURDataAugmentation.enable_random_brightness = False 37 | query/SURDataAugmentation.random_brightness_delta = 0 38 | query/SURDataAugmentation.enable_random_contrast = False 39 | query/SURDataAugmentation.random_contrast_delta = 0 40 | query/SURDataAugmentation.enable_random_hue = False 41 | query/SURDataAugmentation.random_hue_delta = 0 42 | query/SURDataAugmentation.enable_random_saturation = False 43 | query/SURDataAugmentation.random_saturation_delta = 0 44 | 45 | process_batch.batch_decoder = @batch/SURImageDecoder() 46 | batch/SURImageDecoder.data_augmentation = @batch/SURDataAugmentation() 47 | 
batch/SURDataAugmentation.enable_jitter = True 48 | batch/SURDataAugmentation.jitter_amount = 8 49 | batch/SURDataAugmentation.enable_gaussian_noise = True 50 | batch/SURDataAugmentation.gaussian_noise_std = 0.0 51 | batch/SURDataAugmentation.enable_random_flip = False 52 | batch/SURDataAugmentation.enable_random_brightness = True 53 | batch/SURDataAugmentation.random_brightness_delta = 0.125 54 | batch/SURDataAugmentation.enable_random_contrast = True 55 | batch/SURDataAugmentation.random_contrast_delta = 0.2 56 | batch/SURDataAugmentation.enable_random_hue = True 57 | batch/SURDataAugmentation.random_hue_delta = 0.03 58 | batch/SURDataAugmentation.enable_random_saturation = True 59 | batch/SURDataAugmentation.random_saturation_delta = 0.2 60 | 61 | # For dumped episodes: 62 | process_dumped_episode.support_decoder = @support/ImageDecoder() 63 | support/ImageDecoder.data_augmentation = @support/DataAugmentation() 64 | support/DataAugmentation.enable_jitter = True 65 | support/DataAugmentation.jitter_amount = 0 66 | support/DataAugmentation.enable_gaussian_noise = True 67 | support/DataAugmentation.gaussian_noise_std = 0.0 68 | process_dumped_episode.query_decoder = @query/ImageDecoder() 69 | query/ImageDecoder.data_augmentation = @query/DataAugmentation() 70 | query/DataAugmentation.enable_jitter = False 71 | query/DataAugmentation.jitter_amount = 0 72 | query/DataAugmentation.enable_gaussian_noise = False 73 | query/DataAugmentation.gaussian_noise_std = 0.0 74 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/setups/trainer_config.gin: -------------------------------------------------------------------------------- 1 | Trainer.num_updates = 75000 2 | Trainer.batch_size = 256 # Only applicable to non-episodic models. 
3 | Trainer.num_eval_episodes = 600 4 | Trainer.checkpoint_every = 500 5 | Trainer.validate_every = 500 6 | Trainer.log_every = 100 7 | Trainer.distribute = False 8 | # Enable TensorFlow optimizations. It can add a few minutes to the first 9 | # calls to session.run(), but decrease memory usage. 10 | Trainer.enable_tf_optimizations = True 11 | 12 | Learner.transductive_batch_norm = False 13 | Learner.backprop_through_moments = True 14 | 15 | 16 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/setups/trainer_config_debug.gin: -------------------------------------------------------------------------------- 1 | Trainer.num_updates = 100 2 | Trainer.batch_size = 8 # Only applicable to non-episodic models. 3 | Trainer.num_eval_episodes = 10 4 | Trainer.checkpoint_every = 10 5 | Trainer.validate_every = 5 6 | Trainer.log_every = 1 7 | Learner.transductive_batch_norm = False 8 | Learner.backprop_through_moments = True 9 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/setups/trainer_config_flute.gin: -------------------------------------------------------------------------------- 1 | Trainer_flute.num_updates = 640000 2 | Trainer_flute.batch_size = 16 3 | Trainer_flute.num_eval_episodes = 600 4 | Trainer_flute.checkpoint_every = 2000 5 | Trainer_flute.validate_every = 1000 6 | Trainer_flute.log_every = 100 7 | Trainer_flute.decay_learning_rate = True 8 | Trainer_flute.checkpoint_to_restore = '' 9 | Trainer_flute.dataset_classifier_to_restore = '' 10 | Trainer_flute.learning_rate = 0.01 11 | Trainer_flute.decay_every = 5000 12 | Trainer_flute.decay_rate = 0.5 13 | Trainer_flute.learn_rate_scheduler = 'cosine_decay_restarts' 14 | Trainer_flute.optimizer_type = 'momentum' 15 | Trainer_flute.meta_batch_size = 8 16 | Trainer_flute.sample_half_from_imagenet = True 17 | 18 | Trainer_flute.distribute = False 19 | # Enable TensorFlow optimizations. 
It can add a few minutes to the first 20 | # calls to session.run(), but decrease memory usage. 21 | Trainer_flute.enable_tf_optimizations = True 22 | Trainer_flute.normalized_gradient_descent = False 23 | Trainer_flute.pretrained_source = '' 24 | 25 | Trainer_flute.data_config = @DataConfig() 26 | Trainer_flute.train_episode_config = @train/EpisodeDescriptionConfig() 27 | Trainer_flute.eval_episode_config = @EpisodeDescriptionConfig() 28 | -------------------------------------------------------------------------------- /meta_dataset/learn/gin/setups/variable_way_and_shot.gin: -------------------------------------------------------------------------------- 1 | EpisodeDescriptionConfig.num_ways = None 2 | EpisodeDescriptionConfig.num_support = None 3 | EpisodeDescriptionConfig.num_query = None 4 | -------------------------------------------------------------------------------- /meta_dataset/learners/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Meta-Dataset Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Module containing (meta-)Learners.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from meta_dataset.learners.base import BatchLearner 23 | from meta_dataset.learners.base import EpisodicLearner 24 | from meta_dataset.learners.base import Learner 25 | from meta_dataset.learners.baseline_learners import BaselineLearner 26 | from meta_dataset.learners.metric_learners import MatchingNetworkLearner 27 | from meta_dataset.learners.metric_learners import MetricLearner 28 | from meta_dataset.learners.metric_learners import PrototypicalNetworkLearner 29 | from meta_dataset.learners.metric_learners import RelationNetworkLearner 30 | from meta_dataset.learners.optimization_learners import BaselineFinetuneLearner 31 | from meta_dataset.learners.optimization_learners import FLUTEFiLMLearner 32 | from meta_dataset.learners.optimization_learners import MAMLLearner 33 | from meta_dataset.learners.optimization_learners import OptimizationLearner 34 | 35 | __all__ = [ 36 | 'BaselineFinetuneLearner', 37 | 'BatchLearner', 38 | 'BaselineLearner', 39 | 'EpisodicLearner', 40 | 'Learner', 41 | 'MAMLLearner', 42 | 'MatchingNetworkLearner', 43 | 'MetricLearner', 44 | 'OptimizationLearner', 45 | 'PrototypicalNetworkLearner', 46 | 'RelationNetworkLearner', 47 | 'FLUTEFiLMLearner', 48 | ] 49 | -------------------------------------------------------------------------------- /meta_dataset/learners/base.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Meta-Dataset Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Abstract learners.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import gin.tf 23 | import tensorflow.compat.v1 as tf 24 | 25 | 26 | @gin.configurable 27 | class Learner(object): 28 | """A Learner.""" 29 | 30 | def __init__( 31 | self, 32 | is_training, 33 | logit_dim, 34 | transductive_batch_norm, 35 | backprop_through_moments, 36 | embedding_fn, 37 | input_shape, 38 | ): 39 | """Initializes a Learner. 40 | 41 | Note that Gin configuration of subclasses of `Learner` will override any 42 | corresponding Gin configurations of `Learner`, since parameters are passed 43 | to the `Learner` base class's constructor (see 44 | https://github.com/google/gin-config/blob/master/README.md for more 45 | details). 46 | 47 | Args: 48 | is_training: Whether the `Learner` is in training (as opposed to 49 | evaluation) mode. 50 | logit_dim: An integer; the maximum dimensionality of output predictions or 51 | a list of ints, as required for each subclass of Learner. 52 | transductive_batch_norm: Whether to batch-normalize in the transductive 53 | setting, where means and variances for normalization are computed from 54 | each of the support and query sets (rather than using the support set 55 | statistics for normalization of both the support and query set). 56 | backprop_through_moments: Whether to allow gradients to flow through the 57 | support set moments; only applies to non-transductive batch norm. 
58 | embedding_fn: A function that embeds examples. 59 | input_shape: A Tensor corresponding to `[batch_size] + example_shape`. 60 | """ 61 | self.is_training = is_training 62 | self.logit_dim = logit_dim 63 | self.transductive_batch_norm = transductive_batch_norm 64 | self.backprop_through_moments = backprop_through_moments 65 | self.embedding_fn = embedding_fn 66 | self.input_shape = input_shape 67 | 68 | def build(self): 69 | """Additional build functionality for subclasses of `Learner`.""" 70 | 71 | def compute_regularizer(self): 72 | """Computes a regularizer, independent of the data.""" 73 | return tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) 74 | 75 | def compute_loss(self, onehot_labels, predictions): 76 | """Computes the CE loss of `predictions` with respect to `onehot_labels`. 77 | 78 | Args: 79 | onehot_labels: A `tf.Tensor` containing the class labels; each vector 80 | along the (last) class dimension should represent a valid probability 81 | distribution. 82 | predictions: A `tf.Tensor` containing the class predictions, 83 | interpreted as unnormalized log probabilities. 84 | 85 | Returns: 86 | A `tf.Tensor` representing the loss per example. 87 | """ 88 | cross_entropy_loss = tf.losses.softmax_cross_entropy( 89 | onehot_labels=onehot_labels, 90 | logits=predictions, 91 | reduction=tf.losses.Reduction.NONE) 92 | return cross_entropy_loss 93 | 94 | def compute_accuracy(self, onehot_labels, predictions): 95 | """Computes the accuracy of `predictions` with respect to `onehot_labels`. 96 | 97 | Args: 98 | onehot_labels: A `tf.Tensor` containing the class labels; each vector 99 | along the (last) class dimension is expected to contain only a single 100 | `1`. 101 | predictions: A `tf.Tensor` containing the class predictions 102 | represented as unnormalized log probabilities. 
103 | 104 | Returns: 105 | A `tf.Tensor` of ones and zeros representing the correctness of 106 | individual predictions; use `tf.reduce_mean(...)` to obtain the average 107 | accuracy. 108 | """ 109 | correct = tf.equal(tf.argmax(onehot_labels, -1), tf.argmax(predictions, -1)) 110 | return tf.cast(correct, tf.float32) 111 | 112 | def forward_pass(self, data): 113 | """Returns the (query if episodic) logits.""" 114 | raise NotImplementedError('Abstract method.') 115 | 116 | 117 | class EpisodicLearner(Learner): 118 | """An episodic learner.""" 119 | 120 | pass 121 | 122 | 123 | class BatchLearner(Learner): 124 | """A batch learner.""" 125 | 126 | pass 127 | -------------------------------------------------------------------------------- /meta_dataset/learners/baseline_learners_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Meta-Dataset Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for `meta_dataset.learners.baseline_learners`.""" 17 | 18 | import gin.tf 19 | 20 | from meta_dataset.learners import base_test 21 | from meta_dataset.learners import baseline_learners 22 | 23 | import tensorflow.compat.v1 as tf 24 | tf.disable_eager_execution() 25 | 26 | BASELINE_ARGS = { 27 | **base_test.VALID_LEARNER_INIT_ARGS, 28 | 'cosine_logits_multiplier': 1.0, 29 | 'knn_in_fc': False, 30 | 'knn_distance': 'l2', 31 | } 32 | 33 | gin.bind_parameter('linear_classifier.weight_decay', 0.01) 34 | 35 | 36 | class BaselineTest(base_test.TestBatchLearner): 37 | learner_cls = baseline_learners.BaselineLearner 38 | learner_kwargs = { 39 | **BASELINE_ARGS, 40 | 'cosine_classifier': False, 41 | 'use_weight_norm': False, 42 | } 43 | 44 | 45 | class WeightNormalizedBaselineTest(base_test.TestBatchLearner): 46 | learner_cls = baseline_learners.BaselineLearner 47 | learner_kwargs = { 48 | **BASELINE_ARGS, 49 | 'cosine_classifier': False, 50 | 'use_weight_norm': True, 51 | } 52 | 53 | 54 | class CosineClassifierBaselineTest(base_test.TestBatchLearner): 55 | learner_cls = baseline_learners.BaselineLearner 56 | learner_kwargs = { 57 | **BASELINE_ARGS, 58 | 'cosine_classifier': True, 59 | 'use_weight_norm': False, 60 | } 61 | 62 | 63 | class WeightNormalizedCosineClassifierBaselineTest(base_test.TestBatchLearner): 64 | learner_cls = baseline_learners.BaselineLearner 65 | learner_kwargs = { 66 | **BASELINE_ARGS, 67 | 'cosine_classifier': True, 68 | 'use_weight_norm': True, 69 | } 70 | 71 | 72 | if __name__ == '__main__': 73 | tf.test.main() 74 | -------------------------------------------------------------------------------- /meta_dataset/learners/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Meta-Dataset Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Module containing (meta-)Learners.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from meta_dataset.learners.experimental.base import ExperimentalBatchLearner 23 | from meta_dataset.learners.experimental.base import ExperimentalEpisodicLearner 24 | from meta_dataset.learners.experimental.base import ExperimentalLearner 25 | from meta_dataset.learners.experimental.optimization_learners import ANIL 26 | from meta_dataset.learners.experimental.optimization_learners import ExperimentalOptimizationLearner 27 | from meta_dataset.learners.experimental.optimization_learners import HeadAndBackboneLearner 28 | from meta_dataset.learners.experimental.optimization_learners import MAML 29 | 30 | __all__ = [ 31 | 'ExperimentalBatchLearner', 32 | 'ExperimentalEpisodicLearner', 33 | 'ExperimentalLearner', 34 | 'ExperimentalOptimizationLearner', 35 | 'HeadAndBackboneLearner', 36 | 'MAML', 37 | 'ANIL', 38 | ] 39 | -------------------------------------------------------------------------------- /meta_dataset/learners/experimental/base.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Meta-Dataset Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Abstract experimental learners that use `ReparameterizableModule`s.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import gin.tf 23 | 24 | from meta_dataset.learners import base as learner_base 25 | import tensorflow as tf 26 | 27 | 28 | class NotBuiltError(RuntimeError): 29 | 30 | def __init__(self): 31 | super(NotBuiltError, self).__init__( 32 | 'The `build` method of `ExperimentalLearner` must be called before ' 33 | 'accessing its variables.') 34 | 35 | 36 | def class_specific_data(onehot_labels, data, num_classes, axis=0): 37 | # TODO(eringrant): Deal with case of no data for a class in [1...num_classes]. 38 | data_shape = [s for i, s in enumerate(data.shape) if i != axis] 39 | labels = tf.argmax(onehot_labels, axis=-1) 40 | class_idx = [tf.where(tf.equal(labels, i)) for i in range(num_classes)] 41 | return [ 42 | tf.reshape(tf.gather(data, idx, axis=axis), [-1] + data_shape) 43 | for idx in class_idx 44 | ] 45 | 46 | 47 | @gin.configurable 48 | class ExperimentalLearner(learner_base.Learner): 49 | """An experimental learner.""" 50 | 51 | def __init__(self, **kwargs): 52 | """Constructs an `ExperimentalLearner`. 53 | 54 | Args: 55 | **kwargs: Keyword arguments common to all `Learner`s. 
56 | 57 | Raises: 58 | ValueError: If the `embedding_fn` provided is not an instance of 59 | `tf.Module`. 60 | """ 61 | super(ExperimentalLearner, self).__init__(**kwargs) 62 | 63 | if not isinstance(self.embedding_fn, tf.Module): 64 | raise ValueError('The `embedding_fn` provided to `ExperimentalLearner`s ' 65 | 'must be an instance of `tf.Module`.') 66 | 67 | self._built = False 68 | 69 | def compute_regularizer(self, onehot_labels, predictions): 70 | """Computes a regularizer, maybe using `predictions` and `onehot_labels`.""" 71 | del onehot_labels 72 | del predictions 73 | return tf.reduce_sum(input_tensor=self.embedding_fn.losses) 74 | 75 | def build(self): 76 | """Instantiate the parameters belonging to this `ExperimentalLearner`.""" 77 | if not self.embedding_fn.built: 78 | self.embedding_fn.build([None] + self.input_shape) 79 | self.embedding_shape = self.embedding_fn.compute_output_shape( 80 | [None] + self.input_shape) 81 | self._built = True 82 | 83 | @property 84 | def variables(self): 85 | """Returns a list of this `ExperimentalLearner`'s variables.""" 86 | if not self._built: 87 | raise NotBuiltError 88 | return self.embedding_fn.variables 89 | 90 | @property 91 | def trainable_variables(self): 92 | """Returns a list of this `ExperimentalLearner`'s trainable variables.""" 93 | if not self._built: 94 | raise NotBuiltError 95 | return self.embedding_fn.trainable_variables 96 | 97 | 98 | class ExperimentalEpisodicLearner(ExperimentalLearner, 99 | learner_base.EpisodicLearner): 100 | """An experimental episodic learner.""" 101 | 102 | pass 103 | 104 | 105 | class ExperimentalBatchLearner(ExperimentalLearner, learner_base.BatchLearner): 106 | """An experimental batch learner.""" 107 | 108 | pass 109 | -------------------------------------------------------------------------------- /meta_dataset/learners/experimental/metric_learners_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 
2024 The Meta-Dataset Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for meta_dataset.learners.experimental.metric_learners.""" 17 | 18 | from meta_dataset.learners import base_test 19 | from meta_dataset.learners.experimental import metric_learners 20 | from meta_dataset.models.experimental import reparameterizable_backbones 21 | import tensorflow.compat.v1 as tf 22 | 23 | tf.compat.v1.disable_eager_execution() 24 | tf.compat.v1.experimental.output_all_intermediates(True) 25 | 26 | metric_learner_kwargs = { 27 | 'backprop_through_moments': True, 28 | 'transductive_batch_norm': True, 29 | 'input_shape': [84, 84, 3], 30 | 'logit_dim': 5, 31 | 'is_training': True, 32 | 'distance_metric': metric_learners.euclidean_distance, 33 | } 34 | 35 | 36 | class PrototypicalNetworkTest(base_test.TestEpisodicLearner): 37 | learner_cls = metric_learners.PrototypicalNetwork 38 | learner_kwargs = metric_learner_kwargs 39 | 40 | 41 | class MatchingNetworkTest(base_test.TestEpisodicLearner): 42 | learner_cls = metric_learners.MatchingNetwork 43 | learner_kwargs = metric_learner_kwargs 44 | 45 | 46 | class RelationNetworkTest(base_test.TestEpisodicLearner): 47 | learner_cls = metric_learners.RelationNetwork 48 | learner_kwargs = metric_learner_kwargs 49 | 50 | def set_up_learner(self): 51 | """Set up a `reparameterizable_backbones.RelationNetConvNet` backbone.""" 52 | learner_kwargs = self.learner_kwargs 53 | 
learner_kwargs['embedding_fn'] = ( 54 | reparameterizable_backbones.RelationNetConvNet(keep_spatial_dims=True)) 55 | data = self.random_data() 56 | learner = self.learner_cls(**learner_kwargs) 57 | return data, learner 58 | 59 | 60 | if __name__ == '__main__': 61 | tf.test.main() 62 | -------------------------------------------------------------------------------- /meta_dataset/learners/experimental/optimization_learners_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Meta-Dataset Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for meta_dataset.learners.experimental.optimization_learners.""" 17 | 18 | import gin.tf 19 | from meta_dataset.learners import base_test 20 | from meta_dataset.learners.experimental import optimization_learners 21 | import tensorflow.compat.v1 as tf 22 | 23 | tf.compat.v1.disable_eager_execution() 24 | tf.compat.v1.experimental.output_all_intermediates(True) 25 | 26 | 27 | def mock_sgd(): 28 | 29 | def init(x0): 30 | return x0 31 | 32 | def update(i, grad, state): 33 | del i 34 | x = state 35 | return x - 0.01 * grad 36 | 37 | def get_params(state): 38 | x = state 39 | return x 40 | 41 | return init, update, get_params 42 | 43 | 44 | optimization_learner_kwargs = { 45 | 'backprop_through_moments': True, 46 | 'input_shape': [84, 84, 3], 47 | 'logit_dim': 5, 48 | 'is_training': True, 49 | 'update_fn': mock_sgd, 50 | 'additional_evaluation_update_steps': 5, 51 | 'clip_grad_norm': 10.0, 52 | 'num_update_steps': 5, 53 | } 54 | 55 | 56 | class FirstOrderMAMLTest(base_test.TestEpisodicLearner): 57 | learner_cls = optimization_learners.MAML 58 | learner_kwargs = dict( 59 | **optimization_learner_kwargs, **{ 60 | 'transductive_batch_norm': False, 61 | 'proto_maml_fc_layer_init': False, 62 | 'zero_fc_layer_init': False, 63 | 'first_order': False, 64 | 'adapt_batch_norm': True, 65 | }) 66 | 67 | 68 | class VanillaMAMLTest(base_test.TestEpisodicLearner): 69 | learner_cls = optimization_learners.MAML 70 | learner_kwargs = dict( 71 | **optimization_learner_kwargs, **{ 72 | 'transductive_batch_norm': True, 73 | 'proto_maml_fc_layer_init': False, 74 | 'zero_fc_layer_init': False, 75 | 'first_order': False, 76 | 'adapt_batch_norm': False, 77 | }) 78 | 79 | 80 | class ProtoMAMLTest(base_test.TestEpisodicLearner): 81 | 82 | def setUp(self): 83 | super().setUp() 84 | gin.bind_parameter('proto_maml_fc_layer_init_fn.prototype_multiplier', 1.0) 85 | 86 | def tearDown(self): 87 | gin.clear_config() 88 | super().tearDown() 89 | 90 | learner_cls = optimization_learners.MAML 91 
| learner_kwargs = dict( 92 | **optimization_learner_kwargs, **{ 93 | 'transductive_batch_norm': False, 94 | 'proto_maml_fc_layer_init': True, 95 | 'zero_fc_layer_init': False, 96 | 'first_order': False, 97 | 'adapt_batch_norm': True, 98 | }) 99 | 100 | 101 | class FirstOrderANILTest(base_test.TestEpisodicLearner): 102 | learner_cls = optimization_learners.ANIL 103 | learner_kwargs = dict( 104 | **optimization_learner_kwargs, **{ 105 | 'transductive_batch_norm': False, 106 | 'proto_maml_fc_layer_init': False, 107 | 'zero_fc_layer_init': False, 108 | 'first_order': False, 109 | 'adapt_batch_norm': True, 110 | }) 111 | 112 | 113 | class VanillaANILTest(base_test.TestEpisodicLearner): 114 | learner_cls = optimization_learners.ANIL 115 | learner_kwargs = dict( 116 | **optimization_learner_kwargs, **{ 117 | 'transductive_batch_norm': True, 118 | 'proto_maml_fc_layer_init': False, 119 | 'zero_fc_layer_init': False, 120 | 'first_order': False, 121 | 'adapt_batch_norm': False, 122 | }) 123 | 124 | 125 | class ProtoANILTest(base_test.TestEpisodicLearner): 126 | 127 | def setUp(self): 128 | super().setUp() 129 | gin.bind_parameter('proto_maml_fc_layer_init_fn.prototype_multiplier', 1.0) 130 | 131 | def tearDown(self): 132 | gin.clear_config() 133 | super().tearDown() 134 | 135 | learner_cls = optimization_learners.ANIL 136 | learner_kwargs = dict( 137 | **optimization_learner_kwargs, **{ 138 | 'transductive_batch_norm': False, 139 | 'proto_maml_fc_layer_init': True, 140 | 'zero_fc_layer_init': False, 141 | 'first_order': False, 142 | 'adapt_batch_norm': True, 143 | }) 144 | 145 | 146 | if __name__ == '__main__': 147 | tf.test.main() 148 | -------------------------------------------------------------------------------- /meta_dataset/learners/metric_learners_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Meta-Dataset Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for `meta_dataset.learners.metric_learners`.""" 17 | 18 | import gin 19 | 20 | from meta_dataset.learners import base_test 21 | from meta_dataset.learners import metric_learners 22 | 23 | import tensorflow.compat.v1 as tf 24 | tf.disable_eager_execution() 25 | 26 | metric_learner_kwargs = { 27 | **base_test.VALID_LEARNER_INIT_ARGS, 28 | } 29 | 30 | gin.bind_parameter('relationnet_convnet.weight_decay', 0.01) 31 | gin.bind_parameter('relation_module.weight_decay', 0.01) 32 | 33 | 34 | class CosineMatchingNetworkLearnerTest(base_test.TestEpisodicLearner): 35 | learner_cls = metric_learners.MatchingNetworkLearner 36 | learner_kwargs = { 37 | **metric_learner_kwargs, 38 | 'exact_cosine_distance': False, 39 | } 40 | 41 | 42 | class ExactCosineMatchingNetworkLearnerTest(base_test.TestEpisodicLearner): 43 | learner_cls = metric_learners.MatchingNetworkLearner 44 | learner_kwargs = { 45 | **metric_learner_kwargs, 46 | 'exact_cosine_distance': True, 47 | } 48 | 49 | 50 | class PrototypicalNetworkLearnerTest(base_test.TestEpisodicLearner): 51 | learner_cls = metric_learners.PrototypicalNetworkLearner 52 | learner_kwargs = { 53 | **metric_learner_kwargs, 54 | } 55 | 56 | 57 | class RelationNetworkLearnerTest(base_test.TestEpisodicLearner): 58 | learner_cls = metric_learners.RelationNetworkLearner 59 | learner_kwargs = { 60 | **metric_learner_kwargs, 61 | } 62 | 63 | 64 | if 
__name__ == '__main__': 65 | tf.test.main() 66 | -------------------------------------------------------------------------------- /meta_dataset/learners/optimization_learners_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Meta-Dataset Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for `meta_dataset.learners.optimization_learners`.""" 17 | 18 | import gin.tf 19 | 20 | from meta_dataset.learners import base_test 21 | from meta_dataset.learners import optimization_learners 22 | 23 | import tensorflow.compat.v1 as tf 24 | tf.disable_eager_execution() 25 | 26 | BASELINE_FINETUNE_ARGS = { 27 | **base_test.VALID_LEARNER_INIT_ARGS, 28 | 'knn_in_fc': False, 29 | 'knn_distance': False, 30 | 'cosine_classifier': False, 31 | 'cosine_logits_multiplier': 1.0, 32 | 'use_weight_norm': False, 33 | 'is_training': False, 34 | } 35 | 36 | gin.bind_parameter('MAMLLearner.classifier_weight_decay', 0.01) 37 | 38 | 39 | class BaselineFinetuneTest(): 40 | 41 | def testLearnerConvergence(self): 42 | # `BaselineFinetuneLearner` differs from `BaselineLearner` only at 43 | # evaluation time. 44 | pass 45 | 46 | def testLearnerImprovement(self): 47 | # `BaselineFinetuneLearner` differs from `BaselineLearner` only at 48 | # evaluation time. 
class BaselineFinetuneAllLayersAdamTest(BaselineFinetuneTest,
                                        base_test.TestEpisodicLearner):
  """`BaselineFinetuneLearner`: all layers, Adam, 5 steps at lr 0.01."""
  learner_cls = optimization_learners.BaselineFinetuneLearner
  learner_kwargs = dict(
      BASELINE_FINETUNE_ARGS,
      num_finetune_steps=5,
      finetune_lr=0.01,
      finetune_all_layers=True,
      finetune_with_adam=True,
  )


class BaselineFinetuneAllLayersGDTest(BaselineFinetuneTest,
                                      base_test.TestEpisodicLearner):
  """`BaselineFinetuneLearner`: all layers, no Adam, 5 steps at lr 0.1."""
  learner_cls = optimization_learners.BaselineFinetuneLearner
  learner_kwargs = dict(
      BASELINE_FINETUNE_ARGS,
      num_finetune_steps=5,
      finetune_lr=0.1,
      finetune_all_layers=True,
      finetune_with_adam=False,
  )


class BaselineFinetuneLastLayerGDTest(BaselineFinetuneTest,
                                      base_test.TestEpisodicLearner):
  """`BaselineFinetuneLearner`: not all layers, no Adam, 10 steps at lr 0.1."""
  learner_cls = optimization_learners.BaselineFinetuneLearner
  learner_kwargs = dict(
      BASELINE_FINETUNE_ARGS,
      num_finetune_steps=10,
      finetune_lr=0.1,
      finetune_all_layers=False,
      finetune_with_adam=False,
  )


# Keyword arguments shared by every `MAMLLearner` test configuration below.
MAML_KWARGS = dict(
    base_test.VALID_LEARNER_INIT_ARGS,
    additional_evaluation_update_steps=0,
    first_order=True,
    adapt_batch_norm=True,
)


class MAMLLearnerTest(base_test.TestEpisodicLearner):
  """`MAMLLearner` with neither zero nor proto-MAML fc-layer initialization."""
  learner_cls = optimization_learners.MAMLLearner
  learner_kwargs = dict(
      MAML_KWARGS,
      num_update_steps=5,
      alpha=0.01,
      zero_fc_layer=False,
      proto_maml_fc_layer_init=False,
  )


class ZeroInitMAMLLearnerTest(base_test.TestEpisodicLearner):
  """`MAMLLearner` with `zero_fc_layer` enabled and 10 update steps."""
  learner_cls = optimization_learners.MAMLLearner
  learner_kwargs = dict(
      MAML_KWARGS,
      num_update_steps=10,
      alpha=0.01,
      zero_fc_layer=True,
      proto_maml_fc_layer_init=False,
  )


class ProtoMAMLLearnerTest(base_test.TestEpisodicLearner):
  """`MAMLLearner` with `proto_maml_fc_layer_init` enabled."""
  learner_cls = optimization_learners.MAMLLearner
  learner_kwargs = dict(
      MAML_KWARGS,
      num_update_steps=5,
      alpha=0.01,
      zero_fc_layer=False,
      proto_maml_fc_layer_init=True,
  )


if __name__ == '__main__':
  tf.test.main()
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | -------------------------------------------------------------------------------- /meta_dataset/test_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Meta-Dataset Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Testing utilities for meta-dataset. 17 | 18 | This module includes functions that inspect class APIs in order to recursively 19 | enumerate arguments to a method of the class as well as arguments passed 20 | implicitly (via *args and/or **kwargs) to the overridden method of any ancestor 21 | classes. These functions are useful for concisely defining tests for classes 22 | that take potentially-overlapping but non-identical arguments. 23 | 24 | For examples, see 25 | `meta_dataset.models.experimental.reparameterizable_base_test`.
def get_argspec(fn):
  """Return the argument specification of the (potentially wrapped) `fn`.

  Decorator chains are followed through `__wrapped__` so that the innermost
  callable is the one inspected.

  Args:
    fn: The callable to inspect.

  Returns:
    An `inspect.FullArgSpec` when `inspect.getfullargspec` is available,
    otherwise an `inspect.ArgSpec`. If the callable cannot be inspected (the
    inspection raises `TypeError`), an empty specification is returned.
  """
  while hasattr(fn, '__wrapped__'):
    fn = fn.__wrapped__

  # Feature detection instead of a version check: `getfullargspec` only exists
  # on Python 3.
  if hasattr(inspect, 'getfullargspec'):
    try:
      return inspect.getfullargspec(fn)
    except TypeError:
      # `getfullargspec` raises `TypeError` for unsupported callables; mirror
      # the legacy branch below and return an empty specification.
      return inspect.FullArgSpec(
          args=[], varargs=None, varkw=None, defaults=None,
          kwonlyargs=[], kwonlydefaults=None, annotations={})

  try:
    return inspect.getargspec(fn)  # pylint: disable=deprecated-method
  except TypeError:
    # Cannot inspect C variables (https://stackoverflow.com/a/7628130), so
    # return an empty `ArgSpec`.
    return inspect.ArgSpec([], None, None, ())


def _is_static_method(cls, cls_method):
  """Return whether `cls_method` is declared as a `staticmethod` for `cls`."""
  for klass in cls.__mro__:
    if cls_method in klass.__dict__:
      # The first class in the MRO that declares the method decides.
      return isinstance(klass.__dict__[cls_method], staticmethod)
  return False


def _explicit_arg_names(cls, cls_method):
  """Return the named arguments of `cls.cls_method`, minus the bound one."""
  arg_names = get_argspec(getattr(cls, cls_method)).args
  if _is_static_method(cls, cls_method):
    # Static methods have no implicit first argument to strip.
    return list(arg_names)
  return arg_names[1:]


def get_inherited_args(cls, cls_method):
  """Return arguments to `cls_method` of `cls` and all parents of `cls`.

  Args:
    cls: The class whose method (and ancestors' overridden methods) to inspect.
    cls_method: The name of the method, as a string.

  Returns:
    A `set` of argument names. The first argument of each method is dropped
    unless the method is declared as a `staticmethod`.
  """
  args = list(_explicit_arg_names(cls, cls_method))

  # Inspect every class in the MRO that declares the method itself.
  for parent_cls in cls.__mro__:
    if cls_method in parent_cls.__dict__:
      args.extend(_explicit_arg_names(parent_cls, cls_method))

  return set(args)


get_inherited_init_args = functools.partial(
    get_inherited_args, cls_method='__init__')
get_inherited_call_args = functools.partial(
    get_inherited_args, cls_method='__call__')


def _kwarg_combos(arg_names, valid_values):
  """Yield one kwargs `dict` per combination of the valid argument values."""
  # Names are sorted so that the order of the combinations does not depend on
  # the iteration order of the `arg_names` set.
  choices = [
      [(name, value) for value in valid_values[name]]
      for name in sorted(arg_names)
      if name in valid_values
  ]
  return (dict(combo) for combo in itertools.product(*choices))


def get_valid_kwargs(module_cls, valid_module_init_args,
                     valid_module_call_args):
  """Return all valid kwarg configurations for `module_cls`.

  Args:
    module_cls: The class whose `__init__` and `__call__` arguments are
      enumerated, including those of its ancestors.
    valid_module_init_args: A mapping from `__init__` argument name to an
      iterable of valid values; names the class does not accept are ignored.
    valid_module_call_args: The same, for `__call__` arguments.

  Returns:
    An iterator over `(init_kwargs, call_kwargs)` pairs of `dict`s covering
    the Cartesian product of all valid values.
  """
  init_kwargs_combos = _kwarg_combos(
      get_inherited_init_args(module_cls), valid_module_init_args)
  call_kwargs_combos = _kwarg_combos(
      get_inherited_call_args(module_cls), valid_module_call_args)

  return itertools.product(init_kwargs_combos, call_kwargs_combos)
from setuptools import find_packages
from setuptools import setup
from setuptools.command import install

SIMCLR_DIR = 'simclr'
DATA_UTILS_URL = 'https://raw.githubusercontent.com/google-research/simclr/master/data_util.py'


class DownloadSimCLRAugmentationCommand(cmd.Command):
  """Downloads SimCLR data_utils.py as it's not built into an egg."""
  # The class docstring is reused as the command description.
  description = __doc__
  user_options = []

  def initialize_options(self):
    """No options to initialize; present because `Command` expects it."""
    pass

  def finalize_options(self):
    """No options to finalize; present because `Command` expects it."""
    pass

  def run(self):
    """Download `data_util.py` into `<build_lib>/simclr/`."""
    build_cmd = self.get_finalized_command('build')
    dist_root = os.path.realpath(build_cmd.build_lib)
    output_dir = os.path.join(dist_root, SIMCLR_DIR)
    # `exist_ok=True` avoids the check-then-create race of a separate
    # `os.path.exists` test.
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, 'data_util.py')
    # `urllib.request.URLopener` is deprecated; `urlretrieve` performs the
    # same download-to-file operation.
    urllib.request.urlretrieve(DATA_UTILS_URL, output_path)


class InstallCommand(install.install):
  """`install` command that first runs the `simclr_download` command."""

  def run(self):
    """Run `simclr_download`, then the regular install."""
    self.run_command('simclr_download')
    install.install.run(self)


setup(
    name='meta_dataset',
    version='0.2.0',
    description='A benchmark for few-shot classification.',
    author='Google LLC',
    license='Apache License, Version 2.0',
    python_requires='>=2.7, <3.10',
    packages=find_packages(),
    include_package_data=True,
    install_requires=[
        'absl-py>=0.7.0',
        'etils>=0.4.0',
        'gin-config>=0.1.2',
        'numpy>=1.13.3',
        'scipy>=1.0.0',
        'setuptools',
        'six>=1.12.0',
        # Note that this will install tf 2.0, even though this is a tf 1.0
        # project. This is necessary because we rely on augmentation from
        # tf-models-official that wasn't added until after tf 2.0 was released.
        'tensorflow-gpu',
        # The `sklearn` name on PyPI is a deprecated placeholder that no longer
        # installs; the real distribution is `scikit-learn`.
        'scikit-learn',
        'tensorflow_probability<=0.7',
        'tf-models-official',
        'tensorflow-datasets',
    ],
    cmdclass={
        'simclr_download': DownloadSimCLRAugmentationCommand,
        'install': InstallCommand,
    })