├── .gitignore
├── AUTHORS
├── CONTRIBUTING.md
├── Intro_to_Metadataset.ipynb
├── LICENSE
├── Leaderboard.ipynb
├── MANIFEST.in
├── README.md
├── VTAB-plus-MD.md
├── doc
├── dataset_conversion.md
└── reproducing_best_results.md
├── meta_dataset
├── __init__.py
├── analysis
│ ├── __init__.py
│ ├── select_best_model.py
│ └── transferability
│ │ ├── __init__.py
│ │ ├── criticality.py
│ │ └── leep.py
├── analyze.py
├── data
│ ├── __init__.py
│ ├── config.py
│ ├── dataset_spec.py
│ ├── decoder.py
│ ├── decoder_test.py
│ ├── dump_and_read_episodes_test.py
│ ├── dump_episodes.py
│ ├── imagenet_specification.py
│ ├── imagenet_specification_test.py
│ ├── imagenet_stats.py
│ ├── learning_spec.py
│ ├── pipeline.py
│ ├── pipeline_test.py
│ ├── providers.py
│ ├── read_episodes.py
│ ├── reader.py
│ ├── reader_test.py
│ ├── sampling.py
│ ├── sampling_test.py
│ ├── sur_decoder.py
│ ├── test_utils.py
│ ├── tfds
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── api.py
│ │ ├── constants.py
│ │ ├── example_generators.py
│ │ ├── md_tfds.py
│ │ ├── meta_dataset_test.py
│ │ └── test_utils.py
│ └── utils.py
├── dataset_conversion
│ ├── ImageNet_CUBirds_duplicates.txt
│ ├── ImageNet_Caltech101_duplicates.txt
│ ├── ImageNet_Caltech256_duplicates.txt
│ ├── TrafficSign_labels.txt
│ ├── VggFlower_labels.txt
│ ├── __init__.py
│ ├── check_dataset_consistency.py
│ ├── convert_datasets_to_records.py
│ ├── dataset_specs
│ │ ├── aircraft_dataset_spec.json
│ │ ├── cu_birds_dataset_spec.json
│ │ ├── dtd_dataset_spec.json
│ │ ├── fungi_dataset_spec.json
│ │ ├── ilsvrc_2012_dataset_spec.json
│ │ ├── ilsvrc_2012_num_leaf_images.json
│ │ ├── ilsvrc_2012_v2_dataset_spec.json
│ │ ├── ilsvrc_2012_v2_num_leaf_images.json
│ │ ├── mscoco_dataset_spec.json
│ │ ├── omniglot_dataset_spec.json
│ │ ├── quickdraw_dataset_spec.json
│ │ ├── traffic_sign_dataset_spec.json
│ │ └── vgg_flower_dataset_spec.json
│ ├── dataset_to_records.py
│ └── splits
│ │ ├── aircraft_splits.json
│ │ ├── cu_birds_splits.json
│ │ ├── dtd_splits.json
│ │ ├── fungi_splits.json
│ │ ├── mscoco_splits.json
│ │ ├── quickdraw_splits.json
│ │ ├── traffic_sign_splits.json
│ │ └── vgg_flower_splits.json
├── distribute_utils.py
├── learn
│ └── gin
│ │ ├── best
│ │ ├── baseline_all.gin
│ │ ├── baseline_imagenet.gin
│ │ ├── baselinefinetune_all.gin
│ │ ├── baselinefinetune_cosine_mini_imagenet_fiveshot.gin
│ │ ├── baselinefinetune_cosine_mini_imagenet_oneshot.gin
│ │ ├── baselinefinetune_imagenet.gin
│ │ ├── baselinefinetune_mini_imagenet_fiveshot.gin
│ │ ├── baselinefinetune_mini_imagenet_oneshot.gin
│ │ ├── flute.gin
│ │ ├── flute_init_from_imagenet.gin
│ │ ├── flute_init_from_scratch.gin
│ │ ├── maml_all.gin
│ │ ├── maml_all_from_scratch.gin
│ │ ├── maml_imagenet.gin
│ │ ├── maml_imagenet_from_scratch.gin
│ │ ├── maml_init_with_proto_all.gin
│ │ ├── maml_init_with_proto_all_from_scratch.gin
│ │ ├── maml_init_with_proto_imagenet.gin
│ │ ├── maml_init_with_proto_imagenet_from_scratch.gin
│ │ ├── maml_init_with_proto_inference_all.gin
│ │ ├── maml_init_with_proto_inference_imagenet.gin
│ │ ├── matching_all.gin
│ │ ├── matching_all_from_scratch.gin
│ │ ├── matching_imagenet.gin
│ │ ├── matching_imagenet_from_scratch.gin
│ │ ├── matching_inference_all.gin
│ │ ├── matching_inference_imagenet.gin
│ │ ├── pretrain_imagenet_convnet.gin
│ │ ├── pretrain_imagenet_resnet.gin
│ │ ├── pretrain_imagenet_wide_resnet.gin
│ │ ├── pretrained_convnet.gin
│ │ ├── pretrained_resnet.gin
│ │ ├── pretrained_wide_resnet.gin
│ │ ├── prototypical_all.gin
│ │ ├── prototypical_all_from_scratch.gin
│ │ ├── prototypical_imagenet.gin
│ │ ├── prototypical_imagenet_from_scratch.gin
│ │ ├── prototypical_inference_all.gin
│ │ ├── prototypical_inference_imagenet.gin
│ │ ├── relationnet_all.gin
│ │ ├── relationnet_all_from_scratch.gin
│ │ ├── relationnet_imagenet.gin
│ │ ├── relationnet_imagenet_from_scratch.gin
│ │ └── relationnet_mini_imagenet_fiveshot.gin
│ │ ├── default
│ │ ├── baseline_all.gin
│ │ ├── baseline_cosine_imagenet.gin
│ │ ├── baseline_imagenet.gin
│ │ ├── baseline_mini_imagenet_fiveshot.gin
│ │ ├── baseline_mini_imagenet_oneshot.gin
│ │ ├── baselinefinetune_all.gin
│ │ ├── baselinefinetune_cosine_imagenet.gin
│ │ ├── baselinefinetune_cosine_mini_imagenet_fiveshot.gin
│ │ ├── baselinefinetune_cosine_mini_imagenet_oneshot.gin
│ │ ├── baselinefinetune_imagenet.gin
│ │ ├── baselinefinetune_mini_imagenet_fiveshot.gin
│ │ ├── baselinefinetune_mini_imagenet_oneshot.gin
│ │ ├── crosstransformer_imagenet.gin
│ │ ├── crosstransformer_simclreps_imagenet.gin
│ │ ├── debug_proto_fungi.gin
│ │ ├── debug_proto_mini_imagenet.gin
│ │ ├── flute.gin
│ │ ├── flute_dataset_classifier.gin
│ │ ├── maml_all.gin
│ │ ├── maml_imagenet.gin
│ │ ├── maml_init_with_proto_all.gin
│ │ ├── maml_init_with_proto_imagenet.gin
│ │ ├── maml_init_with_proto_inference_all.gin
│ │ ├── maml_init_with_proto_inference_imagenet.gin
│ │ ├── maml_init_with_proto_mini_imagenet_fiveshot.gin
│ │ ├── maml_init_with_proto_mini_imagenet_oneshot.gin
│ │ ├── maml_mini_imagenet_fiveshot.gin
│ │ ├── maml_mini_imagenet_oneshot.gin
│ │ ├── maml_protonet_all.gin
│ │ ├── matching_all.gin
│ │ ├── matching_imagenet.gin
│ │ ├── matching_inference_all.gin
│ │ ├── matching_inference_imagenet.gin
│ │ ├── matching_mini_imagenet_fiveshot.gin
│ │ ├── matching_mini_imagenet_oneshot.gin
│ │ ├── pretrained_resnet34_224.gin
│ │ ├── prototypical_all.gin
│ │ ├── prototypical_imagenet.gin
│ │ ├── prototypical_inference_all.gin
│ │ ├── prototypical_inference_imagenet.gin
│ │ ├── prototypical_mini_imagenet_fiveshot.gin
│ │ ├── prototypical_mini_imagenet_oneshot.gin
│ │ ├── relationnet_all.gin
│ │ ├── relationnet_imagenet.gin
│ │ ├── relationnet_mini_imagenet_fiveshot.gin
│ │ ├── relationnet_mini_imagenet_oneshot.gin
│ │ └── resnet34_stride16.gin
│ │ ├── learners
│ │ ├── baseline_config.gin
│ │ ├── baseline_cosine_config.gin
│ │ ├── baselinefinetune_config.gin
│ │ ├── baselinefinetune_cosine_config.gin
│ │ ├── crosstransformer_config.gin
│ │ ├── learner_config.gin
│ │ ├── maml_config.gin
│ │ ├── maml_init_with_proto_config.gin
│ │ ├── matching_config.gin
│ │ ├── prototypical_config.gin
│ │ └── relationnet_config.gin
│ │ ├── metadataset_v2
│ │ ├── baseline_all.gin
│ │ ├── baseline_imagenet.gin
│ │ ├── baselinefinetune_all.gin
│ │ ├── baselinefinetune_imagenet.gin
│ │ ├── best
│ │ │ ├── baseline_imagenet.gin
│ │ │ ├── baselinefinetune_all.gin
│ │ │ ├── baselinefinetune_imagenet.gin
│ │ │ ├── maml_init_with_proto_all.gin
│ │ │ ├── maml_init_with_proto_imagenet.gin
│ │ │ ├── prototypical_all.gin
│ │ │ └── prototypical_imagenet.gin
│ │ ├── maml_init_with_proto_all.gin
│ │ ├── maml_init_with_proto_imagenet.gin
│ │ └── prototypical_all.gin
│ │ └── setups
│ │ ├── all.gin
│ │ ├── all_datasets.gin
│ │ ├── all_v2.gin
│ │ ├── data_config.gin
│ │ ├── data_config_common.gin
│ │ ├── data_config_feature.gin
│ │ ├── data_config_flute.gin
│ │ ├── data_config_no_decoder.gin
│ │ ├── data_config_string.gin
│ │ ├── data_config_tfds.gin
│ │ ├── fixed_way_and_shot.gin
│ │ ├── imagenet.gin
│ │ ├── imagenet_v2.gin
│ │ ├── mini_imagenet.gin
│ │ ├── sur_data_config.gin
│ │ ├── trainer_config.gin
│ │ ├── trainer_config_debug.gin
│ │ ├── trainer_config_flute.gin
│ │ └── variable_way_and_shot.gin
├── learners
│ ├── __init__.py
│ ├── base.py
│ ├── base_test.py
│ ├── baseline_learners.py
│ ├── baseline_learners_test.py
│ ├── experimental
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── metric_learners.py
│ │ ├── metric_learners_test.py
│ │ ├── optimization_learners.py
│ │ └── optimization_learners_test.py
│ ├── metric_learners.py
│ ├── metric_learners_test.py
│ ├── optimization_learners.py
│ └── optimization_learners_test.py
├── models
│ ├── __init__.py
│ ├── experimental
│ │ ├── __init__.py
│ │ ├── parameter_adapter.py
│ │ ├── reparameterizable_backbones.py
│ │ ├── reparameterizable_backbones_test.py
│ │ ├── reparameterizable_base.py
│ │ ├── reparameterizable_base_test.py
│ │ └── reparameterizable_distributions.py
│ ├── functional_backbones.py
│ └── functional_classifiers.py
├── test_utils.py
├── train.py
├── train_flute.py
├── trainer.py
├── trainer_flute.py
└── trainer_test.py
├── requirements.txt
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | *.pyc
3 | __pycache__
4 | *.swp
5 |
--------------------------------------------------------------------------------
/AUTHORS:
--------------------------------------------------------------------------------
1 | # This is the list of the Meta-Dataset Project authors for copyright purposes.
2 | #
3 | # This does not necessarily list everyone who has contributed code, since in
4 | # some cases, their employer may be the copyright holder. To see the full list
5 | # of contributors, see the revision history in source control.
6 | Google LLC
7 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows
28 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
29 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | recursive-include meta_dataset/learn *.gin
2 | recursive-include meta_dataset/dataset_conversion *.json
--------------------------------------------------------------------------------
/meta_dataset/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Meta-Dataset Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 |
--------------------------------------------------------------------------------
/meta_dataset/analysis/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Meta-Dataset Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 |
--------------------------------------------------------------------------------
/meta_dataset/analysis/transferability/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Meta-Dataset Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 |
--------------------------------------------------------------------------------
/meta_dataset/analysis/transferability/leep.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Meta-Dataset Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # pyformat: disable
17 | """Implementation of the log expected empirical prediction measure (LEEP).
18 |
19 | #### References
20 |
21 | [1]: Nguyen, Cuong V., Tal Hassner, Matthias Seeger, and Cedric Archambeau.
22 | LEEP: A new measure to evaluate transferability of learned representations.
23 | In _Proceedings of 37th International Conference on Machine Learning_,
24 | 2020.
25 | https://arxiv.org/abs/2002.12462
26 | """
27 | # pyformat: enable
28 |
29 | from __future__ import absolute_import
30 | from __future__ import division
31 | from __future__ import print_function
32 |
33 | import numpy as np
34 | import tensorflow as tf
35 |
36 |
37 | def compute_leep(source_predictions,
38 | target_onehot_labels):
39 | """Compute LEEP using `source_predictions` and `target_onehot_labels`.
40 |
41 | Args:
42 | source_predictions: Predictions over the source label set for a batch of
43 | input data from the target dataset.
44 | target_onehot_labels: One-hot labels from the target label set.
45 |
46 | Returns:
47 | The log expected empirical prediction measure (LEEP) on the given batch.
48 | """
49 | batch_size = tf.cast(source_predictions.shape[0], tf.float32)
50 | if target_onehot_labels.shape[0] != batch_size:
51 | raise ValueError('`source_predictions` and `target_onehot_labels` must '
52 | 'represent the same number of examples.')
53 |
54 | # Normalize the predicted probabilities in log space.
55 | source_predictions -= tf.math.reduce_logsumexp(
56 | source_predictions, axis=1, keepdims=True)
57 |
58 | # p_model(y, z | x_i) for source labels `z`, target labels `y`, and each `i`.
59 | log_per_example_full_joint = tf.einsum('iz,iy->iyz', source_predictions,
60 | target_onehot_labels)
61 |
62 | # Workaround for one-hot indexing in log space.
63 | log_per_example_full_joint = (
64 | tf.where(
65 | tf.expand_dims(target_onehot_labels != 0, axis=2),
66 | log_per_example_full_joint,
67 | tf.ones_like(log_per_example_full_joint) * -np.inf))
68 |
69 | # Re-normalize the joint.
70 | log_per_example_full_joint -= tf.math.reduce_logsumexp(
71 | log_per_example_full_joint, axis=(1, 2), keepdims=True)
72 |
73 |   # Average example-wise probabilities (Eq. (1) in the paper).
74 | log_full_joint = tf.math.reduce_logsumexp(
75 | log_per_example_full_joint, axis=0) - tf.math.log(batch_size)
76 |
77 | # p_model(z) for source labels `z`, marginalizing out target labels `y`.
78 | log_full_target_marginal = tf.math.reduce_logsumexp(
79 | source_predictions, axis=0) - tf.math.log(batch_size)
80 |
81 |   # p_model(y | z) for target labels `y`, conditioning on source labels `z`.
82 | log_full_conditional = (
83 | log_full_joint - tf.expand_dims(log_full_target_marginal, axis=0))
84 |
85 | # p_model(y = y_i | z) for datapoint `i`.
86 | log_predicted_conditional = (
87 | tf.einsum('iy,yz->iz', target_onehot_labels, log_full_conditional))
88 |
89 |   # p_model(y = y_i | z) * p_model(z | x_i) for datapoint `i`.
90 | log_predicted_joint = (log_predicted_conditional + source_predictions)
91 |
92 | # p_model(y = y_i) for datapoint `i`.
93 | log_predicted_marginal = tf.math.reduce_logsumexp(log_predicted_joint, axis=1)
94 |
95 | return tf.reduce_sum(log_predicted_marginal) / batch_size
96 |
--------------------------------------------------------------------------------
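
A minimal, hypothetical usage sketch of `compute_leep` (not part of the repository). It assumes TF2 eager execution, that `source_predictions` holds unnormalized log-probabilities (logits) over the source label set, and that the target labels are float one-hot vectors; all shapes and values below are illustrative.

    # Hypothetical illustration of calling compute_leep; shapes are arbitrary.
    import numpy as np
    import tensorflow as tf
    from meta_dataset.analysis.transferability import leep

    batch_size, num_source_classes, num_target_classes = 8, 5, 3
    # Logits over the source label set for a batch of target-dataset inputs;
    # compute_leep normalizes them in log space internally.
    source_predictions = tf.random.normal([batch_size, num_source_classes])
    # Float one-hot labels from the target label set.
    target_onehot_labels = tf.one_hot(
        np.random.randint(num_target_classes, size=batch_size),
        num_target_classes)
    leep_score = leep.compute_leep(source_predictions, target_onehot_labels)
    print(float(leep_score))  # Scalar; higher suggests better transferability.
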
/meta_dataset/data/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Meta-Dataset Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Sub-module for reading data and assembling episodes."""
17 | # Whether datasets with example-level splits, or pools, are supported.
18 | # Currently, this is not implemented.
19 | POOL_SUPPORTED = False
20 |
--------------------------------------------------------------------------------
/meta_dataset/data/decoder_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Meta-Dataset Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for meta_dataset.data.decoder."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from meta_dataset.data import decoder
23 | from meta_dataset.dataset_conversion import dataset_to_records
24 | import numpy as np
25 | import tensorflow.compat.v1 as tf
26 |
27 |
28 | class DecoderTest(tf.test.TestCase):
29 |
30 | def test_string_decoder(self):
31 | # Make random image.
32 | image_size = 32
33 | image = np.random.randint(
34 | low=0, high=255, size=[image_size, image_size, 3]).astype(np.ubyte)
35 |
36 | # Encode
37 | image_bytes = dataset_to_records.encode_image(image, image_format='PNG')
38 | label = np.zeros(1).astype(np.int64)
39 | image_example = dataset_to_records.make_example([
40 | ('image', 'bytes', [image_bytes]), ('label', 'int64', [label])
41 | ])
42 |
43 | # Decode
44 | string_decoder = decoder.StringDecoder()
45 | image_string = string_decoder(image_example)
46 | decoded_image = tf.image.decode_image(image_string)
47 | # Assert perfect reconstruction.
48 | with self.session(use_gpu=False) as sess:
49 | image_decoded = sess.run(decoded_image)
50 | self.assertAllClose(image, image_decoded)
51 |
52 | def test_image_decoder(self):
53 | # Make random image.
54 | image_size = 84
55 | image = np.random.randint(
56 | low=0, high=255, size=[image_size, image_size, 3]).astype(np.ubyte)
57 |
58 | # Encode
59 | image_bytes = dataset_to_records.encode_image(image, image_format='PNG')
60 | label = np.zeros(1).astype(np.int64)
61 | image_example = dataset_to_records.make_example([
62 | ('image', 'bytes', [image_bytes]), ('label', 'int64', [label])
63 | ])
64 |
65 | # Decode
66 | image_decoder = decoder.ImageDecoder(image_size=image_size)
67 | image_decoded = image_decoder(image_example)
68 | # Assert perfect reconstruction.
69 | with self.session(use_gpu=False) as sess:
70 | image_rec_numpy = sess.run(image_decoded)
71 | self.assertAllClose(2 * (image.astype(np.float32) / 255.0 - 0.5),
72 | image_rec_numpy)
73 |
74 | def test_feature_decoder(self):
75 | # Make random feature.
76 | feat_size = 64
77 | feat = np.random.randn(feat_size).astype(np.float32)
78 | label = np.zeros(1).astype(np.int64)
79 |
80 | # Encode
81 | feat_example = dataset_to_records.make_example([
82 | ('image/embedding', 'float32', feat),
83 | ('image/class/label', 'int64', [label]),
84 | ])
85 |
86 | # Decode
87 | feat_decoder = decoder.FeatureDecoder(feat_len=feat_size)
88 | feat_decoded = feat_decoder(feat_example)
89 |
90 | # Assert perfect reconstruction.
91 | with self.session(use_gpu=False) as sess:
92 | feat_rec_numpy = sess.run(feat_decoded)
93 | self.assertAllEqual(feat_rec_numpy, feat)
94 |
95 |
96 | if __name__ == '__main__':
97 | tf.test.main()
98 |
--------------------------------------------------------------------------------
/meta_dataset/data/dump_episodes.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Meta-Dataset Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | r"""Dumps Meta-Dataset episodes to disk as tfrecords files.
17 |
18 | Episodes are stored as a pair of `{episode_number}-train.tfrecords` and
19 | `{episode_number}-test.tfrecords` files, which contain the serialized
20 | TFExample strings of the support and query sets, respectively.
21 |
22 | python -m meta_dataset.data.dump_episodes \
23 | --gin_config=meta_dataset/learn/gin/setups/\
24 | data_config_string.gin --gin_config=meta_dataset/learn/gin/\
25 | setups/variable_way_and_shot.gin \
26 | --gin_bindings="DataConfig.num_prefetch="
27 | """
28 | import json
29 | import os
30 | from absl import app
31 | from absl import flags
32 | from absl import logging
33 |
34 | import gin
35 | from meta_dataset.data import config
36 | from meta_dataset.data import dataset_spec as dataset_spec_lib
37 | from meta_dataset.data import learning_spec
38 | from meta_dataset.data import pipeline
39 | from meta_dataset.data import utils
40 | import tensorflow.compat.v1 as tf
41 |
42 | tf.enable_eager_execution()
43 |
44 | flags.DEFINE_multi_string('gin_config', None,
45 | 'List of paths to the config files.')
46 | flags.DEFINE_multi_string('gin_bindings', None,
47 | 'List of Gin parameter bindings.')
48 | flags.DEFINE_string('output_dir', '/tmp/cached_episodes/',
49 | 'Root directory for saving episodes.')
50 | flags.DEFINE_integer('num_episodes', 600, 'Number of episodes to sample.')
51 | flags.DEFINE_string('dataset_name', 'omniglot', 'Dataset name to create '
52 | 'episodes from.')
53 | flags.DEFINE_enum_class('split', learning_spec.Split.TEST, learning_spec.Split,
54 | 'See learning_spec.Split for '
55 | 'allowed values.')
56 | flags.DEFINE_boolean(
57 | 'ignore_dag_ontology', False, 'If True the dag ontology'
58 | ' for Imagenet dataset is not used.')
59 | flags.DEFINE_boolean(
60 | 'ignore_bilevel_ontology', False, 'If True the bilevel'
61 | ' sampling for Omniglot dataset is not used.')
62 | tf.flags.DEFINE_string('records_root_dir', '',
63 | 'Root directory containing a subdirectory per dataset.')
64 | FLAGS = flags.FLAGS
65 |
66 |
67 | def main(unused_argv):
68 | logging.info(FLAGS.output_dir)
69 | tf.io.gfile.makedirs(FLAGS.output_dir)
70 | gin.parse_config_files_and_bindings(
71 | FLAGS.gin_config, FLAGS.gin_bindings, finalize_config=True)
72 | dataset_spec = dataset_spec_lib.load_dataset_spec(
73 | os.path.join(FLAGS.records_root_dir, FLAGS.dataset_name))
74 | data_config = config.DataConfig()
75 | episode_descr_config = config.EpisodeDescriptionConfig()
76 | use_dag_ontology = (
77 | FLAGS.dataset_name in ('ilsvrc_2012', 'ilsvrc_2012_v2') and
78 | not FLAGS.ignore_dag_ontology)
79 | use_bilevel_ontology = (
80 | FLAGS.dataset_name == 'omniglot' and not FLAGS.ignore_bilevel_ontology)
81 | data_pipeline = pipeline.make_one_source_episode_pipeline(
82 | dataset_spec,
83 | use_dag_ontology=use_dag_ontology,
84 | use_bilevel_ontology=use_bilevel_ontology,
85 | split=FLAGS.split,
86 | episode_descr_config=episode_descr_config,
87 | # TODO(evcu) Maybe set the following to 0 to prevent shuffling and check
88 | # reproducibility of dumping.
89 | shuffle_buffer_size=data_config.shuffle_buffer_size,
90 | read_buffer_size_bytes=data_config.read_buffer_size_bytes,
91 | num_prefetch=data_config.num_prefetch)
92 | dataset = data_pipeline.take(FLAGS.num_episodes)
93 |
94 | images_per_class_dict = {}
95 | # Ignoring dataset number since we are loading one dataset.
96 | for episode_number, (episode, _) in enumerate(dataset):
97 | logging.info('Dumping episode %d', episode_number)
98 | train_imgs, train_labels, _, test_imgs, test_labels, _ = episode
99 | path_train = utils.get_file_path(FLAGS.output_dir, episode_number, 'train')
100 | path_test = utils.get_file_path(FLAGS.output_dir, episode_number, 'test')
101 | utils.dump_as_tfrecord(path_train, train_imgs, train_labels)
102 | utils.dump_as_tfrecord(path_test, test_imgs, test_labels)
103 | images_per_class_dict[os.path.basename(path_train)] = (
104 | utils.get_label_counts(train_labels))
105 | images_per_class_dict[os.path.basename(path_test)] = (
106 | utils.get_label_counts(test_labels))
107 | info_path = utils.get_info_path(FLAGS.output_dir)
108 | with tf.io.gfile.GFile(info_path, 'w') as f:
109 | f.write(json.dumps(images_per_class_dict, indent=2))
110 |
111 |
112 | if __name__ == '__main__':
113 | app.run(main)
114 |
--------------------------------------------------------------------------------
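
As a rough sketch (not from the repository), the dumped episode files can be read back with plain `tf.data`, assuming they were produced with the string-decoder config from the example command above, so that each record stores raw image bytes under the `image` (bytes) / `label` (int64) feature names parsed by the repo's decoders. The episode path below follows `FILE_NAME_TEMPLATE` in `meta_dataset/data/utils.py` and is illustrative.

    # Sketch: reading one dumped support ("train") set back with plain tf.data.
    import os
    import tensorflow as tf

    def parse_episode_example(example_string):
      features = tf.io.parse_single_example(
          example_string,
          features={'image': tf.io.FixedLenFeature([], dtype=tf.string),
                    'label': tf.io.FixedLenFeature([], tf.int64)})
      image = tf.image.decode_image(features['image'], channels=3)
      return image, features['label']

    episode_path = os.path.join('/tmp/cached_episodes',
                                'episode-0000-train.tfrecords')
    support_set = tf.data.TFRecordDataset(episode_path).map(parse_episode_example)
    for image, label in support_set.take(2):
      print(image.shape, int(label))
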
/meta_dataset/data/learning_spec.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Meta-Dataset Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Interfaces for learning specifications."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import collections
23 | import enum
24 |
25 |
26 | class Split(enum.Enum):
27 | """The possible data splits."""
28 | TRAIN = 0
29 | VALID = 1
30 | TEST = 2
31 |
32 |
33 | class BatchSpecification(
34 | collections.namedtuple('BatchSpecification', 'split, batch_size')):
35 | """The specification of an episode.
36 |
37 | Args:
38 | split: the Split from which to pick data.
39 | batch_size: an int, the number of (image, label) pairs in the batch.
40 | """
41 | pass
42 |
43 |
44 | class EpisodeSpecification(
45 | collections.namedtuple(
46 | 'EpisodeSpecification',
47 | 'split, num_classes, num_train_examples, num_test_examples')):
48 | """The specification of an episode.
49 |
50 | Args:
51 | split: A Split from which to pick data.
52 | num_classes: The number of classes in the episode, or None for variable.
53 | num_train_examples: The number of examples to use per class in the train
54 | phase, or None for variable.
55 | num_test_examples: the number of examples to use per class in the test
56 | phase, or None for variable.
57 | """
58 |
--------------------------------------------------------------------------------
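
For illustration only (not in the repository): the specification tuples above are plain namedtuples and can be constructed directly. The concrete numbers here are arbitrary.

    # Illustrative construction of the specification namedtuples defined above.
    from meta_dataset.data import learning_spec

    batch_spec = learning_spec.BatchSpecification(
        split=learning_spec.Split.TRAIN, batch_size=256)
    # A fixed 5-way episode with 1 support and 15 query examples per class;
    # passing None for a field leaves that quantity variable.
    episode_spec = learning_spec.EpisodeSpecification(
        split=learning_spec.Split.TEST,
        num_classes=5,
        num_train_examples=1,
        num_test_examples=15)
    print(batch_spec.split, episode_spec.num_classes)
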
/meta_dataset/data/pipeline_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Meta-Dataset Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | r"""Tests for meta_dataset.data.pipeline.
17 |
18 | """
19 |
20 | from __future__ import absolute_import
21 | from __future__ import division
22 | from __future__ import print_function
23 |
24 | import os
25 |
26 | from absl.testing import parameterized
27 | import gin
28 | from meta_dataset.data import config
29 | from meta_dataset.data import learning_spec
30 | from meta_dataset.data import pipeline
31 | from meta_dataset.data import test_utils
32 | from meta_dataset.data.dataset_spec import DatasetSpecification
33 | import numpy as np
34 | import tensorflow.compat.v1 as tf
35 |
36 |
37 | class PipelineTest(tf.test.TestCase, parameterized.TestCase):
38 |
39 | @parameterized.named_parameters(
40 | ('FeatureDecoder', 'feature',
41 | 'meta_dataset/learn/gin/setups/data_config_feature.gin'),
42 | ('NoDecoder', 'none',
43 | 'meta_dataset/learn/gin/setups/data_config_no_decoder.gin'
44 | ),
45 | )
46 | def test_make_multisource_episode_pipeline_feature(self, decoder_type,
47 | config_file_path):
48 |
49 | # Create some feature records and write them to a temp directory.
50 | feat_size = 64
51 | num_examples = 100
52 | num_classes = 10
53 | output_path = self.get_temp_dir()
54 | gin.parse_config_file(config_file_path)
55 |
56 | # 1-Write feature records to temp directory.
57 | self.rng = np.random.RandomState(0)
58 | class_features = []
59 | class_examples = []
60 | for class_id in range(num_classes):
61 | features = self.rng.randn(num_examples, feat_size).astype(np.float32)
62 | label = np.array(class_id).astype(np.int64)
63 | output_file = os.path.join(output_path, str(class_id) + '.tfrecords')
64 | examples = test_utils.write_feature_records(features, label, output_file)
65 | class_examples.append(examples)
66 | class_features.append(features)
67 | class_examples = np.stack(class_examples)
68 | class_features = np.stack(class_features)
69 |
70 | # 2-Read records back using multi-source pipeline.
71 | # DatasetSpecification to use in tests
72 | dataset_spec = DatasetSpecification(
73 | name=None,
74 | classes_per_split={
75 | learning_spec.Split.TRAIN: 5,
76 | learning_spec.Split.VALID: 2,
77 | learning_spec.Split.TEST: 3
78 | },
79 | images_per_class={i: num_examples for i in range(num_classes)},
80 | class_names=None,
81 | path=output_path,
82 | file_pattern='{}.tfrecords')
83 |
84 | # Duplicate the dataset to simulate reading from multiple datasets.
85 | use_bilevel_ontology_list = [False] * 2
86 | use_dag_ontology_list = [False] * 2
87 | all_dataset_specs = [dataset_spec] * 2
88 |
89 | fixed_ways_shots = config.EpisodeDescriptionConfig(
90 | num_query=5, num_support=5, num_ways=5)
91 |
92 | dataset_episodic = pipeline.make_multisource_episode_pipeline(
93 | dataset_spec_list=all_dataset_specs,
94 | use_dag_ontology_list=use_dag_ontology_list,
95 | use_bilevel_ontology_list=use_bilevel_ontology_list,
96 | episode_descr_config=fixed_ways_shots,
97 | split=learning_spec.Split.TRAIN,
98 | image_size=None)
99 |
100 | episode, _ = self.evaluate(
101 | dataset_episodic.make_one_shot_iterator().get_next())
102 |
103 | if decoder_type == 'feature':
104 | # 3-Check that support and query features are in class_features and have
105 | # the correct corresponding label.
106 | support_features, support_class_ids = episode[0], episode[2]
107 | query_features, query_class_ids = episode[3], episode[5]
108 |
109 | for feat, class_id in zip(
110 | list(support_features), list(support_class_ids)):
111 | abs_err = np.abs(np.sum(class_features - feat[None][None], axis=-1))
112 | # Make sure the feature is present in the original data.
113 | self.assertEqual(abs_err.min(), 0.0)
114 | found_class_id = np.where(abs_err == 0.0)[0][0]
115 | self.assertEqual(found_class_id, class_id)
116 |
117 | for feat, class_id in zip(list(query_features), list(query_class_ids)):
118 | abs_err = np.abs(np.sum(class_features - feat[None][None], axis=-1))
119 | # Make sure the feature is present in the original data.
120 | self.assertEqual(abs_err.min(), 0.0)
121 | found_class_id = np.where(abs_err == 0.0)[0][0]
122 | self.assertEqual(found_class_id, class_id)
123 |
124 | elif decoder_type == 'none':
125 | # 3-Check that support and query examples are in class_examples and have
126 | # the correct corresponding label.
127 |
128 | support_examples, support_class_ids = episode[0], episode[2]
129 | query_examples, query_class_ids = episode[3], episode[5]
130 |
131 | for example, class_id in zip(
132 | list(support_examples), list(support_class_ids)):
133 | found_class_id = np.where(class_examples == example)[0][0]
134 | self.assertEqual(found_class_id, class_id)
135 |
136 | for example, class_id in zip(list(query_examples), list(query_class_ids)):
137 | found_class_id = np.where(class_examples == example)[0][0]
138 | self.assertEqual(found_class_id, class_id)
139 |
140 |
141 | if __name__ == '__main__':
142 | tf.test.main()
143 |
--------------------------------------------------------------------------------
/meta_dataset/data/sur_decoder.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Meta-Dataset Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Module responsible for decoding images as in SUR (Dvornik et al).
17 |
18 | MIT License
19 |
20 | Copyright (c) 2021 Nikita Dvornik
21 |
22 | Permission is hereby granted, free of charge, to any person obtaining a copy
23 | of this software and associated documentation files (the "Software"), to deal
24 | in the Software without restriction, including without limitation the rights
25 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
26 | copies of the Software, and to permit persons to whom the Software is
27 | furnished to do so, subject to the following conditions:
28 |
29 | The above copyright notice and this permission notice shall be included in all
30 | copies or substantial portions of the Software.
31 |
32 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
33 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
34 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
35 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
36 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
37 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
38 | SOFTWARE.
39 | """
40 |
41 | import gin.tf
42 | import tensorflow.compat.v1 as tf
43 |
44 |
45 | @gin.configurable
46 | class SURDataAugmentation(object):
47 | """Configurations for performing data augmentation."""
48 |
49 | def __init__(self, enable_jitter, jitter_amount, enable_gaussian_noise,
50 | gaussian_noise_std, enable_random_flip, enable_random_brightness,
51 | random_brightness_delta, enable_random_contrast,
52 | random_contrast_delta, enable_random_hue, random_hue_delta,
53 | enable_random_saturation, random_saturation_delta):
54 | """Initializes a DataAugmentation."""
55 | self.enable_jitter = enable_jitter
56 | self.jitter_amount = jitter_amount
57 | self.enable_gaussian_noise = enable_gaussian_noise
58 | self.gaussian_noise_std = gaussian_noise_std
59 | self.enable_random_flip = enable_random_flip
60 | self.enable_random_brightness = enable_random_brightness
61 | self.random_brightness_delta = random_brightness_delta
62 | self.enable_random_contrast = enable_random_contrast
63 | self.random_contrast_delta = random_contrast_delta
64 | self.enable_random_hue = enable_random_hue
65 | self.random_hue_delta = random_hue_delta
66 | self.enable_random_saturation = enable_random_saturation
67 | self.random_saturation_delta = random_saturation_delta
68 |
69 |
70 | @gin.configurable
71 | class SURImageDecoder(object):
72 | """Image decoder."""
73 | out_type = tf.float32
74 |
75 | def __init__(self, image_size=None, data_augmentation=None):
76 | """Class constructor.
77 |
78 | Args:
79 | image_size: int, desired image size. The extracted image will be resized
80 | to `[image_size, image_size]`.
81 | data_augmentation: A DataAugmentation object with parameters for
82 | perturbing the images.
83 | """
84 |
85 | self.image_size = image_size
86 | self.data_augmentation = data_augmentation
87 |
88 | def __call__(self, example_string):
89 | """Processes a single example string.
90 |
91 | Extracts and processes the image, and ignores the label. We assume that the
92 | image has three channels.
93 | Args:
94 | example_string: str, an Example protocol buffer.
95 |
96 | Returns:
97 | image_rescaled: the image, resized to `image_size x image_size` and
98 | rescaled to [-1, 1]. Note that Gaussian data augmentation may cause values
99 | to go beyond this range.
100 | """
101 | image_string = tf.parse_single_example(
102 | example_string,
103 | features={
104 | 'image': tf.FixedLenFeature([], dtype=tf.string),
105 | 'label': tf.FixedLenFeature([], tf.int64)
106 | })['image']
107 | image_decoded = tf.image.decode_image(image_string, channels=3)
108 | image_decoded.set_shape([None, None, 3])
109 | image_resized = tf.image.resize_images(
110 | image_decoded, [self.image_size, self.image_size],
111 | method=tf.image.ResizeMethod.BILINEAR,
112 | align_corners=True)
113 | image = tf.cast(image_resized, tf.float32)
114 |
115 | if self.data_augmentation is not None:
116 | if self.data_augmentation.enable_random_brightness:
117 | delta = self.data_augmentation.random_brightness_delta
118 | image = tf.image.random_brightness(image, delta)
119 |
120 | if self.data_augmentation.enable_random_saturation:
121 | delta = self.data_augmentation.random_saturation_delta
122 | image = tf.image.random_saturation(image, 1 - delta, 1 + delta)
123 |
124 | if self.data_augmentation.enable_random_contrast:
125 | delta = self.data_augmentation.random_contrast_delta
126 | image = tf.image.random_contrast(image, 1 - delta, 1 + delta)
127 |
128 | if self.data_augmentation.enable_random_hue:
129 | delta = self.data_augmentation.random_hue_delta
130 | image = tf.image.random_hue(image, delta)
131 |
132 | if self.data_augmentation.enable_random_flip:
133 | image = tf.image.random_flip_left_right(image)
134 |
135 | image = 2 * (image / 255.0 - 0.5) # Rescale to [-1, 1].
136 |
137 | if self.data_augmentation is not None:
138 | if self.data_augmentation.enable_gaussian_noise:
139 | image = image + tf.random_normal(
140 | tf.shape(image)) * self.data_augmentation.gaussian_noise_std
141 |
142 | if self.data_augmentation.enable_jitter:
143 | j = self.data_augmentation.jitter_amount
144 | paddings = tf.constant([[j, j], [j, j], [0, 0]])
145 | image = tf.pad(image, paddings, 'REFLECT')
146 | image = tf.image.random_crop(image,
147 | [self.image_size, self.image_size, 3])
148 |
149 | return image
150 |
--------------------------------------------------------------------------------
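
Below is a hedged sketch (not part of the repository) of wiring `SURImageDecoder` up by hand instead of through gin; the augmentation flag values and the tfrecords path are purely illustrative.

    # Sketch: constructing SURImageDecoder directly; values are illustrative.
    import tensorflow.compat.v1 as tf
    from meta_dataset.data import sur_decoder

    augmentation = sur_decoder.SURDataAugmentation(
        enable_jitter=True, jitter_amount=8,
        enable_gaussian_noise=False, gaussian_noise_std=0.0,
        enable_random_flip=True,
        enable_random_brightness=False, random_brightness_delta=0.0,
        enable_random_contrast=False, random_contrast_delta=0.0,
        enable_random_hue=False, random_hue_delta=0.0,
        enable_random_saturation=False, random_saturation_delta=0.0)
    decode = sur_decoder.SURImageDecoder(
        image_size=84, data_augmentation=augmentation)

    # Each element becomes an 84x84x3 float image rescaled to roughly [-1, 1]
    # (Gaussian noise, when enabled, can push values outside that range).
    dataset = tf.data.TFRecordDataset('/path/to/some_class.tfrecords').map(decode)
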
/meta_dataset/data/test_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Meta-Dataset Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Utility functions for input pipeline tests."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import gin.tf
23 | from meta_dataset.data.dataset_spec import DatasetSpecification
24 | from meta_dataset.data.learning_spec import Split
25 | from meta_dataset.dataset_conversion import dataset_to_records
26 | import numpy as np
27 | import tensorflow.compat.v1 as tf
28 |
29 | # DatasetSpecification to use in tests
30 | DATASET_SPEC = DatasetSpecification(
31 | name=None,
32 | classes_per_split={
33 | Split.TRAIN: 15,
34 | Split.VALID: 5,
35 | Split.TEST: 10
36 | },
37 | images_per_class=dict(enumerate([10, 20, 30] * 10)),
38 | class_names=['%d' % i for i in range(30)],
39 | path=None,
40 | file_pattern='{}.tfrecords')
41 |
42 | # Define defaults for the input pipeline.
43 | MIN_WAYS = 5
44 | MAX_WAYS_UPPER_BOUND = 50
45 | MAX_NUM_QUERY = 10
46 | MAX_SUPPORT_SET_SIZE = 500
47 | MAX_SUPPORT_SIZE_CONTRIB_PER_CLASS = 100
48 | MIN_LOG_WEIGHT = np.log(0.5)
49 | MAX_LOG_WEIGHT = np.log(2)
50 |
51 |
52 | def set_episode_descr_config_defaults():
53 | """Sets default values for EpisodeDescriptionConfig using gin."""
54 | gin.parse_config('import meta_dataset.data.config')
55 |
56 | gin.bind_parameter('EpisodeDescriptionConfig.num_ways', None)
57 | gin.bind_parameter('EpisodeDescriptionConfig.num_support', None)
58 | gin.bind_parameter('EpisodeDescriptionConfig.num_query', None)
59 | gin.bind_parameter('EpisodeDescriptionConfig.min_ways', MIN_WAYS)
60 | gin.bind_parameter('EpisodeDescriptionConfig.max_ways_upper_bound',
61 | MAX_WAYS_UPPER_BOUND)
62 | gin.bind_parameter('EpisodeDescriptionConfig.max_num_query', MAX_NUM_QUERY)
63 | gin.bind_parameter('EpisodeDescriptionConfig.max_support_set_size',
64 | MAX_SUPPORT_SET_SIZE)
65 | gin.bind_parameter(
66 | 'EpisodeDescriptionConfig.max_support_size_contrib_per_class',
67 | MAX_SUPPORT_SIZE_CONTRIB_PER_CLASS)
68 | gin.bind_parameter('EpisodeDescriptionConfig.min_log_weight', MIN_LOG_WEIGHT)
69 | gin.bind_parameter('EpisodeDescriptionConfig.max_log_weight', MAX_LOG_WEIGHT)
70 | gin.bind_parameter('EpisodeDescriptionConfig.ignore_dag_ontology', False)
71 | gin.bind_parameter('EpisodeDescriptionConfig.ignore_bilevel_ontology', False)
72 | gin.bind_parameter('EpisodeDescriptionConfig.ignore_hierarchy_probability',
73 | 0.)
74 | gin.bind_parameter('EpisodeDescriptionConfig.simclr_episode_fraction', 0.)
75 |
76 | # Following is set in a different scope.
77 | gin.bind_parameter('none/EpisodeDescriptionConfig.min_ways', None)
78 | gin.bind_parameter('none/EpisodeDescriptionConfig.max_ways_upper_bound', None)
79 | gin.bind_parameter('none/EpisodeDescriptionConfig.max_num_query', None)
80 | gin.bind_parameter('none/EpisodeDescriptionConfig.max_support_set_size', None)
81 | gin.bind_parameter(
82 | 'none/EpisodeDescriptionConfig.max_support_size_contrib_per_class', None)
83 | gin.bind_parameter('none/EpisodeDescriptionConfig.min_log_weight', None)
84 | gin.bind_parameter('none/EpisodeDescriptionConfig.max_log_weight', None)
85 |
86 |
87 | def write_feature_records(features, label, output_path):
88 | """Creates a record file from features and labels.
89 |
90 | Args:
91 | features: An [n, m] numpy array of features.
92 | label: An integer, the label common to all records.
93 | output_path: A string specifying the location of the record.
94 |
95 | Returns:
96 |     serialized_examples: A list of the serialized tf.Example protos written by the writer.
97 | """
98 | writer = tf.python_io.TFRecordWriter(output_path)
99 | serialized_examples = []
100 | for feat in list(features):
101 | # Write the example.
102 | serialized_example = dataset_to_records.make_example([
103 | ('image/embedding', 'float32', feat.tolist()),
104 | ('image/class/label', 'int64', [label])
105 | ])
106 | writer.write(serialized_example)
107 | serialized_examples.append(serialized_example)
108 |
109 | writer.close()
110 | return serialized_examples
111 |
--------------------------------------------------------------------------------
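
A small, hypothetical sketch of `write_feature_records`, mirroring how the pipeline tests use it; the output path is illustrative.

    # Sketch: writing ten 64-d feature embeddings for one class to a record file.
    import numpy as np
    from meta_dataset.data import test_utils

    features = np.random.randn(10, 64).astype(np.float32)
    serialized = test_utils.write_feature_records(
        features, label=3, output_path='/tmp/3.tfrecords')
    print(len(serialized))  # 10 serialized tf.Example strings.
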
/meta_dataset/data/tfds/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Meta-Dataset Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Meta-Dataset data sources."""
17 | from .md_tfds import MetaDataset
18 |
--------------------------------------------------------------------------------
/meta_dataset/data/tfds/test_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Meta-Dataset Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Test utility functions and constants."""
17 |
18 | import numpy as np
19 | import tensorflow as tf
20 |
21 | _SOURCES = {
22 | 'aircraft': {
23 | 'same_across_md_versions': True,
24 | 'meta_splits': ('train', 'valid', 'test'),
25 | 'num_classes': (70, 15, 15)
26 | },
27 | 'cu_birds': {
28 | 'same_across_md_versions': True,
29 | 'meta_splits': ('train', 'valid', 'test'),
30 | 'num_classes': (140, 30, 30)
31 | },
32 | 'dtd': {
33 | 'same_across_md_versions': True,
34 | 'meta_splits': ('train', 'valid', 'test'),
35 | 'num_classes': (33, 7, 7)
36 | },
37 | 'fungi': {
38 | 'same_across_md_versions': True,
39 | 'meta_splits': ('train', 'valid', 'test'),
40 | 'num_classes': (994, 200, 200)
41 | },
42 | 'ilsvrc_2012': {
43 | 'same_across_md_versions': False,
44 | 'meta_splits': {'v1': ('train', 'valid', 'test'), 'v2': ('train',)},
45 | 'num_classes': {'v1': (712, 158, 130), 'v2': (1000,)},
46 | 'remap_labels': {'v1': False, 'v2': True},
47 | 'md_source': {'v1': 'ilsvrc_2012', 'v2': 'ilsvrc_2012_v2'},
48 | },
49 | 'mscoco': {
50 | 'same_across_md_versions': True,
51 | 'meta_splits': ('valid', 'test'),
52 | 'num_classes': (40, 40)
53 | },
54 | 'omniglot': {
55 | 'same_across_md_versions': True,
56 | 'meta_splits': ('train', 'valid', 'test'),
57 | 'num_classes': (883, 81, 659)
58 | },
59 | 'quickdraw': {
60 | 'same_across_md_versions': True,
61 | 'meta_splits': ('train', 'valid', 'test'),
62 | 'num_classes': (241, 52, 52)
63 | },
64 | 'traffic_sign': {
65 | 'same_across_md_versions': True,
66 | 'meta_splits': ('test',), 'num_classes': (42,)
67 | },
68 | 'vgg_flower': {
69 | 'same_across_md_versions': False,
70 | 'meta_splits': {'v1': ('train', 'valid', 'test')},
71 | 'num_classes': {'v1': (71, 15, 16)},
72 | },
73 | }
74 |
75 |
76 | def make_class_dataset_comparison_test_cases():
77 | """Returns class dataset test cases for Meta-Dataset."""
78 | testcases = []
79 | for source, info in _SOURCES.items():
80 | meta_splits = ({'v1': info['meta_splits'], 'v2': info['meta_splits']}
81 | if info['same_across_md_versions']
82 | else info['meta_splits'])
83 | num_classes = ({'v1': info['num_classes'], 'v2': info['num_classes']}
84 | if info['same_across_md_versions']
85 | else info['num_classes'])
86 |
87 | for md_version in meta_splits:
88 | offsets = np.cumsum((0,) + num_classes[md_version][:-1])
89 | remap_labels = (info['remap_labels'][md_version] if 'remap_labels' in info
90 | else False)
91 | md_source = (info['md_source'][md_version] if 'md_source' in info
92 | else source)
93 | for meta_split, num_labels, offset in zip(
94 | meta_splits[md_version], num_classes[md_version], offsets):
95 | testcases.append((f'{source}_{md_version}_{meta_split}', source,
96 | md_source, md_version, meta_split, num_labels, offset,
97 | remap_labels))
98 |
99 | return testcases
100 |
101 |
102 | def parse_example(example_string):
103 | return tf.io.parse_single_example(
104 | example_string,
105 | features={'image': tf.io.FixedLenFeature([], dtype=tf.string),
106 | 'label': tf.io.FixedLenFeature([], tf.int64)})
107 |
--------------------------------------------------------------------------------
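
For orientation (not from the repository): each test case produced by `make_class_dataset_comparison_test_cases` is an 8-tuple of (testcase name, source, md_source, md_version, meta_split, number of classes, label offset, remap_labels). A quick way to inspect them, assuming the tfds extras are importable:

    # Sketch: listing the generated class-dataset comparison test cases.
    from meta_dataset.data.tfds import test_utils

    for (name, source, md_source, md_version, meta_split,
         num_classes, offset, remap_labels) in (
             test_utils.make_class_dataset_comparison_test_cases()):
      print(f'{name}: {num_classes} classes, label offset {offset}')
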
/meta_dataset/data/utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Meta-Dataset Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | r"""Utils for dumping Meta-Dataset episodes to disk as tfrecords files."""
17 | import collections
18 | import os
19 |
20 | from absl import logging
21 |
22 | from meta_dataset.dataset_conversion import dataset_to_records
23 | import tensorflow.compat.v1 as tf
24 |
25 | TRAIN_SUFFIX = 'train.tfrecords'
26 | TEST_SUFFIX = 'test.tfrecords'
27 | # For output file.
28 | FILE_NAME_TEMPLATE = 'episode-{:04d}-{}'
29 |
30 |
31 | def get_label_counts(labels):
32 | """Creates a JSON compatible dictionary of image per class counts."""
33 | # JSON does not support integer keys.
34 | counts = {str(k): v for k, v in collections.Counter(labels.numpy()).items()}
35 | return counts
36 |
37 |
38 | def dump_as_tfrecord(path, images, labels):
39 | logging.info('Dumping records to: %s', path)
40 | with tf.io.TFRecordWriter(path) as writer:
41 | for image, label in zip(images, labels):
42 | dataset_to_records.write_example(image.numpy(), label.numpy(), writer)
43 |
44 |
45 | def get_file_path(folder, idx, split):
46 | if split == 'train':
47 | suffix = TRAIN_SUFFIX
48 | elif split == 'test':
49 | suffix = TEST_SUFFIX
50 | else:
51 | raise ValueError('Split: %s, should be train or test.' % split)
52 | return os.path.join(folder, FILE_NAME_TEMPLATE.format(idx, suffix))
53 |
54 |
55 | def get_info_path(folder):
56 | return os.path.join(folder, 'images_per_class.json')
57 |
--------------------------------------------------------------------------------
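
A quick illustration (not from the repository) of the naming helpers above; the episode directory is an arbitrary example.

    # Illustration of the file-naming helpers; the directory is an example.
    from meta_dataset.data import utils

    print(utils.get_file_path('/tmp/cached_episodes', 3, 'train'))
    # -> /tmp/cached_episodes/episode-0003-train.tfrecords
    print(utils.get_info_path('/tmp/cached_episodes'))
    # -> /tmp/cached_episodes/images_per_class.json
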
/meta_dataset/dataset_conversion/ImageNet_CUBirds_duplicates.txt:
--------------------------------------------------------------------------------
1 | # List from https://gist.github.com/arunmallya/a6889f151483dcb348fa70523cb4f578
2 | ## File name in ImageNet # File name in CUBirds
3 | n01531178/n01531178_12730.JPEG # American_Goldfinch_0062_31921.jpg
4 | n01537544/n01537544_9540.JPEG # Indigo_Bunting_0063_11820.jpg
5 | n01580077/n01580077_4622.JPEG # Blue_Jay_0053_62744.jpg
6 | n01531178/n01531178_17834.JPEG # American_Goldfinch_0131_32911.jpg
7 | n01534433/n01534433_12777.JPEG # Dark_Eyed_Junco_0057_68650.jpg
8 | n01537544/n01537544_2126.JPEG # Indigo_Bunting_0051_12837.jpg
9 | n01534433/n01534433_9482.JPEG # Dark_Eyed_Junco_0102_67402.jpg
10 | n01531178/n01531178_14394.JPEG # American_Goldfinch_0012_32338.jpg
11 | n02058221/n02058221_16284.JPEG # Laysan_Albatross_0033_658.jpg
12 | n02058221/n02058221_6390.JPEG # Black_Footed_Albatross_0024_796089.jpg
13 | n01537544/n01537544_26009.JPEG # Indigo_Bunting_0072_14197.jpg
14 | n01833805/n01833805_10347.JPEG # Green_Violetear_0002_795699.jpg
15 | n02058221/n02058221_7776.JPEG # Black_Footed_Albatross_0033_796086.jpg
16 | n02058221/n02058221_9823.JPEG # Black_Footed_Albatross_0086_796062.jpg
17 | n01833805/n01833805_10306.JPEG # Anna_Hummingbird_0034_56614.jpg
18 | n01531178/n01531178_8622.JPEG # American_Goldfinch_0064_32142.jpg
19 | n01855032/n01855032_19954.JPEG # Red_Breasted_Merganser_0068_79203.jpg
20 | n01580077/n01580077_5373.JPEG # Blue_Jay_0033_62024.jpg
21 | n01833805/n01833805_1561.JPEG # Ruby_Throated_Hummingbird_0090_57411.jpg
22 | n01537544/n01537544_25021.JPEG # Indigo_Bunting_0071_11639.jpg
23 | n01855032/n01855032_300.JPEG # Red_Breasted_Merganser_0001_79199.jpg
24 | n01537544/n01537544_8418.JPEG # Indigo_Bunting_0060_14495.jpg
25 | n02058221/n02058221_10440.JPEG # Laysan_Albatross_0053_543.jpg
26 | n01531178/n01531178_1502.JPEG # American_Goldfinch_0018_32324.jpg
27 | n01855032/n01855032_2846.JPEG # Red_Breasted_Merganser_0034_79292.jpg
28 | n01847000/n01847000_9239.JPEG # Mallard_0067_77623.jpg
29 | n01855032/n01855032_8051.JPEG # Red_Breasted_Merganser_0083_79562.jpg
30 | n02058221/n02058221_16869.JPEG # Laysan_Albatross_0049_918.jpg
31 | n02058221/n02058221_29424.JPEG # Black_Footed_Albatross_0002_55.jpg
32 | n01855032/n01855032_1136.JPEG # Red_Breasted_Merganser_0012_79425.jpg
33 | n01534433/n01534433_8318.JPEG # Dark_Eyed_Junco_0031_66785.jpg
34 | n01534433/n01534433_22902.JPEG # Dark_Eyed_Junco_0037_66321.jpg
35 | n01537544/n01537544_10504.JPEG # Indigo_Bunting_0031_13300.jpg
36 | n01580077/n01580077_5633.JPEG # Blue_Jay_0049_63082.jpg
37 | n01534433/n01534433_11948.JPEG # Dark_Eyed_Junco_0111_66488.jpg
38 | n01537544/n01537544_16.JPEG # Indigo_Bunting_0010_13000.jpg
39 | n01855032/n01855032_6416.JPEG # Red_Breasted_Merganser_0004_79232.jpg
40 | n01855032/n01855032_1358.JPEG # Red_Breasted_Merganser_0045_79358.jpg
41 | n01833805/n01833805_18429.JPEG # Ruby_Throated_Hummingbird_0040_57982.jpg
42 | n01531178/n01531178_8148.JPEG # American_Goldfinch_0116_31943.jpg
43 | n01580077/n01580077_8031.JPEG # Blue_Jay_0068_61543.jpg
44 | n01537544/n01537544_2813.JPEG # Indigo_Bunting_0073_13933.jpg
45 | n01534433/n01534433_420.JPEG # Dark_Eyed_Junco_0104_67820.jpg
46 |
--------------------------------------------------------------------------------
/meta_dataset/dataset_conversion/ImageNet_Caltech101_duplicates.txt:
--------------------------------------------------------------------------------
1 | ## File name in ImageNet # File name in Caltech101. '_m' suffix indicates the mirrored image was a match.
2 | n04552348/n04552348_16200.JPEG # airplanes/image_0266.jpg
3 | n02363005/n02363005_3962.JPEG # beaver/image_0003_m.jpg
4 | n02363005/n02363005_3752.JPEG # beaver/image_0004.jpg
5 | n02363005/n02363005_143.JPEG # beaver/image_0005_m.jpg
6 | n02363005/n02363005_4158.JPEG # beaver/image_0006.jpg
7 | n02363005/n02363005_2918.JPEG # beaver/image_0014.jpg
8 | n02363005/n02363005_3752.JPEG # beaver/image_0022.jpg
9 | n02363005/n02363005_2733.JPEG # beaver/image_0027.jpg
10 | n02363005/n02363005_3366.JPEG # beaver/image_0030_m.jpg
11 | n02363005/n02363005_899.JPEG # beaver/image_0033_m.jpg
12 | n02363005/n02363005_6314.JPEG # beaver/image_0042.jpg
13 | n02363005/n02363005_1129.JPEG # beaver/image_0046.jpg
14 | n03991062/n03991062_4071.JPEG # bonsai/image_0037.jpg
15 | n02950826/n02950826_14812.JPEG # cannon/image_0011_m.jpg
16 | n04099969/n04099969_15758.JPEG # chair/image_0028_m.jpg
17 | n02125311/n02125311_8090.JPEG # cougar_body/image_0002_m.jpg
18 | n02125311/n02125311_12897.JPEG # cougar_body/image_0011.jpg
19 | n02125311/n02125311_4683.JPEG # cougar_body/image_0014_m.jpg
20 | n02125311/n02125311_8090.JPEG # cougar_body/image_0026_m.jpg
21 | n02125311/n02125311_6360.JPEG # cougar_body/image_0033_m.jpg
22 | n02125311/n02125311_8090.JPEG # cougar_body/image_0035_m.jpg
23 | n02125311/n02125311_43757.JPEG # cougar_body/image_0039.jpg
24 | n02125311/n02125311_21461.JPEG # cougar_body/image_0040_m.jpg
25 | n02125311/n02125311_518.JPEG # cougar_body/image_0046.jpg
26 | n02125311/n02125311_11820.JPEG # cougar_face/image_0056.jpg
27 | n01985128/n01985128_16.JPEG # crayfish/image_0002_m.jpg
28 | n01985128/n01985128_18494.JPEG # crayfish/image_0021_m.jpg
29 | n01985128/n01985128_9302.JPEG # crayfish/image_0049_m.jpg
30 | n01984695/n01984695_23488.JPEG # crayfish/image_0057.jpg
31 | n01697457/n01697457_2629.JPEG # crocodile_head/image_0004.jpg
32 | n02110341/n02110341_6788.JPEG # dalmatian/image_0055.jpg
33 | n02504458/n02504458_14537.JPEG # elephant/image_0025.jpg
34 | n02504013/n02504013_165.JPEG # elephant/image_0045.jpg
35 | n03452741/n03452741_9398.JPEG # grand_piano/image_0047.jpg
36 | n02443114/n02443114_421.JPEG # hedgehog/image_0002_m.jpg
37 | n02346627/n02346627_13777.JPEG # hedgehog/image_0014.jpg
38 | n02346627/n02346627_16530.JPEG # hedgehog/image_0054_m.jpg
39 | n02128385/n02128385_26353.JPEG # Leopards/image_0051.jpg
40 | n02437616/n02437616_2924.JPEG # llama/image_0006.jpg
41 | n02437616/n02437616_6670.JPEG # llama/image_0018_m.jpg
42 | n01983481/n01983481_604.JPEG # lobster/image_0003.jpg
43 | n01983481/n01983481_16549.JPEG # lobster/image_0007.jpg
44 | n01968897/n01968897_4233.JPEG # nautilus/image_0001.jpg
45 | n01968897/n01968897_2749.JPEG # nautilus/image_0007.jpg
46 | n01968897/n01968897_7568.JPEG # nautilus/image_0014.jpg
47 | n01968897/n01968897_2639.JPEG # nautilus/image_0019.jpg
48 | n01968897/n01968897_12974.JPEG # nautilus/image_0020.jpg
49 | n01968897/n01968897_8442.JPEG # nautilus/image_0028.jpg
50 | n01968897/n01968897_1579.JPEG # nautilus/image_0029.jpg
51 | n01968897/n01968897_12974.JPEG # nautilus/image_0033.jpg
52 | n01968897/n01968897_22259.JPEG # nautilus/image_0043.jpg
53 | n01968897/n01968897_10474.JPEG # nautilus/image_0047.jpg
54 | n01968897/n01968897_8901.JPEG # nautilus/image_0054.jpg
55 | n01968897/n01968897_11310.JPEG # nautilus/image_0055.jpg
56 | n01873310/n01873310_11884.JPEG # platypus/image_0004.jpg
57 | n01873310/n01873310_12404.JPEG # platypus/image_0004_m.jpg
58 | n01873310/n01873310_3541.JPEG # platypus/image_0005.jpg
59 | n01873310/n01873310_12404.JPEG # platypus/image_0008_m.jpg
60 | n01873310/n01873310_12691.JPEG # platypus/image_0017_m.jpg
61 | n01873310/n01873310_1543.JPEG # platypus/image_0019.jpg
62 | n01873310/n01873310_12801.JPEG # platypus/image_0021_m.jpg
63 | n01873310/n01873310_13795.JPEG # platypus/image_0025.jpg
64 | n01873310/n01873310_3149.JPEG # platypus/image_0026_m.jpg
65 | n04086273/n04086273_6486.JPEG # revolver/image_0032.jpg
66 | n04086273/n04086273_23080.JPEG # revolver/image_0046_m.jpg
67 | n04086273/n04086273_23336.JPEG # revolver/image_0056_m.jpg
68 | n01770393/n01770393_19263.JPEG # scorpion/image_0006_m.jpg
69 | n01770393/n01770393_14293.JPEG # scorpion/image_0024.jpg
70 | n01770393/n01770393_10779.JPEG # scorpion/image_0031_m.jpg
71 | n01770393/n01770393_1713.JPEG # scorpion/image_0049_m.jpg
72 | n01770393/n01770393_13927.JPEG # scorpion/image_0056.jpg
73 | n01770393/n01770393_19263.JPEG # scorpion/image_0072_m.jpg
74 | n01770393/n01770393_19263.JPEG # scorpion/image_0076_m.jpg
75 | n01770393/n01770393_3256.JPEG # scorpion/image_0081_m.jpg
76 | n04147183/n04147183_5706.JPEG # schooner/image_0007.jpg
77 | n04147183/n04147183_13667.JPEG # schooner/image_0013.jpg
78 | n04147183/n04147183_14898.JPEG # schooner/image_0029.jpg
79 | n04147183/n04147183_13627.JPEG # schooner/image_0032.jpg
80 | n04147183/n04147183_17475.JPEG # schooner/image_0039.jpg
81 | n04147183/n04147183_2764.JPEG # schooner/image_0049.jpg
82 | n04147183/n04147183_5881.JPEG # schooner/image_0051.jpg
83 | n04147183/n04147183_5898.JPEG # schooner/image_0058.jpg
84 | n04254680/n04254680_1090.JPEG # soccer_ball/image_0001.jpg
85 | n04254680/n04254680_577.JPEG # soccer_ball/image_0014.jpg
86 | n04254680/n04254680_34.JPEG # soccer_ball/image_0027.jpg
87 | n02317335/n02317335_35.JPEG # starfish/image_0046.jpg
88 | n01776313/n01776313_37040.JPEG # tick/image_0043.jpg
89 | n01776313/n01776313_37040.JPEG # tick/image_0049_m.jpg
90 | n04336792/n04336792_10301.JPEG # wheelchair/image_0025.jpg
91 | n02128925/n02128925_15131.JPEG # wild_cat/image_0030.jpg
92 | n02124075/n02124075_3497.JPEG # wild_cat/image_0031.jpg
93 | n02124075/n02124075_10583.JPEG # wild_cat/image_0033_m.jpg
94 |
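The duplicate lists above share a simple line format: comment/header lines start with '#', and each data line holds the ImageNet file on the left and the other dataset's file on the right, separated by '#'. A minimal parsing sketch (not part of the repository; the helper name is made up for illustration), assuming it is run from the repository root:

import os

def parse_duplicates(path):
  """Returns (imagenet_file, other_dataset_file) pairs from a duplicates list."""
  pairs = []
  with open(path) as f:
    for line in f:
      line = line.strip()
      # Header/comment lines start with '#'; data lines use '#' as a separator.
      if not line or line.startswith('#'):
        continue
      imagenet_file, other_file = (part.strip() for part in line.split('#', 1))
      pairs.append((imagenet_file, other_file))
  return pairs

pairs = parse_duplicates(
    os.path.join('meta_dataset', 'dataset_conversion',
                 'ImageNet_Caltech101_duplicates.txt'))
print(len(pairs))  # 92 pairs in the Caltech101 list above.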
--------------------------------------------------------------------------------
/meta_dataset/dataset_conversion/TrafficSign_labels.txt:
--------------------------------------------------------------------------------
1 | # From: https://github.com/navoshta/traffic-signs/blob/master/signnames.csv
2 | Speed limit (20km/h)
3 | Speed limit (30km/h)
4 | Speed limit (50km/h)
5 | Speed limit (60km/h)
6 | Speed limit (70km/h)
7 | Speed limit (80km/h)
8 | End of speed limit (80km/h)
9 | Speed limit (100km/h)
10 | Speed limit (120km/h)
11 | No passing
12 | No passing for vehicles over 3.5 metric tons
13 | Right-of-way at the next intersection
14 | Priority road
15 | Yield
16 | Stop
17 | No vehicles
18 | Vehicles over 3.5 metric tons prohibited
19 | No entry
20 | General caution
21 | Dangerous curve to the left
22 | Dangerous curve to the right
23 | Double curve
24 | Bumpy road
25 | Slippery road
26 | Road narrows on the right
27 | Road work
28 | Traffic signals
29 | Pedestrians
30 | Children crossing
31 | Bicycles crossing
32 | Beware of ice/snow
33 | Wild animals crossing
34 | End of all speed and passing limits
35 | Turn right ahead
36 | Turn left ahead
37 | Ahead only
38 | Go straight or right
39 | Go straight or left
40 | Keep right
41 | Keep left
42 | Roundabout mandatory
43 | End of no passing
44 | End of no passing by vehicles over 3.5 metric tons
45 |
--------------------------------------------------------------------------------
/meta_dataset/dataset_conversion/VggFlower_labels.txt:
--------------------------------------------------------------------------------
1 | # From: https://gist.github.com/JosephKJ/94c7728ed1a8e0cd87fe6a029769cde1
2 | pink primrose
3 | hard-leaved pocket orchid
4 | canterbury bells
5 | sweet pea
6 | english marigold
7 | tiger lily
8 | moon orchid
9 | bird of paradise
10 | monkshood
11 | globe thistle
12 | snapdragon
13 | colt's foot
14 | king protea
15 | spear thistle
16 | yellow iris
17 | globe-flower
18 | purple coneflower
19 | peruvian lily
20 | balloon flower
21 | giant white arum lily
22 | fire lily
23 | pincushion flower
24 | fritillary
25 | red ginger
26 | grape hyacinth
27 | corn poppy
28 | prince of wales feathers
29 | stemless gentian
30 | artichoke
31 | sweet william
32 | carnation
33 | garden phlox
34 | love in the mist
35 | mexican aster
36 | alpine sea holly
37 | ruby-lipped cattleya
38 | cape flower
39 | great masterwort
40 | siam tulip
41 | lenten rose
42 | barbeton daisy
43 | daffodil
44 | sword lily
45 | poinsettia
46 | bolero deep blue
47 | wallflower
48 | marigold
49 | buttercup
50 | oxeye daisy
51 | common dandelion
52 | petunia
53 | wild pansy
54 | primula
55 | sunflower
56 | pelargonium
57 | bishop of llandaff
58 | gaura
59 | geranium
60 | orange dahlia
61 | pink-yellow dahlia?
62 | cautleya spicata
63 | japanese anemone
64 | black-eyed susan
65 | silverbush
66 | californian poppy
67 | osteospermum
68 | spring crocus
69 | bearded iris
70 | windflower
71 | tree poppy
72 | gazania
73 | azalea
74 | water lily
75 | rose
76 | thorn apple
77 | morning glory
78 | passion flower
79 | lotus
80 | toad lily
81 | anthurium
82 | frangipani
83 | clematis
84 | hibiscus
85 | columbine
86 | desert-rose
87 | tree mallow
88 | magnolia
89 | cyclamen
90 | watercress
91 | canna lily
92 | hippeastrum
93 | bee balm
94 | ball moss
95 | foxglove
96 | bougainvillea
97 | camellia
98 | mallow
99 | mexican petunia
100 | bromelia
101 | blanket flower
102 | trumpet creeper
103 | blackberry lily
104 |
--------------------------------------------------------------------------------
/meta_dataset/dataset_conversion/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Meta-Dataset Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 |
--------------------------------------------------------------------------------
/meta_dataset/dataset_conversion/dataset_specs/aircraft_dataset_spec.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "aircraft",
3 | "classes_per_split": {
4 | "TRAIN": 70,
5 | "VALID": 15,
6 | "TEST": 15
7 | },
8 | "images_per_class": {
9 | "0": 100,
10 | "1": 100,
11 | "2": 100,
12 | "3": 100,
13 | "4": 100,
14 | "5": 100,
15 | "6": 100,
16 | "7": 100,
17 | "8": 100,
18 | "9": 100,
19 | "10": 100,
20 | "11": 100,
21 | "12": 100,
22 | "13": 100,
23 | "14": 100,
24 | "15": 100,
25 | "16": 100,
26 | "17": 100,
27 | "18": 100,
28 | "19": 100,
29 | "20": 100,
30 | "21": 100,
31 | "22": 100,
32 | "23": 100,
33 | "24": 100,
34 | "25": 100,
35 | "26": 100,
36 | "27": 100,
37 | "28": 100,
38 | "29": 100,
39 | "30": 100,
40 | "31": 100,
41 | "32": 100,
42 | "33": 100,
43 | "34": 100,
44 | "35": 100,
45 | "36": 100,
46 | "37": 100,
47 | "38": 100,
48 | "39": 100,
49 | "40": 100,
50 | "41": 100,
51 | "42": 100,
52 | "43": 100,
53 | "44": 100,
54 | "45": 100,
55 | "46": 100,
56 | "47": 100,
57 | "48": 100,
58 | "49": 100,
59 | "50": 100,
60 | "51": 100,
61 | "52": 100,
62 | "53": 100,
63 | "54": 100,
64 | "55": 100,
65 | "56": 100,
66 | "57": 100,
67 | "58": 100,
68 | "59": 100,
69 | "60": 100,
70 | "61": 100,
71 | "62": 100,
72 | "63": 100,
73 | "64": 100,
74 | "65": 100,
75 | "66": 100,
76 | "67": 100,
77 | "68": 100,
78 | "69": 100,
79 | "70": 100,
80 | "71": 100,
81 | "72": 100,
82 | "73": 100,
83 | "74": 100,
84 | "75": 100,
85 | "76": 100,
86 | "77": 100,
87 | "78": 100,
88 | "79": 100,
89 | "80": 100,
90 | "81": 100,
91 | "82": 100,
92 | "83": 100,
93 | "84": 100,
94 | "85": 100,
95 | "86": 100,
96 | "87": 100,
97 | "88": 100,
98 | "89": 100,
99 | "90": 100,
100 | "91": 100,
101 | "92": 100,
102 | "93": 100,
103 | "94": 100,
104 | "95": 100,
105 | "96": 100,
106 | "97": 100,
107 | "98": 100,
108 | "99": 100
109 | },
110 | "class_names": {
111 | "0": "A340-300",
112 | "1": "A318",
113 | "2": "Falcon 2000",
114 | "3": "F-16A/B",
115 | "4": "F/A-18",
116 | "5": "C-130",
117 | "6": "MD-80",
118 | "7": "BAE 146-200",
119 | "8": "777-200",
120 | "9": "747-400",
121 | "10": "Cessna 172",
122 | "11": "An-12",
123 | "12": "A330-300",
124 | "13": "A321",
125 | "14": "Fokker 100",
126 | "15": "Fokker 50",
127 | "16": "DHC-1",
128 | "17": "Fokker 70",
129 | "18": "A340-200",
130 | "19": "DC-6",
131 | "20": "747-200",
132 | "21": "Il-76",
133 | "22": "747-300",
134 | "23": "Model B200",
135 | "24": "Saab 340",
136 | "25": "Cessna 560",
137 | "26": "Dornier 328",
138 | "27": "E-195",
139 | "28": "ERJ 135",
140 | "29": "747-100",
141 | "30": "737-600",
142 | "31": "C-47",
143 | "32": "DR-400",
144 | "33": "ATR-72",
145 | "34": "A330-200",
146 | "35": "727-200",
147 | "36": "737-700",
148 | "37": "PA-28",
149 | "38": "ERJ 145",
150 | "39": "737-300",
151 | "40": "767-300",
152 | "41": "737-500",
153 | "42": "737-200",
154 | "43": "DHC-6",
155 | "44": "Falcon 900",
156 | "45": "DC-3",
157 | "46": "Eurofighter Typhoon",
158 | "47": "Challenger 600",
159 | "48": "Hawk T1",
160 | "49": "A380",
161 | "50": "777-300",
162 | "51": "E-190",
163 | "52": "DHC-8-100",
164 | "53": "Cessna 525",
165 | "54": "Metroliner",
166 | "55": "EMB-120",
167 | "56": "Tu-134",
168 | "57": "Embraer Legacy 600",
169 | "58": "Gulfstream IV",
170 | "59": "Tu-154",
171 | "60": "MD-87",
172 | "61": "A300B4",
173 | "62": "A340-600",
174 | "63": "A340-500",
175 | "64": "MD-11",
176 | "65": "707-320",
177 | "66": "Cessna 208",
178 | "67": "Global Express",
179 | "68": "A319",
180 | "69": "DH-82",
181 | "70": "737-900",
182 | "71": "757-300",
183 | "72": "767-200",
184 | "73": "A310",
185 | "74": "A320",
186 | "75": "BAE 146-300",
187 | "76": "CRJ-900",
188 | "77": "DC-10",
189 | "78": "DC-8",
190 | "79": "DC-9-30",
191 | "80": "DHC-8-300",
192 | "81": "Gulfstream V",
193 | "82": "SR-20",
194 | "83": "Tornado",
195 | "84": "Yak-42",
196 | "85": "737-400",
197 | "86": "737-800",
198 | "87": "757-200",
199 | "88": "767-400",
200 | "89": "ATR-42",
201 | "90": "BAE-125",
202 | "91": "Beechcraft 1900",
203 | "92": "Boeing 717",
204 | "93": "CRJ-200",
205 | "94": "CRJ-700",
206 | "95": "E-170",
207 | "96": "L-1011",
208 | "97": "MD-90",
209 | "98": "Saab 2000",
210 | "99": "Spitfire"
211 | },
212 | "path": "/path/to/records/aircraft",
213 | "file_pattern": "{}.tfrecords",
214 | "__class__": "DatasetSpecification"
215 | }
216 |
--------------------------------------------------------------------------------
/meta_dataset/dataset_conversion/dataset_specs/dtd_dataset_spec.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "dtd",
3 | "classes_per_split": {
4 | "TRAIN": 33,
5 | "VALID": 7,
6 | "TEST": 7
7 | },
8 | "images_per_class": {
9 | "0": 120,
10 | "1": 120,
11 | "2": 120,
12 | "3": 120,
13 | "4": 120,
14 | "5": 120,
15 | "6": 120,
16 | "7": 120,
17 | "8": 120,
18 | "9": 120,
19 | "10": 120,
20 | "11": 120,
21 | "12": 120,
22 | "13": 120,
23 | "14": 120,
24 | "15": 120,
25 | "16": 120,
26 | "17": 120,
27 | "18": 120,
28 | "19": 120,
29 | "20": 120,
30 | "21": 120,
31 | "22": 120,
32 | "23": 120,
33 | "24": 120,
34 | "25": 120,
35 | "26": 120,
36 | "27": 120,
37 | "28": 120,
38 | "29": 120,
39 | "30": 120,
40 | "31": 120,
41 | "32": 120,
42 | "33": 120,
43 | "34": 120,
44 | "35": 120,
45 | "36": 120,
46 | "37": 120,
47 | "38": 120,
48 | "39": 120,
49 | "40": 120,
50 | "41": 120,
51 | "42": 120,
52 | "43": 120,
53 | "44": 120,
54 | "45": 120,
55 | "46": 120
56 | },
57 | "class_names": {
58 | "0": "chequered",
59 | "1": "braided",
60 | "2": "interlaced",
61 | "3": "matted",
62 | "4": "honeycombed",
63 | "5": "marbled",
64 | "6": "veined",
65 | "7": "frilly",
66 | "8": "zigzagged",
67 | "9": "cobwebbed",
68 | "10": "pitted",
69 | "11": "waffled",
70 | "12": "fibrous",
71 | "13": "flecked",
72 | "14": "grooved",
73 | "15": "potholed",
74 | "16": "blotchy",
75 | "17": "stained",
76 | "18": "crystalline",
77 | "19": "dotted",
78 | "20": "striped",
79 | "21": "swirly",
80 | "22": "meshed",
81 | "23": "bubbly",
82 | "24": "studded",
83 | "25": "pleated",
84 | "26": "lacelike",
85 | "27": "polka-dotted",
86 | "28": "perforated",
87 | "29": "freckled",
88 | "30": "smeared",
89 | "31": "cracked",
90 | "32": "wrinkled",
91 | "33": "gauzy",
92 | "34": "grid",
93 | "35": "lined",
94 | "36": "paisley",
95 | "37": "porous",
96 | "38": "scaly",
97 | "39": "spiralled",
98 | "40": "banded",
99 | "41": "bumpy",
100 | "42": "crosshatched",
101 | "43": "knitted",
102 | "44": "sprinkled",
103 | "45": "stratified",
104 | "46": "woven"
105 | },
106 | "path": "/path/to/records/dtd",
107 | "file_pattern": "{}.tfrecords",
108 | "__class__": "DatasetSpecification"
109 | }
110 |
--------------------------------------------------------------------------------
/meta_dataset/dataset_conversion/dataset_specs/mscoco_dataset_spec.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "mscoco",
3 | "classes_per_split": {
4 | "TRAIN": 0,
5 | "VALID": 40,
6 | "TEST": 40
7 | },
8 | "images_per_class": {
9 | "67": 2918,
10 | "7": 5508,
11 | "71": 8652,
12 | "29": 5805,
13 | "48": 10806,
14 | "6": 4768,
15 | "49": 6587,
16 | "50": 9509,
17 | "51": 8147,
18 | "58": 24342,
19 | "70": 5779,
20 | "69": 38491,
21 | "27": 15714,
22 | "40": 7113,
23 | "41": 43867,
24 | "1": 8725,
25 | "42": 5135,
26 | "43": 6069,
27 | "2": 4571,
28 | "45": 10759,
29 | "0": 262465,
30 | "46": 1983,
31 | "12": 11431,
32 | "53": 6496,
33 | "16": 6347,
34 | "22": 4373,
35 | "72": 4192,
36 | "32": 6434,
37 | "77": 2637,
38 | "36": 6334,
39 | "39": 1954,
40 | "44": 9973,
41 | "3": 12884,
42 | "4": 1865,
43 | "5": 1285,
44 | "47": 9838,
45 | "8": 5513,
46 | "52": 5131,
47 | "54": 2682,
48 | "55": 6646,
49 | "15": 2685,
50 | "17": 9076,
51 | "56": 3276,
52 | "18": 3747,
53 | "19": 5543,
54 | "20": 6126,
55 | "57": 4812,
56 | "59": 7913,
57 | "60": 20650,
58 | "61": 5479,
59 | "21": 7770,
60 | "62": 6165,
61 | "63": 14358,
62 | "64": 9458,
63 | "65": 5851,
64 | "66": 6399,
65 | "23": 7308,
66 | "24": 7852,
67 | "68": 5821,
68 | "25": 7179,
69 | "26": 6353,
70 | "28": 4157,
71 | "30": 4970,
72 | "73": 2262,
73 | "31": 5703,
74 | "74": 2855,
75 | "33": 1673,
76 | "34": 3334,
77 | "75": 225,
78 | "76": 5610,
79 | "35": 24715,
80 | "78": 6613,
81 | "79": 1481,
82 | "37": 4793,
83 | "38": 198,
84 | "11": 8720,
85 | "13": 12354,
86 | "14": 6192,
87 | "10": 5303,
88 | "9": 1294
89 | },
90 | "class_names": {
91 | "0": "person",
92 | "1": "motorcycle",
93 | "2": "train",
94 | "3": "traffic light",
95 | "4": "fire hydrant",
96 | "5": "parking meter",
97 | "6": "cat",
98 | "7": "dog",
99 | "8": "elephant",
100 | "9": "bear",
101 | "10": "zebra",
102 | "11": "backpack",
103 | "12": "umbrella",
104 | "13": "handbag",
105 | "14": "suitcase",
106 | "15": "snowboard",
107 | "16": "sports ball",
108 | "17": "kite",
109 | "18": "baseball glove",
110 | "19": "skateboard",
111 | "20": "surfboard",
112 | "21": "knife",
113 | "22": "sandwich",
114 | "23": "broccoli",
115 | "24": "carrot",
116 | "25": "donut",
117 | "26": "cake",
118 | "27": "dining table",
119 | "28": "toilet",
120 | "29": "tv",
121 | "30": "laptop",
122 | "31": "remote",
123 | "32": "cell phone",
124 | "33": "microwave",
125 | "34": "oven",
126 | "35": "book",
127 | "36": "clock",
128 | "37": "teddy bear",
129 | "38": "hair drier",
130 | "39": "toothbrush",
131 | "40": "bicycle",
132 | "41": "car",
133 | "42": "airplane",
134 | "43": "bus",
135 | "44": "truck",
136 | "45": "boat",
137 | "46": "stop sign",
138 | "47": "bench",
139 | "48": "bird",
140 | "49": "horse",
141 | "50": "sheep",
142 | "51": "cow",
143 | "52": "giraffe",
144 | "53": "tie",
145 | "54": "frisbee",
146 | "55": "skis",
147 | "56": "baseball bat",
148 | "57": "tennis racket",
149 | "58": "bottle",
150 | "59": "wine glass",
151 | "60": "cup",
152 | "61": "fork",
153 | "62": "spoon",
154 | "63": "bowl",
155 | "64": "banana",
156 | "65": "apple",
157 | "66": "orange",
158 | "67": "hot dog",
159 | "68": "pizza",
160 | "69": "chair",
161 | "70": "couch",
162 | "71": "potted plant",
163 | "72": "bed",
164 | "73": "mouse",
165 | "74": "keyboard",
166 | "75": "toaster",
167 | "76": "sink",
168 | "77": "refrigerator",
169 | "78": "vase",
170 | "79": "scissors"
171 | },
172 | "path": "/path/to/records/mscoco",
173 | "file_pattern": "{}.tfrecords",
174 | "__class__": "DatasetSpecification"
175 | }
176 |
--------------------------------------------------------------------------------
/meta_dataset/dataset_conversion/dataset_specs/traffic_sign_dataset_spec.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "traffic_sign",
3 | "classes_per_split": {
4 | "TRAIN": 0,
5 | "VALID": 0,
6 | "TEST": 43
7 | },
8 | "images_per_class": {
9 | "0": 210,
10 | "1": 2220,
11 | "2": 2250,
12 | "3": 1410,
13 | "4": 1980,
14 | "5": 1860,
15 | "6": 420,
16 | "7": 1440,
17 | "8": 1410,
18 | "9": 1470,
19 | "10": 2010,
20 | "11": 1320,
21 | "12": 2100,
22 | "13": 2160,
23 | "14": 780,
24 | "15": 630,
25 | "16": 420,
26 | "17": 1110,
27 | "18": 1200,
28 | "19": 210,
29 | "20": 360,
30 | "21": 330,
31 | "22": 390,
32 | "23": 510,
33 | "24": 270,
34 | "25": 1500,
35 | "26": 600,
36 | "27": 240,
37 | "28": 540,
38 | "29": 270,
39 | "30": 450,
40 | "31": 780,
41 | "32": 240,
42 | "33": 689,
43 | "34": 420,
44 | "35": 1200,
45 | "36": 390,
46 | "37": 210,
47 | "38": 2070,
48 | "39": 300,
49 | "40": 360,
50 | "41": 240,
51 | "42": 240
52 | },
53 | "class_names": {
54 | "0": "00.Speed limit (20km/h)",
55 | "1": "01.Speed limit (30km/h)",
56 | "2": "02.Speed limit (50km/h)",
57 | "3": "03.Speed limit (60km/h)",
58 | "4": "04.Speed limit (70km/h)",
59 | "5": "05.Speed limit (80km/h)",
60 | "6": "06.End of speed limit (80km/h)",
61 | "7": "07.Speed limit (100km/h)",
62 | "8": "08.Speed limit (120km/h)",
63 | "9": "09.No passing",
64 | "10": "10.No passing for vehicles over 3.5 metric tons",
65 | "11": "11.Right-of-way at the next intersection",
66 | "12": "12.Priority road",
67 | "13": "13.Yield",
68 | "14": "14.Stop",
69 | "15": "15.No vehicles",
70 | "16": "16.Vehicles over 3.5 metric tons prohibited",
71 | "17": "17.No entry",
72 | "18": "18.General caution",
73 | "19": "19.Dangerous curve to the left",
74 | "20": "20.Dangerous curve to the right",
75 | "21": "21.Double curve",
76 | "22": "22.Bumpy road",
77 | "23": "23.Slippery road",
78 | "24": "24.Road narrows on the right",
79 | "25": "25.Road work",
80 | "26": "26.Traffic signals",
81 | "27": "27.Pedestrians",
82 | "28": "28.Children crossing",
83 | "29": "29.Bicycles crossing",
84 | "30": "30.Beware of ice/snow",
85 | "31": "31.Wild animals crossing",
86 | "32": "32.End of all speed and passing limits",
87 | "33": "33.Turn right ahead",
88 | "34": "34.Turn left ahead",
89 | "35": "35.Ahead only",
90 | "36": "36.Go straight or right",
91 | "37": "37.Go straight or left",
92 | "38": "38.Keep right",
93 | "39": "39.Keep left",
94 | "40": "40.Roundabout mandatory",
95 | "41": "41.End of no passing",
96 | "42": "42.End of no passing by vehicles over 3.5 metric tons"
97 | },
98 | "path": "/path/to/records/traffic_sign",
99 | "file_pattern": "{}.tfrecords",
100 | "__class__": "DatasetSpecification"
101 | }
102 |
--------------------------------------------------------------------------------
/meta_dataset/dataset_conversion/dataset_specs/vgg_flower_dataset_spec.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "vgg_flower",
3 | "classes_per_split": {
4 | "TRAIN": 71,
5 | "VALID": 15,
6 | "TEST": 16
7 | },
8 | "images_per_class": {
9 | "0": 82,
10 | "1": 56,
11 | "2": 105,
12 | "3": 85,
13 | "4": 78,
14 | "5": 87,
15 | "6": 130,
16 | "7": 40,
17 | "8": 56,
18 | "9": 52,
19 | "10": 52,
20 | "11": 63,
21 | "12": 85,
22 | "13": 42,
23 | "14": 128,
24 | "15": 251,
25 | "16": 137,
26 | "17": 50,
27 | "18": 154,
28 | "19": 171,
29 | "20": 71,
30 | "21": 45,
31 | "22": 40,
32 | "23": 49,
33 | "24": 41,
34 | "25": 62,
35 | "26": 258,
36 | "27": 54,
37 | "28": 48,
38 | "29": 109,
39 | "30": 87,
40 | "31": 41,
41 | "32": 54,
42 | "33": 108,
43 | "34": 75,
44 | "35": 66,
45 | "36": 71,
46 | "37": 40,
47 | "38": 46,
48 | "39": 60,
49 | "40": 82,
50 | "41": 42,
51 | "42": 45,
52 | "43": 40,
53 | "44": 93,
54 | "45": 107,
55 | "46": 120,
56 | "47": 96,
57 | "48": 85,
58 | "49": 86,
59 | "50": 194,
60 | "51": 40,
61 | "52": 61,
62 | "53": 61,
63 | "54": 67,
64 | "55": 92,
65 | "56": 76,
66 | "57": 54,
67 | "58": 49,
68 | "59": 78,
69 | "60": 166,
70 | "61": 58,
71 | "62": 66,
72 | "63": 59,
73 | "64": 46,
74 | "65": 63,
75 | "66": 40,
76 | "67": 49,
77 | "68": 56,
78 | "69": 41,
79 | "70": 114,
80 | "71": 45,
81 | "72": 41,
82 | "73": 85,
83 | "74": 91,
84 | "75": 41,
85 | "76": 67,
86 | "77": 93,
87 | "78": 109,
88 | "79": 67,
89 | "80": 55,
90 | "81": 112,
91 | "82": 131,
92 | "83": 58,
93 | "84": 66,
94 | "85": 48,
95 | "86": 65,
96 | "87": 46,
97 | "88": 49,
98 | "89": 49,
99 | "90": 43,
100 | "91": 67,
101 | "92": 127,
102 | "93": 59,
103 | "94": 40,
104 | "95": 196,
105 | "96": 102,
106 | "97": 63,
107 | "98": 184,
108 | "99": 162,
109 | "100": 91,
110 | "101": 82
111 | },
112 | "class_names": {
113 | "0": "090.canna lily",
114 | "1": "038.great masterwort",
115 | "2": "080.anthurium",
116 | "3": "030.sweet william",
117 | "4": "029.artichoke",
118 | "5": "012.colt's foot",
119 | "6": "043.sword lily",
120 | "7": "027.prince of wales feathers",
121 | "8": "004.sweet pea",
122 | "9": "064.silverbush",
123 | "10": "031.carnation",
124 | "11": "099.bromelia",
125 | "12": "008.bird of paradise",
126 | "13": "067.spring crocus",
127 | "14": "095.bougainvillea",
128 | "15": "077.passion flower",
129 | "16": "078.lotus",
130 | "17": "061.cautleya spicata",
131 | "18": "088.cyclamen",
132 | "19": "074.rose",
133 | "20": "055.pelargonium",
134 | "21": "032.garden phlox",
135 | "22": "021.fire lily",
136 | "23": "013.king protea",
137 | "24": "079.toad lily",
138 | "25": "070.tree poppy",
139 | "26": "051.petunia",
140 | "27": "069.windflower",
141 | "28": "014.spear thistle",
142 | "29": "060.pink-yellow dahlia?",
143 | "30": "011.snapdragon",
144 | "31": "039.siam tulip",
145 | "32": "063.black-eyed susan",
146 | "33": "037.cape flower",
147 | "34": "036.ruby-lipped cattleya",
148 | "35": "028.stemless gentian",
149 | "36": "048.buttercup",
150 | "37": "007.moon orchid",
151 | "38": "093.ball moss",
152 | "39": "002.hard-leaved pocket orchid",
153 | "40": "018.peruvian lily",
154 | "41": "024.red ginger",
155 | "42": "006.tiger lily",
156 | "43": "003.canterbury bells",
157 | "44": "044.poinsettia",
158 | "45": "076.morning glory",
159 | "46": "075.thorn apple",
160 | "47": "072.azalea",
161 | "48": "052.wild pansy",
162 | "49": "084.columbine",
163 | "50": "073.water lily",
164 | "51": "034.mexican aster",
165 | "52": "054.sunflower",
166 | "53": "066.osteospermum",
167 | "54": "059.orange dahlia",
168 | "55": "050.common dandelion",
169 | "56": "091.hippeastrum",
170 | "57": "068.bearded iris",
171 | "58": "100.blanket flower",
172 | "59": "071.gazania",
173 | "60": "081.frangipani",
174 | "61": "101.trumpet creeper",
175 | "62": "092.bee balm",
176 | "63": "022.pincushion flower",
177 | "64": "033.love in the mist",
178 | "65": "087.magnolia",
179 | "66": "001.pink primrose",
180 | "67": "049.oxeye daisy",
181 | "68": "020.giant white arum lily",
182 | "69": "025.grape hyacinth",
183 | "70": "058.geranium",
184 | "71": "010.globe thistle",
185 | "72": "016.globe-flower",
186 | "73": "017.purple coneflower",
187 | "74": "023.fritillary",
188 | "75": "026.corn poppy",
189 | "76": "047.marigold",
190 | "77": "053.primula",
191 | "78": "056.bishop of llandaff",
192 | "79": "057.gaura",
193 | "80": "062.japanese anemone",
194 | "81": "082.clematis",
195 | "82": "083.hibiscus",
196 | "83": "086.tree mallow",
197 | "84": "097.mallow",
198 | "85": "102.blackberry lily",
199 | "86": "005.english marigold",
200 | "87": "009.monkshood",
201 | "88": "015.yellow iris",
202 | "89": "019.balloon flower",
203 | "90": "035.alpine sea holly",
204 | "91": "040.lenten rose",
205 | "92": "041.barbeton daisy",
206 | "93": "042.daffodil",
207 | "94": "045.bolero deep blue",
208 | "95": "046.wallflower",
209 | "96": "065.californian poppy",
210 | "97": "085.desert-rose",
211 | "98": "089.watercress",
212 | "99": "094.foxglove",
213 | "100": "096.camellia",
214 | "101": "098.mexican petunia"
215 | },
216 | "path": "/path/to/records/vgg_flower",
217 | "file_pattern": "{}.tfrecords",
218 | "__class__": "DatasetSpecification"
219 | }
220 |
--------------------------------------------------------------------------------
/meta_dataset/dataset_conversion/splits/aircraft_splits.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": [
3 | "A340-300",
4 | "A318",
5 | "Falcon 2000",
6 | "F-16A/B",
7 | "F/A-18",
8 | "C-130",
9 | "MD-80",
10 | "BAE 146-200",
11 | "777-200",
12 | "747-400",
13 | "Cessna 172",
14 | "An-12",
15 | "A330-300",
16 | "A321",
17 | "Fokker 100",
18 | "Fokker 50",
19 | "DHC-1",
20 | "Fokker 70",
21 | "A340-200",
22 | "DC-6",
23 | "747-200",
24 | "Il-76",
25 | "747-300",
26 | "Model B200",
27 | "Saab 340",
28 | "Cessna 560",
29 | "Dornier 328",
30 | "E-195",
31 | "ERJ 135",
32 | "747-100",
33 | "737-600",
34 | "C-47",
35 | "DR-400",
36 | "ATR-72",
37 | "A330-200",
38 | "727-200",
39 | "737-700",
40 | "PA-28",
41 | "ERJ 145",
42 | "737-300",
43 | "767-300",
44 | "737-500",
45 | "737-200",
46 | "DHC-6",
47 | "Falcon 900",
48 | "DC-3",
49 | "Eurofighter Typhoon",
50 | "Challenger 600",
51 | "Hawk T1",
52 | "A380",
53 | "777-300",
54 | "E-190",
55 | "DHC-8-100",
56 | "Cessna 525",
57 | "Metroliner",
58 | "EMB-120",
59 | "Tu-134",
60 | "Embraer Legacy 600",
61 | "Gulfstream IV",
62 | "Tu-154",
63 | "MD-87",
64 | "A300B4",
65 | "A340-600",
66 | "A340-500",
67 | "MD-11",
68 | "707-320",
69 | "Cessna 208",
70 | "Global Express",
71 | "A319",
72 | "DH-82"
73 | ],
74 | "test": [
75 | "737-400",
76 | "737-800",
77 | "757-200",
78 | "767-400",
79 | "ATR-42",
80 | "BAE-125",
81 | "Beechcraft 1900",
82 | "Boeing 717",
83 | "CRJ-200",
84 | "CRJ-700",
85 | "E-170",
86 | "L-1011",
87 | "MD-90",
88 | "Saab 2000",
89 | "Spitfire"
90 | ],
91 | "valid": [
92 | "737-900",
93 | "757-300",
94 | "767-200",
95 | "A310",
96 | "A320",
97 | "BAE 146-300",
98 | "CRJ-900",
99 | "DC-10",
100 | "DC-8",
101 | "DC-9-30",
102 | "DHC-8-300",
103 | "Gulfstream V",
104 | "SR-20",
105 | "Tornado",
106 | "Yak-42"
107 | ]
108 | }
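Each dataset spec JSON above and its corresponding splits JSON describe the same classes from two angles: the spec records per-split class counts, per-class image counts, and an id-to-name mapping, while the splits file partitions the class names into train/valid/test. A small consistency check for the aircraft files (not part of the repository), assuming it is run from the repository root:

import json

with open('meta_dataset/dataset_conversion/dataset_specs/'
          'aircraft_dataset_spec.json') as f:
  spec = json.load(f)
with open('meta_dataset/dataset_conversion/splits/aircraft_splits.json') as f:
  splits = json.load(f)

counts = spec['classes_per_split']          # {'TRAIN': 70, 'VALID': 15, 'TEST': 15}
names = set(spec['class_names'].values())   # 100 aircraft variants.

# Every split has the advertised number of classes, and together the splits
# cover exactly the classes named in the spec.
assert len(splits['train']) == counts['TRAIN']
assert len(splits['valid']) == counts['VALID']
assert len(splits['test']) == counts['TEST']
assert names == set(splits['train']) | set(splits['valid']) | set(splits['test'])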
--------------------------------------------------------------------------------
/meta_dataset/dataset_conversion/splits/dtd_splits.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": [
3 | "chequered",
4 | "braided",
5 | "interlaced",
6 | "matted",
7 | "honeycombed",
8 | "marbled",
9 | "veined",
10 | "frilly",
11 | "zigzagged",
12 | "cobwebbed",
13 | "pitted",
14 | "waffled",
15 | "fibrous",
16 | "flecked",
17 | "grooved",
18 | "potholed",
19 | "blotchy",
20 | "stained",
21 | "crystalline",
22 | "dotted",
23 | "striped",
24 | "swirly",
25 | "meshed",
26 | "bubbly",
27 | "studded",
28 | "pleated",
29 | "lacelike",
30 | "polka-dotted",
31 | "perforated",
32 | "freckled",
33 | "smeared",
34 | "cracked",
35 | "wrinkled"
36 | ],
37 | "test": [
38 | "banded",
39 | "bumpy",
40 | "crosshatched",
41 | "knitted",
42 | "sprinkled",
43 | "stratified",
44 | "woven"
45 | ],
46 | "valid": [
47 | "gauzy",
48 | "grid",
49 | "lined",
50 | "paisley",
51 | "porous",
52 | "scaly",
53 | "spiralled"
54 | ]
55 | }
--------------------------------------------------------------------------------
/meta_dataset/dataset_conversion/splits/mscoco_splits.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": [],
3 | "test": [
4 | "bicycle",
5 | "car",
6 | "airplane",
7 | "bus",
8 | "truck",
9 | "boat",
10 | "stop sign",
11 | "bench",
12 | "bird",
13 | "horse",
14 | "sheep",
15 | "cow",
16 | "giraffe",
17 | "tie",
18 | "frisbee",
19 | "skis",
20 | "baseball bat",
21 | "tennis racket",
22 | "bottle",
23 | "wine glass",
24 | "cup",
25 | "fork",
26 | "spoon",
27 | "bowl",
28 | "banana",
29 | "apple",
30 | "orange",
31 | "hot dog",
32 | "pizza",
33 | "chair",
34 | "couch",
35 | "potted plant",
36 | "bed",
37 | "mouse",
38 | "keyboard",
39 | "toaster",
40 | "sink",
41 | "refrigerator",
42 | "vase",
43 | "scissors"
44 | ],
45 | "valid": [
46 | "person",
47 | "motorcycle",
48 | "train",
49 | "traffic light",
50 | "fire hydrant",
51 | "parking meter",
52 | "cat",
53 | "dog",
54 | "elephant",
55 | "bear",
56 | "zebra",
57 | "backpack",
58 | "umbrella",
59 | "handbag",
60 | "suitcase",
61 | "snowboard",
62 | "sports ball",
63 | "kite",
64 | "baseball glove",
65 | "skateboard",
66 | "surfboard",
67 | "knife",
68 | "sandwich",
69 | "broccoli",
70 | "carrot",
71 | "donut",
72 | "cake",
73 | "dining table",
74 | "toilet",
75 | "tv",
76 | "laptop",
77 | "remote",
78 | "cell phone",
79 | "microwave",
80 | "oven",
81 | "book",
82 | "clock",
83 | "teddy bear",
84 | "hair drier",
85 | "toothbrush"
86 | ]
87 | }
--------------------------------------------------------------------------------
/meta_dataset/dataset_conversion/splits/traffic_sign_splits.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": [],
3 | "test": [
4 | "00.Speed limit (20km/h)",
5 | "01.Speed limit (30km/h)",
6 | "02.Speed limit (50km/h)",
7 | "03.Speed limit (60km/h)",
8 | "04.Speed limit (70km/h)",
9 | "05.Speed limit (80km/h)",
10 | "06.End of speed limit (80km/h)",
11 | "07.Speed limit (100km/h)",
12 | "08.Speed limit (120km/h)",
13 | "09.No passing",
14 | "10.No passing for vehicles over 3.5 metric tons",
15 | "11.Right-of-way at the next intersection",
16 | "12.Priority road",
17 | "13.Yield",
18 | "14.Stop",
19 | "15.No vehicles",
20 | "16.Vehicles over 3.5 metric tons prohibited",
21 | "17.No entry",
22 | "18.General caution",
23 | "19.Dangerous curve to the left",
24 | "20.Dangerous curve to the right",
25 | "21.Double curve",
26 | "22.Bumpy road",
27 | "23.Slippery road",
28 | "24.Road narrows on the right",
29 | "25.Road work",
30 | "26.Traffic signals",
31 | "27.Pedestrians",
32 | "28.Children crossing",
33 | "29.Bicycles crossing",
34 | "30.Beware of ice/snow",
35 | "31.Wild animals crossing",
36 | "32.End of all speed and passing limits",
37 | "33.Turn right ahead",
38 | "34.Turn left ahead",
39 | "35.Ahead only",
40 | "36.Go straight or right",
41 | "37.Go straight or left",
42 | "38.Keep right",
43 | "39.Keep left",
44 | "40.Roundabout mandatory",
45 | "41.End of no passing",
46 | "42.End of no passing by vehicles over 3.5 metric tons"
47 | ],
48 | "valid": []
49 | }
50 |
--------------------------------------------------------------------------------
/meta_dataset/dataset_conversion/splits/vgg_flower_splits.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": [
3 | "090.canna lily",
4 | "038.great masterwort",
5 | "080.anthurium",
6 | "030.sweet william",
7 | "029.artichoke",
8 | "012.colt's foot",
9 | "043.sword lily",
10 | "027.prince of wales feathers",
11 | "004.sweet pea",
12 | "064.silverbush",
13 | "031.carnation",
14 | "099.bromelia",
15 | "008.bird of paradise",
16 | "067.spring crocus",
17 | "095.bougainvillea",
18 | "077.passion flower",
19 | "078.lotus",
20 | "061.cautleya spicata",
21 | "088.cyclamen",
22 | "074.rose",
23 | "055.pelargonium",
24 | "032.garden phlox",
25 | "021.fire lily",
26 | "013.king protea",
27 | "079.toad lily",
28 | "070.tree poppy",
29 | "051.petunia",
30 | "069.windflower",
31 | "014.spear thistle",
32 | "060.pink-yellow dahlia?",
33 | "011.snapdragon",
34 | "039.siam tulip",
35 | "063.black-eyed susan",
36 | "037.cape flower",
37 | "036.ruby-lipped cattleya",
38 | "028.stemless gentian",
39 | "048.buttercup",
40 | "007.moon orchid",
41 | "093.ball moss",
42 | "002.hard-leaved pocket orchid",
43 | "018.peruvian lily",
44 | "024.red ginger",
45 | "006.tiger lily",
46 | "003.canterbury bells",
47 | "044.poinsettia",
48 | "076.morning glory",
49 | "075.thorn apple",
50 | "072.azalea",
51 | "052.wild pansy",
52 | "084.columbine",
53 | "073.water lily",
54 | "034.mexican aster",
55 | "054.sunflower",
56 | "066.osteospermum",
57 | "059.orange dahlia",
58 | "050.common dandelion",
59 | "091.hippeastrum",
60 | "068.bearded iris",
61 | "100.blanket flower",
62 | "071.gazania",
63 | "081.frangipani",
64 | "101.trumpet creeper",
65 | "092.bee balm",
66 | "022.pincushion flower",
67 | "033.love in the mist",
68 | "087.magnolia",
69 | "001.pink primrose",
70 | "049.oxeye daisy",
71 | "020.giant white arum lily",
72 | "025.grape hyacinth",
73 | "058.geranium"
74 | ],
75 | "valid": [
76 | "010.globe thistle",
77 | "016.globe-flower",
78 | "017.purple coneflower",
79 | "023.fritillary",
80 | "026.corn poppy",
81 | "047.marigold",
82 | "053.primula",
83 | "056.bishop of llandaff",
84 | "057.gaura",
85 | "062.japanese anemone",
86 | "082.clematis",
87 | "083.hibiscus",
88 | "086.tree mallow",
89 | "097.mallow",
90 | "102.blackberry lily"
91 | ],
92 | "test": [
93 | "005.english marigold",
94 | "009.monkshood",
95 | "015.yellow iris",
96 | "019.balloon flower",
97 | "035.alpine sea holly",
98 | "040.lenten rose",
99 | "041.barbeton daisy",
100 | "042.daffodil",
101 | "045.bolero deep blue",
102 | "046.wallflower",
103 | "065.californian poppy",
104 | "085.desert-rose",
105 | "089.watercress",
106 | "094.foxglove",
107 | "096.camellia",
108 | "098.mexican petunia"
109 | ]
110 | }
111 |
--------------------------------------------------------------------------------
/meta_dataset/distribute_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Meta-Dataset Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Utilities for working in a tf.distribute context."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import tensorflow.compat.v1 as tf
23 |
24 |
25 | def _pad(tensor, total_size, position):
26 | """Pad tensor with zeros along its first axis.
27 |
28 | The output will have first dimension total_size, and
29 | tensor will appear at position position.
30 |
31 | Args:
32 | tensor: tensor to pad.
33 | total_size: the output size.
34 |     position: the position in the output at which tensor should appear.
35 |
36 | Returns:
37 | output: The padded tensor.
38 | """
39 | shape_rest = tf.shape(tensor)[1:]
40 | after_dim = total_size - position - tf.shape(tensor)[0]
41 | pad_before = tf.zeros(
42 | tf.concat([[position], shape_rest], axis=0), dtype=tensor.dtype)
43 | pad_after = tf.zeros(
44 | tf.concat([[after_dim], shape_rest], axis=0), dtype=tensor.dtype)
45 | return tf.concat([pad_before, tensor, pad_after], axis=0)
46 |
47 |
48 | def aggregate(tensor):
49 | """Aggregate a tensor across distributed replicas.
50 |
51 | If not running in a distributed context, this just returns the input tensor.
52 |
53 | Args:
54 |     tensor: the tensor to aggregate.
55 |
56 | Returns:
57 | output: A single tensor with all values across different replicas
58 |       concatenated along the first axis. The output is ordered by replica (GPU) index.
59 | """
60 |
61 | replica_ctx = tf.distribute.get_replica_context()
62 | if not replica_ctx:
63 | return tensor
64 | num = tf.shape(tensor)[0:1]
65 | padded_num = _pad(num, replica_ctx.num_replicas_in_sync,
66 | replica_ctx.replica_id_in_sync_group)
67 | all_num = replica_ctx.all_reduce('sum', padded_num)
68 | index_in_output = tf.gather(
69 | tf.cumsum(tf.concat([[0], all_num], axis=0)),
70 | replica_ctx.replica_id_in_sync_group)
71 | total_num = tf.reduce_sum(all_num)
72 | padded_tensor = _pad(tensor, total_num, index_in_output)
73 | return replica_ctx.all_reduce('sum', padded_tensor)
74 |
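A short NumPy sketch (not part of the repository) of the zero-padding trick that `_pad` and `aggregate` rely on: each replica writes its rows into a disjoint slice of a zero tensor of the full output size, so a 'sum' all-reduce of the padded tensors is equivalent to concatenating every replica's rows in replica order.

import numpy as np

def pad(tensor, total_size, position):
  # NumPy mirror of _pad: zeros before and after, so `tensor` occupies rows
  # [position, position + len(tensor)) of a `total_size`-row output.
  before = np.zeros((position,) + tensor.shape[1:], dtype=tensor.dtype)
  after = np.zeros((total_size - position - tensor.shape[0],) + tensor.shape[1:],
                   dtype=tensor.dtype)
  return np.concatenate([before, tensor, after], axis=0)

# Three "replicas" holding different numbers of rows.
replica_tensors = [np.ones((2, 3)), 2 * np.ones((1, 3)), 3 * np.ones((3, 3))]
sizes = [t.shape[0] for t in replica_tensors]      # [2, 1, 3]
offsets = np.cumsum([0] + sizes[:-1])              # [0, 2, 3]
total = sum(sizes)                                 # 6

# The padded tensors never overlap, so an elementwise sum (the all-reduce in
# `aggregate`) reconstructs the concatenation in replica order.
aggregated = sum(pad(t, total, o) for t, o in zip(replica_tensors, offsets))
print(aggregated.shape)  # (6, 3)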
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/baseline_all.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all.gin'
2 | include 'meta_dataset/learn/gin/learners/baseline_config.gin'
3 | BatchSplitReaderGetReader.add_dataset_offset = True
4 |
5 | # Backbone hypers.
6 | include 'meta_dataset/learn/gin/best_v2/pretrained_wide_resnet.gin'
7 |
8 | # Model hypers.
9 | BaselineLearner.knn_distance = 'cosine'
10 | BaselineLearner.cosine_classifier = True
11 | BaselineLearner.cosine_logits_multiplier = 1
12 | BaselineLearner.use_weight_norm = True
13 |
14 | # Data hypers.
15 | DataConfig.image_height = 126
16 |
17 | # Training hypers (not needed for eval).
18 | Trainer.decay_every = 500
19 | Trainer.decay_rate = 0.8778059962506467
20 | Trainer.learning_rate = 0.000253906846867988
21 | weight_decay = 0.00002393929026012612
22 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/baseline_imagenet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet.gin'
2 | include 'meta_dataset/learn/gin/learners/baseline_config.gin'
3 |
4 | # Backbone hypers.
5 | Learner.embedding_fn = @wide_resnet
6 | weight_decay = 0.000007138118976497546
7 |
8 | Trainer.checkpoint_to_restore = ''
9 | Trainer.pretrained_source = 'scratch'
10 |
11 | # Model hypers.
12 | BaselineLearner.knn_distance = 'cosine'
13 | BaselineLearner.cosine_classifier = False
14 | BaselineLearner.cosine_logits_multiplier = 10
15 | BaselineLearner.use_weight_norm = False
16 |
17 | # Data hypers.
18 | DataConfig.image_height = 126
19 |
20 | # Training hypers (not needed for eval).
21 | Trainer.decay_every = 10000
22 | Trainer.decay_rate = 0.7294597641152971
23 | Trainer.learning_rate = 0.007634189137886614
24 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/baselinefinetune_all.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all.gin'
2 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin'
3 | BatchSplitReaderGetReader.add_dataset_offset = True
4 |
5 | # Backbone hypers.
6 | include 'meta_dataset/learn/gin/best_v2/pretrained_wide_resnet.gin'
7 |
8 | # Model hypers.
9 | BaselineLearner.cosine_classifier = False
10 | BaselineLearner.use_weight_norm = True
11 | BaselineLearner.cosine_logits_multiplier = 1
12 | BaselineFinetuneLearner.num_finetune_steps = 200
13 | BaselineFinetuneLearner.finetune_lr = 0.01
14 | BaselineFinetuneLearner.finetune_all_layers = True
15 | BaselineFinetuneLearner.finetune_with_adam = True
16 |
17 | # Data hypers.
18 | DataConfig.image_height = 84
19 |
20 | # Training hypers (not needed for eval).
21 | Trainer.decay_every = 5000
22 | Trainer.decay_rate = 0.5559080744371039
23 | Trainer.learning_rate = 0.0027015533546616804
24 | weight_decay = 0.00002266979856832968
25 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/baselinefinetune_cosine_mini_imagenet_fiveshot.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/mini_imagenet_five_way_five_shot.gin'
2 | include 'meta_dataset/learn/gin/learners/baselinefinetune_cosine_config.gin'
3 |
4 | Learner.embedding_fn = @wide_resnet
5 | weight_decay = 0
6 |
7 | DataConfig.image_height = 84
8 | Trainer.pretrained_source = 'scratch'
9 |
10 | BaselineFinetuneLearner.num_finetune_steps = 200
11 | BaselineFinetuneLearner.finetune_lr = 0.01
12 | BaselineFinetuneLearner.finetune_with_adam = False
13 | BaselineLearner.use_weight_norm = False
14 | BaselineFinetuneLearner.finetune_all_layers = False
15 | BaselineLearner.cosine_logits_multiplier = 10
16 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/baselinefinetune_cosine_mini_imagenet_oneshot.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/mini_imagenet_five_way_five_shot.gin'
2 | EpisodeDescriptionConfig.num_support = 1 # Change 5-shot to 1-shot.
3 | include 'meta_dataset/learn/gin/learners/baselinefinetune_cosine_config.gin'
4 |
5 | Learner.embedding_fn = @wide_resnet
6 | weight_decay = 0
7 |
8 | DataConfig.image_height = 84
9 | Trainer.pretrained_source = 'scratch'
10 |
11 | BaselineFinetuneLearner.num_finetune_steps = 200
12 | BaselineFinetuneLearner.finetune_lr = 0.01
13 | BaselineFinetuneLearner.finetune_with_adam = False
14 | BaselineLearner.use_weight_norm = False
15 | BaselineFinetuneLearner.finetune_all_layers = False
16 | BaselineLearner.cosine_logits_multiplier = 10
17 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/baselinefinetune_imagenet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet.gin'
2 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin'
3 |
4 | # Backbone hypers.
5 | include 'meta_dataset/learn/gin/best_v2/pretrained_wide_resnet.gin'
6 |
7 | # Model hypers.
8 | BaselineLearner.cosine_classifier = False
9 | BaselineLearner.use_weight_norm = True
10 | BaselineLearner.cosine_logits_multiplier = 1
11 | BaselineFinetuneLearner.num_finetune_steps = 200
12 | BaselineFinetuneLearner.finetune_lr = 0.01
13 | BaselineFinetuneLearner.finetune_all_layers = True
14 | BaselineFinetuneLearner.finetune_with_adam = True
15 |
16 | # Data hypers.
17 | DataConfig.image_height = 84
18 |
19 | # Training hypers (not needed for eval).
20 | Trainer.decay_every = 5000
21 | Trainer.decay_rate = 0.5559080744371039
22 | Trainer.learning_rate = 0.0027015533546616804
23 | weight_decay = 0.00002266979856832968
24 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/baselinefinetune_mini_imagenet_fiveshot.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin'
2 | EpisodeDescriptionConfig.num_support = 5
3 |
4 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin'
5 |
6 | Learner.embedding_fn = @wide_resnet
7 | weight_decay = 0
8 |
9 | DataConfig.image_height = 84
10 | Trainer.pretrained_source = 'scratch'
11 |
12 | BaselineFinetuneLearner.num_finetune_steps = 200
13 | BaselineFinetuneLearner.finetune_lr = 0.01
14 | BaselineFinetuneLearner.finetune_with_adam = False
15 | BaselineLearner.use_weight_norm = False
16 | BaselineFinetuneLearner.finetune_all_layers = False
17 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/baselinefinetune_mini_imagenet_oneshot.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/mini_imagenet_five_way_five_shot.gin'
2 | EpisodeDescriptionConfig.num_support = 1 # Change 5-shot to 1-shot.
3 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin'
4 |
5 | Learner.embedding_fn = @wide_resnet
6 | weight_decay = 0
7 |
8 | DataConfig.image_height = 84
9 | Trainer.pretrained_source = 'scratch'
10 |
11 | BaselineFinetuneLearner.num_finetune_steps = 200
12 | BaselineFinetuneLearner.finetune_lr = 0.01
13 | BaselineFinetuneLearner.finetune_with_adam = False
14 | BaselineLearner.use_weight_norm = False
15 | BaselineFinetuneLearner.finetune_all_layers = False
16 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/flute.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all.gin'
2 | include 'meta_dataset/learn/gin/learners/learner_config.gin'
3 | include 'meta_dataset/learn/gin/setups/trainer_config_flute.gin'
4 | include 'meta_dataset/learn/gin/setups/data_config_flute.gin'
5 |
6 | # Learner to use for evaluation (the `train_learner_class` isn't used now).
7 | Trainer_flute.train_learner_class = @DatasetConditionalBaselineLearner
8 | Trainer_flute.eval_learner_class = @FLUTEFiLMLearner
9 |
10 | # The path to the dataset classifier checkpoint to restore (this is required if
11 | # using the `blender` or `hard blender` as the `film_init` heuristic).
12 | Trainer_flute.dataset_classifier_to_restore = "path/to/trained/dataset/classifier"
13 |
14 | # FLUTE FiLM Learner settings.
15 | FLUTEFiLMLearner.film_init = 'blender'
16 | FLUTEFiLMLearner.num_steps = 6
17 | FLUTEFiLMLearner.lr = 0.005
18 |
19 | # Backbone settings.
20 | FLUTEFiLMLearner.embedding_fn = @flute_resnet
21 | bn_wrapper.batch_norm_fn = @bn_flute_eval
22 | bn_wrapper.num_film_sets = %num_film_sets
23 | dataset_classifier.weight_decay = %weight_decay
24 | dataset_classifier.num_datasets = %num_film_sets
25 | flute_resnet.weight_decay = %weight_decay
26 | weight_decay = 0.
27 | num_film_sets = 8
28 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/flute_init_from_imagenet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/best_v2/flute.gin'
2 |
3 | Trainer_flute.dataset_classifier_to_restore = ""
4 | FLUTEFiLMLearner.film_init = 'imagenet'
5 | FLUTEFiLMLearner.num_steps = 2
6 | FLUTEFiLMLearner.lr = 0.001
7 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/flute_init_from_scratch.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/best_v2/flute.gin'
2 |
3 | Trainer_flute.dataset_classifier_to_restore = ""
4 | FLUTEFiLMLearner.film_init = 'scratch'
5 | FLUTEFiLMLearner.num_steps = 4
6 | FLUTEFiLMLearner.lr = 0.005
7 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/maml_all.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all.gin'
2 | include 'meta_dataset/learn/gin/learners/maml_config.gin'
3 |
4 | # Backbone hypers.
5 | include 'meta_dataset/learn/gin/best_v2/pretrained_resnet.gin'
6 |
7 | # Model hypers. (new results)
8 | MAMLLearner.first_order = True
9 | MAMLLearner.alpha = 0.09915203576061224
10 | MAMLLearner.additional_evaluation_update_steps = 5
11 | MAMLLearner.num_update_steps = 10
12 |
13 | # Data hypers.
14 | DataConfig.image_height = 126
15 |
16 | # Training hypers (not needed for eval).
17 | Trainer.decay_learning_rate = True
18 | Trainer.decay_every = 10000
19 | Trainer.decay_rate = 0.8354967819706964
20 | Trainer.learning_rate = 0.00023702989294963628
21 | weight_decay = 0.00030572401935585053
22 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/maml_all_from_scratch.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all.gin'
2 | include 'meta_dataset/learn/gin/learners/maml_config.gin'
3 |
4 | # Backbone hypers.
5 | Learner.embedding_fn = @resnet
6 |
7 | Trainer.pretrained_source = 'scratch'
8 | Trainer.checkpoint_to_restore = ''
9 |
10 | # Model hypers.
11 | MAMLLearner.first_order = True
12 | MAMLLearner.alpha = 0.01406233807768908
13 | MAMLLearner.additional_evaluation_update_steps = 5
14 | MAMLLearner.num_update_steps = 6
15 |
16 | # Data hypers.
17 | DataConfig.image_height = 126
18 |
19 | # Training hypers (not needed for eval).
20 | Trainer.decay_learning_rate = True
21 | Trainer.decay_every = 10000
22 | Trainer.decay_rate = 0.5575524279100452
23 | Trainer.learning_rate = 0.0004067024273450244
24 | weight_decay = 0.0000033242172150107816
25 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/maml_imagenet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet.gin'
2 | include 'meta_dataset/learn/gin/learners/maml_config.gin'
3 |
4 | # Backbone hypers.
5 | include 'meta_dataset/learn/gin/best_v2/pretrained_resnet.gin'
6 |
7 | # Model hypers.
8 | MAMLLearner.first_order = True
9 | MAMLLearner.alpha = 0.09915203576061224
10 | MAMLLearner.additional_evaluation_update_steps = 5
11 | MAMLLearner.num_update_steps = 10
12 |
13 | # Data hypers.
14 | DataConfig.image_height = 126
15 |
16 | # Training hypers (not needed for eval).
17 | Trainer.decay_learning_rate = True
18 | Trainer.decay_every = 10000
19 | Trainer.decay_rate = 0.8354967819706964
20 | Trainer.learning_rate = 0.0003601303690709516
21 | weight_decay = 0.00030572401935585053
22 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/maml_imagenet_from_scratch.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet.gin'
2 | include 'meta_dataset/learn/gin/learners/maml_config.gin'
3 |
4 | # Backbone hypers.
5 | Learner.embedding_fn = @four_layer_convnet
6 |
7 | Trainer.pretrained_source = 'scratch'
8 | Trainer.checkpoint_to_restore = ''
9 |
10 | # Model hypers.
11 | MAMLLearner.first_order = True
12 | MAMLLearner.alpha = 0.05602373320602529
13 | MAMLLearner.additional_evaluation_update_steps = 5
14 | MAMLLearner.num_update_steps = 10
15 |
16 | # Data hypers.
17 | DataConfig.image_height = 84
18 |
19 | # Training hypers (not needed for eval).
20 | Trainer.decay_learning_rate = True
21 | Trainer.decay_every = 5000
22 | Trainer.decay_rate = 0.9153310958769965
23 | Trainer.learning_rate = 0.00012385665511953882
24 | weight_decay = 0.0006858703351826797
25 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/maml_init_with_proto_all.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all.gin'
2 | include 'meta_dataset/learn/gin/learners/maml_init_with_proto_config.gin'
3 |
4 | # Backbone hypers.
5 | include 'meta_dataset/learn/gin/best_v2/pretrained_resnet.gin'
6 |
7 | # Model hypers.
8 | MAMLLearner.first_order = True
9 | MAMLLearner.alpha = 0.01988024603751885
10 | MAMLLearner.additional_evaluation_update_steps = 5
11 | MAMLLearner.num_update_steps = 1
12 |
13 | # Data hypers.
14 | DataConfig.image_height = 126
15 |
16 | # Training hypers (not needed for eval).
17 | Trainer.decay_learning_rate = True
18 | Trainer.decay_every = 500
19 | Trainer.decay_rate = 0.7659662681067452
20 | Trainer.learning_rate = 0.00029809808680211807
21 | weight_decay = 0.0007159940362690606
22 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/maml_init_with_proto_all_from_scratch.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all.gin'
2 | include 'meta_dataset/learn/gin/learners/maml_init_with_proto_config.gin'
3 |
4 | # Backbone hypers.
5 | Learner.embedding_fn = @resnet
6 |
7 | Trainer.pretrained_source = 'scratch'
8 | Trainer.checkpoint_to_restore = ''
9 |
10 | # Model hypers.
11 | MAMLLearner.first_order = True
12 | MAMLLearner.alpha = 0.01139772733084185
13 | MAMLLearner.additional_evaluation_update_steps = 5
14 | MAMLLearner.num_update_steps = 1
15 |
16 | # Data hypers.
17 | DataConfig.image_height = 126
18 |
19 | # Training hypers (not needed for eval).
20 | Trainer.decay_learning_rate = True
21 | Trainer.decay_every = 10000
22 | Trainer.decay_rate = 0.6731268939317062
23 | Trainer.learning_rate = 0.0007252750256444753
24 | weight_decay = 0.00005323094246436839
25 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/maml_init_with_proto_imagenet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet.gin'
2 | include 'meta_dataset/learn/gin/learners/maml_init_with_proto_config.gin'
3 |
4 | # Backbone hypers.
5 | include 'meta_dataset/learn/gin/best_v2/pretrained_resnet.gin'
6 |
7 | # Model hypers.
8 | MAMLLearner.first_order = True
9 | MAMLLearner.alpha = 0.005435022808033229
10 | MAMLLearner.additional_evaluation_update_steps = 0
11 | MAMLLearner.num_update_steps = 6
12 |
13 | # Data hypers.
14 | DataConfig.image_height = 126
15 |
16 | # Training hypers (not needed for eval).
17 | Trainer.decay_learning_rate = True
18 | Trainer.decay_every = 2500
19 | Trainer.decay_rate = 0.6477898086638092
20 | Trainer.learning_rate = 0.00036339913514891586
21 | weight_decay = 0.000013656044331235537
22 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/maml_init_with_proto_imagenet_from_scratch.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet.gin'
2 | include 'meta_dataset/learn/gin/learners/maml_init_with_proto_config.gin'
3 |
4 | # Backbone hypers.
5 |
6 | Learner.embedding_fn = @resnet
7 |
8 | Trainer.pretrained_source = 'scratch'
9 | Trainer.checkpoint_to_restore = ''
10 |
11 | # Model hypers.
12 | MAMLLearner.first_order = True
13 | MAMLLearner.alpha = 0.31106175977182243
14 | MAMLLearner.additional_evaluation_update_steps = 5
15 | MAMLLearner.num_update_steps = 6
16 |
17 | # Data hypers.
18 | DataConfig.image_height = 126
19 |
20 | # Training hypers (not needed for eval).
21 | Trainer.decay_learning_rate = True
22 | Trainer.decay_every = 5000
23 | Trainer.decay_rate = 0.6431136271727287
24 | Trainer.learning_rate = 0.0007181155997029211
25 | weight_decay = 0.00003630199690303937
26 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/maml_init_with_proto_inference_all.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all.gin'
2 | include 'meta_dataset/learn/gin/learners/maml_init_with_proto_config.gin'
3 |
4 | # Backbone hypers.
5 | include 'meta_dataset/learn/gin/best_v2/pretrained_resnet.gin'
6 |
7 | # Model hypers.
8 | MAMLLearner.first_order = True
9 | MAMLLearner.alpha = 0.061445396411177286
10 | MAMLLearner.additional_evaluation_update_steps = 0
11 | MAMLLearner.num_update_steps = 1
12 |
13 | # Data hypers.
14 | DataConfig.image_height = 126
15 |
16 | # Training hypers (not needed for eval).
17 | Trainer.decay_learning_rate = True
18 | Trainer.decay_every = 5000
19 | Trainer.decay_rate = 0.8449806376151254
20 | Trainer.learning_rate = 0.0008118174317224871
21 | weight_decay = 0.0000038819970639091496
22 |
23 | # Baseline hypers (just for the record).
24 | BaselineLearner.cosine_logits_multiplier = 1
25 | BaselineLearner.use_weight_norm = True
26 | BaselineLearner.knn_distance = 'l2'
27 | BaselineLearner.cosine_classifier = False
28 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/maml_init_with_proto_inference_imagenet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet.gin'
2 | include 'meta_dataset/learn/gin/learners/maml_init_with_proto_config.gin'
3 |
4 | # Backbone hypers.
5 | Learner.embedding_fn = @resnet
6 |
7 | Trainer.pretrained_source = 'scratch'
8 | Trainer.checkpoint_to_restore = ''
9 |
10 | # Model hypers.
11 | MAMLLearner.first_order = True
12 | MAMLLearner.alpha = 0.00577257710776957
13 | MAMLLearner.additional_evaluation_update_steps = 0
14 | MAMLLearner.num_update_steps = 1
15 |
16 | # Data hypers.
17 | DataConfig.image_height = 126
18 |
19 | # Training hypers (not needed for eval).
20 | Trainer.decay_learning_rate = True
21 | Trainer.decay_every = 500
22 | Trainer.decay_rate = 0.885662482266546
23 | Trainer.learning_rate = 0.00025036275525430426
24 | weight_decay = 0.003952085241872012
25 |
26 | # Baseline hypers (just for the record).
27 | BaselineLearner.cosine_logits_multiplier = 10
28 | BaselineLearner.use_weight_norm = True
29 | BaselineLearner.knn_distance = 'l2'
30 | BaselineLearner.cosine_classifier = True
31 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/matching_all.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all.gin'
2 | include 'meta_dataset/learn/gin/learners/matching_config.gin'
3 |
4 | # Backbone hypers.
5 | include 'meta_dataset/learn/gin/best_v2/pretrained_convnet.gin'
6 |
7 | # Data hypers.
8 | DataConfig.image_height = 84
9 |
10 | # Training hypers (not needed for eval).
11 | Trainer.decay_learning_rate = True
12 | Trainer.decay_every = 500
13 | Trainer.decay_rate = 0.5937020776489796
14 | Trainer.learning_rate = 0.005052178216688174
15 | weight_decay = 0.00000392173790384195
16 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/matching_all_from_scratch.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all.gin'
2 | include 'meta_dataset/learn/gin/learners/matching_config.gin'
3 |
4 | # Backbone hypers.
5 | Learner.embedding_fn = @four_layer_convnet
6 |
7 | Trainer.pretrained_source = 'scratch'
8 | Trainer.checkpoint_to_restore = ''
9 |
10 | # Data hypers.
11 | DataConfig.image_height = 126
12 |
13 | # Training hypers (not needed for eval).
14 | Trainer.decay_learning_rate = True
15 | Trainer.decay_every = 2500
16 | Trainer.decay_rate = 0.9197652570309498
17 | Trainer.learning_rate = 0.002748105034397538
18 | weight_decay = 0.0000013254789476822292
19 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/matching_imagenet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet.gin'
2 | include 'meta_dataset/learn/gin/learners/matching_config.gin'
3 |
4 | # Backbone hypers.
5 | include 'meta_dataset/learn/gin/best_v2/pretrained_resnet.gin'
6 |
7 | # Data hypers.
8 | DataConfig.image_height = 126
9 |
10 | # Training hypers (not needed for eval).
11 | Trainer.decay_learning_rate = True
12 | Trainer.decay_every = 500
13 | Trainer.decay_rate = 0.915193802145601
14 | Trainer.learning_rate = 0.0012064626897259694
15 | weight_decay = 0.0000885787420909229
16 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/matching_imagenet_from_scratch.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet.gin'
2 | include 'meta_dataset/learn/gin/learners/matching_config.gin'
3 |
4 | # Backbone hypers.
5 | Learner.embedding_fn = @resnet
6 |
7 | Trainer.pretrained_source = 'scratch'
8 | Trainer.checkpoint_to_restore = ''
9 |
10 | # Data hypers.
11 | DataConfig.image_height = 126
12 |
13 | # Training hypers (not needed for eval).
14 | Trainer.decay_learning_rate = True
15 | Trainer.decay_every = 2500
16 | Trainer.decay_rate = 0.8333411536286996
17 | Trainer.learning_rate = 0.0006229660387662655
18 | weight_decay = 0.00018036259587809225
19 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/matching_inference_all.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all.gin'
2 | include 'meta_dataset/learn/gin/learners/matching_config.gin'
3 |
4 | # Backbone hypers.
5 | include 'meta_dataset/learn/gin/best_v2/pretrained_resnet.gin'
6 |
7 | # Data hypers.
8 | DataConfig.image_height = 84
9 |
10 | # Training hypers (not needed for eval).
11 | Trainer.decay_learning_rate = True
12 | Trainer.decay_every = 1000
13 | Trainer.decay_rate = 0.5568532448398684
14 | Trainer.learning_rate = 0.0009946451213635535
15 | weight_decay = 0.000006779214221539446
16 |
17 | # Baseline hypers (just for the record).
18 | BaselineLearner.cosine_logits_multiplier = 100
19 | BaselineLearner.use_weight_norm = False
20 | BaselineLearner.knn_distance = 'cosine'
21 | BaselineLearner.cosine_classifier = False
22 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/matching_inference_imagenet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet.gin'
2 | include 'meta_dataset/learn/gin/learners/matching_config.gin'
3 |
4 | # Backbone hypers.
5 | include 'meta_dataset/learn/gin/best_v2/pretrained_resnet.gin'
6 |
7 | # Data hypers.
8 | DataConfig.image_height = 84
9 |
10 | # Training hypers (not needed for eval).
11 | Trainer.decay_learning_rate = True
12 | Trainer.decay_every = 10000
13 | Trainer.decay_rate = 0.9714872833376629
14 | Trainer.learning_rate = 0.002670412664551181
15 | weight_decay = 0.000013646011503835137
16 |
17 | # Baseline hypers (just for the record).
18 | BaselineLearner.cosine_logits_multiplier = 1
19 | BaselineLearner.use_weight_norm = True
20 | BaselineLearner.knn_distance = 'cosine'
21 | BaselineLearner.cosine_classifier = False
22 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/pretrain_imagenet_convnet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet.gin'
2 | include 'meta_dataset/learn/gin/learners/baseline_config.gin'
3 |
4 | # Backbone hypers.
5 | Learner.embedding_fn = @four_layer_convnet
6 |
7 | Trainer.checkpoint_to_restore = ''
8 | Trainer.pretrained_source = 'scratch'
9 |
10 | # Model hypers.
11 | BaselineLearner.knn_distance = 'cosine'
12 | BaselineLearner.cosine_classifier = True
13 | BaselineLearner.cosine_logits_multiplier = 1
14 | BaselineLearner.use_weight_norm = True
15 |
16 | # Data hypers.
17 | DataConfig.image_height = 84
18 |
19 | # Training hypers (not needed for eval).
20 | Trainer.num_updates = 50000
21 | Trainer.decay_every = 1000
22 | Trainer.decay_rate = 0.9105573818947892
23 | Trainer.learning_rate = 0.008644114436633987
24 | weight_decay = 0.000005171477829794739
25 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/pretrain_imagenet_resnet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet.gin'
2 | include 'meta_dataset/learn/gin/learners/baseline_config.gin'
3 |
4 | # Backbone hypers.
5 | Learner.embedding_fn = @resnet
6 |
7 | Trainer.checkpoint_to_restore = ''
8 | Trainer.pretrained_source = 'scratch'
9 |
10 | # Model hypers.
11 | BaselineLearner.knn_distance = 'cosine'
12 | BaselineLearner.cosine_classifier = False
13 | BaselineLearner.cosine_logits_multiplier = 2
14 | BaselineLearner.use_weight_norm = False
15 |
16 | # Data hypers.
17 | DataConfig.image_height = 126
18 |
19 | # Training hypers (not needed for eval).
20 | Trainer.num_updates = 50000
21 | Trainer.decay_every = 1000
22 | Trainer.decay_rate = 0.9967524905880909
23 | Trainer.learning_rate = 0.00375640851370052
24 | weight_decay = 0.00002628042826116842
25 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/pretrain_imagenet_wide_resnet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet.gin'
2 | include 'meta_dataset/learn/gin/learners/baseline_config.gin'
3 |
4 | # Backbone hypers.
5 | Learner.embedding_fn = @wide_resnet
6 |
7 | Trainer.checkpoint_to_restore = ''
8 | Trainer.pretrained_source = 'scratch'
9 |
10 | # Model hypers.
11 | BaselineLearner.knn_distance = 'cosine'
12 | BaselineLearner.cosine_classifier = True
13 | BaselineLearner.cosine_logits_multiplier = 1
14 | BaselineLearner.use_weight_norm = True
15 |
16 | # Data hypers.
17 | DataConfig.image_height = 126
18 |
19 | # Training hypers (not needed for eval).
20 | Trainer.num_updates = 50000
21 | Trainer.decay_every = 100
22 | Trainer.decay_rate = 0.5082121576573064
23 | Trainer.learning_rate = 0.007084776688116927
24 | weight_decay = 0.000005078192976503067
25 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/pretrained_convnet.gin:
--------------------------------------------------------------------------------
1 | # Best pre-trained convnet. Update the path to the checkpoint.
2 | Learner.embedding_fn = @four_layer_convnet
3 |
4 | Trainer.checkpoint_to_restore = '/path/to/checkpoints/pretrain_imagenet_convnet/model_42500.ckpt'
5 | Trainer.pretrained_source = 'imagenet'
6 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/pretrained_resnet.gin:
--------------------------------------------------------------------------------
1 | # Best pre-trained resnet. Update the path to the checkpoint.
2 | Learner.embedding_fn = @resnet
3 |
4 | Trainer.checkpoint_to_restore = '/path/to/checkpoints/pretrain_imagenet_resnet/model_27500.ckpt'
5 | Trainer.pretrained_source = 'imagenet'
6 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/pretrained_wide_resnet.gin:
--------------------------------------------------------------------------------
1 | # Best pre-trained wide resnet. Update the path to the checkpoint.
2 | Learner.embedding_fn = @wide_resnet
3 |
4 | Trainer.checkpoint_to_restore = '/path/to/checkpoints/pretrain_imagenet_wide_resnet/model_46000.ckpt'
5 | Trainer.pretrained_source = 'imagenet'
6 |
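For reference: the three pretrained_*.gin fragments above ship with placeholder checkpoint paths that are meant to be edited. One way to point them at a local checkpoint without modifying the files is to pass an extra gin binding when the configs are parsed. The sketch below is illustrative only; it assumes the gin-config package, that the repository's configurables (Trainer, resnet, ...) have already been imported and registered, and that the checkpoint path is replaced with a real local one.

import gin

# Hypothetical paths; substitute your own checkpoint location.
gin.parse_config_files_and_bindings(
    ['meta_dataset/learn/gin/best/pretrained_resnet.gin'],
    ["Trainer.checkpoint_to_restore = '/my/checkpoints/pretrain_imagenet_resnet/model_27500.ckpt'"],
)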
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/prototypical_all.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all.gin'
2 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin'
3 |
4 | # Backbone hypers.
5 | include 'meta_dataset/learn/gin/best_v2/pretrained_resnet.gin'
6 |
7 | # Data hypers.
8 | DataConfig.image_height = 126
9 |
10 | # Training hypers (not needed for eval).
11 | Trainer.decay_learning_rate = True
12 | Trainer.decay_every = 100
13 | Trainer.decay_rate = 0.5475646651850319
14 | Trainer.learning_rate = 0.0018231563858704268
15 | weight_decay = 0.0074472606534577565
16 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/prototypical_all_from_scratch.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all.gin'
2 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin'
3 |
4 | # Backbone hypers.
5 |
6 | Learner.embedding_fn = @resnet
7 |
8 | Trainer.pretrained_source = 'scratch'
9 | Trainer.checkpoint_to_restore = ''
10 |
11 | # Data hypers.
12 | DataConfig.image_height = 126
13 |
14 | # Training hypers (not needed for eval).
15 | Trainer.decay_learning_rate = True
16 | Trainer.decay_every = 2500
17 | Trainer.decay_rate = 0.8333411536286996
18 | Trainer.learning_rate = 0.0006229660387662655
19 | weight_decay = 0.00018036259587809225
20 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/prototypical_imagenet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet.gin'
2 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin'
3 |
4 | # Backbone hypers.
5 | include 'meta_dataset/learn/gin/best_v2/pretrained_resnet.gin'
6 |
7 | # Data hypers.
8 | DataConfig.image_height = 126
9 |
10 | # Training hypers (not needed for eval).
11 | Trainer.decay_learning_rate = True
12 | Trainer.decay_every = 500
13 | Trainer.decay_rate = 0.915193802145601
14 | Trainer.learning_rate = 0.0012064626897259694
15 | weight_decay = 0.0000885787420909229
16 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/prototypical_imagenet_from_scratch.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet.gin'
2 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin'
3 |
4 | # Backbone hypers.
5 | Learner.embedding_fn = @resnet
6 |
7 | Trainer.pretrained_source = 'scratch'
8 | Trainer.checkpoint_to_restore = ''
9 |
10 | # Data hypers.
11 | DataConfig.image_height = 126
12 |
13 | # Training hypers (not needed for eval).
14 | Trainer.decay_learning_rate = True
15 | Trainer.decay_every = 2500
16 | Trainer.decay_rate = 0.8333411536286996
17 | Trainer.learning_rate = 0.0006229660387662655
18 | weight_decay = 0.00018036259587809225
19 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/prototypical_inference_all.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all.gin'
2 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin'
3 |
4 | # Backbone hypers.
5 | include 'meta_dataset/learn/gin/best_v2/pretrained_wide_resnet.gin'
6 |
7 | # Data hypers.
8 | DataConfig.image_height = 126
9 |
10 | # Training hypers (not needed for eval).
11 | Trainer.decay_learning_rate = True
12 | Trainer.decay_every = 100
13 | Trainer.decay_rate = 0.7426809516243701
14 | Trainer.learning_rate = 0.0009325756201058525
15 | weight_decay = 0.00003386806355382518
16 |
17 | # Baseline hypers (just for the record).
18 | BaselineLearner.cosine_logits_multiplier = 2
19 | BaselineLearner.use_weight_norm = True
20 | BaselineLearner.knn_distance = 'l2'
21 | BaselineLearner.cosine_classifier = False
22 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/prototypical_inference_imagenet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet.gin'
2 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin'
3 |
4 | # Backbone hypers.
5 | include 'meta_dataset/learn/gin/best_v2/pretrained_wide_resnet.gin'
6 |
7 | # Data hypers.
8 | DataConfig.image_height = 126
9 |
10 | # Training hypers (not needed for eval).
11 | Trainer.decay_learning_rate = True
12 | Trainer.decay_every = 100
13 | Trainer.decay_rate = 0.7426809516243701
14 | Trainer.learning_rate = 0.0009325756201058525
15 | weight_decay = 0.00003386806355382518
16 |
17 | # Baseline hypers (just for the record).
18 | BaselineLearner.cosine_logits_multiplier = 2
19 | BaselineLearner.use_weight_norm = True
20 | BaselineLearner.knn_distance = 'l2'
21 | BaselineLearner.cosine_classifier = False
22 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/relationnet_all.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all.gin'
2 | include 'meta_dataset/learn/gin/learners/relationnet_config.gin'
3 |
4 | # Backbone hypers.
5 | include 'meta_dataset/learn/gin/best_v2/pretrained_convnet.gin'
6 |
7 | # Data hypers.
8 | DataConfig.image_height = 126
9 |
10 | # Training hypers (not needed for eval).
11 | Trainer.decay_learning_rate = True
12 | Trainer.decay_every = 5000
13 | Trainer.decay_rate = 0.8707355191010226
14 | Trainer.learning_rate = 0.000906783323100297
15 | weight_decay = 0.0000617576791048944
16 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/relationnet_all_from_scratch.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all.gin'
2 | include 'meta_dataset/learn/gin/learners/relationnet_config.gin'
3 |
4 | # Backbone hypers.
5 | Learner.embedding_fn = @four_layer_convnet
6 |
7 | Trainer.pretrained_source = 'scratch'
8 | Trainer.checkpoint_to_restore = ''
9 |
10 | # Data hypers.
11 | DataConfig.image_height = 126
12 |
13 | # Training hypers (not needed for eval).
14 | Trainer.decay_learning_rate = True
15 | Trainer.decay_every = 2500
16 | Trainer.decay_rate = 0.9197652570309498
17 | Trainer.learning_rate = 0.002748105034397538
18 | weight_decay = 0.0000013254789476822292
19 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/relationnet_imagenet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet.gin'
2 | include 'meta_dataset/learn/gin/learners/relationnet_config.gin'
3 |
4 | # Backbone hypers.
5 | include 'meta_dataset/learn/gin/best_v2/pretrained_convnet.gin'
6 |
7 | # Data hypers.
8 | DataConfig.image_height = 126
9 |
10 | # Training hypers (not needed for eval).
11 | Trainer.decay_learning_rate = True
12 | Trainer.decay_every = 5000
13 | Trainer.decay_rate = 0.8707355191010226
14 | Trainer.learning_rate = 0.000906783323100297
15 | weight_decay = 0.0000617576791048944
16 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/relationnet_imagenet_from_scratch.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet.gin'
2 | include 'meta_dataset/learn/gin/learners/relationnet_config.gin'
3 |
4 | # Backbone hypers.
5 | Learner.embedding_fn = @four_layer_convnet
6 |
7 | Trainer.pretrained_source = 'scratch'
8 | Trainer.checkpoint_to_restore = ''
9 |
10 | # Data hypers.
11 | DataConfig.image_height = 126
12 |
13 | # Training hypers (not needed for eval).
14 | Trainer.decay_learning_rate = True
15 | Trainer.decay_every = 2500
16 | Trainer.decay_rate = 0.9197652570309498
17 | Trainer.learning_rate = 0.002748105034397538
18 | weight_decay = 0.0000013254789476822292
19 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/best/relationnet_mini_imagenet_fiveshot.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin'
2 | include 'meta_dataset/learn/gin/learners/relationnet_config.gin'
3 |
4 | ## Hyperparameters used to train the best model
5 | Trainer.checkpoint_to_restore = ''
6 | Trainer.pretrained_source = 'scratch'
7 | weight_decay = 1e-6
8 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/baseline_all.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all.gin'
2 |
3 | include 'meta_dataset/learn/gin/learners/baseline_config.gin'
4 | Trainer.learning_rate = 1e-5
5 | BatchSplitReaderGetReader.add_dataset_offset = True
6 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/baseline_cosine_imagenet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet.gin'
2 |
3 | include 'meta_dataset/learn/gin/learners/baseline_cosine_config.gin'
4 | Trainer.learning_rate = 1e-5
5 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/baseline_imagenet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet.gin'
2 |
3 | include 'meta_dataset/learn/gin/learners/baseline_config.gin'
4 | Trainer.learning_rate = 1e-5
5 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/baseline_mini_imagenet_fiveshot.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin'
2 | EpisodeDescriptionConfig.num_support = 5
3 |
4 | include 'meta_dataset/learn/gin/learners/baseline_config.gin'
5 | Trainer.learning_rate = 1e-4
6 | Trainer.decay_learning_rate = True
7 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/baseline_mini_imagenet_oneshot.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin'
2 | EpisodeDescriptionConfig.num_support = 1
3 |
4 | include 'meta_dataset/learn/gin/learners/baseline_config.gin'
5 | Trainer.learning_rate = 1e-4
6 | Trainer.decay_learning_rate = True
7 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/baselinefinetune_all.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all.gin'
2 |
3 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin'
4 | Trainer.learning_rate = 1e-5
5 | BatchSplitReaderGetReader.add_dataset_offset = True
6 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/baselinefinetune_cosine_imagenet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet.gin'
2 |
3 | include 'meta_dataset/learn/gin/learners/baselinefinetune_cosine_config.gin'
4 | Trainer.learning_rate = 1e-5
5 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/baselinefinetune_cosine_mini_imagenet_fiveshot.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin'
2 | EpisodeDescriptionConfig.num_support = 5
3 |
4 | include 'meta_dataset/learn/gin/learners/baselinefinetune_cosine_config.gin'
5 | Trainer.learning_rate = 1e-4
6 | Trainer.decay_learning_rate = True
7 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/baselinefinetune_cosine_mini_imagenet_oneshot.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin'
2 | EpisodeDescriptionConfig.num_support = 1
3 |
4 | include 'meta_dataset/learn/gin/learners/baselinefinetune_cosine_config.gin'
5 | Trainer.learning_rate = 1e-4
6 | Trainer.decay_learning_rate = True
7 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/baselinefinetune_imagenet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet.gin'
2 |
3 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin'
4 | Trainer.learning_rate = 1e-5
5 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/baselinefinetune_mini_imagenet_fiveshot.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin'
2 | EpisodeDescriptionConfig.num_support = 5
3 |
4 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin'
5 | Trainer.learning_rate = 1e-4
6 | Trainer.decay_learning_rate = True
7 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/baselinefinetune_mini_imagenet_oneshot.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin'
2 | EpisodeDescriptionConfig.num_support = 1
3 |
4 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin'
5 | Trainer.learning_rate = 1e-4
6 | Trainer.decay_learning_rate = True
7 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/crosstransformer_imagenet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet.gin'
2 | include 'meta_dataset/learn/gin/learners/crosstransformer_config.gin'
3 |
4 | # Backbone hypers.
5 | include 'meta_dataset/learn/gin/default/resnet34_stride16.gin'
6 |
7 | # Data hypers.
8 | DataConfig.image_height = 224
9 | train/EpisodeDescriptionConfig.ignore_hierarchy_probability = 0.5
10 | resnet34.deeplab_alignment = False
11 |
12 | # Training hypers (not needed for eval).
13 | Trainer.num_updates = 100000
14 | Trainer.normalized_gradient_descent = True
15 | Trainer.decay_learning_rate = True
16 | Trainer.decay_every = 2000
17 | Trainer.decay_rate = 0.915193802145601
18 | Trainer.learning_rate = 0.0006
19 | weight_decay = 0.0000885787420909229
20 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/crosstransformer_simclreps_imagenet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet.gin'
2 | include 'meta_dataset/learn/gin/learners/crosstransformer_config.gin'
3 |
4 | # Backbone hypers.
5 | include 'meta_dataset/learn/gin/default/resnet34_stride16.gin'
6 |
7 | # Data hypers.
8 | DataConfig.image_height = 224
9 | train/EpisodeDescriptionConfig.ignore_hierarchy_probability = 0.5
10 | resnet34.deeplab_alignment = False
11 |
12 | # Training hypers (not needed for eval).
13 | Trainer.normalized_gradient_descent = True
14 | Trainer.decay_learning_rate = True
15 | Trainer.decay_every = 4000
16 | Trainer.decay_rate = 0.915193802145601
17 | Trainer.learning_rate = 0.0006
18 | Trainer.num_updates = 400000
19 | Trainer.enable_tf_optimizations = False
20 | weight_decay = 0.0000885787420909229
21 |
22 | train/EpisodeDescriptionConfig.simclr_episode_fraction = 0.5
23 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/debug_proto_fungi.gin:
--------------------------------------------------------------------------------
1 | # 'fungi' has many classes (> 1000, more than ImageNet) and a wide spread of
2 | # class sizes (from 6 to 442), while being overall not too big (13GB on disk),
3 | # which makes it a good candidate to profile imbalanced datasets with many
4 | # classes locally.
5 | benchmark.datasets = 'fungi'
6 | include 'meta_dataset/learn/gin/setups/data_config.gin'
7 | include 'meta_dataset/learn/gin/setups/trainer_config_debug.gin'
8 | include 'meta_dataset/learn/gin/setups/variable_way_and_shot.gin'
9 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin'
10 | Trainer.data_config = @DataConfig()
11 |
12 | # Total number of updates is 100, do not checkpoint or validate during profiling.
13 | Trainer.checkpoint_every = 1000
14 | Trainer.validate_every = 1000
15 | Trainer.log_every = 10
16 |
17 | Trainer.train_episode_config = @train/EpisodeDescriptionConfig()
18 | Trainer.eval_episode_config = @EpisodeDescriptionConfig()
19 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/debug_proto_mini_imagenet.gin:
--------------------------------------------------------------------------------
1 | benchmark.datasets = 'mini_imagenet'
2 | include 'meta_dataset/learn/gin/setups/data_config.gin'
3 | include 'meta_dataset/learn/gin/setups/trainer_config_debug.gin'
4 | include 'meta_dataset/learn/gin/setups/variable_way_and_shot.gin'
5 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin'
6 | Trainer.data_config = @DataConfig()
7 | Trainer.learning_rate = 1e-4
8 |
9 | Trainer.train_episode_config = @train/EpisodeDescriptionConfig()
10 | Trainer.eval_episode_config = @EpisodeDescriptionConfig()
11 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/flute.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all.gin'
2 | include 'meta_dataset/learn/gin/learners/learner_config.gin'
3 | include 'meta_dataset/learn/gin/setups/trainer_config_flute.gin'
4 | include 'meta_dataset/learn/gin/setups/data_config_flute.gin'
5 |
6 | # Learners to use at train and validation.
7 | Trainer_flute.train_learner_class = @DatasetConditionalBaselineLearner
8 | Trainer_flute.eval_learner_class = @DatasetConditionalPrototypicalNetworkLearner
9 | BaselineLearner.cosine_classifier = True
10 | BaselineLearner.cosine_logits_multiplier = 10.
11 | BaselineLearner.use_weight_norm = False
12 | BatchSplitReaderGetReader.add_dataset_offset = True
13 | BaselineLearner.knn_in_fc = False
14 | BaselineLearner.knn_distance = 'l2' # Does not matter.
15 |
16 | # Optimization.
17 | Trainer_flute.optimizer_type = 'momentum'
18 | Trainer_flute.learn_rate_scheduler = 'cosine_decay_restarts'
19 | Trainer_flute.decay_learning_rate = True
20 | Trainer_flute.sample_half_from_imagenet = True
21 | Trainer_flute.meta_batch_size = 8
22 | Trainer_flute.batch_size = 16
23 | Trainer_flute.learning_rate = 0.01
24 | Trainer_flute.decay_every = 5000
25 | Trainer_flute.num_updates = 640000
26 | Trainer_flute.validate_every = 1000
27 | Trainer_flute.checkpoint_every = 2000
28 |
29 | # Weight decay and learner settings.
30 | separate_head_linear_classifier.learnable_scale = True
31 | separate_head_linear_classifier.weight_decay = %weight_decay
32 | flute_resnet.weight_decay = %weight_decay
33 | weight_decay = 7e-4
34 | Learner.transductive_batch_norm = False
35 | Learner.backprop_through_moments = True
36 |
37 | # Backbone settings.
38 | Learner.embedding_fn = @flute_resnet
39 | bn_wrapper.batch_norm_fn = @bn_flute_train
40 | bn_flute_train.film_weight_decay = 0.0001
41 | DatasetConditionalBaselineLearner.num_sets = %num_film_sets
42 | DatasetConditionalPrototypicalNetworkLearner.num_sets = %num_film_sets
43 | bn_wrapper.num_film_sets = %num_film_sets
44 | num_film_sets = 8
45 |
46 | # Data settings.
47 | # Validate on the same datasets used for training, and in the same order, so
48 | # that the ground truth source ID can be used for validation-time forward passes
49 | # without remapping.
50 | benchmark.train_datasets = 'ilsvrc_2012,aircraft,cu_birds,omniglot,quickdraw,vgg_flower,dtd,fungi'
51 | benchmark.eval_datasets = 'ilsvrc_2012,aircraft,cu_birds,omniglot,quickdraw,vgg_flower,dtd,fungi'
52 |
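For readers unfamiliar with FLUTE: num_film_sets = 8 above corresponds to one set of FiLM parameters (a batch-norm scale and shift) per training dataset, and the dataset-conditional learners select which set to apply from the episode's source ID. The following is a minimal NumPy sketch of that kind of per-dataset modulation; the class and names are hypothetical stand-ins for, not reproductions of, bn_flute_train.

import numpy as np

class FilmSets:
    """One (gamma, beta) pair per training dataset, i.e. one 'FiLM set'."""

    def __init__(self, num_film_sets, num_channels):
        self.gammas = np.ones((num_film_sets, num_channels))
        self.betas = np.zeros((num_film_sets, num_channels))

    def __call__(self, normalized_features, dataset_id):
        # normalized_features: [batch, num_channels], already batch-normalized.
        return self.gammas[dataset_id] * normalized_features + self.betas[dataset_id]

film = FilmSets(num_film_sets=8, num_channels=64)
out = film(np.random.randn(4, 64), dataset_id=3)  # use dataset 3's FiLM parameters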
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/flute_dataset_classifier.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all.gin'
2 | include 'meta_dataset/learn/gin/learners/learner_config.gin'
3 | include 'meta_dataset/learn/gin/setups/trainer_config_flute.gin'
4 | include 'meta_dataset/learn/gin/setups/data_config_flute.gin'
5 |
6 | # Learners to use at train and validation.
7 | Trainer_flute.train_learner_class = @DatasetLearner
8 | Trainer_flute.eval_learner_class = @DatasetLearner
9 | BatchSplitReaderGetReader.add_dataset_offset = True
10 | Learner.transductive_batch_norm = False
11 | Learner.backprop_through_moments = True
12 |
13 | # Optimization.
14 | Trainer_flute.optimizer_type = 'adam'
15 | Trainer_flute.learn_rate_scheduler = 'cosine_decay'
16 | Trainer_flute.decay_learning_rate = True
17 | Trainer_flute.sample_half_from_imagenet = False
18 | Trainer_flute.meta_batch_size = 8
19 | Trainer_flute.batch_size = 16
20 | Trainer_flute.learning_rate = 0.001
21 | Trainer_flute.decay_every = 3000
22 | Trainer_flute.num_updates = 14000
23 |
24 | # Backbone settings.
25 | Learner.embedding_fn = @dataset_classifier
26 | dataset_classifier.weight_decay = %weight_decay
27 | dataset_classifier.num_datasets = %num_film_sets
28 | num_film_sets = 8
29 | weight_decay = 7e-4
30 |
31 | # Data settings.
32 | # Validate on the same datasets used for training, and in the same order, so
33 | # that the ground truth source ID can be used for validation-time forward passes
34 | # without remapping.
35 | benchmark.train_datasets = 'ilsvrc_2012,aircraft,cu_birds,omniglot,quickdraw,vgg_flower,dtd,fungi'
36 | benchmark.eval_datasets = 'ilsvrc_2012,aircraft,cu_birds,omniglot,quickdraw,vgg_flower,dtd,fungi'
37 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/maml_all.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all.gin'
2 |
3 | include 'meta_dataset/learn/gin/learners/maml_config.gin'
4 | Trainer.learning_rate = 1e-3
5 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/maml_imagenet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet.gin'
2 |
3 | include 'meta_dataset/learn/gin/learners/maml_config.gin'
4 | Trainer.learning_rate = 1e-3
5 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/maml_init_with_proto_all.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all.gin'
2 |
3 | include 'meta_dataset/learn/gin/learners/maml_config.gin'
4 | Trainer.learning_rate = 1e-3
5 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/maml_init_with_proto_imagenet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet.gin'
2 |
3 | include 'meta_dataset/learn/gin/learners/maml_config.gin'
4 | Trainer.learning_rate = 1e-3
5 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/maml_init_with_proto_inference_all.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all.gin'
2 | include 'meta_dataset/learn/gin/learners/learner_config.gin'
3 | BatchSplitReaderGetReader.add_dataset_offset = True
4 | Trainer.train_learner_class = @BaselineLearner
5 | Trainer.eval_learner_class = @MAMLLearner
6 | # The following line is what makes this proto-MAML.
7 | MAMLLearner.proto_maml_fc_layer_init = True
8 | weight_decay = 1e-4
9 | BaselineLearner.knn_in_fc = False
10 | MAMLLearner.debug = False
11 |
12 | BaselineLearner.transductive_batch_norm = False
13 | BaselineLearner.backprop_through_moments = True
14 |
15 | MAMLLearner.transductive_batch_norm = False
16 | MAMLLearner.backprop_through_moments = True
17 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/maml_init_with_proto_inference_imagenet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet.gin'
2 | include 'meta_dataset/learn/gin/learners/learner_config.gin'
3 | Trainer.train_learner_class = @BaselineLearner
4 | Trainer.eval_learner_class = @MAMLLearner
5 | # The following line is what makes this proto-MAML.
6 | MAMLLearner.proto_maml_fc_layer_init = True
7 | weight_decay = 1e-4
8 | BaselineLearner.knn_in_fc = False
9 | MAMLLearner.debug = False
10 |
11 | BaselineLearner.transductive_batch_norm = False
12 | BaselineLearner.backprop_through_moments = True
13 |
14 | MAMLLearner.transductive_batch_norm = False
15 | MAMLLearner.backprop_through_moments = True
16 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/maml_init_with_proto_mini_imagenet_fiveshot.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin'
2 | EpisodeDescriptionConfig.num_support = 5
3 |
4 | include 'meta_dataset/learn/gin/learners/maml_config.gin'
5 | Trainer.learning_rate = 1e-4
6 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/maml_init_with_proto_mini_imagenet_oneshot.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin'
2 | EpisodeDescriptionConfig.num_support = 1
3 |
4 | include 'meta_dataset/learn/gin/learners/maml_config.gin'
5 | Trainer.learning_rate = 1e-4
6 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/maml_mini_imagenet_fiveshot.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin'
2 | EpisodeDescriptionConfig.num_support = 5
3 |
4 | include 'meta_dataset/learn/gin/learners/maml_config.gin'
5 | Trainer.learning_rate = 1e-4
6 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/maml_mini_imagenet_oneshot.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin'
2 | EpisodeDescriptionConfig.num_support = 1
3 |
4 | include 'meta_dataset/learn/gin/learners/maml_config.gin'
5 | Trainer.learning_rate = 1e-4
6 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/maml_protonet_all.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all.gin'
2 |
3 | include 'meta_dataset/learn/gin/learners/maml_config.gin'
4 | Trainer.learning_rate = 1e-3
5 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/matching_all.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all.gin'
2 |
3 | include 'meta_dataset/learn/gin/learners/matching_config.gin'
4 | Trainer.learning_rate = 1e-3
5 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/matching_imagenet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet.gin'
2 |
3 | include 'meta_dataset/learn/gin/learners/matching_config.gin'
4 | Trainer.learning_rate = 1e-3
5 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/matching_inference_all.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all.gin'
2 | include 'meta_dataset/learn/gin/learners/learner_config.gin'
3 | BatchSplitReaderGetReader.add_dataset_offset = True
4 | Trainer.train_learner_class = @BaselineLearner
5 | Trainer.eval_learner_class = @MatchingNetworkLearner
6 | weight_decay = 1e-4
7 | BaselineLearner.knn_in_fc = False
8 | MatchingNetworkLearner.exact_cosine_distance = False
9 |
10 | BaselineLearner.transductive_batch_norm = False
11 | BaselineLearner.backprop_through_moments = True
12 |
13 | MatchingNetworkLearner.transductive_batch_norm = False
14 | MatchingNetworkLearner.backprop_through_moments = True
15 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/matching_inference_imagenet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet.gin'
2 | include 'meta_dataset/learn/gin/learners/learner_config.gin'
3 | Trainer.train_learner_class = @BaselineLearner
4 | Trainer.eval_learner_class = @MatchingNetworkLearner
5 | weight_decay = 1e-4
6 | BaselineLearner.knn_in_fc = False
7 | MatchingNetworkLearner.exact_cosine_distance = False
8 |
9 | BaselineLearner.transductive_batch_norm = False
10 | BaselineLearner.backprop_through_moments = True
11 |
12 | MatchingNetworkLearner.transductive_batch_norm = False
13 | MatchingNetworkLearner.backprop_through_moments = True
14 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/matching_mini_imagenet_fiveshot.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin'
2 | EpisodeDescriptionConfig.num_support = 5
3 |
4 | include 'meta_dataset/learn/gin/learners/matching_config.gin'
5 | Trainer.learning_rate = 1e-4
6 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/matching_mini_imagenet_oneshot.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin'
2 | EpisodeDescriptionConfig.num_support = 1
3 |
4 | include 'meta_dataset/learn/gin/learners/matching_config.gin'
5 | Trainer.learning_rate = 1e-4
6 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/pretrained_resnet34_224.gin:
--------------------------------------------------------------------------------
1 | # Best pre-trained resnet34. Update the path to the checkpoint.
2 | Learner.embedding_fn = @resnet34
3 | Trainer.checkpoint_to_restore = '/path/to/checkpoints/pretrain_imagenet_resnet/model_27500.ckpt'
4 | Trainer.pretrained_source = 'imagenet'
5 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/prototypical_all.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all.gin'
2 |
3 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin'
4 | Trainer.learning_rate = 1e-3
5 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/prototypical_imagenet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet.gin'
2 |
3 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin'
4 | Trainer.learning_rate = 1e-3
5 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/prototypical_inference_all.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all.gin'
2 | include 'meta_dataset/learn/gin/learners/learner_config.gin'
3 | BatchSplitReaderGetReader.add_dataset_offset = True
4 | Trainer.train_learner_class = @BaselineLearner
5 | Trainer.eval_learner_class = @PrototypicalNetworkLearner
6 | weight_decay = 1e-4
7 | BaselineLearner.knn_in_fc = False
8 |
9 | BaselineLearner.transductive_batch_norm = False
10 | BaselineLearner.backprop_through_moments = True
11 |
12 | PrototypicalNetworkLearner.transductive_batch_norm = False
13 | PrototypicalNetworkLearner.backprop_through_moments = True
14 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/prototypical_inference_imagenet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet.gin'
2 | include 'meta_dataset/learn/gin/learners/learner_config.gin'
3 | Trainer.train_learner_class = @BaselineLearner
4 | Trainer.eval_learner_class = @PrototypicalNetworkLearner
5 | weight_decay = 1e-4
6 | BaselineLearner.knn_in_fc = False
7 |
8 | BaselineLearner.transductive_batch_norm = False
9 | BaselineLearner.backprop_through_moments = True
10 |
11 | PrototypicalNetworkLearner.transductive_batch_norm = False
12 | PrototypicalNetworkLearner.backprop_through_moments = True
13 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/prototypical_mini_imagenet_fiveshot.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin'
2 | EpisodeDescriptionConfig.num_support = 5
3 | train/EpisodeDescriptionConfig.num_ways = 20
4 |
5 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin'
6 | Trainer.learning_rate = 1e-4
7 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/prototypical_mini_imagenet_oneshot.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin'
2 | EpisodeDescriptionConfig.num_support = 1
3 | train/EpisodeDescriptionConfig.num_ways = 30
4 |
5 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin'
6 | Trainer.learning_rate = 1e-4
7 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/relationnet_all.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all.gin'
2 |
3 | include 'meta_dataset/learn/gin/learners/relationnet_config.gin'
4 |
5 | # There are more parameters than usual, so allow more updates.
6 | Trainer.num_updates = 100000
7 |
8 | Trainer.learning_rate = 1e-3
9 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/relationnet_imagenet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet.gin'
2 | include 'meta_dataset/learn/gin/learners/relationnet_config.gin'
3 |
4 | # There are more parameters than usual, so allow more updates.
5 | Trainer.num_updates = 100000
6 |
7 | Trainer.learning_rate = 1e-3
8 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/relationnet_mini_imagenet_fiveshot.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin'
2 | include 'meta_dataset/learn/gin/learners/relationnet_config.gin'
3 |
4 | Trainer.learning_rate = 1e-3
5 | Trainer.decay_every = 100000
6 | Trainer.decay_rate = 0.5
7 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/relationnet_mini_imagenet_oneshot.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin'
2 | EpisodeDescriptionConfig.num_support = 1
3 | include 'meta_dataset/learn/gin/learners/relationnet_config.gin'
4 |
5 | Trainer.learning_rate = 1e-3
6 | Trainer.decay_every = 100000
7 | Trainer.decay_rate = 0.5
8 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/default/resnet34_stride16.gin:
--------------------------------------------------------------------------------
1 | Learner.embedding_fn = @resnet34
2 | Trainer.pretrained_source = 'imagenet'
3 | resnet34.max_stride = 16
4 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/learners/baseline_config.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/learners/learner_config.gin'
2 | Trainer.train_learner_class = @BaselineLearner
3 | Trainer.eval_learner_class = @BaselineLearner
4 |
5 | Learner.embedding_fn = @four_layer_convnet
6 | weight_decay = 1e-4
7 |
8 | BaselineLearner.knn_in_fc = False
9 | BaselineLearner.knn_distance = 'l2'
10 | BaselineLearner.cosine_classifier = False
11 | BaselineLearner.cosine_logits_multiplier = None
12 | BaselineLearner.use_weight_norm = False
13 |
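The interaction between cosine_classifier and cosine_logits_multiplier above is easiest to see in code: with the cosine classifier enabled, logits are scaled cosine similarities between the normalized embeddings and per-class weight vectors. A minimal NumPy sketch, illustrative only and not the repository's BaselineLearner:

import numpy as np

def cosine_classifier_logits(embeddings, class_weights, multiplier=1.0):
    """Logits as scaled cosine similarities; `multiplier` plays the role of
    cosine_logits_multiplier in the config above."""
    emb = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    w = class_weights / np.linalg.norm(class_weights, axis=1, keepdims=True)
    return multiplier * emb @ w.T  # [batch, num_classes]

logits = cosine_classifier_logits(np.random.randn(4, 64), np.random.randn(10, 64), multiplier=10.0)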
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/learners/baseline_cosine_config.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/learners/baseline_config.gin'
2 |
3 | BaselineLearner.cosine_classifier = True
4 | BaselineLearner.cosine_logits_multiplier = 1
5 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/learners/baselinefinetune_config.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/learners/baseline_config.gin'
2 | Trainer.train_learner_class = @BaselineFinetuneLearner
3 | Trainer.eval_learner_class = @BaselineFinetuneLearner
4 | weight_decay = 1e-4
5 | BaselineFinetuneLearner.num_finetune_steps = 5
6 | BaselineFinetuneLearner.finetune_lr = 1e-4
7 | BaselineFinetuneLearner.finetune_all_layers = False
8 | BaselineFinetuneLearner.finetune_with_adam = False
9 |
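As configured above, the finetune-style baseline adapts to each test episode by training a fresh linear head on the support embeddings for num_finetune_steps steps at finetune_lr (the embedding stays frozen unless finetune_all_layers = True, and plain SGD is used unless finetune_with_adam = True). A minimal NumPy sketch of that head-only loop; names and shapes are illustrative, not the repository's BaselineFinetuneLearner:

import numpy as np

def finetune_linear_head(support_embeddings, support_labels, num_classes,
                         num_finetune_steps=5, finetune_lr=1e-4):
    """Softmax cross-entropy SGD on a linear classifier over frozen embeddings."""
    dim = support_embeddings.shape[1]
    weights = np.zeros((num_classes, dim))
    biases = np.zeros(num_classes)
    one_hot = np.eye(num_classes)[support_labels]
    for _ in range(num_finetune_steps):
        logits = support_embeddings @ weights.T + biases
        probs = np.exp(logits - logits.max(axis=1, keepdims=True))
        probs /= probs.sum(axis=1, keepdims=True)
        grad_logits = (probs - one_hot) / len(support_labels)
        weights -= finetune_lr * grad_logits.T @ support_embeddings
        biases -= finetune_lr * grad_logits.sum(axis=0)
    return weights, biases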
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/learners/baselinefinetune_cosine_config.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin'
2 |
3 | BaselineLearner.cosine_classifier = True
4 | BaselineLearner.cosine_logits_multiplier = 1
5 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/learners/crosstransformer_config.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/learners/learner_config.gin'
2 | Trainer.train_learner_class = @CrossTransformerLearner
3 | Trainer.eval_learner_class = @CrossTransformerLearner
4 | Trainer.normalized_gradient_descent = True
5 | Learner.embedding_fn = 'four_layer_convnet'
6 | weight_decay = 1e-4
7 | CrossTransformerLearner.tformer_weight_decay = %weight_decay
8 | bn.use_ema = True
9 | Trainer.distribute = True
10 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/learners/learner_config.gin:
--------------------------------------------------------------------------------
1 | # Backbone regularization settings.
2 | linear_classifier.weight_decay = %weight_decay
3 | fully_connected_network.weight_decay = %weight_decay
4 | four_layer_convnet.weight_decay = %weight_decay
5 | relationnet_convnet.weight_decay = %weight_decay
6 | relation_module.weight_decay = %weight_decay
7 | resnet.weight_decay = %weight_decay
8 | resnet34.weight_decay = %weight_decay
9 | wide_resnet.weight_decay = %weight_decay
10 | MAMLLearner.classifier_weight_decay = %weight_decay
11 | weight_decay = 0.001
12 |
13 | # Training settings.
14 | Trainer.checkpoint_to_restore = ''
15 | Trainer.learning_rate = 1e-4
16 | Trainer.decay_learning_rate = False
17 | Trainer.decay_every = 5000
18 | Trainer.decay_rate = 0.5
19 | Trainer.normalized_gradient_descent = False
20 | Trainer.pretrained_source = ''
21 |
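The %weight_decay lines above are gin macros: one value, set once at the bottom of the block, fans out to every configurable that references it. A self-contained sketch of the mechanism using hypothetical stub functions (not from this repository), assuming the gin-config package:

import gin

@gin.configurable
def four_layer_convnet_stub(weight_decay=None):
    return weight_decay

@gin.configurable
def resnet_stub(weight_decay=None):
    return weight_decay

gin.parse_config("""
    four_layer_convnet_stub.weight_decay = %weight_decay
    resnet_stub.weight_decay = %weight_decay
    weight_decay = 0.001
""")

assert four_layer_convnet_stub() == resnet_stub() == 0.001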
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/learners/maml_config.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/learners/learner_config.gin'
2 | Trainer.train_learner_class = @MAMLLearner
3 | Trainer.eval_learner_class = @MAMLLearner
4 | Trainer.decay_learning_rate = True
5 |
6 | Learner.embedding_fn = @four_layer_convnet
7 | weight_decay = 1e-4
8 |
9 | MAMLLearner.num_update_steps = 5
10 | MAMLLearner.additional_evaluation_update_steps = 0
11 | MAMLLearner.first_order = True
12 | MAMLLearner.alpha = 0.01
13 | MAMLLearner.adapt_batch_norm = False
14 | MAMLLearner.debug = False
15 | MAMLLearner.zero_fc_layer = True
16 | MAMLLearner.proto_maml_fc_layer_init = False
17 |
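In the configuration above, alpha is the inner-loop step size, num_update_steps is how many adaptation steps are taken on each episode's support set, additional_evaluation_update_steps adds extra steps at evaluation time, and first_order = True means the adaptation gradient itself is not differentiated through. A minimal NumPy sketch of that inner loop, with a hypothetical support_grad_fn standing in for the support-set loss gradient (illustrative only, not the repository's MAMLLearner):

import numpy as np

def inner_loop_adapt(theta, support_grad_fn, alpha=0.01, num_update_steps=5,
                     additional_evaluation_update_steps=0, is_training=True):
    """Take SGD steps on the support-set loss, starting from meta-parameters theta."""
    steps = num_update_steps
    if not is_training:
        steps += additional_evaluation_update_steps
    for _ in range(steps):
        theta = theta - alpha * support_grad_fn(theta)
    return theta

# Toy usage: adapt towards the minimum of the quadratic "support loss" ||theta||^2.
adapted = inner_loop_adapt(np.ones(3), lambda t: 2.0 * t, alpha=0.1, num_update_steps=10)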
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/learners/maml_init_with_proto_config.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/learners/learner_config.gin'
2 | Trainer.train_learner_class = @MAMLLearner
3 | Trainer.decay_learning_rate = True
4 | Trainer.eval_learner_class = @MAMLLearner
5 |
6 | Learner.embedding_fn = @four_layer_convnet
7 | weight_decay = 1e-4
8 |
9 | MAMLLearner.num_update_steps = 5
10 | MAMLLearner.additional_evaluation_update_steps = 0
11 | MAMLLearner.first_order = True
12 | MAMLLearner.alpha = 0.01
13 | MAMLLearner.adapt_batch_norm = False
14 | MAMLLearner.debug = False
15 | MAMLLearner.zero_fc_layer = True
16 | MAMLLearner.proto_maml_fc_layer_init = True
17 |
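The only change relative to `maml_config.gin` is `proto_maml_fc_layer_init = True`: the episode-specific output layer is initialized from class prototypes (per-class means of the support embeddings) so that, before any inner-loop steps, its logits match a prototypical network's up to a per-example constant. A sketch of that initialization under squared Euclidean distance (illustrative; it assumes the prototypes have already been computed):

import numpy as np

def proto_maml_fc_init(prototypes):
  # prototypes: [num_classes, dim], the mean support embedding of each class.
  # -||z - c_k||^2 = 2 * z.c_k - ||c_k||^2 - ||z||^2, and the last term is the
  # same for every class, so these weights/biases reproduce the class ranking.
  weights = 2.0 * prototypes
  biases = -np.sum(prototypes ** 2, axis=1)
  return weights, biases
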
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/learners/matching_config.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/learners/learner_config.gin'
2 | Trainer.train_learner_class = @MatchingNetworkLearner
3 | Trainer.eval_learner_class = @MatchingNetworkLearner
4 |
5 | Trainer.decay_learning_rate = True
6 | Learner.embedding_fn = @four_layer_convnet
7 | weight_decay = 1e-4
8 |
9 | MatchingNetworkLearner.exact_cosine_distance = False
10 |
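As a reminder of the prediction rule these settings feed into: a matching network classifies each query example via a softmax attention over the support labels, weighted by the similarity between query and support embeddings (`exact_cosine_distance` controls how strictly that similarity is a normalized cosine). A schematic numpy sketch (illustrative only):

import numpy as np

def matching_net_predict(query_emb, support_emb, support_onehot, eps=1e-8):
  q = query_emb / (np.linalg.norm(query_emb, axis=1, keepdims=True) + eps)
  s = support_emb / (np.linalg.norm(support_emb, axis=1, keepdims=True) + eps)
  attention = np.exp(q @ s.T)                        # [num_query, num_support]
  attention /= attention.sum(axis=1, keepdims=True)  # softmax over the support set
  return attention @ support_onehot                  # [num_query, num_classes]
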
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/learners/prototypical_config.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/learners/learner_config.gin'
2 | Trainer.train_learner_class = @PrototypicalNetworkLearner
3 | Trainer.eval_learner_class = @PrototypicalNetworkLearner
4 | Trainer.decay_learning_rate = True
5 |
6 | Learner.embedding_fn = @four_layer_convnet
7 | weight_decay = 1e-4
8 |
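For reference, the prototypical network configured here scores each query by its negative squared Euclidean distance to the class prototypes, i.e. the per-class means of the support embeddings. A minimal numpy sketch (illustrative only):

import numpy as np

def prototypical_logits(query_emb, support_emb, support_onehot):
  # Assumes every class has at least one support example.
  counts = support_onehot.sum(axis=0, keepdims=True).T     # [num_classes, 1]
  prototypes = support_onehot.T @ support_emb / counts     # [num_classes, dim]
  diffs = query_emb[:, None, :] - prototypes[None, :, :]   # [num_query, num_classes, dim]
  return -np.sum(diffs ** 2, axis=-1)                      # [num_query, num_classes]
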
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/learners/relationnet_config.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/learners/learner_config.gin'
2 | Trainer.train_learner_class = @RelationNetworkLearner
3 | Trainer.eval_learner_class = @RelationNetworkLearner
4 | Trainer.decay_learning_rate = True
5 |
6 | Learner.embedding_fn = @relationnet_convnet
7 | Learner.transductive_batch_norm = True
8 | weight_decay = 1e-6
9 |
10 | # Allow transductive batch norm to be faithful to the original Relation Net
11 | # implementation, even though this gives it an advantage over the rest of the
12 | # models we used on Meta-Dataset.
13 | Learner.transductive_batch_norm = True
14 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/metadataset_v2/baseline_all.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all_v2.gin'
2 | include 'meta_dataset/learn/gin/learners/baseline_config.gin'
3 | Trainer.learning_rate = 1e-5
4 | BatchSplitReaderGetReader.add_dataset_offset = True
5 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/metadataset_v2/baseline_imagenet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet_v2.gin'
2 | include 'meta_dataset/learn/gin/learners/baseline_config.gin'
3 | Trainer.learning_rate = 1e-5
4 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/metadataset_v2/baselinefinetune_all.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all_v2.gin'
2 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin'
3 | Trainer.learning_rate = 1e-5
4 | BatchSplitReaderGetReader.add_dataset_offset = True
5 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/metadataset_v2/baselinefinetune_imagenet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet_v2.gin'
2 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin'
3 | Trainer.learning_rate = 1e-5
4 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/metadataset_v2/best/baseline_imagenet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/metadataset_v2/baseline_imagenet.gin'
2 | include 'meta_dataset/learn/gin/learners/baseline_config.gin'
3 |
4 | BatchSplitReaderGetReader.add_dataset_offset = True
5 |
6 | # Backbone hypers.
7 | Learner.embedding_fn = @resnet
8 | Trainer.checkpoint_to_restore = ''
9 | Trainer.pretrained_source = ''
10 |
11 | # Model hypers.
12 | BaselineLearner.knn_distance = 'cosine'
13 | BaselineLearner.cosine_classifier = False
14 | BaselineLearner.cosine_logits_multiplier = 2
15 | BaselineLearner.use_weight_norm = False
16 |
17 | Trainer.decay_every = 1000
18 | Trainer.decay_rate = 0.5979159492081371
19 | Trainer.learning_rate = 0.00047244647904730503
20 | weight_decay = 0.026388517138594258
21 | DataConfig.image_height = 126
22 | DataAugmentation.jitter_amount = 6
23 | DataAugmentation.gaussian_noise_std = 0.17564536824131866
24 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/metadataset_v2/best/baselinefinetune_all.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all_v2.gin'
2 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin'
3 | # Backbone hypers.
4 | Learner.embedding_fn = @resnet
5 | weight_decay = 0.0001
6 |
7 | # Model hypers.
8 | BaselineLearner.cosine_classifier = True
9 | BaselineLearner.cosine_logits_multiplier = 10
10 | BaselineLearner.knn_distance = 'l2'
11 | BaselineLearner.knn_in_fc = False
12 | BaselineLearner.use_weight_norm = True
13 | BaselineFinetuneLearner.finetune_all_layers = True
14 | BaselineFinetuneLearner.finetune_lr = 0.01
15 | BaselineFinetuneLearner.finetune_with_adam = True
16 | BaselineFinetuneLearner.num_finetune_steps = 100
17 |
18 | # Training hypers (not needed for eval).
19 | Trainer.decay_every = 500
20 | Trainer.decay_learning_rate = False
21 | Trainer.decay_rate = 0.5278384940678894
22 | Trainer.learning_rate = 3.4293725734843445e-06
23 | Trainer.pretrained_source = 'imagenet'
24 | Trainer.checkpoint_to_restore = 'path/to/pretrained_checkpoint'
25 |
26 |
27 | DataConfig.image_height = 126
28 | batch/DataAugmentation.gaussian_noise_std = 0.026413512951864337
29 | batch/DataAugmentation.jitter_amount = 5
30 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/metadataset_v2/best/baselinefinetune_imagenet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet_v2.gin'
2 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin'
3 | # Backbone hypers.
4 | Learner.embedding_fn = @resnet
5 | weight_decay = 0.0001
6 |
7 | # Model hypers.
8 | BaselineLearner.cosine_classifier = True
9 | BaselineLearner.cosine_logits_multiplier = 10
10 | BaselineLearner.knn_distance = 'l2'
11 | BaselineLearner.knn_in_fc = False
12 | BaselineLearner.use_weight_norm = True
13 | BaselineFinetuneLearner.finetune_all_layers = True
14 | BaselineFinetuneLearner.finetune_lr = 0.01
15 | BaselineFinetuneLearner.finetune_with_adam = True
16 | BaselineFinetuneLearner.num_finetune_steps = 100
17 |
18 | # Training hypers (not needed for eval).
19 | Trainer.decay_every = 500
20 | Trainer.decay_learning_rate = False
21 | Trainer.decay_rate = 0.5278384940678894
22 | Trainer.learning_rate = 3.4293725734843445e-06
23 | Trainer.pretrained_source = 'imagenet'
24 | Trainer.checkpoint_to_restore = 'path/to/pretrained_checkpoint'
25 |
26 | DataConfig.image_height = 126
27 | batch/DataAugmentation.gaussian_noise_std = 0.026413512951864337
28 | batch/DataAugmentation.jitter_amount = 5
29 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/metadataset_v2/best/maml_init_with_proto_all.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all_v2.gin'
2 | include 'meta_dataset/learn/gin/learners/maml_init_with_proto_config.gin'
3 | # Backbone hypers.
4 | Learner.embedding_fn = @resnet
5 | weight_decay = 0.0001
6 | Trainer.pretrained_source = 'imagenet'
7 | Trainer.checkpoint_to_restore = 'path/to/pretrained_checkpoint'
8 |
9 | # Model hypers.
10 | MAMLLearner.adapt_batch_norm = False
11 | MAMLLearner.additional_evaluation_update_steps = 0
12 | MAMLLearner.alpha = 0.005435022808033229
13 | MAMLLearner.first_order = True
14 | MAMLLearner.num_update_steps = 10
15 | MAMLLearner.proto_maml_fc_layer_init = True
16 | MAMLLearner.zero_fc_layer = True
17 | # Training hypers (not needed for eval).
18 | Trainer.decay_every = 1000
19 | Trainer.decay_learning_rate = True
20 | Trainer.decay_rate = 0.6477898086638092
21 | Trainer.learning_rate = 0.00036339913514891586
22 |
23 | DataConfig.image_height = 126
24 | support/DataAugmentation.gaussian_noise_std = 0.4658549336962272
25 | support/DataAugmentation.jitter_amount = 0
26 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/metadataset_v2/best/maml_init_with_proto_imagenet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet_v2.gin'
2 | include 'meta_dataset/learn/gin/learners/maml_init_with_proto_config.gin'
3 | # Backbone hypers.
4 | Learner.embedding_fn = @resnet
5 | weight_decay = 0.0001
6 | Trainer.pretrained_source = 'imagenet'
7 | Trainer.checkpoint_to_restore = 'path/to/pretrained_checkpoint'
8 |
9 | # Model hypers.
10 | MAMLLearner.adapt_batch_norm = False
11 | MAMLLearner.additional_evaluation_update_steps = 90
12 | MAMLLearner.alpha = 0.01
13 | MAMLLearner.first_order = True
14 | MAMLLearner.num_update_steps = 10
15 | MAMLLearner.proto_maml_fc_layer_init = True
16 | MAMLLearner.zero_fc_layer = True
17 | # Training hypers (not needed for eval).
18 | Trainer.decay_every = 1000
19 | Trainer.decay_learning_rate = True
20 | Trainer.decay_rate = 0.6477898086638092
21 | Trainer.learning_rate = 0.00036339913514891586
22 |
23 | DataConfig.image_height = 126
24 | support/DataAugmentation.gaussian_noise_std = 0.4658549336962272
25 | support/DataAugmentation.jitter_amount = 0
26 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/metadataset_v2/best/prototypical_all.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all_v2.gin'
2 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin'
3 | # Backbone hypers.
4 | Learner.embedding_fn = @resnet
5 | Trainer.pretrained_source = 'imagenet'
6 | Trainer.checkpoint_to_restore = 'path/to/pretrained_checkpoint'
7 |
8 | # Training hypers (not needed for eval).
9 | Trainer.decay_every = 500
10 | Trainer.decay_learning_rate = True
11 | Trainer.decay_rate = 0.885662482266546
12 | Trainer.learning_rate = 0.00025036275525430426
13 | Learner.backprop_through_moments = True
14 | Learner.transductive_batch_norm = False
15 | weight_decay = 0.0001
16 |
17 | DataConfig.image_height = 126
18 | support/DataAugmentation.gaussian_noise_std = 0.15335348868374565
19 | support/DataAugmentation.jitter_amount = 5
20 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/metadataset_v2/best/prototypical_imagenet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet_v2.gin'
2 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin'
3 | # Backbone hypers.
4 | Learner.embedding_fn = @resnet
5 | Trainer.pretrained_source = 'imagenet'
6 | Trainer.checkpoint_to_restore = 'path/to/pretrained_checkpoint'
7 |
8 | # Training hypers (not needed for eval).
9 | Trainer.decay_every = 500
10 | Trainer.decay_learning_rate = True
11 | Trainer.decay_rate = 0.885662482266546
12 | Trainer.learning_rate = 0.00025036275525430426
13 | Learner.backprop_through_moments = True
14 | Learner.transductive_batch_norm = False
15 | weight_decay = 0.0001
16 |
17 | DataConfig.image_height = 126
18 | support/DataAugmentation.gaussian_noise_std = 0.15335348868374565
19 | support/DataAugmentation.jitter_amount = 5
20 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/metadataset_v2/maml_init_with_proto_all.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all_v2.gin'
2 | include 'meta_dataset/learn/gin/learners/maml_config.gin'
3 | Trainer.learning_rate = 1e-3
4 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/metadataset_v2/maml_init_with_proto_imagenet.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/imagenet_v2.gin'
2 | include 'meta_dataset/learn/gin/learners/maml_config.gin'
3 | Trainer.learning_rate = 1e-3
4 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/metadataset_v2/prototypical_all.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all_v2.gin'
2 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin'
3 | Trainer.learning_rate = 1e-3
4 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/setups/all.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/all_datasets.gin'
2 | include 'meta_dataset/learn/gin/setups/data_config.gin'
3 | include 'meta_dataset/learn/gin/setups/trainer_config.gin'
4 | include 'meta_dataset/learn/gin/setups/variable_way_and_shot.gin'
5 | Trainer.data_config = @DataConfig()
6 | Trainer.train_episode_config = @train/EpisodeDescriptionConfig()
7 | Trainer.eval_episode_config = @EpisodeDescriptionConfig()
8 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/setups/all_datasets.gin:
--------------------------------------------------------------------------------
1 | benchmark.train_datasets = 'ilsvrc_2012,aircraft,cu_birds,omniglot,quickdraw,vgg_flower,dtd,fungi'
2 | benchmark.eval_datasets = 'ilsvrc_2012'
3 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/setups/all_v2.gin:
--------------------------------------------------------------------------------
1 | benchmark.train_datasets = 'ilsvrc_2012_v2,aircraft,cu_birds,omniglot,quickdraw,dtd,fungi'
2 | benchmark.eval_datasets = 'fungi,aircraft,quickdraw,omniglot,mscoco,cu_birds,dtd'
3 | include 'meta_dataset/learn/gin/setups/data_config.gin'
4 | include 'meta_dataset/learn/gin/setups/trainer_config.gin'
5 | include 'meta_dataset/learn/gin/setups/variable_way_and_shot.gin'
6 | Trainer.data_config = @DataConfig()
7 | Trainer.train_episode_config = @train/EpisodeDescriptionConfig()
8 | Trainer.eval_episode_config = @EpisodeDescriptionConfig()
9 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/setups/data_config.gin:
--------------------------------------------------------------------------------
1 | import meta_dataset.data.config
2 | import meta_dataset.data.decoder
3 | include 'meta_dataset/learn/gin/setups/data_config_common.gin'
4 |
5 | # To decode features instead, change the lines below to use FeatureDecoder.
6 | process_dumped_episode.support_decoder = @support/ImageDecoder()
7 | process_episode.support_decoder = @support/ImageDecoder()
8 | support/ImageDecoder.data_augmentation = @support/DataAugmentation()
9 | support/DataAugmentation.enable_jitter = True
10 | support/DataAugmentation.jitter_amount = 0
11 | support/DataAugmentation.enable_gaussian_noise = True
12 | support/DataAugmentation.gaussian_noise_std = 0.0
13 |
14 | process_dumped_episode.query_decoder = @query/ImageDecoder()
15 | process_episode.query_decoder = @query/ImageDecoder()
16 | query/ImageDecoder.data_augmentation = @query/DataAugmentation()
17 | query/DataAugmentation.enable_jitter = False
18 | query/DataAugmentation.jitter_amount = 0
19 | query/DataAugmentation.enable_gaussian_noise = False
20 | query/DataAugmentation.gaussian_noise_std = 0.0
21 |
22 | # It is possible to override the support and query decoders for the meta-train
23 | # and / or meta-evaluation phases using the scopes 'train' or 'evaluation'
24 | # (thanks to trainer.py, which picks the appropriate scope for each phase).
25 | # The following commented-out lines show an example:
26 | # train/process_episode.support_decoder = @train/support/ImageDecoder()
27 | # train/support/ImageDecoder.image_size = ...
28 |
29 | process_batch.batch_decoder = @batch/ImageDecoder()
30 | batch/ImageDecoder.data_augmentation = @batch/DataAugmentation()
31 | batch/DataAugmentation.enable_jitter = True
32 | batch/DataAugmentation.jitter_amount = 0
33 | batch/DataAugmentation.enable_gaussian_noise = True
34 | batch/DataAugmentation.gaussian_noise_std = 0.0
35 |
36 |
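As the comment above notes, the support and query decoders can be overridden per phase through the 'train' or 'evaluation' gin scopes. A hedged sketch of applying such an override from Python rather than from a .gin file (assumes the meta_dataset package is importable; the image size of 126 is only an illustrative value):

import gin
import meta_dataset.data.decoder   # registers ImageDecoder with gin
import meta_dataset.data.pipeline  # registers process_episode with gin

gin.parse_config("""
train/process_episode.support_decoder = @train/support/ImageDecoder()
train/support/ImageDecoder.image_size = 126
""")
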
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/setups/data_config_common.gin:
--------------------------------------------------------------------------------
1 | # This gin file doesn't set process_episode.support_decoder or .query_decoder,
2 | # so the image strings are not decoded. This gin file can be used directly
3 | # if the episodes are needed without decoding.
4 | import meta_dataset.data.config
5 | # Default values for sampling variable shots / ways.
6 | EpisodeDescriptionConfig.min_ways = 5
7 | EpisodeDescriptionConfig.max_ways_upper_bound = 50
8 | EpisodeDescriptionConfig.max_num_query = 10
9 | # For weak shot experiments where we have missing class data.
10 | # This should not affect any other experiments.
11 | EpisodeDescriptionConfig.min_examples_in_class = 0
12 | EpisodeDescriptionConfig.max_support_set_size = 500
13 | EpisodeDescriptionConfig.max_support_size_contrib_per_class = 100
14 | EpisodeDescriptionConfig.min_log_weight = -0.69314718055994529 # np.log(0.5)
15 | EpisodeDescriptionConfig.max_log_weight = 0.69314718055994529 # np.log(2)
16 | EpisodeDescriptionConfig.ignore_dag_ontology = False
17 | EpisodeDescriptionConfig.ignore_bilevel_ontology = False
18 | EpisodeDescriptionConfig.ignore_hierarchy_probability = 0.0
19 |
20 | # By default don't use SimCLR Episodes.
21 | EpisodeDescriptionConfig.simclr_episode_fraction = 0.0
22 | # It is possible to override some of the above defaults only for meta-training.
23 | # An example is shown in the following two commented-out lines.
24 | # train/EpisodeDescriptionConfig.min_ways = 5
25 | # train/EpisodeDescriptionConfig.max_ways_upper_bound = 50
26 |
27 | # Other default values for the data pipeline.
28 | DataConfig.image_height = 84
29 | DataConfig.shuffle_buffer_size = 1000
30 | DataConfig.read_buffer_size_bytes = 1048576 # 1 MB (1024**2)
31 | DataConfig.num_prefetch = 64
32 |
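On the two log-weight bounds above: they delimit a per-class multiplicative factor drawn log-uniformly between 0.5 and 2, used when apportioning the support set across the sampled classes, so a class's share can be scaled up or down by at most a factor of two. A tiny illustration of drawing such a factor (illustrative only):

import numpy as np

rng = np.random.default_rng(0)
factor = np.exp(rng.uniform(np.log(0.5), np.log(2.0)))  # always within [0.5, 2.0]
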
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/setups/data_config_feature.gin:
--------------------------------------------------------------------------------
1 | import meta_dataset.data.config
2 | import meta_dataset.data.decoder
3 | import meta_dataset.data.pipeline
4 | include 'meta_dataset/learn/gin/setups/data_config_common.gin'
5 |
6 | process_episode.support_decoder = @FeatureDecoder()
7 | process_episode.query_decoder = @FeatureDecoder()
8 | process_batch.batch_decoder = @FeatureDecoder()
9 | FeatureDecoder.feat_len = 64
10 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/setups/data_config_flute.gin:
--------------------------------------------------------------------------------
1 | include 'meta_dataset/learn/gin/setups/data_config_common.gin'
2 | # Do SUR-like image processing and data augmentation. This differs from the
3 | # default Meta-Dataset image processing in that more augmentations are added
4 | # (e.g., random flips, random brightness, contrast, saturation, etc.).
5 | include 'meta_dataset/learn/gin/setups/sur_data_config.gin'
6 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/setups/data_config_no_decoder.gin:
--------------------------------------------------------------------------------
1 | import meta_dataset.data.config
2 | import meta_dataset.data.decoder
3 | import meta_dataset.data.pipeline
4 | include 'meta_dataset/learn/gin/setups/data_config_common.gin'
5 |
6 | process_episode.support_decoder = None
7 | process_episode.query_decoder = None
8 | process_batch.batch_decoder = None
9 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/setups/data_config_string.gin:
--------------------------------------------------------------------------------
1 | import meta_dataset.data.config
2 | import meta_dataset.data.decoder
3 | import meta_dataset.data.pipeline
4 | include 'meta_dataset/learn/gin/setups/data_config_common.gin'
5 |
6 | process_episode.support_decoder = @StringDecoder()
7 | process_episode.query_decoder = @StringDecoder()
8 | process_batch.batch_decoder = @StringDecoder()
9 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/setups/data_config_tfds.gin:
--------------------------------------------------------------------------------
1 | import meta_dataset.data.config
2 | import meta_dataset.data.decoder
3 | import meta_dataset.data.pipeline
4 |
5 | # Variable shots / ways.
6 | EpisodeDescriptionConfig.num_ways = None
7 | EpisodeDescriptionConfig.num_support = None
8 | EpisodeDescriptionConfig.num_query = None
9 |
10 | # Default values for sampling variable shots / ways.
11 | EpisodeDescriptionConfig.min_ways = 5
12 | EpisodeDescriptionConfig.max_ways_upper_bound = 50
13 | EpisodeDescriptionConfig.max_num_query = 10
14 |
15 | EpisodeDescriptionConfig.min_examples_in_class = 0
16 | EpisodeDescriptionConfig.max_support_set_size = 500
17 | EpisodeDescriptionConfig.max_support_size_contrib_per_class = 100
18 | EpisodeDescriptionConfig.min_log_weight = -0.69314718055994529 # np.log(0.5)
19 | EpisodeDescriptionConfig.max_log_weight = 0.69314718055994529 # np.log(2)
20 | EpisodeDescriptionConfig.ignore_dag_ontology = False
21 | EpisodeDescriptionConfig.ignore_bilevel_ontology = False
22 | EpisodeDescriptionConfig.ignore_hierarchy_probability = 0.0
23 |
24 | # By default don't use SimCLR Episodes.
25 | EpisodeDescriptionConfig.simclr_episode_fraction = 0.0
26 |
27 | # Other default values for the data pipeline.
28 | DataConfig.image_height = 84
29 | DataConfig.shuffle_buffer_size = 1000
30 | DataConfig.read_buffer_size_bytes = 1048576 # 1 MB (1024**2)
31 | DataConfig.num_prefetch = 64
32 |
33 | process_dumped_episode.support_decoder = @support/ImageDecoder()
34 | process_episode.support_decoder = @support/ImageDecoder()
35 | support/ImageDecoder.data_augmentation = @support/DataAugmentation()
36 | support/DataAugmentation.enable_jitter = True
37 | support/DataAugmentation.jitter_amount = 0
38 | support/DataAugmentation.enable_gaussian_noise = True
39 | support/DataAugmentation.gaussian_noise_std = 0.0
40 |
41 | process_dumped_episode.query_decoder = @query/ImageDecoder()
42 | process_episode.query_decoder = @query/ImageDecoder()
43 | query/ImageDecoder.data_augmentation = @query/DataAugmentation()
44 | query/DataAugmentation.enable_jitter = False
45 | query/DataAugmentation.jitter_amount = 0
46 | query/DataAugmentation.enable_gaussian_noise = False
47 | query/DataAugmentation.gaussian_noise_std = 0.0
48 |
49 |
50 | ImageDecoder.skip_tfexample_decoding = True
51 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/setups/fixed_way_and_shot.gin:
--------------------------------------------------------------------------------
1 | EpisodeDescriptionConfig.num_ways = 5
2 | EpisodeDescriptionConfig.num_support = 5
3 | EpisodeDescriptionConfig.num_query = 15
4 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/setups/imagenet.gin:
--------------------------------------------------------------------------------
1 | benchmark.train_datasets = 'ilsvrc_2012'
2 | benchmark.eval_datasets = 'ilsvrc_2012'
3 | include 'meta_dataset/learn/gin/setups/data_config.gin'
4 | include 'meta_dataset/learn/gin/setups/trainer_config.gin'
5 | include 'meta_dataset/learn/gin/setups/variable_way_and_shot.gin'
6 | Trainer.data_config = @DataConfig()
7 | Trainer.train_episode_config = @train/EpisodeDescriptionConfig()
8 | Trainer.eval_episode_config = @EpisodeDescriptionConfig()
9 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/setups/imagenet_v2.gin:
--------------------------------------------------------------------------------
1 | benchmark.train_datasets = 'ilsvrc_2012_v2'
2 | benchmark.eval_datasets = 'fungi,aircraft,quickdraw,omniglot,mscoco,cu_birds,dtd'
3 | include 'meta_dataset/learn/gin/setups/data_config.gin'
4 | include 'meta_dataset/learn/gin/setups/trainer_config.gin'
5 | include 'meta_dataset/learn/gin/setups/variable_way_and_shot.gin'
6 | Trainer.data_config = @DataConfig()
7 | Trainer.train_episode_config = @train/EpisodeDescriptionConfig()
8 | Trainer.eval_episode_config = @EpisodeDescriptionConfig()
9 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/setups/mini_imagenet.gin:
--------------------------------------------------------------------------------
1 | benchmark.train_datasets = 'mini_imagenet'
2 | benchmark.eval_datasets = 'mini_imagenet'
3 | include 'meta_dataset/learn/gin/setups/data_config.gin'
4 | include 'meta_dataset/learn/gin/setups/trainer_config.gin'
5 | include 'meta_dataset/learn/gin/setups/fixed_way_and_shot.gin'
6 | Trainer.data_config = @DataConfig()
7 | Trainer.train_episode_config = @train/EpisodeDescriptionConfig()
8 | Trainer.eval_episode_config = @EpisodeDescriptionConfig()
9 | DataConfig.image_height = 84
10 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/setups/sur_data_config.gin:
--------------------------------------------------------------------------------
1 | import meta_dataset.data.sur_decoder
2 | import meta_dataset.data.config
3 | import meta_dataset.data.decoder
4 |
5 | # Other default values for the data pipeline.
6 | DataConfig.image_height = 84
7 | DataConfig.shuffle_buffer_size = 1000
8 | DataConfig.read_buffer_size_bytes = 1048576 # 1 MB (1024**2)
9 | DataConfig.num_prefetch = 400
10 | SURImageDecoder.image_size = 84
11 |
12 | # To decode features instead, change the lines below to use FeatureDecoder.
13 | process_episode.support_decoder = @support/SURImageDecoder()
14 | support/SURImageDecoder.data_augmentation = @support/SURDataAugmentation()
15 | support/SURDataAugmentation.enable_jitter = True
16 | support/SURDataAugmentation.jitter_amount = 0
17 | support/SURDataAugmentation.enable_gaussian_noise = True
18 | support/SURDataAugmentation.gaussian_noise_std = 0.0
19 | support/SURDataAugmentation.enable_random_flip = False
20 | support/SURDataAugmentation.enable_random_brightness = False
21 | support/SURDataAugmentation.random_brightness_delta = 0
22 | support/SURDataAugmentation.enable_random_contrast = False
23 | support/SURDataAugmentation.random_contrast_delta = 0
24 | support/SURDataAugmentation.enable_random_hue = False
25 | support/SURDataAugmentation.random_hue_delta = 0
26 | support/SURDataAugmentation.enable_random_saturation = False
27 | support/SURDataAugmentation.random_saturation_delta = 0
28 |
29 | process_episode.query_decoder = @query/SURImageDecoder()
30 | query/SURImageDecoder.data_augmentation = @query/SURDataAugmentation()
31 | query/SURDataAugmentation.enable_jitter = False
32 | query/SURDataAugmentation.jitter_amount = 0
33 | query/SURDataAugmentation.enable_gaussian_noise = False
34 | query/SURDataAugmentation.gaussian_noise_std = 0.0
35 | query/SURDataAugmentation.enable_random_flip = False
36 | query/SURDataAugmentation.enable_random_brightness = False
37 | query/SURDataAugmentation.random_brightness_delta = 0
38 | query/SURDataAugmentation.enable_random_contrast = False
39 | query/SURDataAugmentation.random_contrast_delta = 0
40 | query/SURDataAugmentation.enable_random_hue = False
41 | query/SURDataAugmentation.random_hue_delta = 0
42 | query/SURDataAugmentation.enable_random_saturation = False
43 | query/SURDataAugmentation.random_saturation_delta = 0
44 |
45 | process_batch.batch_decoder = @batch/SURImageDecoder()
46 | batch/SURImageDecoder.data_augmentation = @batch/SURDataAugmentation()
47 | batch/SURDataAugmentation.enable_jitter = True
48 | batch/SURDataAugmentation.jitter_amount = 8
49 | batch/SURDataAugmentation.enable_gaussian_noise = True
50 | batch/SURDataAugmentation.gaussian_noise_std = 0.0
51 | batch/SURDataAugmentation.enable_random_flip = False
52 | batch/SURDataAugmentation.enable_random_brightness = True
53 | batch/SURDataAugmentation.random_brightness_delta = 0.125
54 | batch/SURDataAugmentation.enable_random_contrast = True
55 | batch/SURDataAugmentation.random_contrast_delta = 0.2
56 | batch/SURDataAugmentation.enable_random_hue = True
57 | batch/SURDataAugmentation.random_hue_delta = 0.03
58 | batch/SURDataAugmentation.enable_random_saturation = True
59 | batch/SURDataAugmentation.random_saturation_delta = 0.2
60 |
61 | # For dumped episodes:
62 | process_dumped_episode.support_decoder = @support/ImageDecoder()
63 | support/ImageDecoder.data_augmentation = @support/DataAugmentation()
64 | support/DataAugmentation.enable_jitter = True
65 | support/DataAugmentation.jitter_amount = 0
66 | support/DataAugmentation.enable_gaussian_noise = True
67 | support/DataAugmentation.gaussian_noise_std = 0.0
68 | process_dumped_episode.query_decoder = @query/ImageDecoder()
69 | query/ImageDecoder.data_augmentation = @query/DataAugmentation()
70 | query/DataAugmentation.enable_jitter = False
71 | query/DataAugmentation.jitter_amount = 0
72 | query/DataAugmentation.enable_gaussian_noise = False
73 | query/DataAugmentation.gaussian_noise_std = 0.0
74 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/setups/trainer_config.gin:
--------------------------------------------------------------------------------
1 | Trainer.num_updates = 75000
2 | Trainer.batch_size = 256 # Only applicable to non-episodic models.
3 | Trainer.num_eval_episodes = 600
4 | Trainer.checkpoint_every = 500
5 | Trainer.validate_every = 500
6 | Trainer.log_every = 100
7 | Trainer.distribute = False
8 | # Enable TensorFlow optimizations. It can add a few minutes to the first
9 | # calls to session.run(), but decreases memory usage.
10 | Trainer.enable_tf_optimizations = True
11 |
12 | Learner.transductive_batch_norm = False
13 | Learner.backprop_through_moments = True
14 |
15 |
16 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/setups/trainer_config_debug.gin:
--------------------------------------------------------------------------------
1 | Trainer.num_updates = 100
2 | Trainer.batch_size = 8 # Only applicable to non-episodic models.
3 | Trainer.num_eval_episodes = 10
4 | Trainer.checkpoint_every = 10
5 | Trainer.validate_every = 5
6 | Trainer.log_every = 1
7 | Learner.transductive_batch_norm = False
8 | Learner.backprop_through_moments = True
9 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/setups/trainer_config_flute.gin:
--------------------------------------------------------------------------------
1 | Trainer_flute.num_updates = 640000
2 | Trainer_flute.batch_size = 16
3 | Trainer_flute.num_eval_episodes = 600
4 | Trainer_flute.checkpoint_every = 2000
5 | Trainer_flute.validate_every = 1000
6 | Trainer_flute.log_every = 100
7 | Trainer_flute.decay_learning_rate = True
8 | Trainer_flute.checkpoint_to_restore = ''
9 | Trainer_flute.dataset_classifier_to_restore = ''
10 | Trainer_flute.learning_rate = 0.01
11 | Trainer_flute.decay_every = 5000
12 | Trainer_flute.decay_rate = 0.5
13 | Trainer_flute.learn_rate_scheduler = 'cosine_decay_restarts'
14 | Trainer_flute.optimizer_type = 'momentum'
15 | Trainer_flute.meta_batch_size = 8
16 | Trainer_flute.sample_half_from_imagenet = True
17 |
18 | Trainer_flute.distribute = False
19 | # Enable TensorFlow optimizations. It can add a few minutes to the first
20 | # calls to session.run(), but decreases memory usage.
21 | Trainer_flute.enable_tf_optimizations = True
22 | Trainer_flute.normalized_gradient_descent = False
23 | Trainer_flute.pretrained_source = ''
24 |
25 | Trainer_flute.data_config = @DataConfig()
26 | Trainer_flute.train_episode_config = @train/EpisodeDescriptionConfig()
27 | Trainer_flute.eval_episode_config = @EpisodeDescriptionConfig()
28 |
--------------------------------------------------------------------------------
/meta_dataset/learn/gin/setups/variable_way_and_shot.gin:
--------------------------------------------------------------------------------
1 | EpisodeDescriptionConfig.num_ways = None
2 | EpisodeDescriptionConfig.num_support = None
3 | EpisodeDescriptionConfig.num_query = None
4 |
--------------------------------------------------------------------------------
/meta_dataset/learners/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Meta-Dataset Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Module containing (meta-)Learners."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from meta_dataset.learners.base import BatchLearner
23 | from meta_dataset.learners.base import EpisodicLearner
24 | from meta_dataset.learners.base import Learner
25 | from meta_dataset.learners.baseline_learners import BaselineLearner
26 | from meta_dataset.learners.metric_learners import MatchingNetworkLearner
27 | from meta_dataset.learners.metric_learners import MetricLearner
28 | from meta_dataset.learners.metric_learners import PrototypicalNetworkLearner
29 | from meta_dataset.learners.metric_learners import RelationNetworkLearner
30 | from meta_dataset.learners.optimization_learners import BaselineFinetuneLearner
31 | from meta_dataset.learners.optimization_learners import FLUTEFiLMLearner
32 | from meta_dataset.learners.optimization_learners import MAMLLearner
33 | from meta_dataset.learners.optimization_learners import OptimizationLearner
34 |
35 | __all__ = [
36 | 'BaselineFinetuneLearner',
37 | 'BatchLearner',
38 | 'BaselineLearner',
39 | 'EpisodicLearner',
40 | 'Learner',
41 | 'MAMLLearner',
42 | 'MatchingNetworkLearner',
43 | 'MetricLearner',
44 | 'OptimizationLearner',
45 | 'PrototypicalNetworkLearner',
46 | 'RelationNetworkLearner',
47 | 'FLUTEFiLMLearner',
48 | ]
49 |
--------------------------------------------------------------------------------
/meta_dataset/learners/base.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Meta-Dataset Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Abstract learners."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import gin.tf
23 | import tensorflow.compat.v1 as tf
24 |
25 |
26 | @gin.configurable
27 | class Learner(object):
28 | """A Learner."""
29 |
30 | def __init__(
31 | self,
32 | is_training,
33 | logit_dim,
34 | transductive_batch_norm,
35 | backprop_through_moments,
36 | embedding_fn,
37 | input_shape,
38 | ):
39 | """Initializes a Learner.
40 |
41 | Note that Gin configuration of subclasses of `Learner` will override any
42 | corresponding Gin configurations of `Learner`, since parameters are passed
43 | to the `Learner` base class's constructor (see
44 | https://github.com/google/gin-config/blob/master/README.md for more
45 | details).
46 |
47 | Args:
48 | is_training: Whether the `Learner` is in training (as opposed to
49 | evaluation) mode.
50 | logit_dim: An integer giving the maximum dimensionality of output
51 | predictions, or a list of ints, as required by each subclass of Learner.
52 | transductive_batch_norm: Whether to batch-normalize in the transductive
53 | setting, where means and variances for normalization are computed from
54 | each of the support and query sets (rather than using the support set
55 | statistics for normalization of both the support and query set).
56 | backprop_through_moments: Whether to allow gradients to flow through the
57 | support set moments; only applies to non-transductive batch norm.
58 | embedding_fn: A function that embeds examples.
59 | input_shape: A Tensor corresponding to `[batch_size] + example_shape`.
60 | """
61 | self.is_training = is_training
62 | self.logit_dim = logit_dim
63 | self.transductive_batch_norm = transductive_batch_norm
64 | self.backprop_through_moments = backprop_through_moments
65 | self.embedding_fn = embedding_fn
66 | self.input_shape = input_shape
67 |
68 | def build(self):
69 | """Additional build functionality for subclasses of `Learner`."""
70 |
71 | def compute_regularizer(self):
72 | """Computes a regularizer, independent of the data."""
73 | return tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
74 |
75 | def compute_loss(self, onehot_labels, predictions):
76 | """Computes the CE loss of `predictions` with respect to `onehot_labels`.
77 |
78 | Args:
79 | onehot_labels: A `tf.Tensor` containing the class labels; each vector
80 | along the (last) class dimension should represent a valid probability
81 | distribution.
82 | predictions: A `tf.Tensor` containing the class predictions,
83 | interpreted as unnormalized log probabilities.
84 |
85 | Returns:
86 | A `tf.Tensor` representing the loss per example.
87 | """
88 | cross_entropy_loss = tf.losses.softmax_cross_entropy(
89 | onehot_labels=onehot_labels,
90 | logits=predictions,
91 | reduction=tf.losses.Reduction.NONE)
92 | return cross_entropy_loss
93 |
94 | def compute_accuracy(self, onehot_labels, predictions):
95 | """Computes the accuracy of `predictions` with respect to `onehot_labels`.
96 |
97 | Args:
98 | onehot_labels: A `tf.Tensor` containing the class labels; each vector
99 | along the (last) class dimension is expected to contain only a single
100 | `1`.
101 | predictions: A `tf.Tensor` containing the class predictions
102 | represented as unnormalized log probabilities.
103 |
104 | Returns:
105 | A `tf.Tensor` of ones and zeros representing the correctness of
106 | individual predictions; use `tf.reduce_mean(...)` to obtain the average
107 | accuracy.
108 | """
109 | correct = tf.equal(tf.argmax(onehot_labels, -1), tf.argmax(predictions, -1))
110 | return tf.cast(correct, tf.float32)
111 |
112 | def forward_pass(self, data):
113 | """Returns the (query if episodic) logits."""
114 | raise NotImplementedError('Abstract method.')
115 |
116 |
117 | class EpisodicLearner(Learner):
118 | """An episodic learner."""
119 |
120 | pass
121 |
122 |
123 | class BatchLearner(Learner):
124 | """A batch learner."""
125 |
126 | pass
127 |
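For a concrete sense of the two helper methods above, here is a minimal graph-mode sketch (it assumes the meta_dataset package is importable; all values are illustrative):

import numpy as np
import tensorflow.compat.v1 as tf
from meta_dataset.learners.base import Learner

tf.disable_eager_execution()

learner = Learner(
    is_training=False, logit_dim=3, transductive_batch_norm=False,
    backprop_through_moments=True, embedding_fn=None, input_shape=[84, 84, 3])

onehot = tf.constant(np.eye(3, dtype=np.float32))  # three examples, one per class
logits = tf.constant([[5., 0., 0.], [0., 5., 0.], [5., 0., 0.]])

with tf.Session() as sess:
  loss, acc = sess.run([learner.compute_loss(onehot, logits),
                        learner.compute_accuracy(onehot, logits)])
# acc == [1., 1., 0.]: the third prediction picks class 0 but the label is class 2.
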
--------------------------------------------------------------------------------
/meta_dataset/learners/baseline_learners_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Meta-Dataset Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for `meta_dataset.learners.baseline_learners`."""
17 |
18 | import gin.tf
19 |
20 | from meta_dataset.learners import base_test
21 | from meta_dataset.learners import baseline_learners
22 |
23 | import tensorflow.compat.v1 as tf
24 | tf.disable_eager_execution()
25 |
26 | BASELINE_ARGS = {
27 | **base_test.VALID_LEARNER_INIT_ARGS,
28 | 'cosine_logits_multiplier': 1.0,
29 | 'knn_in_fc': False,
30 | 'knn_distance': 'l2',
31 | }
32 |
33 | gin.bind_parameter('linear_classifier.weight_decay', 0.01)
34 |
35 |
36 | class BaselineTest(base_test.TestBatchLearner):
37 | learner_cls = baseline_learners.BaselineLearner
38 | learner_kwargs = {
39 | **BASELINE_ARGS,
40 | 'cosine_classifier': False,
41 | 'use_weight_norm': False,
42 | }
43 |
44 |
45 | class WeightNormalizedBaselineTest(base_test.TestBatchLearner):
46 | learner_cls = baseline_learners.BaselineLearner
47 | learner_kwargs = {
48 | **BASELINE_ARGS,
49 | 'cosine_classifier': False,
50 | 'use_weight_norm': True,
51 | }
52 |
53 |
54 | class CosineClassifierBaselineTest(base_test.TestBatchLearner):
55 | learner_cls = baseline_learners.BaselineLearner
56 | learner_kwargs = {
57 | **BASELINE_ARGS,
58 | 'cosine_classifier': True,
59 | 'use_weight_norm': False,
60 | }
61 |
62 |
63 | class WeightNormalizedCosineClassifierBaselineTest(base_test.TestBatchLearner):
64 | learner_cls = baseline_learners.BaselineLearner
65 | learner_kwargs = {
66 | **BASELINE_ARGS,
67 | 'cosine_classifier': True,
68 | 'use_weight_norm': True,
69 | }
70 |
71 |
72 | if __name__ == '__main__':
73 | tf.test.main()
74 |
--------------------------------------------------------------------------------
/meta_dataset/learners/experimental/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Meta-Dataset Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Module containing (meta-)Learners."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from meta_dataset.learners.experimental.base import ExperimentalBatchLearner
23 | from meta_dataset.learners.experimental.base import ExperimentalEpisodicLearner
24 | from meta_dataset.learners.experimental.base import ExperimentalLearner
25 | from meta_dataset.learners.experimental.optimization_learners import ANIL
26 | from meta_dataset.learners.experimental.optimization_learners import ExperimentalOptimizationLearner
27 | from meta_dataset.learners.experimental.optimization_learners import HeadAndBackboneLearner
28 | from meta_dataset.learners.experimental.optimization_learners import MAML
29 |
30 | __all__ = [
31 | 'ExperimentalBatchLearner',
32 | 'ExperimentalEpisodicLearner',
33 | 'ExperimentalLearner',
34 | 'ExperimentalOptimizationLearner',
35 | 'HeadAndBackboneLearner',
36 | 'MAML',
37 | 'ANIL',
38 | ]
39 |
--------------------------------------------------------------------------------
/meta_dataset/learners/experimental/base.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Meta-Dataset Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Abstract experimental learners that use `ReparameterizableModule`s."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import gin.tf
23 |
24 | from meta_dataset.learners import base as learner_base
25 | import tensorflow as tf
26 |
27 |
28 | class NotBuiltError(RuntimeError):
29 |
30 | def __init__(self):
31 | super(NotBuiltError, self).__init__(
32 | 'The `build` method of `ExperimentalLearner` must be called before '
33 | 'accessing its variables.')
34 |
35 |
36 | def class_specific_data(onehot_labels, data, num_classes, axis=0):
37 | # TODO(eringrant): Deal with case of no data for a class in [1...num_classes].
38 | data_shape = [s for i, s in enumerate(data.shape) if i != axis]
39 | labels = tf.argmax(onehot_labels, axis=-1)
40 | class_idx = [tf.where(tf.equal(labels, i)) for i in range(num_classes)]
41 | return [
42 | tf.reshape(tf.gather(data, idx, axis=axis), [-1] + data_shape)
43 | for idx in class_idx
44 | ]
45 |
46 |
47 | @gin.configurable
48 | class ExperimentalLearner(learner_base.Learner):
49 | """An experimental learner."""
50 |
51 | def __init__(self, **kwargs):
52 | """Constructs an `ExperimentalLearner`.
53 |
54 | Args:
55 | **kwargs: Keyword arguments common to all `Learner`s.
56 |
57 | Raises:
58 | ValueError: If the `embedding_fn` provided is not an instance of
59 | `tf.Module`.
60 | """
61 | super(ExperimentalLearner, self).__init__(**kwargs)
62 |
63 | if not isinstance(self.embedding_fn, tf.Module):
64 | raise ValueError('The `embedding_fn` provided to `ExperimentalLearner`s '
65 | 'must be an instance of `tf.Module`.')
66 |
67 | self._built = False
68 |
69 | def compute_regularizer(self, onehot_labels, predictions):
70 | """Computes a regularizer, maybe using `predictions` and `onehot_labels`."""
71 | del onehot_labels
72 | del predictions
73 | return tf.reduce_sum(input_tensor=self.embedding_fn.losses)
74 |
75 | def build(self):
76 | """Instantiate the parameters belonging to this `ExperimentalLearner`."""
77 | if not self.embedding_fn.built:
78 | self.embedding_fn.build([None] + self.input_shape)
79 | self.embedding_shape = self.embedding_fn.compute_output_shape(
80 | [None] + self.input_shape)
81 | self._built = True
82 |
83 | @property
84 | def variables(self):
85 | """Returns a list of this `ExperimentalLearner`'s variables."""
86 | if not self._built:
87 | raise NotBuiltError
88 | return self.embedding_fn.variables
89 |
90 | @property
91 | def trainable_variables(self):
92 | """Returns a list of this `ExperimentalLearner`'s trainable variables."""
93 | if not self._built:
94 | raise NotBuiltError
95 | return self.embedding_fn.trainable_variables
96 |
97 |
98 | class ExperimentalEpisodicLearner(ExperimentalLearner,
99 | learner_base.EpisodicLearner):
100 | """An experimental episodic learner."""
101 |
102 | pass
103 |
104 |
105 | class ExperimentalBatchLearner(ExperimentalLearner, learner_base.BatchLearner):
106 | """An experimental batch learner."""
107 |
108 | pass
109 |
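A quick illustration of the `class_specific_data` helper defined above, which splits a batch into per-class tensors using the one-hot labels (assumes the meta_dataset package is importable and TF2 eager execution; values are illustrative):

import numpy as np
import tensorflow as tf
from meta_dataset.learners.experimental.base import class_specific_data

onehot = tf.constant([[1., 0.], [0., 1.], [1., 0.]])
data = tf.constant(np.arange(6, dtype=np.float32).reshape(3, 2))
per_class = class_specific_data(onehot, data, num_classes=2)
# per_class[0] holds rows 0 and 2 of `data`; per_class[1] holds row 1.
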
--------------------------------------------------------------------------------
/meta_dataset/learners/experimental/metric_learners_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Meta-Dataset Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for meta_dataset.learners.experimental.metric_learners."""
17 |
18 | from meta_dataset.learners import base_test
19 | from meta_dataset.learners.experimental import metric_learners
20 | from meta_dataset.models.experimental import reparameterizable_backbones
21 | import tensorflow.compat.v1 as tf
22 |
23 | tf.compat.v1.disable_eager_execution()
24 | tf.compat.v1.experimental.output_all_intermediates(True)
25 |
26 | metric_learner_kwargs = {
27 | 'backprop_through_moments': True,
28 | 'transductive_batch_norm': True,
29 | 'input_shape': [84, 84, 3],
30 | 'logit_dim': 5,
31 | 'is_training': True,
32 | 'distance_metric': metric_learners.euclidean_distance,
33 | }
34 |
35 |
36 | class PrototypicalNetworkTest(base_test.TestEpisodicLearner):
37 | learner_cls = metric_learners.PrototypicalNetwork
38 | learner_kwargs = metric_learner_kwargs
39 |
40 |
41 | class MatchingNetworkTest(base_test.TestEpisodicLearner):
42 | learner_cls = metric_learners.MatchingNetwork
43 | learner_kwargs = metric_learner_kwargs
44 |
45 |
46 | class RelationNetworkTest(base_test.TestEpisodicLearner):
47 | learner_cls = metric_learners.RelationNetwork
48 | learner_kwargs = metric_learner_kwargs
49 |
50 | def set_up_learner(self):
51 | """Set up a `reparameterizable_backbones.RelationNetConvNet` backbone."""
52 | learner_kwargs = self.learner_kwargs
53 | learner_kwargs['embedding_fn'] = (
54 | reparameterizable_backbones.RelationNetConvNet(keep_spatial_dims=True))
55 | data = self.random_data()
56 | learner = self.learner_cls(**learner_kwargs)
57 | return data, learner
58 |
59 |
60 | if __name__ == '__main__':
61 | tf.test.main()
62 |
--------------------------------------------------------------------------------
/meta_dataset/learners/experimental/optimization_learners_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Meta-Dataset Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for meta_dataset.learners.experimental.metric_learners."""
17 |
18 | import gin.tf
19 | from meta_dataset.learners import base_test
20 | from meta_dataset.learners.experimental import optimization_learners
21 | import tensorflow.compat.v1 as tf
22 |
23 | tf.compat.v1.disable_eager_execution()
24 | tf.compat.v1.experimental.output_all_intermediates(True)
25 |
26 |
27 | def mock_sgd():
28 |
29 | def init(x0):
30 | return x0
31 |
32 | def update(i, grad, state):
33 | del i
34 | x = state
35 | return x - 0.01 * grad
36 |
37 | def get_params(state):
38 | x = state
39 | return x
40 |
41 | return init, update, get_params
42 |
43 |
44 | optimization_learner_kwargs = {
45 | 'backprop_through_moments': True,
46 | 'input_shape': [84, 84, 3],
47 | 'logit_dim': 5,
48 | 'is_training': True,
49 | 'update_fn': mock_sgd,
50 | 'additional_evaluation_update_steps': 5,
51 | 'clip_grad_norm': 10.0,
52 | 'num_update_steps': 5,
53 | }
54 |
55 |
56 | class FirstOrderMAMLTest(base_test.TestEpisodicLearner):
57 | learner_cls = optimization_learners.MAML
58 | learner_kwargs = dict(
59 | **optimization_learner_kwargs, **{
60 | 'transductive_batch_norm': False,
61 | 'proto_maml_fc_layer_init': False,
62 | 'zero_fc_layer_init': False,
63 | 'first_order': False,
64 | 'adapt_batch_norm': True,
65 | })
66 |
67 |
68 | class VanillaMAMLTest(base_test.TestEpisodicLearner):
69 | learner_cls = optimization_learners.MAML
70 | learner_kwargs = dict(
71 | **optimization_learner_kwargs, **{
72 | 'transductive_batch_norm': True,
73 | 'proto_maml_fc_layer_init': False,
74 | 'zero_fc_layer_init': False,
75 | 'first_order': False,
76 | 'adapt_batch_norm': False,
77 | })
78 |
79 |
80 | class ProtoMAMLTest(base_test.TestEpisodicLearner):
81 |
82 | def setUp(self):
83 | super().setUp()
84 | gin.bind_parameter('proto_maml_fc_layer_init_fn.prototype_multiplier', 1.0)
85 |
86 | def tearDown(self):
87 | gin.clear_config()
88 | super().tearDown()
89 |
90 | learner_cls = optimization_learners.MAML
91 | learner_kwargs = dict(
92 | **optimization_learner_kwargs, **{
93 | 'transductive_batch_norm': False,
94 | 'proto_maml_fc_layer_init': True,
95 | 'zero_fc_layer_init': False,
96 | 'first_order': False,
97 | 'adapt_batch_norm': True,
98 | })
99 |
100 |
101 | class FirstOrderANILTest(base_test.TestEpisodicLearner):
102 | learner_cls = optimization_learners.ANIL
103 | learner_kwargs = dict(
104 | **optimization_learner_kwargs, **{
105 | 'transductive_batch_norm': False,
106 | 'proto_maml_fc_layer_init': False,
107 | 'zero_fc_layer_init': False,
108 | 'first_order': False,
109 | 'adapt_batch_norm': True,
110 | })
111 |
112 |
113 | class VanillaANILTest(base_test.TestEpisodicLearner):
114 | learner_cls = optimization_learners.ANIL
115 | learner_kwargs = dict(
116 | **optimization_learner_kwargs, **{
117 | 'transductive_batch_norm': True,
118 | 'proto_maml_fc_layer_init': False,
119 | 'zero_fc_layer_init': False,
120 | 'first_order': False,
121 | 'adapt_batch_norm': False,
122 | })
123 |
124 |
125 | class ProtoANILTest(base_test.TestEpisodicLearner):
126 |
127 | def setUp(self):
128 | super().setUp()
129 | gin.bind_parameter('proto_maml_fc_layer_init_fn.prototype_multiplier', 1.0)
130 |
131 | def tearDown(self):
132 | gin.clear_config()
133 | super().tearDown()
134 |
135 | learner_cls = optimization_learners.ANIL
136 | learner_kwargs = dict(
137 | **optimization_learner_kwargs, **{
138 | 'transductive_batch_norm': False,
139 | 'proto_maml_fc_layer_init': True,
140 | 'zero_fc_layer_init': False,
141 | 'first_order': False,
142 | 'adapt_batch_norm': True,
143 | })
144 |
145 |
146 | if __name__ == '__main__':
147 | tf.test.main()
148 |
--------------------------------------------------------------------------------
/meta_dataset/learners/metric_learners_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Meta-Dataset Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for `meta_dataset.learners.metric_learners`."""
17 |
18 | import gin
19 |
20 | from meta_dataset.learners import base_test
21 | from meta_dataset.learners import metric_learners
22 |
23 | import tensorflow.compat.v1 as tf
24 | tf.disable_eager_execution()
25 |
26 | metric_learner_kwargs = {
27 | **base_test.VALID_LEARNER_INIT_ARGS,
28 | }
29 |
30 | gin.bind_parameter('relationnet_convnet.weight_decay', 0.01)
31 | gin.bind_parameter('relation_module.weight_decay', 0.01)
32 |
33 |
34 | class CosineMatchingNetworkLearnerTest(base_test.TestEpisodicLearner):
35 | learner_cls = metric_learners.MatchingNetworkLearner
36 | learner_kwargs = {
37 | **metric_learner_kwargs,
38 | 'exact_cosine_distance': False,
39 | }
40 |
41 |
42 | class ExactCosineMatchingNetworkLearnerTest(base_test.TestEpisodicLearner):
43 | learner_cls = metric_learners.MatchingNetworkLearner
44 | learner_kwargs = {
45 | **metric_learner_kwargs,
46 | 'exact_cosine_distance': True,
47 | }
48 |
49 |
50 | class PrototypicalNetworkLearnerTest(base_test.TestEpisodicLearner):
51 | learner_cls = metric_learners.PrototypicalNetworkLearner
52 | learner_kwargs = {
53 | **metric_learner_kwargs,
54 | }
55 |
56 |
57 | class RelationNetworkLearnerTest(base_test.TestEpisodicLearner):
58 | learner_cls = metric_learners.RelationNetworkLearner
59 | learner_kwargs = {
60 | **metric_learner_kwargs,
61 | }
62 |
63 |
64 | if __name__ == '__main__':
65 | tf.test.main()
66 |
--------------------------------------------------------------------------------
/meta_dataset/learners/optimization_learners_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Meta-Dataset Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for `meta_dataset.learners.optimization_learners`."""
17 |
18 | import gin.tf
19 |
20 | from meta_dataset.learners import base_test
21 | from meta_dataset.learners import optimization_learners
22 |
23 | import tensorflow.compat.v1 as tf
24 | tf.disable_eager_execution()
25 |
26 | BASELINE_FINETUNE_ARGS = {
27 | **base_test.VALID_LEARNER_INIT_ARGS,
28 | 'knn_in_fc': False,
29 | 'knn_distance': False,
30 | 'cosine_classifier': False,
31 | 'cosine_logits_multiplier': 1.0,
32 | 'use_weight_norm': False,
33 | 'is_training': False,
34 | }
35 |
36 | gin.bind_parameter('MAMLLearner.classifier_weight_decay', 0.01)
37 |
38 |
39 | class BaselineFinetuneTest():
40 |
41 | def testLearnerConvergence(self):
42 | # `BaselineFinetuneLearner` differs from `BaselineLearner` only at
43 | # evaluation time.
44 | pass
45 |
46 | def testLearnerImprovement(self):
47 | # `BaselineFinetuneLearner` differs from `BaselineLearner` only at
48 | # evaluation time.
49 | pass
50 |
51 |
52 | class BaselineFinetuneAllLayersAdamTest(BaselineFinetuneTest,
53 | base_test.TestEpisodicLearner):
54 | learner_cls = optimization_learners.BaselineFinetuneLearner
55 | learner_kwargs = {
56 | **BASELINE_FINETUNE_ARGS,
57 | 'num_finetune_steps': 5,
58 | 'finetune_lr': 0.01,
59 | 'finetune_all_layers': True,
60 | 'finetune_with_adam': True,
61 | }
62 |
63 |
64 | class BaselineFinetuneAllLayersGDTest(BaselineFinetuneTest,
65 | base_test.TestEpisodicLearner):
66 | learner_cls = optimization_learners.BaselineFinetuneLearner
67 | learner_kwargs = {
68 | **BASELINE_FINETUNE_ARGS,
69 | 'num_finetune_steps': 5,
70 | 'finetune_lr': 0.1,
71 | 'finetune_all_layers': True,
72 | 'finetune_with_adam': False,
73 | }
74 |
75 |
76 | class BaselineFinetuneLastLayerGDTest(BaselineFinetuneTest,
77 | base_test.TestEpisodicLearner):
78 | learner_cls = optimization_learners.BaselineFinetuneLearner
79 | learner_kwargs = {
80 | **BASELINE_FINETUNE_ARGS,
81 | 'num_finetune_steps': 10,
82 | 'finetune_lr': 0.1,
83 | 'finetune_all_layers': False,
84 | 'finetune_with_adam': False,
85 | }
86 |
87 |
88 | MAML_KWARGS = {
89 | **base_test.VALID_LEARNER_INIT_ARGS,
90 | 'additional_evaluation_update_steps':
91 | 0,
92 | 'first_order':
93 | True,
94 | 'adapt_batch_norm':
95 | True,
96 | }
97 |
98 |
99 | class MAMLLearnerTest(base_test.TestEpisodicLearner):
100 | learner_cls = optimization_learners.MAMLLearner
101 | learner_kwargs = {
102 | **MAML_KWARGS,
103 | 'num_update_steps': 5,
104 | 'alpha': 0.01,
105 | 'zero_fc_layer': False,
106 | 'proto_maml_fc_layer_init': False,
107 | }
108 |
109 |
110 | class ZeroInitMAMLLearnerTest(base_test.TestEpisodicLearner):
111 | learner_cls = optimization_learners.MAMLLearner
112 | learner_kwargs = {
113 | **MAML_KWARGS,
114 | 'num_update_steps': 10,
115 | 'alpha': 0.01,
116 | 'zero_fc_layer': True,
117 | 'proto_maml_fc_layer_init': False,
118 | }
119 |
120 |
121 | class ProtoMAMLLearnerTest(base_test.TestEpisodicLearner):
122 | learner_cls = optimization_learners.MAMLLearner
123 | learner_kwargs = {
124 | **MAML_KWARGS,
125 | 'num_update_steps': 5,
126 | 'alpha': 0.01,
127 | 'zero_fc_layer': False,
128 | 'proto_maml_fc_layer_init': True,
129 | }
130 |
131 |
132 | if __name__ == '__main__':
133 | tf.test.main()
134 |
--------------------------------------------------------------------------------
/meta_dataset/models/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Meta-Dataset Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 |
--------------------------------------------------------------------------------
/meta_dataset/models/experimental/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Meta-Dataset Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 |
--------------------------------------------------------------------------------
/meta_dataset/test_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Meta-Dataset Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Testing utilities for meta-dataset.
17 |
18 | This module includes functions that inspect class APIs in order to recursively
19 | enumerate arguments to a method of the class as well as arguments passed
20 | implicitly (via *args and/or **kwargs) to the overridden method of any ancestor
21 | classes. These functions are useful for concisely defining tests for classes
22 | that take potentially overlapping but non-identical arguments.
23 |
24 | For examples, see the usage sketch at the end of this module and
25 | `meta_dataset.models.experimental.reparameterizable_base_test`.
26 | """
27 |
28 | from __future__ import absolute_import
29 | from __future__ import division
30 | from __future__ import print_function
31 |
32 | import functools
33 | import inspect
34 | import itertools
35 | import six
36 |
37 |
38 | def get_argspec(fn):
39 | """Return the `ArgSpec` namespace for this (potentially wrapped) `fn`."""
40 | while hasattr(fn, '__wrapped__'):
41 | fn = fn.__wrapped__
42 | if six.PY3:
43 | return inspect.getfullargspec(fn)
44 | else:
45 | try:
46 | return inspect.getargspec(fn) # pylint: disable=deprecated-method
47 | except TypeError:
48 | # Cannot inspect C variables (https://stackoverflow.com/a/7628130), so
49 | # return an empty `ArgSpec`.
50 | return inspect.ArgSpec([], None, None, ())
51 |
52 |
53 | def get_inherited_args(cls, cls_method):
54 | """Return arguments to `cls_method` of `cls` and all parents of `cls`."""
55 | # TODO(eringrant): This slicing operation will not work for static methods
56 | # (for which the first argument is not the object instance).
57 | args = get_argspec(getattr(cls, cls_method)).args[1:]
58 |
59 | # Inspect all parent classes.
60 | for parent_cls in cls.__mro__:
61 | if cls_method in parent_cls.__dict__:
62 | args += get_argspec(getattr(parent_cls, cls_method)).args[1:]
63 |
64 | return set(args)
65 |
66 |
67 | get_inherited_init_args = functools.partial(
68 | get_inherited_args, cls_method='__init__')
69 | get_inherited_call_args = functools.partial(
70 | get_inherited_args, cls_method='__call__')
71 |
72 |
73 | def get_valid_kwargs(module_cls, valid_module_init_args,
74 | valid_module_call_args):
75 | """Return all valid kwarg configurations for `module_cls`."""
76 | init_args = get_inherited_init_args(module_cls)
77 | call_args = get_inherited_call_args(module_cls)
78 |
79 | valid_init_kwargs = (
80 | zip(itertools.repeat(init_arg), valid_module_init_args[init_arg])
81 | for init_arg in init_args
82 | if init_arg in valid_module_init_args)
83 | valid_call_kwargs = (
84 | zip(itertools.repeat(call_arg), valid_module_call_args[call_arg])
85 | for call_arg in call_args
86 | if call_arg in valid_module_call_args)
87 |
88 | init_kwargs_combos = (
89 | dict(combo) for combo in itertools.product(*valid_init_kwargs))
90 | call_kwargs_combos = (
91 | dict(combo) for combo in itertools.product(*valid_call_kwargs))
92 |
93 | return itertools.product(init_kwargs_combos, call_kwargs_combos)
94 |
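95 |
96 | # Illustrative usage sketch: `_ToyBase` and `_ToyDerived` below are
97 | # hypothetical toy classes (not part of meta-dataset) that demonstrate how
98 | # `get_valid_kwargs` enumerates every combination of the `__init__` and
99 | # `__call__` kwargs accepted by a class or any of its ancestors.
100 | if __name__ == '__main__':
101 |
102 |   class _ToyBase(object):
103 |
104 |     def __init__(self, alpha=0.1, **kwargs):
105 |       del alpha, kwargs  # Unused; only the signature matters here.
106 |
107 |     def __call__(self, training=False):
108 |       del training  # Unused; only the signature matters here.
109 |
110 |   class _ToyDerived(_ToyBase):
111 |
112 |     def __init__(self, beta=0.2, **kwargs):
113 |       del beta  # Unused; only the signature matters here.
114 |       super(_ToyDerived, self).__init__(**kwargs)
115 |
116 |   # `alpha` has two candidate values, `beta` one, and `training` two, so the
117 |   # loop prints 2 * 1 * 2 = 4 (init_kwargs, call_kwargs) configurations.
118 |   for init_kwargs, call_kwargs in get_valid_kwargs(
119 |       _ToyDerived, {'alpha': (0.1, 0.5), 'beta': (1.0,)},
120 |       {'training': (True, False)}):
121 |     print(init_kwargs, call_kwargs)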
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py>=0.7
2 | etils>=0.4.0
3 | gin-config>=0.1.2
4 | numpy>=1.13.3
5 | scipy>=1.0.0
6 | six>=1.10
7 | tensorflow-gpu
8 | sklearn
9 | tensorflow_probability<=0.7
10 | tf-models-official
11 | tensorflow-datasets
12 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """Install script for setuptools."""
2 |
3 | from distutils import cmd
4 | import os
5 | import urllib.request
6 |
7 | from setuptools import find_packages
8 | from setuptools import setup
9 | from setuptools.command import install
10 |
11 | SIMCLR_DIR = 'simclr'
12 | DATA_UTILS_URL = 'https://raw.githubusercontent.com/google-research/simclr/master/data_util.py'
13 |
14 |
15 | class DownloadSimCLRAugmentationCommand(cmd.Command):
16 |   """Downloads SimCLR data_util.py as it's not built into an egg."""
17 | description = __doc__
18 | user_options = []
19 |
20 | def initialize_options(self):
21 | pass
22 |
23 | def finalize_options(self):
24 | pass
25 |
26 | def run(self):
27 | build_cmd = self.get_finalized_command('build')
28 | dist_root = os.path.realpath(build_cmd.build_lib)
29 | output_dir = os.path.join(dist_root, SIMCLR_DIR)
30 | if not os.path.exists(output_dir):
31 | os.makedirs(output_dir)
32 | output_path = os.path.join(output_dir, 'data_util.py')
33 | downloader = urllib.request.URLopener()
34 | downloader.retrieve(DATA_UTILS_URL, output_path)
35 |
36 |
37 | class InstallCommand(install.install):
38 |
39 | def run(self):
40 | self.run_command('simclr_download')
41 | install.install.run(self)
42 |
43 | setup(
44 | name='meta_dataset',
45 | version='0.2.0',
46 | description='A benchmark for few-shot classification.',
47 | author='Google LLC',
48 | license='Apache License, Version 2.0',
49 | python_requires='>=2.7, <3.10',
50 | packages=find_packages(),
51 | include_package_data=True,
52 | install_requires=[
53 | 'absl-py>=0.7.0',
54 | 'etils>=0.4.0',
55 | 'gin-config>=0.1.2',
56 | 'numpy>=1.13.3',
57 | 'scipy>=1.0.0',
58 | 'setuptools',
59 | 'six>=1.12.0',
60 | # Note that this will install tf 2.0, even though this is a tf 1.0
61 | # project. This is necessary because we rely on augmentation from
62 | # tf-models-official that wasn't added until after tf 2.0 was released.
63 | 'tensorflow-gpu',
64 | 'sklearn',
65 | 'tensorflow_probability<=0.7',
66 | 'tf-models-official',
67 | 'tensorflow-datasets',
68 | ],
69 | cmdclass={
70 | 'simclr_download': DownloadSimCLRAugmentationCommand,
71 | 'install': InstallCommand,
72 | })
73 |
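74 | # Usage note: with the `cmdclass` mapping above, the SimCLR download step can
75 | # be run on its own with `python setup.py simclr_download`; it also runs
76 | # automatically when the package is installed via the legacy
77 | # `python setup.py install` path (pip builds that bypass the custom `install`
78 | # command may need the standalone invocation instead).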
--------------------------------------------------------------------------------