├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── download.sh ├── example_visualization.py ├── n_digit_mnist.py └── standard_datasets ├── dataset_mnist_2_instance_test.npz ├── dataset_mnist_2_instance_train.npz ├── dataset_mnist_2_number_test_seen.npz ├── dataset_mnist_2_number_test_unseen.npz ├── dataset_mnist_2_number_train.npz ├── dataset_mnist_3_instance_test.npz ├── dataset_mnist_3_instance_train.npz ├── dataset_mnist_3_number_test_seen.npz ├── dataset_mnist_3_number_test_unseen.npz └── dataset_mnist_3_number_train.npz /.gitignore: -------------------------------------------------------------------------------- 1 | # Project specific 2 | data/ 3 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows [Google's Open Source Community 28 | Guidelines](https://opensource.google.com/conduct/). 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 
30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # n-digit MNIST
2 | 
3 | MNIST handwritten digits are arguably the most popular dataset for machine learning research.
4 | Although state-of-the-art models long ago reached what is possibly the best achievable performance on this benchmark,
5 | the dataset itself remains useful to the research community, providing a simple sanity check for new methods:
6 | if it doesn't work on MNIST, it doesn't work anywhere!
7 | 
8 | We introduce n-digit variants of MNIST here.
9 | By adding more digits per data point, one can exponentially increase the number of classes in the dataset.
10 | Nonetheless, the variants still benefit from the simplicity and lightweight nature of the original data.
11 | These datasets provide simple and useful toy examples for tasks such as face embedding.
12 | One can furthermore draw an analogy between individual digits and, e.g., face attributes.
13 | In this case, the dataset provides quick insights into an embedding algorithm before it is scaled up to more realistic, slow-to-train problems.
14 | 
15 | Due to potential proprietary issues and for greater flexibility, we release the code for _generating_ the dataset from the original MNIST dataset,
16 | rather than the images themselves.
17 | For benchmarking purposes, we release four _standard_ datasets which are, again, generated via code, but deterministically.
18 | 
19 | # Dataset protocols
20 | 
21 | Given `n`, the number of digits per sample, we generate data samples which are horizontal concatenations of the original MNIST digit images.
22 | We introduce training and test sets, each of which is built from individual digit images of the original training and test sets, respectively.
23 | In both the training and test splits, each n-digit class has exactly the same number of examples. The `--domain_gap` option controls how the splits differ: with the `instance` gap, the train and test splits contain the same set of n-digit numbers (classes) but are built from disjoint sets of original digit images; with the `number` gap, the sets of numbers themselves are disjoint, and two test splits are generated (`test_seen`, which renders the training numbers with unseen digit images, and `test_unseen`, which uses held-out numbers).
24 | 
25 | ## Generating the dataset
26 | 
27 | ### Dependencies
28 | 
29 | Only `numpy` is required.
30 | 
31 | ### Download the original MNIST dataset
32 | 
33 | Running
34 | ``` shell
35 | ./download.sh
36 | ```
37 | will download the original MNIST dataset from the [official MNIST website](http://yann.lecun.com/exdb/mnist/)
38 | and unzip the files into the `data/` folder:
39 | ``` shell
40 | data/train-images.idx3-ubyte
41 | data/train-labels.idx1-ubyte
42 | data/t10k-images.idx3-ubyte
43 | data/t10k-labels.idx1-ubyte
44 | ```
45 | 
46 | ### Creating the standard n-digit MNIST datasets
47 | 
48 | We have four _standard n-digit MNIST_ datasets ready: *mnist_2_instance*, *mnist_2_number*, *mnist_3_instance*, *mnist_3_number*.
49 | Unlike custom-built datasets, they are deterministically generated from pre-computed random arrays.
50 | These datasets are suitable for benchmarking model performance.
51 | 
52 | The above four datasets can be created by attaching the `--use_standard_dataset` flag, as shown below.
53 | 
54 | 
55 | ``` shell
56 | python n_digit_mnist.py --num_digits 2 --domain_gap instance --use_standard_dataset
57 | python n_digit_mnist.py --num_digits 2 --domain_gap number --use_standard_dataset
58 | python n_digit_mnist.py --num_digits 3 --domain_gap instance --use_standard_dataset
59 | python n_digit_mnist.py --num_digits 3 --domain_gap number --use_standard_dataset
60 | ```
61 | 
62 | To optionally inspect samples from a generated dataset, run the following commands (requires the `pillow` package):
63 | 
64 | ``` shell
65 | python example_visualization.py --num_digits 2 --domain_gap instance --num_visualize 10 --mnist_split train
66 | python example_visualization.py --num_digits 2 --domain_gap instance --num_visualize 10 --mnist_split test
67 | ```
68 | 
69 | These commands extract 20 random samples from the 2-digit instance-gap dataset (10 from the train split and 10 from the test split) and write them to the visualization subfolder (e.g. `data/dataset_mnist_2_instance/visualization`).
70 | 
71 | ### Create your own dataset
72 | 
73 | See the argument options of `n_digit_mnist.py` to configure a new dataset yourself.
74 | Example of 4-digit MNIST with the `number` domain gap:
75 | 
76 | ``` shell
77 | python n_digit_mnist.py --num_digits 4 --domain_gap number
78 | ```
79 | 
80 | ## Citing the dataset
81 | 
82 | The dataset was introduced in the following publication. Use the following BibTeX entry to cite the dataset:
83 | 
84 | ```
85 | @inproceedings{joon2019iclr,
86 |   title = {Modeling Uncertainty with Hedged Instance Embedding},
87 |   author = {Oh, Seong Joon and Murphy, Kevin and Pan, Jiyan and Roth, Joseph and Schroff, Florian and Gallagher, Andrew},
88 |   year = {2019},
89 |   booktitle = {International Conference on Learning Representations (ICLR)},
90 | }
91 | ```
92 | 
93 | ## License
94 | 
95 | This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details.
96 | 
97 | This is not an officially supported Google product.
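## Loading a generated dataset

For quick reference, each generated split is stored as a single `.npz` file containing an `images` array and a `labels` array (these are the keys written by `n_digit_mnist.py`). Below is a minimal loading sketch; it assumes the 2-digit instance-gap dataset has already been generated with the commands above, so that the file sits at the generator's `data/dataset_mnist_<num_digits>_<domain_gap>/<split>.npz` location.

``` python
import numpy as np

# Produced by: python n_digit_mnist.py --num_digits 2 --domain_gap instance --use_standard_dataset
data = np.load('data/dataset_mnist_2_instance/train.npz')

images = data['images']  # uint8 array of shape (num_samples, 28, 28 * num_digits)
labels = data['labels']  # integer labels in [0, 10**num_digits - 1], i.e. 0..99 for two digits

print(images.shape, labels.shape)
```

The generator also writes small `num_classes_<split>` and `num_samples_<split>` text files next to each `.npz` file, recording the number of classes and samples in that split.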
98 | 
--------------------------------------------------------------------------------
/download.sh:
--------------------------------------------------------------------------------
1 | mkdir -p data
2 | 
3 | if [[ "$OSTYPE" == "linux-gnu" ]]; then
4 |   wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz -O data/train-images-idx3-ubyte.gz
5 |   wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz -O data/train-labels-idx1-ubyte.gz
6 |   wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz -O data/t10k-images-idx3-ubyte.gz
7 |   wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz -O data/t10k-labels-idx1-ubyte.gz
8 | 
9 | elif [[ "$OSTYPE" == "darwin"* ]]; then
10 |   curl http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz -o data/train-images-idx3-ubyte.gz
11 |   curl http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz -o data/train-labels-idx1-ubyte.gz
12 |   curl http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz -o data/t10k-images-idx3-ubyte.gz
13 |   curl http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz -o data/t10k-labels-idx1-ubyte.gz
14 | 
15 | fi
16 | 
17 | gzip -d data/train-images-idx3-ubyte.gz
18 | gzip -d data/train-labels-idx1-ubyte.gz
19 | gzip -d data/t10k-images-idx3-ubyte.gz
20 | gzip -d data/t10k-labels-idx1-ubyte.gz
21 | 
22 | mv data/train-images-idx3-ubyte data/train-images.idx3-ubyte
23 | mv data/train-labels-idx1-ubyte data/train-labels.idx1-ubyte
24 | mv data/t10k-images-idx3-ubyte data/t10k-images.idx3-ubyte
25 | mv data/t10k-labels-idx1-ubyte data/t10k-labels.idx1-ubyte
26 | 
--------------------------------------------------------------------------------
/example_visualization.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Visualize samples from a generated n-digit MNIST dataset."""
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | 
21 | import argparse
22 | import os
23 | import numpy as np
24 | from PIL import Image
25 | 
26 | 
27 | def main():
28 |   parser = argparse.ArgumentParser(
29 |       description='Visualize created n-digit MNIST dataset.')
30 | 
31 |   parser.add_argument('--num_digits', default=1, type=int,
32 |                       help='Number of concatenated digits per data point.')
33 | 
34 |   parser.add_argument('--domain_gap', default='instance', type=str,
35 |                       choices=['instance', 'number'],
36 |                       help='How to split training and test sets '
37 |                            'of the n-digit mnist.')
38 | 
39 |   parser.add_argument('--num_visualize', default=10, type=int,
40 |                       help='How many samples to visualize.')
41 | 
42 |   parser.add_argument('--output_dir', default='data', type=str,
43 |                       help='Directory where the dataset was written.')
44 | 
45 |   parser.add_argument('--mnist_split', default='train', type=str,
46 |                       help='Which MNIST split was used to generate data? '
47 |                            '(train or test)')
48 | 
49 |   args = parser.parse_args()
50 | 
51 |   # Construct dataset directory.
52 |   dataset_dir = os.path.join(args.output_dir,
53 |                              'dataset_mnist_%d_%s' % (args.num_digits,
54 |                                                       args.domain_gap))
55 | 
56 |   path = os.path.join(dataset_dir, '%s.npz' % args.mnist_split)
57 |   with open(path, 'rb') as f:
58 |     data = np.load(f)
59 |     labels = data['labels']
60 |     images = data['images']
61 | 
62 |   visualize_dir = os.path.join(dataset_dir, 'visualization')
63 |   if not os.path.exists(visualize_dir):
64 |     os.makedirs(visualize_dir)
65 | 
66 |   visualize_indices = np.random.choice(range(len(labels)),
67 |                                        args.num_visualize, replace=False)
68 | 
69 |   for i in visualize_indices:
70 |     im = Image.fromarray(images[i])
71 |     im.save(os.path.join(visualize_dir,
72 |                          'sample_image_%d_label_%d.jpg' % (i, labels[i])))
73 | 
74 | if __name__ == '__main__':
75 |   main()
76 | 
77 | 
--------------------------------------------------------------------------------
/n_digit_mnist.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Load MNIST dataset, and generate its n-digit version."""
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | 
21 | import argparse
22 | import logging
23 | import os
24 | import struct
25 | 
26 | import numpy as np
27 | 
28 | 
29 | class NDigitMnist(object):
30 |   """Construct and write n-digit MNIST into .npz files."""
31 | 
32 |   def __init__(self, args):
33 |     self.args = args
34 |     np.random.seed(args.seed)
35 | 
36 |   def _load_mnist(self, mnist_split):
37 |     """Loads the training or test MNIST data and returns an (images, labels) tuple."""
38 |     mnist_map = {
39 |         'train': {
40 |             'images': 'train-images.idx3-ubyte',
41 |             'labels': 'train-labels.idx1-ubyte',
42 |         },
43 |         'test': {
44 |             'images': 't10k-images.idx3-ubyte',
45 |             'labels': 't10k-labels.idx1-ubyte',
46 |         },
47 |     }[mnist_split]
48 | 
49 |     with open(os.path.join(self.args.mnist_dir, mnist_map['images']), 'rb') as f:
50 |       images = self._decode_mnist(f, 'images')
51 |     with open(os.path.join(self.args.mnist_dir, mnist_map['labels']), 'rb') as f:
52 |       labels = self._decode_mnist(f, 'labels')
53 | 
54 |     return images, labels
55 | 
56 |   def _decode_mnist(self, f, which_data):
57 |     """Decode raw MNIST data into a numpy array."""
58 | 
59 |     if which_data == 'images':
60 |       magic, size, rows, cols = struct.unpack('>IIII', f.read(16))
61 |       if magic != 2051:
62 |         raise ValueError('Magic number mismatch.')
63 |       return np.frombuffer(f.read(), dtype=np.uint8).reshape(size, rows, cols)
64 | 
65 |     elif which_data == 'labels':
66 |       magic, size = struct.unpack('>II', f.read(8))
67 |       if magic != 2049:
68 |         raise ValueError('Magic number mismatch.')
69 |       return np.frombuffer(f.read(), dtype=np.uint8).reshape(size)
70 | 
71 |   def _print_mnist_images(self, images, labels, num_print=10):
72 |     shuffle_indices = range(len(labels))
73 |     np.random.shuffle(shuffle_indices)
74 |     for i in xrange(num_print):
75 |       self._print_mnist_image(images[shuffle_indices[i]],
76 |                               labels[shuffle_indices[i]])
77 | 
78 |   def _print_mnist_image(self, image, label):
79 |     """Given an image, prints it as a string array."""
80 |     mapper = np.array(['.', '@'])
81 |     for h_ind in xrange(image.shape[0]):
82 |       for w_ind in xrange(image.shape[1]):
83 |         print(mapper[int(image[h_ind, w_ind] > 0.5)], end='')
84 |       print('')
85 |     print(label)
86 | 
87 |   def _write_num_classes(self, dataset_dir, num_classes, mnist_split):
88 |     with open(
89 |         os.path.join(dataset_dir, 'num_classes_%s' % mnist_split), 'w') as f:
90 |       f.write(str(num_classes))
91 | 
92 |   def _write_num_samples(self, dataset_dir, num_samples, mnist_split):
93 |     with open(
94 |         os.path.join(dataset_dir, 'num_samples_%s' % mnist_split), 'w') as f:
95 |       f.write(str(num_samples))
96 | 
97 |   def _choose_numbers_to_include(self, mnist_split):
98 |     """Recipe for constructing n-digit MNIST.
99 | 
100 |     Strategy for constructing n-digit MNIST: (1) Make a list of numbers
101 |     to be included in the {train,test} split, (2) Given the list of numbers,
102 |     sample digit images from the original MNIST {train,test} splits to compile
103 |     sample_per_number n-digit images per number - see
104 |     self._compile_number_images_and_labels.
105 | 
106 |     When domain_gap is 'number', we ensure that no number is shared across
107 |     train-test. number_ratio_train percent of numbers will be assigned to the
108 |     train split, the rest to the test split.
109 | 
110 |     Args:
111 |       mnist_split: ['train', 'test'] MNIST split where digit images are sampled.
112 | 
113 |     Returns:
114 |       sample_per_number: {int} Number of samples per class.
115 |       chosen_numbers: {list} List of numbers to include in the current split.
116 | 
117 |     Raises:
118 |       ValueError: if the chosen domain_gap parameter is not supported or the
119 |         minimal number of samples per class condition is not satisfiable.
120 | """ 121 | 122 | all_numbers = xrange(10 ** self.args.num_digits) 123 | total_num_sample = (self.args.total_num_train 124 | if mnist_split == 'train' 125 | else self.args.total_num_test) 126 | 127 | if self.args.domain_gap == 'instance': 128 | chosen_numbers = all_numbers 129 | sample_per_number = int(total_num_sample / len(chosen_numbers)) 130 | if (sample_per_number < self.args.min_num_instance_per_number and 131 | mnist_split == 'train'): 132 | raise ValueError('Cannot guarantee to have minimal %d samples ' 133 | 'for each class under current configuration' % 134 | self.args.min_num_instance_per_number) 135 | 136 | elif self.args.domain_gap == 'number': 137 | if mnist_split == 'train': 138 | chosen_numbers = np.random.choice( 139 | all_numbers, 140 | int(self.args.number_ratio_train * len(all_numbers) / 100), 141 | replace=False) 142 | self._train_numbers = chosen_numbers 143 | 144 | elif mnist_split == 'test_seen': 145 | chosen_numbers = self._train_numbers.copy() 146 | else: # test_unseen 147 | chosen_numbers = list( 148 | set(all_numbers).difference(set(self._train_numbers))) 149 | 150 | (sample_per_number, chosen_numbers 151 | ) = self._trim_num_classes_for_min_instance_constraint( 152 | total_num_sample, chosen_numbers) 153 | 154 | else: 155 | raise ValueError('domain_gap should be one of [instance, number]') 156 | 157 | logging.info('split %s total number of samples: %d', 158 | mnist_split, total_num_sample) 159 | logging.info('split %s number of samples per number: %d', 160 | mnist_split, sample_per_number) 161 | 162 | return sample_per_number, chosen_numbers 163 | 164 | def _trim_num_classes_for_min_instance_constraint(self, total_num_sample, 165 | chosen_numbers): 166 | """Subsample the classes to meet the min_instance_per_class constraint.""" 167 | 168 | sample_per_number = int(total_num_sample / len(chosen_numbers)) 169 | 170 | if sample_per_number < self.args.min_num_instance_per_number: 171 | sample_per_number = self.args.min_num_instance_per_number 172 | target_num_chosen_numbers = int(total_num_sample / sample_per_number) 173 | chosen_numbers = np.random.choice(chosen_numbers, 174 | target_num_chosen_numbers, 175 | replace=False) 176 | return sample_per_number, chosen_numbers 177 | 178 | def _collect_image_ids(self, sample_per_number, 179 | chosen_numbers, imageset_indices_by_label): 180 | """Given a list of numbers, construct n-digit images and labels. 181 | 182 | Args: 183 | sample_per_number: {int} Number of samples per class. 184 | chosen_numbers: {list} List of integer labels that construct the dataset. 185 | imageset_indices_by_label: 186 | 187 | Returns: 188 | image_ids: {list(len(chosen_numbers))} List of original MNIST index arrays 189 | for each n-digit class. Ingredient for compiling the final dataset. 190 | """ 191 | image_ids = [] 192 | 193 | for number in chosen_numbers: 194 | number_string = self._number_into_string(number) 195 | 196 | # For each position (digit) in the number, sample (sample_per_number) 197 | # digit image indices to compile the n-digit images. 
198 |       digit_image_indices_all_samples = []
199 |       for digit_char in number_string:
200 |         digit_image_indices_all_samples.append(np.expand_dims(
201 |             np.random.choice(
202 |                 imageset_indices_by_label[int(digit_char)],
203 |                 sample_per_number, replace=False),
204 |             -1))
205 |       digit_image_indices_all_samples = np.concatenate(
206 |           digit_image_indices_all_samples, 1)
207 |       image_ids.append(digit_image_indices_all_samples)
208 | 
209 |     return image_ids
210 | 
211 |   def _compile_number_images_and_labels(self, sample_per_number, chosen_numbers,
212 |                                         image_ids, images):
213 |     """Given a list of numbers, construct n-digit images and labels.
214 | 
215 |     Args:
216 |       sample_per_number: {int} Number of samples per class.
217 |       chosen_numbers: {list} List of integer labels that construct the dataset.
218 |       image_ids: {list(len(chosen_numbers))} List of original MNIST index arrays
219 |         for each class. Ingredient for compiling the final dataset.
220 |       images: {numpy.ndarray(shape=(?, 28, 28))} Original MNIST images.
221 | 
222 |     Returns:
223 |       n_digit_images: {numpy.ndarray(shape=(?, 28, 28*num_digits))}
224 |         batch of n-digit images.
225 |       n_digit_labels: {numpy.ndarray(shape=(?,))} batch of n-digit labels in the
226 |         range [0, 10**num_digits-1].
227 |     """
228 | 
229 |     n_digit_images = []
230 |     n_digit_labels = []
231 | 
232 |     for number, digit_image_indices_all_samples in zip(chosen_numbers,
233 |                                                         image_ids):
234 |       # For each of the sample_per_number sampled index rows, construct the
235 |       # n-digit image and label.
236 |       for sample_index in xrange(sample_per_number):
237 |         digit_image_indices = digit_image_indices_all_samples[sample_index]
238 | 
239 |         number_image = []
240 |         for digit_image_index in digit_image_indices:
241 |           digit_image = images[digit_image_index]
242 |           number_image.append(digit_image)
243 | 
244 |         n_digit_images.append(np.expand_dims(
245 |             np.concatenate(number_image, 1),
246 |             0))
247 |         n_digit_labels.append(number)
248 | 
249 |     return np.concatenate(n_digit_images, 0), np.array(n_digit_labels)
250 | 
251 |   def _number_into_string(self, number):
252 |     """Make e.g. 93 into '093' when num_digits=3."""
253 |     number_string = str(number)
254 |     if len(number_string) < self.args.num_digits:
255 |       number_string = ('0' * (self.args.num_digits - len(number_string))
256 |                        + number_string)
257 |     return number_string
258 | 
259 |   def _save_mnist_to_npz(self, dataset_dir, mnist_split, images, labels):
260 |     path = os.path.join(dataset_dir, '%s.npz' % mnist_split)
261 |     with open(path, 'wb') as f:
262 |       np.savez(f, images=images, labels=labels)
263 | 
264 |   def load_and_write_mnist(self, mnist_split):
265 |     """Loads MNIST into memory, transforms it, and writes it to disk."""
266 | 
267 |     # Construct dataset directory.
268 |     dataset_dir = os.path.join(self.args.output_dir,
269 |                                'dataset_mnist_%d_%s' % (self.args.num_digits,
270 |                                                         self.args.domain_gap))
271 | 
272 |     if not os.path.exists(dataset_dir):
273 |       os.makedirs(dataset_dir)
274 | 
275 |     # We construct the new {train,test} split from the original MNIST {train,test}
276 |     # split respectively, to ensure that no digit instance is shared across splits.
277 |     original_mnist_split = mnist_split if mnist_split == 'train' else 'test'
278 |     images, labels = self._load_mnist(original_mnist_split)
279 | 
280 |     if self.args.use_standard_dataset:
281 |       (sample_per_number, chosen_numbers,
282 |        image_ids) = self._load_standard_dataset(mnist_split)
283 | 
284 |     else:
285 |       sample_per_number, chosen_numbers = self._choose_numbers_to_include(
286 |           mnist_split)
287 | 
288 |       imageset_indices_by_label = {digit: np.where(labels == digit)[0]
289 |                                    for digit in xrange(10)}
290 |       for digit in xrange(10):
291 |         if len(imageset_indices_by_label[digit]) < sample_per_number:
292 |           raise ValueError('Not enough digit examples to build '
293 |                            'the number images.')
294 | 
295 |       image_ids = self._collect_image_ids(
296 |           sample_per_number, chosen_numbers, imageset_indices_by_label)
297 | 
298 |       self._save_standard_dataset(sample_per_number, chosen_numbers,
299 |                                   image_ids, mnist_split)
300 | 
301 |     n_digit_images, n_digit_labels = self._compile_number_images_and_labels(
302 |         sample_per_number, chosen_numbers, image_ids, images)
303 | 
304 |     self._write_num_classes(dataset_dir, len(chosen_numbers), mnist_split)
305 |     self._write_num_samples(dataset_dir, len(n_digit_labels), mnist_split)
306 |     self._save_mnist_to_npz(dataset_dir, mnist_split,
307 |                             n_digit_images, n_digit_labels)
308 | 
309 |   def _save_standard_dataset(self, sample_per_number,
310 |                              chosen_numbers, image_ids, mnist_split):
311 |     with open(os.path.join(self.args.standard_datasets_dir,
312 |                            'dataset_mnist_%d_%s_%s.npz'
313 |                            % (self.args.num_digits,
314 |                               self.args.domain_gap, mnist_split)), 'wb') as f:
315 |       np.savez(f, sample_per_number=sample_per_number,
316 |                chosen_numbers=chosen_numbers, image_ids=image_ids)
317 | 
318 |   def _load_standard_dataset(self, mnist_split):
319 |     with np.load(os.path.join(self.args.standard_datasets_dir,
320 |                               'dataset_mnist_%d_%s_%s.npz'
321 |                               % (self.args.num_digits,
322 |                                  self.args.domain_gap, mnist_split))) as d:
323 |       sample_per_number = d['sample_per_number']
324 |       chosen_numbers = d['chosen_numbers']
325 |       image_ids = d['image_ids']
326 |     return sample_per_number, chosen_numbers, image_ids
327 | 
328 | 
329 | def main():
330 |   parser = argparse.ArgumentParser(
331 |       description='Create an n-digit MNIST dataset.')
332 | 
333 |   parser.add_argument('--num_digits', default=1, type=int,
334 |                       help='Number of concatenated digits per data point.')
335 | 
336 |   parser.add_argument('--domain_gap', default='instance', type=str,
337 |                       choices=['instance', 'number'],
338 |                       help='How to split training and test sets '
339 |                            'of the n-digit mnist.')
340 | 
341 |   parser.add_argument('--number_ratio_train', default=70, type=int,
342 |                       help='When domain_gap is number, the percentage of '
343 |                            'n-digit numbers assigned to the train split; '
344 |                            'the rest go to the test split.')
345 | 
346 |   parser.add_argument('--total_num_train', default=100000, type=int,
347 |                       help='Target total number of training samples.')
348 | 
349 |   parser.add_argument('--total_num_test', default=10000, type=int,
350 |                       help='Target total number of test samples.')
351 | 
352 |   parser.add_argument('--min_num_instance_per_number', default=100, type=int,
353 |                       help='Minimal number of instances per n-digit number. '
354 |                            'Needed to ensure minimal positive pairs for metric '
355 |                            'embedding training. 
If the total_num_train/test is ' 356 | 'too small to ensure min_num_instance_per_number ' 357 | 'for every possible number, subsample the set of ' 358 | 'numbers in the split.') 359 | 360 | parser.add_argument('--use_standard_dataset', dest='use_standard_dataset', 361 | default=False, action='store_true', 362 | help='Standard dataset for reproducibility. Uses ' 363 | 'a predefined random number table.') 364 | 365 | parser.add_argument('--seed', default=714, type=int, 366 | help='Seed for controlling randomness.') 367 | 368 | parser.add_argument('--output_dir', default='data', type=str, 369 | help='Directory to write the dataset.') 370 | 371 | parser.add_argument('--mnist_dir', default='data', type=str, 372 | help='Directory with the original MNIST dataset.') 373 | 374 | parser.add_argument('--standard_datasets_dir', default='standard_datasets', 375 | type=str, help='Standard dataset directory.') 376 | 377 | args = parser.parse_args() 378 | 379 | mnist_writer = NDigitMnist(args) 380 | 381 | if args.domain_gap == 'instance': 382 | splits = ['train', 'test'] 383 | elif args.domain_gap == 'number': 384 | splits = ['train', 'test_seen', 'test_unseen'] 385 | for mnist_split in splits: 386 | mnist_writer.load_and_write_mnist(mnist_split) 387 | 388 | if __name__ == '__main__': 389 | main() 390 | 391 | -------------------------------------------------------------------------------- /standard_datasets/dataset_mnist_2_instance_test.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/n-digit-mnist/20154cafc8acbba8de816af6773db96aecb6f8da/standard_datasets/dataset_mnist_2_instance_test.npz -------------------------------------------------------------------------------- /standard_datasets/dataset_mnist_2_instance_train.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/n-digit-mnist/20154cafc8acbba8de816af6773db96aecb6f8da/standard_datasets/dataset_mnist_2_instance_train.npz -------------------------------------------------------------------------------- /standard_datasets/dataset_mnist_2_number_test_seen.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/n-digit-mnist/20154cafc8acbba8de816af6773db96aecb6f8da/standard_datasets/dataset_mnist_2_number_test_seen.npz -------------------------------------------------------------------------------- /standard_datasets/dataset_mnist_2_number_test_unseen.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/n-digit-mnist/20154cafc8acbba8de816af6773db96aecb6f8da/standard_datasets/dataset_mnist_2_number_test_unseen.npz -------------------------------------------------------------------------------- /standard_datasets/dataset_mnist_2_number_train.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/n-digit-mnist/20154cafc8acbba8de816af6773db96aecb6f8da/standard_datasets/dataset_mnist_2_number_train.npz -------------------------------------------------------------------------------- /standard_datasets/dataset_mnist_3_instance_test.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/n-digit-mnist/20154cafc8acbba8de816af6773db96aecb6f8da/standard_datasets/dataset_mnist_3_instance_test.npz 
-------------------------------------------------------------------------------- /standard_datasets/dataset_mnist_3_instance_train.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/n-digit-mnist/20154cafc8acbba8de816af6773db96aecb6f8da/standard_datasets/dataset_mnist_3_instance_train.npz -------------------------------------------------------------------------------- /standard_datasets/dataset_mnist_3_number_test_seen.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/n-digit-mnist/20154cafc8acbba8de816af6773db96aecb6f8da/standard_datasets/dataset_mnist_3_number_test_seen.npz -------------------------------------------------------------------------------- /standard_datasets/dataset_mnist_3_number_test_unseen.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/n-digit-mnist/20154cafc8acbba8de816af6773db96aecb6f8da/standard_datasets/dataset_mnist_3_number_test_unseen.npz -------------------------------------------------------------------------------- /standard_datasets/dataset_mnist_3_number_train.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/n-digit-mnist/20154cafc8acbba8de816af6773db96aecb6f8da/standard_datasets/dataset_mnist_3_number_train.npz --------------------------------------------------------------------------------
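A note on the `standard_datasets/*.npz` files listed above: they contain no images. As written by `_save_standard_dataset` in `n_digit_mnist.py`, each file stores only the pre-computed sampling tables that make the standard datasets deterministic (`sample_per_number`, `chosen_numbers`, `image_ids`); running `n_digit_mnist.py` with `--use_standard_dataset` reads these tables and recombines them with the original MNIST digit images. A minimal inspection sketch, assuming it is run from the repository root (the key names are taken from the generator code):

``` python
import numpy as np

# One of the checked-in standard-dataset description files.
with np.load('standard_datasets/dataset_mnist_2_instance_train.npz') as d:
    sample_per_number = int(d['sample_per_number'])  # samples generated per 2-digit class
    chosen_numbers = d['chosen_numbers']             # which n-digit numbers (classes) are in this split
    image_ids = d['image_ids']                       # original MNIST digit-image indices, one row per sample

print(sample_per_number, len(chosen_numbers))
```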