├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── ablation ├── ab_cta_remixmatch.py ├── ab_cta_remixmatch_noweak.py └── ab_remixmatch.py ├── cta ├── cta_fsmixup.py ├── cta_mixmatch.py ├── cta_remixmatch.py └── lib │ ├── __init__.py │ └── train.py ├── fully_supervised ├── fs_baseline.py ├── fs_mixup.py ├── lib │ ├── __init__.py │ ├── data.py │ └── train.py └── runs │ └── all.sh ├── ict.py ├── libml ├── __init__.py ├── augment.py ├── ctaugment.py ├── data.py ├── layers.py ├── models.py ├── train.py └── utils.py ├── mean_teacher.py ├── mixmatch.py ├── mixup.py ├── pi_model.py ├── pseudo_label.py ├── remixmatch_no_cta.py ├── requirements.txt ├── runs └── ssl │ ├── ablation.sh │ ├── cifar10.sh │ ├── stl10.sh │ └── svhn.sh ├── scripts ├── aggregate_accuracy.py ├── check_split.py ├── create_datasets.py ├── create_split.py ├── create_unlabeled.py ├── extract_accuracy.py └── inspect_dataset.py ├── third_party ├── LICENSE ├── __init__.py ├── auto_augment │ ├── __init__.py │ ├── augmentations.py │ ├── custom_ops.py │ ├── policies.py │ ├── shake_drop.py │ ├── shake_shake.py │ └── wrn.py └── vat_utils.py └── vat.py /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows [Google's Open Source Community 28 | Guidelines](https://opensource.google.com/conduct/). 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ReMixMatch 2 | 3 | Code for the paper: "[ReMixMatch: Semi-Supervised Learning with Distribution Alignment and Augmentation Anchoring](https://arxiv.org/abs/1911.09785)" by David Berthelot, Nicholas Carlini, Ekin D. Cubuk, Alex Kurakin, Kihyuk Sohn, Han Zhang, and Colin Raffel. 4 | 5 | 6 | This is not an officially supported Google product. 7 | 8 | ## Setup 9 | 10 | **Important**: `ML_DATA` is a shell environment variable that should point to the location where the datasets are installed. See the *Install datasets* section for more details. 11 | 12 | ### Install dependencies 13 | 14 | ```bash 15 | sudo apt install python3-dev python3-virtualenv python3-tk imagemagick 16 | virtualenv -p python3 --system-site-packages env3 17 | . 
env3/bin/activate 18 | pip install -r requirements.txt 19 | ``` 20 | 21 | ### Install datasets 22 | 23 | ```bash 24 | export ML_DATA="path to where you want the datasets saved" 25 | # Download datasets 26 | CUDA_VISIBLE_DEVICES= ./scripts/create_datasets.py 27 | cp $ML_DATA/svhn-test.tfrecord $ML_DATA/svhn_noextra-test.tfrecord 28 | 29 | # Create unlabeled datasets 30 | CUDA_VISIBLE_DEVICES= scripts/create_unlabeled.py $ML_DATA/SSL2/svhn $ML_DATA/svhn-train.tfrecord $ML_DATA/svhn-extra.tfrecord & 31 | CUDA_VISIBLE_DEVICES= scripts/create_unlabeled.py $ML_DATA/SSL2/svhn_noextra $ML_DATA/svhn-train.tfrecord & 32 | CUDA_VISIBLE_DEVICES= scripts/create_unlabeled.py $ML_DATA/SSL2/cifar10 $ML_DATA/cifar10-train.tfrecord & 33 | CUDA_VISIBLE_DEVICES= scripts/create_unlabeled.py $ML_DATA/SSL2/cifar100 $ML_DATA/cifar100-train.tfrecord & 34 | CUDA_VISIBLE_DEVICES= scripts/create_unlabeled.py $ML_DATA/SSL2/stl10 $ML_DATA/stl10-train.tfrecord $ML_DATA/stl10-unlabeled.tfrecord & 35 | wait 36 | 37 | # Create semi-supervised subsets 38 | for seed in 0 1 2 3 4 5; do 39 | for size in 40 250 1000 4000; do 40 | CUDA_VISIBLE_DEVICES= scripts/create_split.py --seed=$seed --size=$size $ML_DATA/SSL2/svhn $ML_DATA/svhn-train.tfrecord $ML_DATA/svhn-extra.tfrecord & 41 | CUDA_VISIBLE_DEVICES= scripts/create_split.py --seed=$seed --size=$size $ML_DATA/SSL2/svhn_noextra $ML_DATA/svhn-train.tfrecord & 42 | CUDA_VISIBLE_DEVICES= scripts/create_split.py --seed=$seed --size=$size $ML_DATA/SSL2/cifar10 $ML_DATA/cifar10-train.tfrecord & 43 | done 44 | CUDA_VISIBLE_DEVICES= scripts/create_split.py --seed=$seed --size=10000 $ML_DATA/SSL2/cifar100 $ML_DATA/cifar100-train.tfrecord & 45 | CUDA_VISIBLE_DEVICES= scripts/create_split.py --seed=$seed --size=2500 $ML_DATA/SSL2/cifar100 $ML_DATA/cifar100-train.tfrecord & 46 | CUDA_VISIBLE_DEVICES= scripts/create_split.py --seed=$seed --size=1000 $ML_DATA/SSL2/stl10 $ML_DATA/stl10-train.tfrecord $ML_DATA/stl10-unlabeled.tfrecord & 47 | wait 48 | done 49 | CUDA_VISIBLE_DEVICES= scripts/create_split.py --seed=1 --size=5000 $ML_DATA/SSL2/stl10 $ML_DATA/stl10-train.tfrecord $ML_DATA/stl10-unlabeled.tfrecord 50 | ``` 51 | 52 | ## Running 53 | 54 | ### Setup 55 | 56 | All commands must be run from the project root. The following environment variables must be defined: 57 | ```bash 58 | export ML_DATA="path to where you want the datasets saved" 59 | export PYTHONPATH=$PYTHONPATH:. 60 | ``` 61 | 62 | ### Example 63 | 64 | For example, training ReMixMatch with 32 filters and 4 augmentations on cifar10 shuffled with `seed=3`, 250 labeled samples and 5000 65 | validation samples: 66 | ```bash 67 | CUDA_VISIBLE_DEVICES=0 python cta/cta_remixmatch.py --filters=32 --K=4 --dataset=cifar10.3@250-5000 --w_match=1.5 --beta=0.75 --train_dir ./experiments/remixmatch 68 | ``` 69 | 70 | Available labeled sizes are 40, 100, 250, 1000, 4000. 71 | For validation, available sizes are 1, 5000. 72 | Possible shuffling seeds are 1, 2, 3, 4, 5 and 0 for no shuffling (0 is not used in practice since the data needs to be 73 | shuffled for gradient descent to work properly).
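Putting these together, dataset names follow the pattern `<dataset>.<seed>@<labeled size>-<validation size>` (the full list is enumerated in the *Valid dataset names* section below). For instance, the name used in the example above decodes as:

```bash
# <dataset>.<seed>@<labeled size>-<validation size>
dataset=cifar10; seed=3; size=250; valid=5000
echo "${dataset}.${seed}@${size}-${valid}"  # prints: cifar10.3@250-5000
```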
74 | 75 | 76 | #### Multi-GPU training 77 | Just pass more GPUs and ReMixMatch automatically scales to them. Here we assign GPUs 4-7 to the program: 78 | ```bash 79 | CUDA_VISIBLE_DEVICES=4,5,6,7 python cta/cta_remixmatch.py --filters=32 --K=4 --dataset=cifar10.3@250-5000 --w_match=1.5 --beta=0.75 --train_dir ./experiments/remixmatch 80 | ``` 81 | 82 | ### Valid dataset names 83 | ```bash 84 | for dataset in cifar10 svhn svhn_noextra; do 85 | for seed in 0 1 2 3 4 5; do 86 | for valid in 1 5000; do 87 | for size in 40 250 1000 4000; do 88 | echo "${dataset}.${seed}@${size}-${valid}" 89 | done; done; done; done 90 | 91 | for seed in 0 1 2 3 4 5; do 92 | for valid in 1 5000; do 93 | echo "cifar100.${seed}@10000-${valid}" 94 | done; done 95 | 96 | for seed in 1 2 3 4 5; do 97 | for valid in 1 5000; do 98 | echo "stl10.${seed}@1000-${valid}" 99 | done; done 100 | echo "stl10.1@5000-1" 101 | ``` 102 | 103 | 104 | ## Monitoring training progress 105 | 106 | You can point TensorBoard to the training folder (by default it is `--train_dir=./experiments`) to monitor the training 107 | process: 108 | 109 | ```bash 110 | tensorboard.sh --port 6007 --logdir experiments 111 | ``` 112 | 113 | ## Checkpoint accuracy 114 | 115 | We compute the median accuracy of the last 20 checkpoints in the paper; this is done with the following command: 116 | 117 | ```bash 118 | # Following the previous example in which we trained cifar10.3@250-5000, extracting accuracy: 119 | ./scripts/extract_accuracy.py experiments/cifar10.d.d.d.3\@250-5000/CTAugment_depth2_th0.80_decay0.990/CTAReMixMatch_K4_archresnet_batch64_beta0.75_filters32_lr0.002_nclass10_redux1st_repeat4_scales3_use_dmTrue_use_xeTrue_w_kl0.5_w_match1.5_w_rot0.5_warmup_kimg1024_wd0.02/ 120 | # The command above will create a stats/accuracy.json file in the model folder. 121 | # The format is JSON so you can either see its content as a text file or process it to your liking. 122 | ``` 123 | 124 | ## Reproducing tables from the paper 125 | 126 | Check the contents of the `runs/*.sh` files; these will give you the commands (and the hyper-parameters) to reproduce the results from the paper. 127 | 128 | ## Citing this work 129 | 130 | ``` 131 | @article{berthelot2019remixmatch, 132 | title={ReMixMatch: Semi-Supervised Learning with Distribution Alignment and Augmentation Anchoring}, 133 | author={David Berthelot and Nicholas Carlini and Ekin D. Cubuk and Alex Kurakin and Kihyuk Sohn and Han Zhang and Colin Raffel}, 134 | journal={arXiv preprint arXiv:1911.09785}, 135 | year={2019}, 136 | } 137 | ``` 138 | -------------------------------------------------------------------------------- /ablation/ab_cta_remixmatch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ReMixMatch training, changes from MixMatch are: 15 | - Add distribution matching.
16 | """ 17 | 18 | import os 19 | 20 | from absl import app 21 | from absl import flags 22 | 23 | from cta.cta_remixmatch import CTAReMixMatch 24 | from libml import utils, data 25 | 26 | FLAGS = flags.FLAGS 27 | 28 | 29 | class ABCTAReMixMatch(CTAReMixMatch): 30 | pass 31 | 32 | 33 | def main(argv): 34 | utils.setup_main() 35 | del argv # Unused. 36 | dataset = data.MANY_DATASETS()[FLAGS.dataset]() 37 | log_width = utils.ilog2(dataset.width) 38 | model = ABCTAReMixMatch( 39 | os.path.join(FLAGS.train_dir, dataset.name, ABCTAReMixMatch.cta_name()), 40 | dataset, 41 | lr=FLAGS.lr, 42 | wd=FLAGS.wd, 43 | arch=FLAGS.arch, 44 | batch=FLAGS.batch, 45 | nclass=dataset.nclass, 46 | 47 | K=FLAGS.K, 48 | beta=FLAGS.beta, 49 | w_kl=FLAGS.w_kl, 50 | w_match=FLAGS.w_match, 51 | w_rot=FLAGS.w_rot, 52 | redux=FLAGS.redux, 53 | use_dm=FLAGS.use_dm, 54 | use_xe=FLAGS.use_xe, 55 | warmup_kimg=FLAGS.warmup_kimg, 56 | 57 | scales=FLAGS.scales or (log_width - 2), 58 | filters=FLAGS.filters, 59 | repeat=FLAGS.repeat) 60 | model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10) 61 | 62 | 63 | if __name__ == '__main__': 64 | utils.setup_tf() 65 | flags.DEFINE_float('wd', 0.02, 'Weight decay.') 66 | flags.DEFINE_float('beta', 0.75, 'Mixup beta distribution.') 67 | flags.DEFINE_float('w_kl', 0.5, 'Weight for KL loss.') 68 | flags.DEFINE_float('w_match', 1.5, 'Weight for distribution matching loss.') 69 | flags.DEFINE_float('w_rot', 0.5, 'Weight for rotation loss.') 70 | flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.') 71 | flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.') 72 | flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.') 73 | flags.DEFINE_integer('warmup_kimg', 1024, 'Unannealing duration for SSL loss.') 74 | flags.DEFINE_enum('redux', '1st', 'swap mean 1st'.split(), 'Logit selection.') 75 | flags.DEFINE_bool('use_dm', True, 'Whether to use distribution matching.') 76 | flags.DEFINE_bool('use_xe', True, 'Whether to use cross-entropy or Brier.') 77 | FLAGS.set_default('augment', 'd.d.d') 78 | FLAGS.set_default('dataset', 'cifar10.3@250-5000') 79 | FLAGS.set_default('batch', 64) 80 | FLAGS.set_default('lr', 0.002) 81 | FLAGS.set_default('train_kimg', 1 << 16) 82 | app.run(main) 83 | -------------------------------------------------------------------------------- /ablation/ab_cta_remixmatch_noweak.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ReMixMatch training, changes from MixMatch are: 15 | - Add distribution matching. 
16 | """ 17 | 18 | import os 19 | 20 | import numpy as np 21 | from absl import app 22 | from absl import flags 23 | 24 | from ablation.ab_cta_remixmatch import ABCTAReMixMatch 25 | from libml import utils, data, ctaugment 26 | from libml.augment import AugmentPoolCTA 27 | 28 | FLAGS = flags.FLAGS 29 | 30 | 31 | class AugmentPoolCTANoWeak(AugmentPoolCTA): 32 | @staticmethod 33 | def numpy_apply_policies(arglist): 34 | x, cta, probe = arglist 35 | if x.ndim == 3: 36 | assert probe 37 | policy = cta.policy(probe=True) 38 | return dict(policy=policy, 39 | probe=ctaugment.apply(x, policy), 40 | image=ctaugment.apply(x, cta.policy(probe=False))) 41 | assert not probe 42 | return dict(image=np.stack([ctaugment.apply(y, cta.policy(probe=False)) for y in x]).astype('f')) 43 | 44 | 45 | class ABCTAReMixMatchNoWeak(ABCTAReMixMatch): 46 | AUGMENT_POOL_CLASS = AugmentPoolCTANoWeak 47 | 48 | 49 | def main(argv): 50 | utils.setup_main() 51 | del argv # Unused. 52 | dataset = data.MANY_DATASETS()[FLAGS.dataset]() 53 | log_width = utils.ilog2(dataset.width) 54 | model = ABCTAReMixMatchNoWeak( 55 | os.path.join(FLAGS.train_dir, dataset.name, ABCTAReMixMatchNoWeak.cta_name()), 56 | dataset, 57 | lr=FLAGS.lr, 58 | wd=FLAGS.wd, 59 | arch=FLAGS.arch, 60 | batch=FLAGS.batch, 61 | nclass=dataset.nclass, 62 | 63 | K=FLAGS.K, 64 | beta=FLAGS.beta, 65 | w_kl=FLAGS.w_kl, 66 | w_match=FLAGS.w_match, 67 | w_rot=FLAGS.w_rot, 68 | redux=FLAGS.redux, 69 | use_dm=FLAGS.use_dm, 70 | use_xe=FLAGS.use_xe, 71 | warmup_kimg=FLAGS.warmup_kimg, 72 | 73 | scales=FLAGS.scales or (log_width - 2), 74 | filters=FLAGS.filters, 75 | repeat=FLAGS.repeat) 76 | model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10) 77 | 78 | 79 | if __name__ == '__main__': 80 | utils.setup_tf() 81 | flags.DEFINE_float('wd', 0.02, 'Weight decay.') 82 | flags.DEFINE_float('beta', 0.75, 'Mixup beta distribution.') 83 | flags.DEFINE_float('w_kl', 0.5, 'Weight for KL loss.') 84 | flags.DEFINE_float('w_match', 1.5, 'Weight for distribution matching loss.') 85 | flags.DEFINE_float('w_rot', 0.5, 'Weight for rotation loss.') 86 | flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.') 87 | flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.') 88 | flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.') 89 | flags.DEFINE_integer('warmup_kimg', 1024, 'Unannealing duration for SSL loss.') 90 | flags.DEFINE_enum('redux', '1st', 'swap mean 1st'.split(), 'Logit selection.') 91 | flags.DEFINE_bool('use_dm', True, 'Whether to use distribution matching.') 92 | flags.DEFINE_bool('use_xe', True, 'Whether to use cross-entropy or Brier.') 93 | FLAGS.set_default('augment', 'd.d.d') 94 | FLAGS.set_default('dataset', 'cifar10.3@250-5000') 95 | FLAGS.set_default('batch', 64) 96 | FLAGS.set_default('lr', 0.002) 97 | FLAGS.set_default('train_kimg', 1 << 16) 98 | app.run(main) 99 | -------------------------------------------------------------------------------- /ablation/ab_remixmatch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ReMixMatch training, changes from MixMatch are: 15 | - Add distribution matching. 16 | """ 17 | 18 | import os 19 | 20 | from absl import app 21 | from absl import flags 22 | 23 | from libml import data, utils 24 | from remixmatch_no_cta import ReMixMatch 25 | 26 | FLAGS = flags.FLAGS 27 | 28 | 29 | class ABReMixMatch(ReMixMatch): 30 | pass 31 | 32 | 33 | def main(argv): 34 | utils.setup_main() 35 | del argv # Unused. 36 | dataset = data.MANY_DATASETS()[FLAGS.dataset]() 37 | log_width = utils.ilog2(dataset.width) 38 | model = ABReMixMatch( 39 | os.path.join(FLAGS.train_dir, dataset.name), 40 | dataset, 41 | lr=FLAGS.lr, 42 | wd=FLAGS.wd, 43 | arch=FLAGS.arch, 44 | batch=FLAGS.batch, 45 | nclass=dataset.nclass, 46 | 47 | K=FLAGS.K, 48 | beta=FLAGS.beta, 49 | w_kl=FLAGS.w_kl, 50 | w_match=FLAGS.w_match, 51 | w_rot=FLAGS.w_rot, 52 | redux=FLAGS.redux, 53 | use_dm=FLAGS.use_dm, 54 | use_xe=FLAGS.use_xe, 55 | warmup_kimg=FLAGS.warmup_kimg, 56 | 57 | scales=FLAGS.scales or (log_width - 2), 58 | filters=FLAGS.filters, 59 | repeat=FLAGS.repeat) 60 | model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10) 61 | 62 | 63 | if __name__ == '__main__': 64 | utils.setup_tf() 65 | flags.DEFINE_float('wd', 0.02, 'Weight decay.') 66 | flags.DEFINE_float('beta', 0.75, 'Mixup beta distribution.') 67 | flags.DEFINE_float('w_kl', 0.5, 'Weight for KL loss.') 68 | flags.DEFINE_float('w_match', 1.5, 'Weight for distribution matching loss.') 69 | flags.DEFINE_float('w_rot', 0.5, 'Weight for rotation loss.') 70 | flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.') 71 | flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.') 72 | flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.') 73 | flags.DEFINE_integer('warmup_kimg', 1024, 'Unannealing duration for SSL loss.') 74 | flags.DEFINE_enum('redux', 'swap', 'swap mean 1st'.split(), 'Logit selection.') 75 | flags.DEFINE_bool('use_dm', True, 'Whether to use distribution matching.') 76 | flags.DEFINE_bool('use_xe', True, 'Whether to use cross-entropy or Brier.') 77 | FLAGS.set_default('augment', 'd.d.d') 78 | FLAGS.set_default('dataset', 'cifar10.3@250-5000') 79 | FLAGS.set_default('batch', 64) 80 | FLAGS.set_default('lr', 0.002) 81 | FLAGS.set_default('train_kimg', 1 << 16) 82 | app.run(main) 83 | -------------------------------------------------------------------------------- /cta/cta_fsmixup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ReMixMatch training, changes from MixMatch are: 15 | - Add distribution matching. 16 | """ 17 | import os 18 | 19 | from absl import app 20 | from absl import flags 21 | 22 | from cta.lib.train import CTAClassifyFullySupervised 23 | from fully_supervised.fs_mixup import FSMixup 24 | from fully_supervised.lib import data 25 | from libml import utils 26 | 27 | FLAGS = flags.FLAGS 28 | 29 | 30 | class CTAFSMixup(FSMixup, CTAClassifyFullySupervised): 31 | pass 32 | 33 | 34 | def main(argv): 35 | utils.setup_main() 36 | del argv # Unused. 37 | dataset = data.DATASETS()[FLAGS.dataset]() 38 | log_width = utils.ilog2(dataset.width) 39 | model = CTAFSMixup( 40 | os.path.join(FLAGS.train_dir, dataset.name, CTAFSMixup.cta_name()), 41 | dataset, 42 | lr=FLAGS.lr, 43 | wd=FLAGS.wd, 44 | arch=FLAGS.arch, 45 | batch=FLAGS.batch, 46 | nclass=dataset.nclass, 47 | ema=FLAGS.ema, 48 | beta=FLAGS.beta, 49 | dropout=FLAGS.dropout, 50 | 51 | scales=FLAGS.scales or (log_width - 2), 52 | filters=FLAGS.filters, 53 | repeat=FLAGS.repeat) 54 | model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10) 55 | 56 | 57 | if __name__ == '__main__': 58 | utils.setup_tf() 59 | flags.DEFINE_float('wd', 0.002, 'Weight decay.') 60 | flags.DEFINE_float('ema', 0.999, 'Exponential moving average of params.') 61 | flags.DEFINE_float('beta', 0.5, 'Mixup beta distribution.') 62 | flags.DEFINE_float('dropout', 0, 'Dropout on embedding layer.') 63 | flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.') 64 | flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.') 65 | flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.') 66 | FLAGS.set_default('augment', 'd.d') 67 | FLAGS.set_default('dataset', 'cifar10-1') 68 | FLAGS.set_default('batch', 64) 69 | FLAGS.set_default('lr', 0.002) 70 | FLAGS.set_default('train_kimg', 1 << 16) 71 | app.run(main) 72 | -------------------------------------------------------------------------------- /cta/cta_mixmatch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ReMixMatch training, changes from MixMatch are: 15 | - Add distribution matching. 16 | """ 17 | 18 | import os 19 | 20 | from absl import app 21 | from absl import flags 22 | 23 | from cta.lib.train import CTAClassifySemi 24 | from libml import data, utils 25 | from mixmatch import MixMatch 26 | 27 | FLAGS = flags.FLAGS 28 | 29 | 30 | class CTAMixMatch(MixMatch, CTAClassifySemi): 31 | pass 32 | 33 | 34 | def main(argv): 35 | utils.setup_main() 36 | del argv # Unused. 
37 | dataset = data.PAIR_DATASETS()[FLAGS.dataset]() 38 | log_width = utils.ilog2(dataset.width) 39 | model = CTAMixMatch( 40 | os.path.join(FLAGS.train_dir, dataset.name, CTAMixMatch.cta_name()), 41 | dataset, 42 | lr=FLAGS.lr, 43 | wd=FLAGS.wd, 44 | arch=FLAGS.arch, 45 | batch=FLAGS.batch, 46 | nclass=dataset.nclass, 47 | ema=FLAGS.ema, 48 | beta=FLAGS.beta, 49 | w_match=FLAGS.w_match, 50 | scales=FLAGS.scales or (log_width - 2), 51 | filters=FLAGS.filters, 52 | repeat=FLAGS.repeat) 53 | model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10) 54 | 55 | 56 | if __name__ == '__main__': 57 | utils.setup_tf() 58 | flags.DEFINE_float('wd', 0.02, 'Weight decay.') 59 | flags.DEFINE_float('ema', 0.999, 'Exponential moving average of params.') 60 | flags.DEFINE_float('beta', 0.75, 'Mixup beta distribution.') 61 | flags.DEFINE_float('w_match', 100, 'Weight for distribution matching loss.') 62 | flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.') 63 | flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.') 64 | flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.') 65 | FLAGS.set_default('augment', 'd.d.d') 66 | FLAGS.set_default('dataset', 'cifar10.3@250-5000') 67 | FLAGS.set_default('batch', 64) 68 | FLAGS.set_default('lr', 0.002) 69 | FLAGS.set_default('train_kimg', 1 << 16) 70 | app.run(main) 71 | -------------------------------------------------------------------------------- /cta/cta_remixmatch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | 17 | from absl import app 18 | from absl import flags 19 | 20 | from cta.lib.train import CTAClassifySemi 21 | from libml import utils, data 22 | from remixmatch_no_cta import ReMixMatch 23 | 24 | FLAGS = flags.FLAGS 25 | 26 | 27 | class CTAReMixMatch(ReMixMatch, CTAClassifySemi): 28 | pass 29 | 30 | 31 | def main(argv): 32 | utils.setup_main() 33 | del argv # Unused. 
34 | dataset = data.MANY_DATASETS()[FLAGS.dataset]() 35 | log_width = utils.ilog2(dataset.width) 36 | model = CTAReMixMatch( 37 | os.path.join(FLAGS.train_dir, dataset.name, CTAReMixMatch.cta_name()), 38 | dataset, 39 | lr=FLAGS.lr, 40 | wd=FLAGS.wd, 41 | arch=FLAGS.arch, 42 | batch=FLAGS.batch, 43 | nclass=dataset.nclass, 44 | 45 | K=FLAGS.K, 46 | beta=FLAGS.beta, 47 | w_kl=FLAGS.w_kl, 48 | w_match=FLAGS.w_match, 49 | w_rot=FLAGS.w_rot, 50 | redux=FLAGS.redux, 51 | use_dm=FLAGS.use_dm, 52 | use_xe=FLAGS.use_xe, 53 | warmup_kimg=FLAGS.warmup_kimg, 54 | 55 | scales=FLAGS.scales or (log_width - 2), 56 | filters=FLAGS.filters, 57 | repeat=FLAGS.repeat) 58 | model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10) 59 | 60 | 61 | if __name__ == '__main__': 62 | utils.setup_tf() 63 | flags.DEFINE_float('wd', 0.02, 'Weight decay.') 64 | flags.DEFINE_float('beta', 0.75, 'Mixup beta distribution.') 65 | flags.DEFINE_float('w_kl', 0.5, 'Weight for KL loss.') 66 | flags.DEFINE_float('w_match', 1.5, 'Weight for distribution matching loss.') 67 | flags.DEFINE_float('w_rot', 0.5, 'Weight for rotation loss.') 68 | flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.') 69 | flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.') 70 | flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.') 71 | flags.DEFINE_integer('warmup_kimg', 1024, 'Unannealing duration for SSL loss.') 72 | flags.DEFINE_enum('redux', '1st', 'swap mean 1st'.split(), 'Logit selection.') 73 | flags.DEFINE_bool('use_dm', True, 'Whether to use distribution matching.') 74 | flags.DEFINE_bool('use_xe', True, 'Whether to use cross-entropy or Brier.') 75 | FLAGS.set_default('augment', 'd.d.d') 76 | FLAGS.set_default('dataset', 'cifar10.3@250-5000') 77 | FLAGS.set_default('batch', 64) 78 | FLAGS.set_default('lr', 0.002) 79 | FLAGS.set_default('train_kimg', 1 << 16) 80 | app.run(main) 81 | -------------------------------------------------------------------------------- /cta/lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/remixmatch/f7061ebf055227cbeb5c6fced1ce054e0ceecfcd/cta/lib/__init__.py -------------------------------------------------------------------------------- /cta/lib/train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
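# CTAugment-driven training loops: each labeled batch carries an extra augmented 'probe' image,
# and the prediction error on that probe is fed to CTAugment.update_rates() to adapt the
# per-transformation augmentation strengths (see train_step below).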
14 | 15 | import numpy as np 16 | from absl import flags 17 | 18 | from fully_supervised.lib.train import ClassifyFullySupervised 19 | from libml import data 20 | from libml.augment import AugmentPoolCTA 21 | from libml.ctaugment import CTAugment 22 | from libml.train import ClassifySemi 23 | 24 | FLAGS = flags.FLAGS 25 | 26 | flags.DEFINE_integer('adepth', 2, 'Augmentation depth.') 27 | flags.DEFINE_float('adecay', 0.99, 'Augmentation decay.') 28 | flags.DEFINE_float('ath', 0.80, 'Augmentation threshold.') 29 | 30 | 31 | class CTAClassifySemi(ClassifySemi): 32 | """Semi-supervised classification.""" 33 | AUGMENTER_CLASS = CTAugment 34 | AUGMENT_POOL_CLASS = AugmentPoolCTA 35 | 36 | @classmethod 37 | def cta_name(cls): 38 | return '%s_depth%d_th%.2f_decay%.3f' % (cls.AUGMENTER_CLASS.__name__, 39 | FLAGS.adepth, FLAGS.ath, FLAGS.adecay) 40 | 41 | def __init__(self, train_dir: str, dataset: data.DataSets, nclass: int, **kwargs): 42 | ClassifySemi.__init__(self, train_dir, dataset, nclass, **kwargs) 43 | self.augmenter = self.AUGMENTER_CLASS(FLAGS.adepth, FLAGS.ath, FLAGS.adecay) 44 | 45 | def gen_labeled_fn(self, data_iterator): 46 | def wrap(): 47 | batch = self.session.run(data_iterator) 48 | batch['cta'] = self.augmenter 49 | batch['probe'] = True 50 | return batch 51 | 52 | return self.AUGMENT_POOL_CLASS(wrap) 53 | 54 | def gen_unlabeled_fn(self, data_iterator): 55 | def wrap(): 56 | batch = self.session.run(data_iterator) 57 | batch['cta'] = self.augmenter 58 | batch['probe'] = False 59 | return batch 60 | 61 | return self.AUGMENT_POOL_CLASS(wrap) 62 | 63 | def train_step(self, train_session, gen_labeled, gen_unlabeled): 64 | x, y = gen_labeled(), gen_unlabeled() 65 | v = train_session.run([self.ops.classify_op, self.ops.train_op, self.ops.update_step], 66 | feed_dict={self.ops.y: y['image'], 67 | self.ops.x: x['probe'], 68 | self.ops.xt: x['image'], 69 | self.ops.label: x['label']}) 70 | self.tmp.step = v[-1] 71 | lx = v[0] 72 | for p in range(lx.shape[0]): 73 | error = lx[p] 74 | error[x['label'][p]] -= 1 75 | error = np.abs(error).sum() 76 | self.augmenter.update_rates(x['policy'][p], 1 - 0.5 * error) 77 | 78 | def eval_stats(self, batch=None, feed_extra=None, classify_op=None): 79 | """Evaluate model on train, valid and test.""" 80 | batch = batch or FLAGS.batch 81 | classify_op = self.ops.classify_op if classify_op is None else classify_op 82 | accuracies = [] 83 | for subset in ('train_labeled', 'valid', 'test'): 84 | images, labels = self.tmp.cache[subset] 85 | predicted = [] 86 | 87 | for x in range(0, images.shape[0], batch): 88 | p = self.session.run( 89 | classify_op, 90 | feed_dict={ 91 | self.ops.x: images[x:x + batch], 92 | **(feed_extra or {}) 93 | }) 94 | predicted.append(p) 95 | predicted = np.concatenate(predicted, axis=0) 96 | accuracies.append((predicted.argmax(1) == labels).mean() * 100) 97 | self.train_print('kimg %-5d accuracy train/valid/test %.2f %.2f %.2f' % 98 | tuple([self.tmp.step >> 10] + accuracies)) 99 | self.train_print(self.augmenter.stats()) 100 | return np.array(accuracies, 'f') 101 | 102 | 103 | class CTAClassifyFullySupervised(ClassifyFullySupervised, CTAClassifySemi): 104 | """Fully-supervised classification.""" 105 | 106 | def train_step(self, train_session, gen_labeled): 107 | x = gen_labeled() 108 | v = train_session.run([self.ops.classify_op, self.ops.train_op, self.ops.update_step], 109 | feed_dict={self.ops.x: x['probe'], 110 | self.ops.xt: x['image'], 111 | self.ops.label: x['label']}) 112 | self.tmp.step = v[-1] 113 | lx = v[0] 114 | for p 
in range(lx.shape[0]): 115 | error = lx[p] 116 | error[x['label'][p]] -= 1 117 | error = np.abs(error).sum() 118 | self.augmenter.update_rates(x['policy'][p], 1 - 0.5 * error) 119 | -------------------------------------------------------------------------------- /fully_supervised/fs_baseline.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Fully supervised training. 15 | """ 16 | 17 | import functools 18 | import os 19 | 20 | import tensorflow as tf 21 | from absl import app 22 | from absl import flags 23 | 24 | from fully_supervised.lib.data import DATASETS 25 | from fully_supervised.lib.train import ClassifyFullySupervised 26 | from libml import models, utils 27 | from libml.utils import EasyDict 28 | 29 | FLAGS = flags.FLAGS 30 | 31 | 32 | class FSBaseline(ClassifyFullySupervised, models.MultiModel): 33 | 34 | def model(self, batch, lr, wd, ema, **kwargs): 35 | hwc = [self.dataset.height, self.dataset.width, self.dataset.colors] 36 | xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt') # For training 37 | x_in = tf.placeholder(tf.float32, [None] + hwc, 'x') 38 | l_in = tf.placeholder(tf.int32, [batch], 'labels') 39 | wd *= lr 40 | classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits 41 | 42 | x, labels_x = self.augment(xt_in, tf.one_hot(l_in, self.nclass), **kwargs) 43 | logits_x = classifier(x, training=True) 44 | 45 | loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x, logits=logits_x) 46 | loss_xe = tf.reduce_mean(loss_xe) 47 | tf.summary.scalar('losses/xe', loss_xe) 48 | 49 | ema = tf.train.ExponentialMovingAverage(decay=ema) 50 | ema_op = ema.apply(utils.model_vars()) 51 | ema_getter = functools.partial(utils.getter_ema, ema) 52 | post_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) + [ema_op] 53 | post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name]) 54 | 55 | train_op = tf.train.AdamOptimizer(lr).minimize(loss_xe, colocate_gradients_with_ops=True) 56 | with tf.control_dependencies([train_op]): 57 | train_op = tf.group(*post_ops) 58 | 59 | return EasyDict( 60 | xt=xt_in, x=x_in, label=l_in, train_op=train_op, 61 | classify_raw=tf.nn.softmax(classifier(x_in, training=False)), # No EMA, for debugging. 62 | classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False))) 63 | 64 | 65 | def main(argv): 66 | utils.setup_main() 67 | del argv # Unused. 
68 | dataset = DATASETS()[FLAGS.dataset]() 69 | log_width = utils.ilog2(dataset.width) 70 | model = FSBaseline( 71 | os.path.join(FLAGS.train_dir, dataset.name), 72 | dataset, 73 | lr=FLAGS.lr, 74 | wd=FLAGS.wd, 75 | arch=FLAGS.arch, 76 | batch=FLAGS.batch, 77 | nclass=dataset.nclass, 78 | ema=FLAGS.ema, 79 | dropout=FLAGS.dropout, 80 | smoothing=FLAGS.smoothing, 81 | 82 | scales=FLAGS.scales or (log_width - 2), 83 | filters=FLAGS.filters, 84 | repeat=FLAGS.repeat) 85 | model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10) 86 | 87 | 88 | if __name__ == '__main__': 89 | utils.setup_tf() 90 | flags.DEFINE_float('wd', 0.002, 'Weight decay.') 91 | flags.DEFINE_float('ema', 0.999, 'Exponential moving average of params.') 92 | flags.DEFINE_float('dropout', 0, 'Dropout on embedding layer.') 93 | flags.DEFINE_float('smoothing', 0.001, 'Label smoothing.') 94 | flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.') 95 | flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.') 96 | flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.') 97 | FLAGS.set_default('dataset', 'cifar10') 98 | FLAGS.set_default('batch', 64) 99 | FLAGS.set_default('lr', 0.002) 100 | FLAGS.set_default('train_kimg', 1 << 16) 101 | app.run(main) 102 | -------------------------------------------------------------------------------- /fully_supervised/fs_mixup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Mixup fully supervised training. 15 | """ 16 | 17 | import os 18 | 19 | import tensorflow as tf 20 | from absl import app 21 | from absl import flags 22 | 23 | from fully_supervised.fs_baseline import FSBaseline 24 | from fully_supervised.lib.data import DATASETS 25 | from libml import utils 26 | 27 | FLAGS = flags.FLAGS 28 | 29 | 30 | class FSMixup(FSBaseline): 31 | 32 | def augment(self, x, l, beta, **kwargs): 33 | del kwargs 34 | with tf.device('/cpu'): 35 | mix = tf.distributions.Beta(beta, beta).sample([tf.shape(x)[0], 1, 1, 1]) 36 | mix = tf.maximum(mix, 1 - mix) 37 | xmix = x * mix + x[::-1] * (1 - mix) 38 | lmix = l * mix[:, :, 0, 0] + l[::-1] * (1 - mix[:, :, 0, 0]) 39 | return xmix, lmix 40 | 41 | 42 | def main(argv): 43 | utils.setup_main() 44 | del argv # Unused. 
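    # Note: fully supervised dataset names use a '-<validation size>' suffix (e.g. cifar10-1),
    # unlike the '<dataset>.<seed>@<size>-<valid>' names used for semi-supervised training.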
45 | dataset = DATASETS()[FLAGS.dataset]() 46 | log_width = utils.ilog2(dataset.width) 47 | model = FSMixup( 48 | os.path.join(FLAGS.train_dir, dataset.name), 49 | dataset, 50 | lr=FLAGS.lr, 51 | wd=FLAGS.wd, 52 | arch=FLAGS.arch, 53 | batch=FLAGS.batch, 54 | nclass=dataset.nclass, 55 | ema=FLAGS.ema, 56 | beta=FLAGS.beta, 57 | dropout=FLAGS.dropout, 58 | 59 | scales=FLAGS.scales or (log_width - 2), 60 | filters=FLAGS.filters, 61 | repeat=FLAGS.repeat) 62 | model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10) 63 | 64 | 65 | if __name__ == '__main__': 66 | utils.setup_tf() 67 | flags.DEFINE_float('wd', 0.002, 'Weight decay.') 68 | flags.DEFINE_float('ema', 0.999, 'Exponential moving average of params.') 69 | flags.DEFINE_float('beta', 0.5, 'Mixup beta distribution.') 70 | flags.DEFINE_float('dropout', 0, 'Dropout on embedding layer.') 71 | flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.') 72 | flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.') 73 | flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.') 74 | FLAGS.set_default('dataset', 'cifar10-1') 75 | FLAGS.set_default('batch', 64) 76 | FLAGS.set_default('lr', 0.002) 77 | FLAGS.set_default('train_kimg', 1 << 16) 78 | app.run(main) 79 | -------------------------------------------------------------------------------- /fully_supervised/lib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /fully_supervised/lib/data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
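# Fully supervised dataset definitions: every label is kept for training, a validation subset of
# size `valid` is split off, and there is no unlabeled stream (train_unlabeled is None).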
14 | 15 | import os 16 | 17 | import tensorflow as tf 18 | from absl import flags 19 | 20 | from libml import augment as augment_module 21 | from libml import data 22 | 23 | FLAGS = flags.FLAGS 24 | 25 | 26 | class DataSetsFS(data.DataSets): 27 | @classmethod 28 | def creator(cls, name, train_files, test_files, valid, augment, parse_fn=data.record_parse, do_memoize=True, 29 | nclass=10, height=32, width=32, colors=3): 30 | train_files = [os.path.join(data.DATA_DIR, x) for x in train_files] 31 | test_files = [os.path.join(data.DATA_DIR, x) for x in test_files] 32 | if not isinstance(augment, list): 33 | augment = augment(name) 34 | else: 35 | assert len(augment) == 1 36 | augment = augment[0] 37 | 38 | def create(): 39 | image_shape = [height, width, colors] 40 | kwargs = dict(parse_fn=parse_fn, image_shape=image_shape) 41 | train_labeled = data.DataSet.from_files(train_files, augment, **kwargs) 42 | if do_memoize: 43 | train_labeled = train_labeled.memoize() 44 | if FLAGS.whiten: 45 | mean, std = data.compute_mean_std(train_labeled) 46 | else: 47 | mean, std = 0, 1 48 | 49 | valid_data = data.DataSet.from_files(train_files, augment_module.NOAUGMENT, **kwargs).take(valid) 50 | test_data = data.DataSet.from_files(test_files, augment_module.NOAUGMENT, **kwargs) 51 | 52 | return cls(name + '.' + FLAGS.augment.split('.')[0] + '-' + str(valid), 53 | train_labeled=train_labeled.skip(valid), 54 | train_unlabeled=None, 55 | valid=valid_data, 56 | test=test_data, 57 | nclass=nclass, colors=colors, height=height, width=width, mean=mean, std=std) 58 | 59 | return name + '-' + str(valid), create 60 | 61 | 62 | def augment_function(dataset: str): 63 | return augment_module.get_augmentation(dataset, FLAGS.augment.split('.')[0]) 64 | 65 | 66 | def create_datasets(): 67 | d = {} 68 | d.update([DataSetsFS.creator('cifar10', ['cifar10-train.tfrecord'], ['cifar10-test.tfrecord'], valid, 69 | augment_function) for valid in [1, 5000]]) 70 | d.update([DataSetsFS.creator('cifar100', ['cifar100-train.tfrecord'], ['cifar100-test.tfrecord'], valid, 71 | augment_function, nclass=100) for valid in [1, 5000]]) 72 | d.update([DataSetsFS.creator('fashion_mnist', ['fashion_mnist-train.tfrecord'], ['fashion_mnist-test.tfrecord'], 73 | valid, augment_function, height=32, width=32, colors=1, 74 | parse_fn=data.record_parse_mnist) 75 | for valid in [1, 5000]]) 76 | d.update( 77 | [DataSetsFS.creator('stl10', [], [], valid, augment_function, height=96, width=96, do_memoize=False) 78 | for valid in [1, 5000]]) 79 | d.update([DataSetsFS.creator('svhn', ['svhn-train.tfrecord', 'svhn-extra.tfrecord'], ['svhn-test.tfrecord'], 80 | valid, augment_function, do_memoize=False) for valid in [1, 5000]]) 81 | d.update([DataSetsFS.creator('svhn_noextra', ['svhn-train.tfrecord'], ['svhn-test.tfrecord'], 82 | valid, augment_function, do_memoize=False) for valid in [1, 5000]]) 83 | return d 84 | 85 | 86 | DATASETS = create_datasets 87 | -------------------------------------------------------------------------------- /fully_supervised/lib/train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import tensorflow as tf 16 | from absl import flags 17 | from tqdm import trange 18 | 19 | from libml import utils 20 | from libml.train import ClassifySemi 21 | 22 | FLAGS = flags.FLAGS 23 | 24 | 25 | class ClassifyFullySupervised(ClassifySemi): 26 | """Fully supervised classification. 27 | """ 28 | 29 | def train_step(self, train_session, gen_labeled): 30 | x = gen_labeled() 31 | self.tmp.step = train_session.run([self.ops.train_op, self.ops.update_step], 32 | feed_dict={self.ops.xt: x['image'], 33 | self.ops.label: x['label']})[1] 34 | 35 | def train(self, train_nimg, report_nimg): 36 | if FLAGS.eval_ckpt: 37 | self.eval_checkpoint(FLAGS.eval_ckpt) 38 | return 39 | batch = FLAGS.batch 40 | train_labeled = self.dataset.train_labeled.repeat().shuffle(FLAGS.shuffle).parse().augment() 41 | train_labeled = train_labeled.batch(batch).prefetch(16).make_one_shot_iterator().get_next() 42 | scaffold = tf.train.Scaffold(saver=tf.train.Saver(max_to_keep=FLAGS.keep_ckpt, pad_step_number=10)) 43 | 44 | with tf.Session(config=utils.get_config()) as sess: 45 | self.session = sess 46 | self.cache_eval() 47 | 48 | with tf.train.MonitoredTrainingSession( 49 | scaffold=scaffold, 50 | checkpoint_dir=self.checkpoint_dir, 51 | config=utils.get_config(), 52 | save_checkpoint_steps=FLAGS.save_kimg << 10, 53 | save_summaries_steps=report_nimg - batch) as train_session: 54 | self.session = train_session._tf_sess() 55 | gen_labeled = self.gen_labeled_fn(train_labeled) 56 | self.tmp.step = self.session.run(self.step) 57 | while self.tmp.step < train_nimg: 58 | loop = trange(self.tmp.step % report_nimg, report_nimg, batch, 59 | leave=False, unit='img', unit_scale=batch, 60 | desc='Epoch %d/%d' % (1 + (self.tmp.step // report_nimg), train_nimg // report_nimg)) 61 | for _ in loop: 62 | self.train_step(train_session, gen_labeled) 63 | while self.tmp.print_queue: 64 | loop.write(self.tmp.print_queue.pop(0)) 65 | while self.tmp.print_queue: 66 | print(self.tmp.print_queue.pop(0)) 67 | -------------------------------------------------------------------------------- /fully_supervised/runs/all.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright 2019 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
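# Note: run these commands from the project root with ML_DATA and PYTHONPATH set as described in the README.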
16 | 17 | # Fully supervised baseline without mixup (not shown in paper since Mixup is better) 18 | python fully_supervised/fs_baseline.py --train_dir experiments/fs --dataset=cifar10-1 --wd=0.02 --smoothing=0.001 19 | python fully_supervised/fs_baseline.py --train_dir experiments/fs --dataset=cifar100-1 --wd=0.02 --smoothing=0.001 20 | python fully_supervised/fs_baseline.py --train_dir experiments/fs --dataset=svhn-1 --wd=0.002 --smoothing=0.01 21 | python fully_supervised/fs_baseline.py --train_dir experiments/fs --dataset=svhn_noextra-1 --wd=0.002 --smoothing=0.01 22 | 23 | # Fully supervised Mixup baselines (in paper) 24 | # Uses default parameters: --wd=0.002 --beta=0.5 25 | python fully_supervised/fs_mixup.py --train_dir experiments/fs --dataset=cifar10-1 26 | python fully_supervised/fs_mixup.py --train_dir experiments/fs --dataset=svhn-1 27 | python fully_supervised/fs_mixup.py --train_dir experiments/fs --dataset=svhn_noextra-1 28 | 29 | # Fully supervised Mixup baselines on 26M parameter large network (in paper) 30 | # Uses default parameters: --wd=0.002 --beta=0.5 31 | python fully_supervised/fs_mixup.py --train_dir experiments/fs --dataset=cifar10-1 --filters=135 32 | python fully_supervised/fs_mixup.py --train_dir experiments/fs --dataset=cifar100-1 --filters=135 33 | -------------------------------------------------------------------------------- /ict.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Interpolation Consistency Training for Semi-Supervised Learning. 
15 | 16 | Reimplementation of https://arxiv.org/abs/1903.03825 17 | """ 18 | 19 | import functools 20 | import os 21 | 22 | import tensorflow as tf 23 | from absl import app 24 | from absl import flags 25 | 26 | from libml import models, utils 27 | from libml.data import PAIR_DATASETS 28 | from libml.utils import EasyDict 29 | 30 | FLAGS = flags.FLAGS 31 | 32 | 33 | class ICT(models.MultiModel): 34 | 35 | def model(self, batch, lr, wd, ema, warmup_pos, consistency_weight, beta, **kwargs): 36 | hwc = [self.dataset.height, self.dataset.width, self.dataset.colors] 37 | xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt') # For training 38 | x_in = tf.placeholder(tf.float32, [None] + hwc, 'x') 39 | y_in = tf.placeholder(tf.float32, [batch, 2] + hwc, 'y') 40 | l_in = tf.placeholder(tf.int32, [batch], 'labels') 41 | l = tf.one_hot(l_in, self.nclass) 42 | wd *= lr 43 | warmup = tf.clip_by_value(tf.to_float(self.step) / (warmup_pos * (FLAGS.train_kimg << 10)), 0, 1) 44 | 45 | y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc) 46 | y_1, y_2 = tf.split(y, 2) 47 | 48 | mix = tf.distributions.Beta(beta, beta).sample([tf.shape(x_in)[0], 1, 1, 1]) 49 | mix = tf.maximum(mix, 1 - mix) 50 | 51 | classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits 52 | logits_x = classifier(xt_in, training=True) 53 | post_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) # Take only first call to update batch norm. 54 | 55 | ema = tf.train.ExponentialMovingAverage(decay=ema) 56 | ema_op = ema.apply(utils.model_vars()) 57 | ema_getter = functools.partial(utils.getter_ema, ema) 58 | logits_teacher = classifier(y_1, training=True, getter=ema_getter) 59 | labels_teacher = tf.stop_gradient(tf.nn.softmax(logits_teacher)) 60 | labels_teacher = labels_teacher * mix[:, :, 0, 0] + labels_teacher[::-1] * (1 - mix[:, :, 0, 0]) 61 | logits_student = classifier(y_1 * mix + y_1[::-1] * (1 - mix), training=True) 62 | loss_mt = tf.reduce_mean((labels_teacher - tf.nn.softmax(logits_student)) ** 2, -1) 63 | loss_mt = tf.reduce_mean(loss_mt) 64 | 65 | loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=l, logits=logits_x) 66 | loss = tf.reduce_mean(loss) 67 | tf.summary.scalar('losses/xe', loss) 68 | tf.summary.scalar('losses/mt', loss_mt) 69 | 70 | post_ops.append(ema_op) 71 | post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name]) 72 | 73 | train_op = tf.train.AdamOptimizer(lr).minimize(loss + loss_mt * warmup * consistency_weight, 74 | colocate_gradients_with_ops=True) 75 | with tf.control_dependencies([train_op]): 76 | train_op = tf.group(*post_ops) 77 | 78 | return EasyDict( 79 | xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op, 80 | classify_raw=tf.nn.softmax(classifier(x_in, training=False)), # No EMA, for debugging. 81 | classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False))) 82 | 83 | 84 | def main(argv): 85 | utils.setup_main() 86 | del argv # Unused. 
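    # The model's 'y' placeholder takes unlabeled images in pairs ([batch, 2, height, width, colors]),
    # which the PAIR_DATASETS pipeline is expected to supply alongside the labeled stream.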
87 | dataset = PAIR_DATASETS()[FLAGS.dataset]() 88 | log_width = utils.ilog2(dataset.width) 89 | model = ICT( 90 | os.path.join(FLAGS.train_dir, dataset.name), 91 | dataset, 92 | lr=FLAGS.lr, 93 | wd=FLAGS.wd, 94 | arch=FLAGS.arch, 95 | warmup_pos=FLAGS.warmup_pos, 96 | batch=FLAGS.batch, 97 | nclass=dataset.nclass, 98 | ema=FLAGS.ema, 99 | beta=FLAGS.beta, 100 | consistency_weight=FLAGS.consistency_weight, 101 | 102 | scales=FLAGS.scales or (log_width - 2), 103 | filters=FLAGS.filters, 104 | repeat=FLAGS.repeat) 105 | model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10) 106 | 107 | 108 | if __name__ == '__main__': 109 | utils.setup_tf() 110 | flags.DEFINE_float('consistency_weight', 50., 'Consistency weight.') 111 | flags.DEFINE_float('warmup_pos', 0.4, 'Relative position at which constraint loss warmup ends.') 112 | flags.DEFINE_float('wd', 0.02, 'Weight decay.') 113 | flags.DEFINE_float('ema', 0.999, 'Exponential moving average of params.') 114 | flags.DEFINE_float('beta', 0.5, 'Mixup beta.') 115 | flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.') 116 | flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.') 117 | flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.') 118 | FLAGS.set_default('augment', 'd.d.d') 119 | FLAGS.set_default('dataset', 'cifar10.3@250-5000') 120 | FLAGS.set_default('batch', 64) 121 | FLAGS.set_default('lr', 0.002) 122 | FLAGS.set_default('train_kimg', 1 << 16) 123 | app.run(main) 124 | -------------------------------------------------------------------------------- /libml/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /libml/ctaugment.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Control Theory based self-augmentation.""" 15 | import random 16 | from collections import namedtuple 17 | 18 | import numpy as np 19 | from PIL import Image, ImageOps, ImageEnhance, ImageFilter 20 | 21 | OPS = {} 22 | OP = namedtuple('OP', ('f', 'bins')) 23 | Sample = namedtuple('Sample', ('train', 'probe')) 24 | 25 | 26 | def register(*bins): 27 | def wrap(f): 28 | OPS[f.__name__] = OP(f, bins) 29 | return f 30 | 31 | return wrap 32 | 33 | 34 | def apply(x, ops): 35 | if ops is None: 36 | return x 37 | y = Image.fromarray(np.round(127.5 * (1 + x)).clip(0, 255).astype('uint8')) 38 | for op, args in ops: 39 | y = OPS[op].f(y, *args) 40 | return np.asarray(y).astype('f') / 127.5 - 1 41 | 42 | 43 | class CTAugment: 44 | def __init__(self, depth=2, th=0.85, decay=0.99): 45 | self.decay = decay 46 | self.depth = depth 47 | self.th = th 48 | self.rates = {} 49 | for k, op in OPS.items(): 50 | self.rates[k] = tuple([np.ones(x, 'f') for x in op.bins]) 51 | 52 | def rate_to_p(self, rate): 53 | p = rate + (1 - self.decay) 54 | p = p / p.max() 55 | p[p < self.th] = 0 56 | return p 57 | 58 | def policy(self, probe): 59 | kl = list(OPS.keys()) 60 | v = [] 61 | if probe: 62 | for _ in range(self.depth): 63 | k = random.choice(kl) 64 | bins = self.rates[k] 65 | rnd = np.random.uniform(0, 1, len(bins)) 66 | v.append(OP(k, rnd.tolist())) 67 | return v 68 | for _ in range(self.depth): 69 | vt = [] 70 | k = random.choice(kl) 71 | bins = self.rates[k] 72 | rnd = np.random.uniform(0, 1, len(bins)) 73 | for r, bin in zip(rnd, bins): 74 | p = self.rate_to_p(bin) 75 | segments = p[1:] + p[:-1] 76 | segment = np.random.choice(segments.shape[0], p=segments / segments.sum()) 77 | vt.append((segment + r) / segments.shape[0]) 78 | v.append(OP(k, vt)) 79 | return v 80 | 81 | def update_rates(self, policy, accuracy): 82 | for k, bins in policy: 83 | for p, rate in zip(bins, self.rates[k]): 84 | p = int(p * len(rate) * 0.999) 85 | rate[p] = rate[p] * self.decay + accuracy * (1 - self.decay) 86 | 87 | def stats(self): 88 | return '\n'.join('%-16s %s' % (k, ' / '.join(' '.join('%.2f' % x for x in self.rate_to_p(rate)) 89 | for rate in self.rates[k])) 90 | for k in sorted(OPS.keys())) 91 | 92 | 93 | def _enhance(x, op, level): 94 | return op(x).enhance(0.1 + 1.9 * level) 95 | 96 | 97 | def _imageop(x, op, level): 98 | return Image.blend(x, op(x), level) 99 | 100 | 101 | def _filter(x, op, level): 102 | return Image.blend(x, x.filter(op), level) 103 | 104 | 105 | @register(17) 106 | def autocontrast(x, level): 107 | return _imageop(x, ImageOps.autocontrast, level) 108 | 109 | 110 | @register(17) 111 | def blur(x, level): 112 | return _filter(x, ImageFilter.BLUR, level) 113 | 114 | 115 | @register(17) 116 | def brightness(x, brightness): 117 | return _enhance(x, ImageEnhance.Brightness, brightness) 118 | 119 | 120 | @register(17) 121 | def color(x, color): 122 | return _enhance(x, ImageEnhance.Color, color) 123 | 124 | 125 | @register(17) 126 | def contrast(x, contrast): 127 | return _enhance(x, ImageEnhance.Contrast, contrast) 128 | 129 | 130 | @register(17) 131 | def cutout(x, level): 132 | """Apply cutout to pil_img at the specified level.""" 133 | size = 1 + int(level * min(x.size) * 0.499) 134 | img_height, img_width = x.size 135 | height_loc = np.random.randint(low=0, high=img_height) 136 | width_loc = np.random.randint(low=0, high=img_width) 137 | upper_coord = (max(0, height_loc - size // 2), max(0, width_loc - size // 2)) 138 | lower_coord = (min(img_height, height_loc + size // 2), min(img_width, width_loc 
+ size // 2)) 139 | pixels = x.load() # create the pixel map 140 | for i in range(upper_coord[0], lower_coord[0]): # for every col: 141 | for j in range(upper_coord[1], lower_coord[1]): # For every row 142 | pixels[i, j] = (127, 127, 127) # set the colour accordingly 143 | return x 144 | 145 | 146 | @register(17) 147 | def equalize(x, level): 148 | return _imageop(x, ImageOps.equalize, level) 149 | 150 | 151 | @register(17) 152 | def invert(x, level): 153 | return _imageop(x, ImageOps.invert, level) 154 | 155 | 156 | @register() 157 | def identity(x): 158 | return x 159 | 160 | 161 | @register(8) 162 | def posterize(x, level): 163 | level = 1 + int(level * 7.999) 164 | return ImageOps.posterize(x, level) 165 | 166 | 167 | @register(17, 6) 168 | def rescale(x, scale, method): 169 | s = x.size 170 | scale *= 0.25 171 | crop = (scale, scale, s[0] - scale, s[1] - scale) 172 | methods = (Image.ANTIALIAS, Image.BICUBIC, Image.BILINEAR, Image.BOX, Image.HAMMING, Image.NEAREST) 173 | method = methods[int(method * 5.99)] 174 | return x.crop(crop).resize(x.size, method) 175 | 176 | 177 | @register(17) 178 | def rotate(x, angle): 179 | angle = int(np.round((2 * angle - 1) * 45)) 180 | return x.rotate(angle) 181 | 182 | 183 | @register(17) 184 | def sharpness(x, sharpness): 185 | return _enhance(x, ImageEnhance.Sharpness, sharpness) 186 | 187 | 188 | @register(17) 189 | def shear_x(x, shear): 190 | shear = (2 * shear - 1) * 0.3 191 | return x.transform(x.size, Image.AFFINE, (1, shear, 0, 0, 1, 0)) 192 | 193 | 194 | @register(17) 195 | def shear_y(x, shear): 196 | shear = (2 * shear - 1) * 0.3 197 | return x.transform(x.size, Image.AFFINE, (1, 0, 0, shear, 1, 0)) 198 | 199 | 200 | @register(17) 201 | def smooth(x, level): 202 | return _filter(x, ImageFilter.SMOOTH, level) 203 | 204 | 205 | @register(17) 206 | def solarize(x, th): 207 | th = int(th * 255.999) 208 | return ImageOps.solarize(x, th) 209 | 210 | 211 | @register(17) 212 | def translate_x(x, delta): 213 | delta = (2 * delta - 1) * 0.3 214 | return x.transform(x.size, Image.AFFINE, (1, 0, delta, 0, 1, 0)) 215 | 216 | 217 | @register(17) 218 | def translate_y(x, delta): 219 | delta = (2 * delta - 1) * 0.3 220 | return x.transform(x.size, Image.AFFINE, (1, 0, 0, 0, 1, delta)) 221 | -------------------------------------------------------------------------------- /libml/layers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Custom neural network layers and primitives. 15 | """ 16 | import numbers 17 | 18 | import numpy as np 19 | import tensorflow as tf 20 | 21 | from libml.data import DataSets 22 | 23 | 24 | def smart_shape(x): 25 | s, t = x.shape, tf.shape(x) 26 | return [t[i] if s[i].value is None else s[i] for i in range(len(s))] 27 | 28 | 29 | def entropy_from_logits(logits): 30 | """Computes entropy from classifier logits. 
31 | 32 | Args: 33 | logits: a tensor of shape (batch_size, class_count) representing the 34 | logits of a classifier. 35 | 36 | Returns: 37 | A tensor of shape (batch_size,) of floats giving the entropies 38 | batchwise. 39 | """ 40 | distribution = tf.contrib.distributions.Categorical(logits=logits) 41 | return distribution.entropy() 42 | 43 | 44 | def entropy_penalty(logits, entropy_penalty_multiplier, mask): 45 | """Computes an entropy penalty using the classifier logits. 46 | 47 | Args: 48 | logits: a tensor of shape (batch_size, class_count) representing the 49 | logits of a classifier. 50 | entropy_penalty_multiplier: A float by which the entropy is multiplied. 51 | mask: A tensor that optionally masks out some of the costs. 52 | 53 | Returns: 54 | The mean entropy penalty 55 | """ 56 | entropy = entropy_from_logits(logits) 57 | losses = entropy * entropy_penalty_multiplier 58 | losses *= tf.cast(mask, tf.float32) 59 | return tf.reduce_mean(losses) 60 | 61 | 62 | def kl_divergence_from_logits(logits_a, logits_b): 63 | """Gets KL divergence from logits parameterizing categorical distributions. 64 | 65 | Args: 66 | logits_a: A tensor of logits parameterizing the first distribution. 67 | logits_b: A tensor of logits parameterizing the second distribution. 68 | 69 | Returns: 70 | The (batch_size,) shaped tensor of KL divergences. 71 | """ 72 | distribution1 = tf.contrib.distributions.Categorical(logits=logits_a) 73 | distribution2 = tf.contrib.distributions.Categorical(logits=logits_b) 74 | return tf.contrib.distributions.kl_divergence(distribution1, distribution2) 75 | 76 | 77 | def mse_from_logits(output_logits, target_logits): 78 | """Computes MSE between predictions associated with logits. 79 | 80 | Args: 81 | output_logits: A tensor of logits from the primary model. 82 | target_logits: A tensor of logits from the secondary model. 
83 |
84 |     Returns:
85 |         The mean MSE
86 |     """
87 |     diffs = tf.nn.softmax(output_logits) - tf.nn.softmax(target_logits)
88 |     squared_diffs = tf.square(diffs)
89 |     return tf.reduce_mean(squared_diffs, -1)
90 |
91 |
92 | def interleave_offsets(batch, nu):
93 |     groups = [batch // (nu + 1)] * (nu + 1)
94 |     for x in range(batch - sum(groups)):
95 |         groups[-x - 1] += 1
96 |     offsets = [0]
97 |     for g in groups:
98 |         offsets.append(offsets[-1] + g)
99 |     assert offsets[-1] == batch
100 |     return offsets
101 |
102 |
103 | def interleave(xy, batch):
104 |     nu = len(xy) - 1
105 |     offsets = interleave_offsets(batch, nu)
106 |     xy = [[v[offsets[p]:offsets[p + 1]] for p in range(nu + 1)] for v in xy]
107 |     for i in range(1, nu + 1):
108 |         xy[0][i], xy[i][i] = xy[i][i], xy[0][i]
109 |     return [tf.concat(v, axis=0) for v in xy]
110 |
111 |
112 | def renorm(v):
113 |     return v / tf.reduce_sum(v, axis=-1, keepdims=True)
114 |
115 |
116 | def shakeshake(a, b, training):
117 |     if not training:
118 |         return 0.5 * (a + b)
119 |     mu = tf.random_uniform([tf.shape(a)[0]] + [1] * (len(a.shape) - 1), 0, 1)
120 |     mixf = a + mu * (b - a)
121 |     mixb = a + mu[::-1] * (b - a)  # Reversed coefficients so forward and backward passes differ (shake-shake).
122 |     return tf.stop_gradient(mixf - mixb) + mixb
123 |
124 |
125 | class PMovingAverage:
126 |     def __init__(self, name, nclass, buf_size):
127 |         # MEAN aggregation is used by DistributionStrategy to aggregate
128 |         # variable updates across shards
129 |         self.ma = tf.Variable(tf.ones([buf_size, nclass]) / nclass,
130 |                               trainable=False,
131 |                               name=name,
132 |                               aggregation=tf.VariableAggregation.MEAN)
133 |
134 |     def __call__(self):
135 |         v = tf.reduce_mean(self.ma, axis=0)
136 |         return v / tf.reduce_sum(v)
137 |
138 |     def update(self, entry):
139 |         entry = tf.reduce_mean(entry, axis=0)
140 |         return tf.assign(self.ma, tf.concat([self.ma[1:], [entry]], axis=0))
141 |
142 |
143 | class PData:
144 |     def __init__(self, dataset: DataSets):
145 |         self.has_update = False
146 |         if dataset.p_unlabeled is not None:
147 |             self.p_data = tf.constant(dataset.p_unlabeled, name='p_data')
148 |         elif dataset.p_labeled is not None:
149 |             self.p_data = tf.constant(dataset.p_labeled, name='p_data')
150 |         else:
151 |             # MEAN aggregation is used by DistributionStrategy to aggregate
152 |             # variable updates across shards
153 |             self.p_data = tf.Variable(renorm(tf.ones([dataset.nclass])),
154 |                                       trainable=False,
155 |                                       name='p_data',
156 |                                       aggregation=tf.VariableAggregation.MEAN)
157 |             self.has_update = True
158 |
159 |     def __call__(self):
160 |         return self.p_data / tf.reduce_sum(self.p_data)
161 |
162 |     def update(self, entry, decay=0.999):
163 |         entry = tf.reduce_mean(entry, axis=0)
164 |         return tf.assign(self.p_data, self.p_data * decay + entry * (1 - decay))
165 |
166 |
167 | class MixMode:
168 |     # A class for mixing data for various combination of labeled and unlabeled.
169 |     # x = labeled example
170 |     # y = unlabeled example
171 |     # For example "xx.yxy" means: mix x with x, mix y with both x and y.
172 |     MODES = 'xx.yy xxy.yxy xx.yxy xx.yx xx. .yy xxy.
.yxy .'.split() 173 | 174 | def __init__(self, mode): 175 | assert mode in self.MODES 176 | self.mode = mode 177 | 178 | @staticmethod 179 | def augment_pair(x0, l0, x1, l1, beta, **kwargs): 180 | del kwargs 181 | if isinstance(beta, numbers.Integral) and beta <= 0: 182 | return x0, l0 183 | 184 | def np_beta(s, beta): # TF implementation seems unreliable for beta below 0.2 185 | return np.random.beta(beta, beta, s).astype('f') 186 | 187 | with tf.device('/cpu'): 188 | mix = tf.py_func(np_beta, [tf.shape(x0)[0], beta], tf.float32) 189 | mix = tf.reshape(tf.maximum(mix, 1 - mix), [tf.shape(x0)[0], 1, 1, 1]) 190 | index = tf.random_shuffle(tf.range(tf.shape(x0)[0])) 191 | xs = tf.gather(x1, index) 192 | ls = tf.gather(l1, index) 193 | xmix = x0 * mix + xs * (1 - mix) 194 | lmix = l0 * mix[:, :, 0, 0] + ls * (1 - mix[:, :, 0, 0]) 195 | return xmix, lmix 196 | 197 | @staticmethod 198 | def augment(x, l, beta, **kwargs): 199 | return MixMode.augment_pair(x, l, x, l, beta, **kwargs) 200 | 201 | def __call__(self, xl: list, ll: list, betal: list): 202 | assert len(xl) == len(ll) >= 2 203 | assert len(betal) == 2 204 | if self.mode == '.': 205 | return xl, ll 206 | elif self.mode == 'xx.': 207 | mx0, ml0 = self.augment(xl[0], ll[0], betal[0]) 208 | return [mx0] + xl[1:], [ml0] + ll[1:] 209 | elif self.mode == '.yy': 210 | mx1, ml1 = self.augment( 211 | tf.concat(xl[1:], 0), tf.concat(ll[1:], 0), betal[1]) 212 | return (xl[:1] + tf.split(mx1, len(xl) - 1), 213 | ll[:1] + tf.split(ml1, len(ll) - 1)) 214 | elif self.mode == 'xx.yy': 215 | mx0, ml0 = self.augment(xl[0], ll[0], betal[0]) 216 | mx1, ml1 = self.augment( 217 | tf.concat(xl[1:], 0), tf.concat(ll[1:], 0), betal[1]) 218 | return ([mx0] + tf.split(mx1, len(xl) - 1), 219 | [ml0] + tf.split(ml1, len(ll) - 1)) 220 | elif self.mode == 'xxy.': 221 | mx, ml = self.augment( 222 | tf.concat(xl, 0), tf.concat(ll, 0), 223 | sum(betal) / len(betal)) 224 | return (tf.split(mx, len(xl))[:1] + xl[1:], 225 | tf.split(ml, len(ll))[:1] + ll[1:]) 226 | elif self.mode == '.yxy': 227 | mx, ml = self.augment( 228 | tf.concat(xl, 0), tf.concat(ll, 0), 229 | sum(betal) / len(betal)) 230 | return (xl[:1] + tf.split(mx, len(xl))[1:], 231 | ll[:1] + tf.split(ml, len(ll))[1:]) 232 | elif self.mode == 'xxy.yxy': 233 | mx, ml = self.augment( 234 | tf.concat(xl, 0), tf.concat(ll, 0), 235 | sum(betal) / len(betal)) 236 | return tf.split(mx, len(xl)), tf.split(ml, len(ll)) 237 | elif self.mode == 'xx.yxy': 238 | mx0, ml0 = self.augment(xl[0], ll[0], betal[0]) 239 | mx1, ml1 = self.augment(tf.concat(xl, 0), tf.concat(ll, 0), betal[1]) 240 | mx1, ml1 = [tf.split(m, len(xl))[1:] for m in (mx1, ml1)] 241 | return [mx0] + mx1, [ml0] + ml1 242 | elif self.mode == 'xx.yx': 243 | mx0, ml0 = self.augment(xl[0], ll[0], betal[0]) 244 | mx1, ml1 = zip(*[ 245 | self.augment_pair(xl[i], ll[i], xl[0], ll[0], betal[1]) 246 | for i in range(1, len(xl)) 247 | ]) 248 | return [mx0] + list(mx1), [ml0] + list(ml1) 249 | raise NotImplementedError(self.mode) 250 | -------------------------------------------------------------------------------- /libml/models.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Classifier architectures.""" 15 | import functools 16 | import itertools 17 | 18 | import tensorflow as tf 19 | from absl import flags 20 | 21 | from libml import layers 22 | from libml.train import ClassifySemi 23 | from libml.utils import EasyDict 24 | 25 | 26 | class CNN13(ClassifySemi): 27 | """Simplified reproduction of the Mean Teacher paper network. filters=128 in original implementation. 28 | Removed dropout, Gaussians, forked dense layers, basically all non-standard things.""" 29 | 30 | def classifier(self, x, scales, filters, training, getter=None, **kwargs): 31 | del kwargs 32 | assert scales == 3 # Only specified for 32x32 inputs. 33 | conv_args = dict(kernel_size=3, activation=tf.nn.leaky_relu, padding='same') 34 | bn_args = dict(training=training, momentum=0.999) 35 | 36 | with tf.variable_scope('classify', reuse=tf.AUTO_REUSE, custom_getter=getter): 37 | y = tf.layers.conv2d((x - self.dataset.mean) / self.dataset.std, filters, **conv_args) 38 | y = tf.layers.batch_normalization(y, **bn_args) 39 | y = tf.layers.conv2d(y, filters, **conv_args) 40 | y = tf.layers.batch_normalization(y, **bn_args) 41 | y = tf.layers.conv2d(y, filters, **conv_args) 42 | y = tf.layers.batch_normalization(y, **bn_args) 43 | y = tf.layers.max_pooling2d(y, 2, 2) 44 | y = tf.layers.conv2d(y, 2 * filters, **conv_args) 45 | y = tf.layers.batch_normalization(y, **bn_args) 46 | y = tf.layers.conv2d(y, 2 * filters, **conv_args) 47 | y = tf.layers.batch_normalization(y, **bn_args) 48 | y = tf.layers.conv2d(y, 2 * filters, **conv_args) 49 | y = tf.layers.batch_normalization(y, **bn_args) 50 | y = tf.layers.max_pooling2d(y, 2, 2) 51 | y = tf.layers.conv2d(y, 4 * filters, kernel_size=3, activation=tf.nn.leaky_relu, padding='valid') 52 | y = tf.layers.batch_normalization(y, **bn_args) 53 | y = tf.layers.conv2d(y, 2 * filters, kernel_size=1, activation=tf.nn.leaky_relu, padding='same') 54 | y = tf.layers.batch_normalization(y, **bn_args) 55 | y = tf.layers.conv2d(y, 1 * filters, kernel_size=1, activation=tf.nn.leaky_relu, padding='same') 56 | y = tf.layers.batch_normalization(y, **bn_args) 57 | y = tf.reduce_mean(y, [1, 2]) # (b, 6, 6, 128) -> (b, 128) 58 | logits = tf.layers.dense(y, self.nclass) 59 | return EasyDict(logits=logits, embeds=y) 60 | 61 | 62 | class ResNet(ClassifySemi): 63 | def classifier(self, x, scales, filters, repeat, training, getter=None, dropout=0, **kwargs): 64 | del kwargs 65 | leaky_relu = functools.partial(tf.nn.leaky_relu, alpha=0.1) 66 | bn_args = dict(training=training, momentum=0.999) 67 | 68 | def conv_args(k, f): 69 | return dict(padding='same', 70 | kernel_initializer=tf.random_normal_initializer(stddev=tf.rsqrt(0.5 * k * k * f))) 71 | 72 | def residual(x0, filters, stride=1, activate_before_residual=False): 73 | x = leaky_relu(tf.layers.batch_normalization(x0, **bn_args)) 74 | if activate_before_residual: 75 | x0 = x 76 | 77 | x = tf.layers.conv2d(x, filters, 3, strides=stride, **conv_args(3, filters)) 78 | x = leaky_relu(tf.layers.batch_normalization(x, **bn_args)) 79 | x = tf.layers.conv2d(x, filters, 3, **conv_args(3, filters)) 
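            # Shortcut branch: when the residual branch changes the channel count,
            # the block below projects x0 with a (possibly strided) 1x1 convolution
            # so the two paths can be summed.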
80 | 81 | if x0.get_shape()[3] != filters: 82 | x0 = tf.layers.conv2d(x0, filters, 1, strides=stride, **conv_args(1, filters)) 83 | 84 | return x0 + x 85 | 86 | with tf.variable_scope('classify', reuse=tf.AUTO_REUSE, custom_getter=getter): 87 | y = tf.layers.conv2d((x - self.dataset.mean) / self.dataset.std, 16, 3, **conv_args(3, 16)) 88 | for scale in range(scales): 89 | y = residual(y, filters << scale, stride=2 if scale else 1, activate_before_residual=scale == 0) 90 | for i in range(repeat - 1): 91 | y = residual(y, filters << scale) 92 | 93 | y = leaky_relu(tf.layers.batch_normalization(y, **bn_args)) 94 | y = embeds = tf.reduce_mean(y, [1, 2]) 95 | if dropout and training: 96 | y = tf.nn.dropout(y, 1 - dropout) 97 | logits = tf.layers.dense(y, self.nclass, kernel_initializer=tf.glorot_normal_initializer()) 98 | return EasyDict(logits=logits, embeds=embeds) 99 | 100 | 101 | class ShakeNet(ClassifySemi): 102 | def classifier(self, x, scales, filters, repeat, training, getter=None, dropout=0, **kwargs): 103 | del kwargs 104 | bn_args = dict(training=training, momentum=0.999) 105 | 106 | def conv_args(k, f): 107 | return dict(padding='same', use_bias=False, 108 | kernel_initializer=tf.random_normal_initializer(stddev=tf.rsqrt(0.5 * k * k * f))) 109 | 110 | def residual(x0, filters, stride=1): 111 | def branch(): 112 | x = tf.nn.relu(x0) 113 | x = tf.layers.conv2d(x, filters, 3, strides=stride, **conv_args(3, filters)) 114 | x = tf.nn.relu(tf.layers.batch_normalization(x, **bn_args)) 115 | x = tf.layers.conv2d(x, filters, 3, **conv_args(3, filters)) 116 | x = tf.layers.batch_normalization(x, **bn_args) 117 | return x 118 | 119 | x = layers.shakeshake(branch(), branch(), training) 120 | 121 | if stride == 2: 122 | x1 = tf.layers.conv2d(tf.nn.relu(x0[:, ::2, ::2]), filters >> 1, 1, **conv_args(1, filters >> 1)) 123 | x2 = tf.layers.conv2d(tf.nn.relu(x0[:, 1::2, 1::2]), filters >> 1, 1, **conv_args(1, filters >> 1)) 124 | x0 = tf.concat([x1, x2], axis=3) 125 | x0 = tf.layers.batch_normalization(x0, **bn_args) 126 | elif x0.get_shape()[3] != filters: 127 | x0 = tf.layers.conv2d(x0, filters, 1, **conv_args(1, filters)) 128 | x0 = tf.layers.batch_normalization(x0, **bn_args) 129 | 130 | return x0 + x 131 | 132 | with tf.variable_scope('classify', reuse=tf.AUTO_REUSE, custom_getter=getter): 133 | y = tf.layers.conv2d((x - self.dataset.mean) / self.dataset.std, 16, 3, **conv_args(3, 16)) 134 | for scale, i in itertools.product(range(scales), range(repeat)): 135 | with tf.variable_scope('layer%d.%d' % (scale + 1, i)): 136 | if i == 0: 137 | y = residual(y, filters << scale, stride=2 if scale else 1) 138 | else: 139 | y = residual(y, filters << scale) 140 | 141 | y = embeds = tf.reduce_mean(y, [1, 2]) 142 | if dropout and training: 143 | y = tf.nn.dropout(y, 1 - dropout) 144 | logits = tf.layers.dense(y, self.nclass, kernel_initializer=tf.glorot_normal_initializer()) 145 | return EasyDict(logits=logits, embeds=embeds) 146 | 147 | 148 | class MultiModel(CNN13, ResNet, ShakeNet): 149 | MODELS = ('cnn13', 'resnet', 'shake') 150 | MODEL_CNN13, MODEL_RESNET, MODEL_SHAKE = MODELS 151 | 152 | def augment(self, x, l, smoothing, **kwargs): 153 | del kwargs 154 | return x, l - smoothing * (l - 1. 
/ self.nclass) 155 | 156 | def classifier(self, x, arch, **kwargs): 157 | if arch == self.MODEL_CNN13: 158 | return CNN13.classifier(self, x, **kwargs) 159 | elif arch == self.MODEL_RESNET: 160 | return ResNet.classifier(self, x, **kwargs) 161 | elif arch == self.MODEL_SHAKE: 162 | return ShakeNet.classifier(self, x, **kwargs) 163 | raise ValueError('Model %s does not exists, available ones are %s' % (arch, self.MODELS)) 164 | 165 | 166 | flags.DEFINE_enum('arch', MultiModel.MODEL_RESNET, MultiModel.MODELS, 'Architecture.') 167 | -------------------------------------------------------------------------------- /libml/train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Training loop, checkpoint saving and loading, evaluation code.""" 15 | 16 | import json 17 | import os.path 18 | import shutil 19 | 20 | import numpy as np 21 | import tensorflow as tf 22 | from absl import flags 23 | from tqdm import trange, tqdm 24 | 25 | from libml import data, utils 26 | from libml.utils import EasyDict 27 | 28 | FLAGS = flags.FLAGS 29 | flags.DEFINE_string('train_dir', './experiments', 30 | 'Folder where to save training data.') 31 | flags.DEFINE_float('lr', 0.0001, 'Learning rate.') 32 | flags.DEFINE_integer('batch', 64, 'Batch size.') 33 | flags.DEFINE_integer('train_kimg', 1 << 14, 'Training duration in kibi-samples.') 34 | flags.DEFINE_integer('report_kimg', 64, 'Report summary period in kibi-samples.') 35 | flags.DEFINE_integer('save_kimg', 64, 'Save checkpoint period in kibi-samples.') 36 | flags.DEFINE_integer('keep_ckpt', 50, 'Number of checkpoints to keep.') 37 | flags.DEFINE_string('eval_ckpt', '', 'Checkpoint to evaluate. 
If provided, do not do training, just do eval.') 38 | flags.DEFINE_string('rerun', '', 'A string to identify a run if running multiple ones with same parameters.') 39 | 40 | 41 | class Model: 42 | def __init__(self, train_dir: str, dataset: data.DataSets, **kwargs): 43 | self.train_dir = os.path.join(train_dir, FLAGS.rerun, self.experiment_name(**kwargs)) 44 | self.params = EasyDict(kwargs) 45 | self.dataset = dataset 46 | self.session = None 47 | self.tmp = EasyDict(print_queue=[], cache=EasyDict()) 48 | self.step = tf.train.get_or_create_global_step() 49 | self.ops = self.model(**kwargs) 50 | self.ops.update_step = tf.assign_add(self.step, FLAGS.batch) 51 | self.add_summaries(**kwargs) 52 | 53 | print(' Config '.center(80, '-')) 54 | print('train_dir', self.train_dir) 55 | print('%-32s %s' % ('Model', self.__class__.__name__)) 56 | print('%-32s %s' % ('Dataset', dataset.name)) 57 | for k, v in sorted(kwargs.items()): 58 | print('%-32s %s' % (k, v)) 59 | print(' Model '.center(80, '-')) 60 | to_print = [tuple(['%s' % x for x in (v.name, np.prod(v.shape), v.shape)]) for v in utils.model_vars(None)] 61 | to_print.append(('Total', str(sum(int(x[1]) for x in to_print)), '')) 62 | sizes = [max([len(x[i]) for x in to_print]) for i in range(3)] 63 | fmt = '%%-%ds %%%ds %%%ds' % tuple(sizes) 64 | for x in to_print[:-1]: 65 | print(fmt % x) 66 | print() 67 | print(fmt % to_print[-1]) 68 | print('-' * 80) 69 | self._create_initial_files() 70 | 71 | @property 72 | def arg_dir(self): 73 | return os.path.join(self.train_dir, 'args') 74 | 75 | @property 76 | def checkpoint_dir(self): 77 | return os.path.join(self.train_dir, 'tf') 78 | 79 | def train_print(self, text): 80 | self.tmp.print_queue.append(text) 81 | 82 | def _create_initial_files(self): 83 | for dir in (self.checkpoint_dir, self.arg_dir): 84 | tf.gfile.MakeDirs(dir) 85 | self.save_args() 86 | 87 | def _reset_files(self): 88 | shutil.rmtree(self.train_dir) 89 | self._create_initial_files() 90 | 91 | def save_args(self, **extra_params): 92 | with tf.gfile.Open(os.path.join(self.arg_dir, 'args.json'), 'w') as f: 93 | json.dump({**self.params, **extra_params}, f, sort_keys=True, indent=4) 94 | 95 | @classmethod 96 | def load(cls, train_dir): 97 | with tf.gfile.Open(os.path.join(train_dir, 'args/args.json'), 'r') as f: 98 | params = json.load(f) 99 | instance = cls(train_dir=train_dir, **params) 100 | instance.train_dir = train_dir 101 | return instance 102 | 103 | def experiment_name(self, **kwargs): 104 | args = [x + str(y) for x, y in sorted(kwargs.items())] 105 | return '_'.join([self.__class__.__name__] + args) 106 | 107 | def eval_mode(self, ckpt=None): 108 | self.session = tf.Session(config=utils.get_config()) 109 | saver = tf.train.Saver() 110 | if ckpt is None: 111 | ckpt = utils.find_latest_checkpoint(self.checkpoint_dir) 112 | else: 113 | ckpt = os.path.abspath(ckpt) 114 | saver.restore(self.session, ckpt) 115 | self.tmp.step = self.session.run(self.step) 116 | print('Eval model %s at global_step %d' % (self.__class__.__name__, self.tmp.step)) 117 | return self 118 | 119 | def model(self, **kwargs): 120 | raise NotImplementedError() 121 | 122 | def add_summaries(self, **kwargs): 123 | raise NotImplementedError() 124 | 125 | 126 | class ClassifySemi(Model): 127 | """Semi-supervised classification.""" 128 | 129 | def __init__(self, train_dir: str, dataset: data.DataSets, nclass: int, **kwargs): 130 | self.nclass = nclass 131 | Model.__init__(self, train_dir, dataset, nclass=nclass, **kwargs) 132 | 133 | def train_step(self, 
train_session, gen_labeled, gen_unlabeled): 134 | x, y = gen_labeled(), gen_unlabeled() 135 | self.tmp.step = train_session.run([self.ops.train_op, self.ops.update_step], 136 | feed_dict={self.ops.y: y['image'], 137 | self.ops.xt: x['image'], 138 | self.ops.label: x['label']})[1] 139 | 140 | def gen_labeled_fn(self, data_iterator): 141 | return self.dataset.train_labeled.numpy_augment(lambda: self.session.run(data_iterator)) 142 | 143 | def gen_unlabeled_fn(self, data_iterator): 144 | return self.dataset.train_unlabeled.numpy_augment(lambda: self.session.run(data_iterator)) 145 | 146 | def train(self, train_nimg, report_nimg): 147 | if FLAGS.eval_ckpt: 148 | self.eval_checkpoint(FLAGS.eval_ckpt) 149 | return 150 | batch = FLAGS.batch 151 | train_labeled = self.dataset.train_labeled.repeat().shuffle(FLAGS.shuffle).parse().augment() 152 | train_labeled = train_labeled.batch(batch).prefetch(16).make_one_shot_iterator().get_next() 153 | train_unlabeled = self.dataset.train_unlabeled.repeat().shuffle(FLAGS.shuffle).parse().augment() 154 | train_unlabeled = train_unlabeled.batch(batch).prefetch(16).make_one_shot_iterator().get_next() 155 | scaffold = tf.train.Scaffold(saver=tf.train.Saver(max_to_keep=FLAGS.keep_ckpt, pad_step_number=10)) 156 | 157 | with tf.Session(config=utils.get_config()) as sess: 158 | self.session = sess 159 | self.cache_eval() 160 | 161 | with tf.train.MonitoredTrainingSession( 162 | scaffold=scaffold, 163 | checkpoint_dir=self.checkpoint_dir, 164 | config=utils.get_config(), 165 | save_checkpoint_steps=FLAGS.save_kimg << 10, 166 | save_summaries_steps=report_nimg - batch) as train_session: 167 | self.session = train_session._tf_sess() 168 | gen_labeled = self.gen_labeled_fn(train_labeled) 169 | gen_unlabeled = self.gen_unlabeled_fn(train_unlabeled) 170 | self.tmp.step = self.session.run(self.step) 171 | while self.tmp.step < train_nimg: 172 | loop = trange(self.tmp.step % report_nimg, report_nimg, batch, 173 | leave=False, unit='img', unit_scale=batch, 174 | desc='Epoch %d/%d' % (1 + (self.tmp.step // report_nimg), train_nimg // report_nimg)) 175 | for _ in loop: 176 | self.train_step(train_session, gen_labeled, gen_unlabeled) 177 | while self.tmp.print_queue: 178 | loop.write(self.tmp.print_queue.pop(0)) 179 | while self.tmp.print_queue: 180 | print(self.tmp.print_queue.pop(0)) 181 | 182 | def eval_checkpoint(self, ckpt=None): 183 | self.eval_mode(ckpt) 184 | self.cache_eval() 185 | raw = self.eval_stats(classify_op=self.ops.classify_raw) 186 | ema = self.eval_stats(classify_op=self.ops.classify_op) 187 | print('%16s %8s %8s %8s' % ('', 'labeled', 'valid', 'test')) 188 | print('%16s %8s %8s %8s' % (('raw',) + tuple('%.2f' % x for x in raw))) 189 | print('%16s %8s %8s %8s' % (('ema',) + tuple('%.2f' % x for x in ema))) 190 | 191 | def cache_eval(self): 192 | """Cache datasets for computing eval stats.""" 193 | 194 | def collect_samples(dataset, name): 195 | """Return numpy arrays of all the samples from a dataset.""" 196 | pbar = tqdm(desc='Caching %s examples' % name) 197 | it = dataset.batch(1).prefetch(16).make_one_shot_iterator().get_next() 198 | images, labels = [], [] 199 | while 1: 200 | try: 201 | v = self.session.run(it) 202 | except tf.errors.OutOfRangeError: 203 | break 204 | images.append(v['image']) 205 | labels.append(v['label']) 206 | pbar.update() 207 | 208 | images = np.concatenate(images, axis=0) 209 | labels = np.concatenate(labels, axis=0) 210 | pbar.close() 211 | return images, labels 212 | 213 | if 'test' not in self.tmp.cache: 214 | 
self.tmp.cache.test = collect_samples(self.dataset.test.parse(), name='test') 215 | self.tmp.cache.valid = collect_samples(self.dataset.valid.parse(), name='valid') 216 | self.tmp.cache.train_labeled = collect_samples(self.dataset.train_labeled.take(10000).parse(), 217 | name='train_labeled') 218 | 219 | def eval_stats(self, batch=None, feed_extra=None, classify_op=None): 220 | """Evaluate model on train, valid and test.""" 221 | batch = batch or FLAGS.batch 222 | classify_op = self.ops.classify_op if classify_op is None else classify_op 223 | accuracies = [] 224 | for subset in ('train_labeled', 'valid', 'test'): 225 | images, labels = self.tmp.cache[subset] 226 | predicted = [] 227 | 228 | for x in range(0, images.shape[0], batch): 229 | p = self.session.run( 230 | classify_op, 231 | feed_dict={ 232 | self.ops.x: images[x:x + batch], 233 | **(feed_extra or {}) 234 | }) 235 | predicted.append(p) 236 | predicted = np.concatenate(predicted, axis=0) 237 | accuracies.append((predicted.argmax(1) == labels).mean() * 100) 238 | self.train_print('kimg %-5d accuracy train/valid/test %.2f %.2f %.2f' % 239 | tuple([self.tmp.step >> 10] + accuracies)) 240 | return np.array(accuracies, 'f') 241 | 242 | def add_summaries(self, feed_extra=None, **kwargs): 243 | del kwargs 244 | 245 | def gen_stats(): 246 | return self.eval_stats(feed_extra=feed_extra) 247 | 248 | accuracies = tf.py_func(gen_stats, [], tf.float32) 249 | tf.summary.scalar('accuracy/train_labeled', accuracies[0]) 250 | tf.summary.scalar('accuracy/valid', accuracies[1]) 251 | tf.summary.scalar('accuracy', accuracies[2]) 252 | -------------------------------------------------------------------------------- /libml/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
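# Shared helpers used by the trainers below: GPU discovery, checkpoint lookup,
# the EMA variable getter, and the para_* wrappers that split a batch across GPUs.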
14 | """Utilities.""" 15 | 16 | import os 17 | import re 18 | 19 | import numpy as np 20 | import tensorflow as tf 21 | from absl import flags, logging 22 | from tensorflow.python.client import device_lib 23 | 24 | _GPUS = None 25 | FLAGS = flags.FLAGS 26 | flags.DEFINE_bool('log_device_placement', False, 'For debugging purpose.') 27 | 28 | 29 | class EasyDict(dict): 30 | def __init__(self, *args, **kwargs): 31 | super(EasyDict, self).__init__(*args, **kwargs) 32 | self.__dict__ = self 33 | 34 | 35 | def get_config(): 36 | config = tf.ConfigProto() 37 | if len(get_available_gpus()) > 1: 38 | config.allow_soft_placement = True 39 | if FLAGS.log_device_placement: 40 | config.log_device_placement = True 41 | config.gpu_options.allow_growth = True 42 | return config 43 | 44 | 45 | def setup_main(): 46 | pass 47 | 48 | 49 | def setup_tf(): 50 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 51 | logging.set_verbosity(logging.ERROR) 52 | 53 | 54 | def smart_shape(x): 55 | s = x.shape 56 | st = tf.shape(x) 57 | return [s[i] if s[i].value is not None else st[i] for i in range(4)] 58 | 59 | 60 | def ilog2(x): 61 | """Integer log2.""" 62 | return int(np.ceil(np.log2(x))) 63 | 64 | 65 | def find_latest_checkpoint(dir, glob_term='model.ckpt-*.meta'): 66 | """Replacement for tf.train.latest_checkpoint. 67 | 68 | It does not rely on the "checkpoint" file which sometimes contains 69 | absolute path and is generally hard to work with when sharing files 70 | between users / computers. 71 | """ 72 | r_step = re.compile('.*model\.ckpt-(?P\d+)\.meta') 73 | matches = tf.gfile.Glob(os.path.join(dir, glob_term)) 74 | matches = [(int(r_step.match(x).group('step')), x) for x in matches] 75 | ckpt_file = max(matches)[1][:-5] 76 | return ckpt_file 77 | 78 | 79 | def get_latest_global_step(dir): 80 | """Loads the global step from the latest checkpoint in directory. 81 | 82 | Args: 83 | dir: string, path to the checkpoint directory. 84 | 85 | Returns: 86 | int, the global step of the latest checkpoint or 0 if none was found. 87 | """ 88 | try: 89 | checkpoint_reader = tf.train.NewCheckpointReader(find_latest_checkpoint(dir)) 90 | return checkpoint_reader.get_tensor(tf.GraphKeys.GLOBAL_STEP) 91 | except: # pylint: disable=bare-except 92 | return 0 93 | 94 | 95 | def get_latest_global_step_in_subdir(dir): 96 | """Loads the global step from the latest checkpoint in sub-directories. 97 | 98 | Args: 99 | dir: string, parent of the checkpoint directories. 100 | 101 | Returns: 102 | int, the global step of the latest checkpoint or 0 if none was found. 103 | """ 104 | sub_dirs = (x for x in tf.gfile.Glob(os.path.join(dir, '*')) if os.path.isdir(x)) 105 | step = 0 106 | for x in sub_dirs: 107 | step = max(step, get_latest_global_step(x)) 108 | return step 109 | 110 | 111 | def getter_ema(ema, getter, name, *args, **kwargs): 112 | """Exponential moving average getter for variable scopes. 113 | 114 | Args: 115 | ema: ExponentialMovingAverage object, where to get variable moving averages. 116 | getter: default variable scope getter. 117 | name: variable name. 118 | *args: extra args passed to default getter. 119 | **kwargs: extra args passed to default getter. 120 | 121 | Returns: 122 | If found the moving average variable, otherwise the default variable. 
123 | """ 124 | var = getter(name, *args, **kwargs) 125 | ema_var = ema.average(var) 126 | return ema_var if ema_var else var 127 | 128 | 129 | def model_vars(scope=None): 130 | return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope) 131 | 132 | 133 | def gpu(x): 134 | return '/gpu:%d' % (x % max(1, len(get_available_gpus()))) 135 | 136 | 137 | def get_available_gpus(): 138 | global _GPUS 139 | if _GPUS is None: 140 | config = tf.ConfigProto() 141 | config.gpu_options.allow_growth = True 142 | local_device_protos = device_lib.list_local_devices(session_config=config) 143 | _GPUS = tuple([x.name for x in local_device_protos if x.device_type == 'GPU']) 144 | return _GPUS 145 | 146 | 147 | def get_gpu(): 148 | gpus = get_available_gpus() 149 | pos = 0 150 | while 1: 151 | yield gpus[pos] 152 | pos = (pos + 1) % len(gpus) 153 | 154 | 155 | def average_gradients(tower_grads): 156 | # Adapted from: 157 | # https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10_multi_gpu_train.py 158 | """Calculate the average gradient for each shared variable across all towers. 159 | Note that this function provides a synchronization point across all towers. 160 | Args: 161 | tower_grads: List of lists of (gradient, variable) tuples. For each tower, a list of its gradients. 162 | Returns: 163 | List of pairs of (gradient, variable) where the gradient has been averaged 164 | across all towers. 165 | """ 166 | if len(tower_grads) <= 1: 167 | return tower_grads[0] 168 | 169 | average_grads = [] 170 | for grads_and_vars in zip(*tower_grads): 171 | grad = tf.reduce_mean([gv[0] for gv in grads_and_vars], 0) 172 | average_grads.append((grad, grads_and_vars[0][1])) 173 | return average_grads 174 | 175 | 176 | def para_list(fn, *args): 177 | """Run on multiple GPUs in parallel and return list of results.""" 178 | gpus = len(get_available_gpus()) 179 | if gpus <= 1: 180 | return zip(*[fn(*args)]) 181 | splitted = [tf.split(x, gpus) for x in args] 182 | outputs = [] 183 | for gpu, x in enumerate(zip(*splitted)): 184 | with tf.name_scope('tower%d' % gpu): 185 | with tf.device(tf.train.replica_device_setter( 186 | worker_device='/gpu:%d' % gpu, ps_device='/cpu:0', ps_tasks=1)): 187 | outputs.append(fn(*x)) 188 | return zip(*outputs) 189 | 190 | 191 | def para_mean(fn, *args): 192 | """Run on multiple GPUs in parallel and return means.""" 193 | gpus = len(get_available_gpus()) 194 | if gpus <= 1: 195 | return fn(*args) 196 | splitted = [tf.split(x, gpus) for x in args] 197 | outputs = [] 198 | for gpu, x in enumerate(zip(*splitted)): 199 | with tf.name_scope('tower%d' % gpu): 200 | with tf.device(tf.train.replica_device_setter( 201 | worker_device='/gpu:%d' % gpu, ps_device='/cpu:0', ps_tasks=1)): 202 | outputs.append(fn(*x)) 203 | if isinstance(outputs[0], (tuple, list)): 204 | return [tf.reduce_mean(x, 0) for x in zip(*outputs)] 205 | return tf.reduce_mean(outputs, 0) 206 | 207 | 208 | def para_cat(fn, *args): 209 | """Run on multiple GPUs in parallel and return concatenated outputs.""" 210 | gpus = len(get_available_gpus()) 211 | if gpus <= 1: 212 | return fn(*args) 213 | splitted = [tf.split(x, gpus) for x in args] 214 | outputs = [] 215 | for gpu, x in enumerate(zip(*splitted)): 216 | with tf.name_scope('tower%d' % gpu): 217 | with tf.device(tf.train.replica_device_setter( 218 | worker_device='/gpu:%d' % gpu, ps_device='/cpu:0', ps_tasks=1)): 219 | outputs.append(fn(*x)) 220 | if isinstance(outputs[0], (tuple, list)): 221 | return [tf.concat(x, axis=0) for x in zip(*outputs)] 222 | 
return tf.concat(outputs, axis=0) 223 | 224 | 225 | def combine_dicts(*args): 226 | # Python 2 compatible way to combine several dictionaries 227 | # We need it because currently TPU code does not work with python 3 228 | result = {} 229 | for d in args: 230 | result.update(d) 231 | return result 232 | -------------------------------------------------------------------------------- /mean_teacher.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Mean teachers are better role models: 16 | Weight-averaged consistency targets improve semi-supervised deep learning results. 17 | 18 | Reimplementation of https://arxiv.org/abs/1703.01780 19 | """ 20 | import functools 21 | import os 22 | 23 | import tensorflow as tf 24 | from absl import app 25 | from absl import flags 26 | 27 | from libml import models, utils 28 | from libml.data import PAIR_DATASETS 29 | from libml.utils import EasyDict 30 | 31 | FLAGS = flags.FLAGS 32 | 33 | 34 | class MeanTeacher(models.MultiModel): 35 | 36 | def model(self, batch, lr, wd, ema, warmup_pos, consistency_weight, **kwargs): 37 | hwc = [self.dataset.height, self.dataset.width, self.dataset.colors] 38 | xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt') # For training 39 | x_in = tf.placeholder(tf.float32, [None] + hwc, 'x') 40 | y_in = tf.placeholder(tf.float32, [batch, 2] + hwc, 'y') 41 | l_in = tf.placeholder(tf.int32, [batch], 'labels') 42 | l = tf.one_hot(l_in, self.nclass) 43 | wd *= lr 44 | warmup = tf.clip_by_value(tf.to_float(self.step) / (warmup_pos * (FLAGS.train_kimg << 10)), 0, 1) 45 | 46 | classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits 47 | logits_x = classifier(xt_in, training=True) 48 | post_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) # Take only first call to update batch norm. 
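        # The two augmented copies of each unlabeled image in y_in are split below:
        # one copy is scored by the EMA ("teacher") weights with gradients stopped,
        # the other by the current ("student") weights, and their softmax outputs
        # are matched with a mean squared error that is ramped up by `warmup` and
        # scaled by `consistency_weight`.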
49 | y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc) 50 | y_1, y_2 = tf.split(y, 2) 51 | ema = tf.train.ExponentialMovingAverage(decay=ema) 52 | ema_op = ema.apply(utils.model_vars()) 53 | ema_getter = functools.partial(utils.getter_ema, ema) 54 | logits_y = classifier(y_1, training=True, getter=ema_getter) 55 | logits_teacher = tf.stop_gradient(logits_y) 56 | logits_student = classifier(y_2, training=True) 57 | loss_mt = tf.reduce_mean((tf.nn.softmax(logits_teacher) - tf.nn.softmax(logits_student)) ** 2, -1) 58 | loss_mt = tf.reduce_mean(loss_mt) 59 | 60 | loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=l, logits=logits_x) 61 | loss = tf.reduce_mean(loss) 62 | tf.summary.scalar('losses/xe', loss) 63 | tf.summary.scalar('losses/mt', loss_mt) 64 | 65 | post_ops.append(ema_op) 66 | post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name]) 67 | 68 | train_op = tf.train.AdamOptimizer(lr).minimize(loss + loss_mt * warmup * consistency_weight, 69 | colocate_gradients_with_ops=True) 70 | with tf.control_dependencies([train_op]): 71 | train_op = tf.group(*post_ops) 72 | 73 | return EasyDict( 74 | xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op, 75 | classify_raw=tf.nn.softmax(classifier(x_in, training=False)), # No EMA, for debugging. 76 | classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False))) 77 | 78 | 79 | def main(argv): 80 | utils.setup_main() 81 | del argv # Unused. 82 | dataset = PAIR_DATASETS()[FLAGS.dataset]() 83 | log_width = utils.ilog2(dataset.width) 84 | model = MeanTeacher( 85 | os.path.join(FLAGS.train_dir, dataset.name), 86 | dataset, 87 | lr=FLAGS.lr, 88 | wd=FLAGS.wd, 89 | arch=FLAGS.arch, 90 | warmup_pos=FLAGS.warmup_pos, 91 | batch=FLAGS.batch, 92 | nclass=dataset.nclass, 93 | ema=FLAGS.ema, 94 | smoothing=FLAGS.smoothing, 95 | consistency_weight=FLAGS.consistency_weight, 96 | 97 | scales=FLAGS.scales or (log_width - 2), 98 | filters=FLAGS.filters, 99 | repeat=FLAGS.repeat) 100 | model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10) 101 | 102 | 103 | if __name__ == '__main__': 104 | utils.setup_tf() 105 | flags.DEFINE_float('consistency_weight', 50., 'Consistency weight.') 106 | flags.DEFINE_float('warmup_pos', 0.4, 'Relative position at which constraint loss warmup ends.') 107 | flags.DEFINE_float('wd', 0.02, 'Weight decay.') 108 | flags.DEFINE_float('ema', 0.999, 'Exponential moving average of params.') 109 | flags.DEFINE_float('smoothing', 0.001, 'Label smoothing.') 110 | flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.') 111 | flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.') 112 | flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.') 113 | FLAGS.set_default('augment', 'd.d.d') 114 | FLAGS.set_default('dataset', 'cifar10.3@250-5000') 115 | FLAGS.set_default('batch', 64) 116 | FLAGS.set_default('lr', 0.002) 117 | FLAGS.set_default('train_kimg', 1 << 16) 118 | app.run(main) 119 | -------------------------------------------------------------------------------- /mixmatch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """MixMatch training. 15 | - Ensure class consistency by producing a group of `nu` augmentations of the same image and guessing the label for the 16 | group. 17 | - Sharpen the target distribution. 18 | - Use the sharpened distribution directly as a smooth label in MixUp. 19 | """ 20 | 21 | import functools 22 | import os 23 | 24 | import tensorflow as tf 25 | from absl import app 26 | from absl import flags 27 | 28 | from libml import layers, utils, models 29 | from libml.data import PAIR_DATASETS 30 | from libml.layers import MixMode 31 | from libml.utils import EasyDict 32 | 33 | FLAGS = flags.FLAGS 34 | 35 | 36 | class MixMatch(models.MultiModel): 37 | 38 | def distribution_summary(self, p_data, p_model, p_target=None): 39 | def kl(p, q): 40 | p /= tf.reduce_sum(p) 41 | q /= tf.reduce_sum(q) 42 | return -tf.reduce_sum(p * tf.log(q / p)) 43 | 44 | tf.summary.scalar('metrics/kld', kl(p_data, p_model)) 45 | if p_target is not None: 46 | tf.summary.scalar('metrics/kld_target', kl(p_data, p_target)) 47 | 48 | for i in range(self.nclass): 49 | tf.summary.scalar('matching/class%d_ratio' % i, p_model[i] / p_data[i]) 50 | for i in range(self.nclass): 51 | tf.summary.scalar('matching/val%d' % i, p_model[i]) 52 | 53 | def augment(self, x, l, beta, **kwargs): 54 | assert 0, 'Do not call.' 55 | 56 | def guess_label(self, y, classifier, T, **kwargs): 57 | del kwargs 58 | logits_y = [classifier(yi, training=True) for yi in y] 59 | logits_y = tf.concat(logits_y, 0) 60 | # Compute predicted probability distribution py. 61 | p_model_y = tf.reshape(tf.nn.softmax(logits_y), [len(y), -1, self.nclass]) 62 | p_model_y = tf.reduce_mean(p_model_y, axis=0) 63 | # Compute the target distribution. 64 | p_target = tf.pow(p_model_y, 1. 
/ T) 65 | p_target /= tf.reduce_sum(p_target, axis=1, keep_dims=True) 66 | return EasyDict(p_target=p_target, p_model=p_model_y) 67 | 68 | def model(self, batch, lr, wd, ema, beta, w_match, warmup_kimg=1024, nu=2, mixmode='xxy.yxy', dbuf=128, **kwargs): 69 | hwc = [self.dataset.height, self.dataset.width, self.dataset.colors] 70 | xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt') # For training 71 | x_in = tf.placeholder(tf.float32, [None] + hwc, 'x') 72 | y_in = tf.placeholder(tf.float32, [batch, nu] + hwc, 'y') 73 | l_in = tf.placeholder(tf.int32, [batch], 'labels') 74 | wd *= lr 75 | w_match *= tf.clip_by_value(tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1) 76 | augment = MixMode(mixmode) 77 | classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits 78 | 79 | # Moving average of the current estimated label distribution 80 | p_model = layers.PMovingAverage('p_model', self.nclass, dbuf) 81 | p_target = layers.PMovingAverage('p_target', self.nclass, dbuf) # Rectified distribution (only for plotting) 82 | 83 | # Known (or inferred) true unlabeled distribution 84 | p_data = layers.PData(self.dataset) 85 | 86 | y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc) 87 | guess = self.guess_label(tf.split(y, nu), classifier, T=0.5, **kwargs) 88 | ly = tf.stop_gradient(guess.p_target) 89 | lx = tf.one_hot(l_in, self.nclass) 90 | xy, labels_xy = augment([xt_in] + tf.split(y, nu), [lx] + [ly] * nu, [beta, beta]) 91 | x, y = xy[0], xy[1:] 92 | labels_x, labels_y = labels_xy[0], tf.concat(labels_xy[1:], 0) 93 | del xy, labels_xy 94 | 95 | batches = layers.interleave([x] + y, batch) 96 | skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 97 | logits = [classifier(batches[0], training=True)] 98 | post_ops = [v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops] 99 | for batchi in batches[1:]: 100 | logits.append(classifier(batchi, training=True)) 101 | logits = layers.interleave(logits, batch) 102 | logits_x = logits[0] 103 | logits_y = tf.concat(logits[1:], 0) 104 | 105 | loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x, logits=logits_x) 106 | loss_xe = tf.reduce_mean(loss_xe) 107 | loss_l2u = tf.square(labels_y - tf.nn.softmax(logits_y)) 108 | loss_l2u = tf.reduce_mean(loss_l2u) 109 | tf.summary.scalar('losses/xe', loss_xe) 110 | tf.summary.scalar('losses/l2u', loss_l2u) 111 | self.distribution_summary(p_data(), p_model(), p_target()) 112 | 113 | ema = tf.train.ExponentialMovingAverage(decay=ema) 114 | ema_op = ema.apply(utils.model_vars()) 115 | ema_getter = functools.partial(utils.getter_ema, ema) 116 | post_ops.extend([ema_op, 117 | p_model.update(guess.p_model), 118 | p_target.update(guess.p_target)]) 119 | if p_data.has_update: 120 | post_ops.append(p_data.update(lx)) 121 | post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name]) 122 | 123 | train_op = tf.train.AdamOptimizer(lr).minimize(loss_xe + w_match * loss_l2u, colocate_gradients_with_ops=True) 124 | with tf.control_dependencies([train_op]): 125 | train_op = tf.group(*post_ops) 126 | 127 | return EasyDict( 128 | xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op, 129 | classify_raw=tf.nn.softmax(classifier(x_in, training=False)), # No EMA, for debugging. 130 | classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False))) 131 | 132 | 133 | def main(argv): 134 | utils.setup_main() 135 | del argv # Unused. 
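    # Build the paired-augmentation dataset named by --dataset (default
    # cifar10.3@250-5000) and train MixMatch with the hyper-parameters given as flags.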
136 | dataset = PAIR_DATASETS()[FLAGS.dataset]() 137 | log_width = utils.ilog2(dataset.width) 138 | model = MixMatch( 139 | os.path.join(FLAGS.train_dir, dataset.name), 140 | dataset, 141 | lr=FLAGS.lr, 142 | wd=FLAGS.wd, 143 | arch=FLAGS.arch, 144 | batch=FLAGS.batch, 145 | nclass=dataset.nclass, 146 | ema=FLAGS.ema, 147 | 148 | beta=FLAGS.beta, 149 | w_match=FLAGS.w_match, 150 | 151 | scales=FLAGS.scales or (log_width - 2), 152 | filters=FLAGS.filters, 153 | repeat=FLAGS.repeat) 154 | model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10) 155 | 156 | 157 | if __name__ == '__main__': 158 | utils.setup_tf() 159 | flags.DEFINE_float('wd', 0.02, 'Weight decay.') 160 | flags.DEFINE_float('ema', 0.999, 'Exponential moving average of params.') 161 | flags.DEFINE_float('beta', 0.5, 'Mixup beta distribution.') 162 | flags.DEFINE_float('w_match', 100, 'Weight for distribution matching loss.') 163 | flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.') 164 | flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.') 165 | flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.') 166 | FLAGS.set_default('augment', 'd.d.d') 167 | FLAGS.set_default('dataset', 'cifar10.3@250-5000') 168 | FLAGS.set_default('batch', 64) 169 | FLAGS.set_default('lr', 0.002) 170 | FLAGS.set_default('train_kimg', 1 << 16) 171 | app.run(main) 172 | -------------------------------------------------------------------------------- /mixup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """mixup: Beyond Empirical Risk Minimization. 
15 | 16 | Adaption to SSL of MixUp: https://arxiv.org/abs/1710.09412 17 | """ 18 | import functools 19 | import os 20 | 21 | import tensorflow as tf 22 | from absl import app 23 | from absl import flags 24 | 25 | from libml import data, utils, models 26 | from libml.utils import EasyDict 27 | 28 | FLAGS = flags.FLAGS 29 | 30 | 31 | class Mixup(models.MultiModel): 32 | 33 | def augment(self, x, l, beta, **kwargs): 34 | del kwargs 35 | mix = tf.distributions.Beta(beta, beta).sample([tf.shape(x)[0], 1, 1, 1]) 36 | mix = tf.maximum(mix, 1 - mix) 37 | xmix = x * mix + x[::-1] * (1 - mix) 38 | lmix = l * mix[:, :, 0, 0] + l[::-1] * (1 - mix[:, :, 0, 0]) 39 | return xmix, lmix 40 | 41 | def model(self, batch, lr, wd, ema, **kwargs): 42 | hwc = [self.dataset.height, self.dataset.width, self.dataset.colors] 43 | xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt') # For training 44 | x_in = tf.placeholder(tf.float32, [None] + hwc, 'x') 45 | y_in = tf.placeholder(tf.float32, [batch] + hwc, 'y') 46 | l_in = tf.placeholder(tf.int32, [batch], 'labels') 47 | wd *= lr 48 | classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits 49 | 50 | def get_logits(x): 51 | logits = classifier(x, training=True) 52 | return logits 53 | 54 | x, labels_x = self.augment(xt_in, tf.one_hot(l_in, self.nclass), **kwargs) 55 | logits_x = get_logits(x) 56 | post_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 57 | y, labels_y = self.augment(y_in, tf.nn.softmax(get_logits(y_in)), **kwargs) 58 | labels_y = tf.stop_gradient(labels_y) 59 | logits_y = get_logits(y) 60 | 61 | loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x, logits=logits_x) 62 | loss_xe = tf.reduce_mean(loss_xe) 63 | loss_xeu = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_y, logits=logits_y) 64 | loss_xeu = tf.reduce_mean(loss_xeu) 65 | tf.summary.scalar('losses/xe', loss_xe) 66 | tf.summary.scalar('losses/xeu', loss_xeu) 67 | 68 | ema = tf.train.ExponentialMovingAverage(decay=ema) 69 | ema_op = ema.apply(utils.model_vars()) 70 | ema_getter = functools.partial(utils.getter_ema, ema) 71 | post_ops.append(ema_op) 72 | post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name]) 73 | 74 | train_op = tf.train.AdamOptimizer(lr).minimize(loss_xe + loss_xeu, colocate_gradients_with_ops=True) 75 | with tf.control_dependencies([train_op]): 76 | train_op = tf.group(*post_ops) 77 | 78 | return EasyDict( 79 | xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op, 80 | classify_raw=tf.nn.softmax(classifier(x_in, training=False)), # No EMA, for debugging. 81 | classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False))) 82 | 83 | 84 | def main(argv): 85 | utils.setup_main() 86 | del argv # Unused. 
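    # Plain SSL Mixup baseline: the unlabeled targets are the model's own softmax
    # predictions (with gradients stopped), mixed with Beta(beta, beta) in the same
    # way as the labeled batch.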
87 | dataset = data.DATASETS()[FLAGS.dataset]() 88 | log_width = utils.ilog2(dataset.width) 89 | model = Mixup( 90 | os.path.join(FLAGS.train_dir, dataset.name), 91 | dataset, 92 | lr=FLAGS.lr, 93 | wd=FLAGS.wd, 94 | arch=FLAGS.arch, 95 | batch=FLAGS.batch, 96 | nclass=dataset.nclass, 97 | ema=FLAGS.ema, 98 | beta=FLAGS.beta, 99 | 100 | scales=FLAGS.scales or (log_width - 2), 101 | filters=FLAGS.filters, 102 | repeat=FLAGS.repeat) 103 | model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10) 104 | 105 | 106 | if __name__ == '__main__': 107 | utils.setup_tf() 108 | flags.DEFINE_float('wd', 0.02, 'Weight decay.') 109 | flags.DEFINE_float('ema', 0.999, 'Exponential moving average of params.') 110 | flags.DEFINE_float('beta', 0.5, 'Mixup beta distribution.') 111 | flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.') 112 | flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.') 113 | flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.') 114 | FLAGS.set_default('dataset', 'cifar10.3@250-5000') 115 | FLAGS.set_default('batch', 64) 116 | FLAGS.set_default('lr', 0.002) 117 | FLAGS.set_default('train_kimg', 1 << 16) 118 | app.run(main) 119 | -------------------------------------------------------------------------------- /pi_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Temporal Ensembling for Semi-Supervised Learning. 15 | 16 | Pi-model reimplementation of https://arxiv.org/abs/1610.02242 17 | """ 18 | 19 | import functools 20 | import os 21 | 22 | import tensorflow as tf 23 | from absl import app 24 | from absl import flags 25 | 26 | from libml import models, utils 27 | from libml.data import PAIR_DATASETS 28 | from libml.utils import EasyDict 29 | 30 | FLAGS = flags.FLAGS 31 | 32 | 33 | class PiModel(models.MultiModel): 34 | 35 | def model(self, batch, lr, wd, ema, warmup_pos, consistency_weight, **kwargs): 36 | hwc = [self.dataset.height, self.dataset.width, self.dataset.colors] 37 | xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt') # For training 38 | x_in = tf.placeholder(tf.float32, [None] + hwc, 'x') 39 | y_in = tf.placeholder(tf.float32, [batch, 2] + hwc, 'y') 40 | l_in = tf.placeholder(tf.int32, [batch], 'labels') 41 | l = tf.one_hot(l_in, self.nclass) 42 | wd *= lr 43 | warmup = tf.clip_by_value(tf.to_float(self.step) / (warmup_pos * (FLAGS.train_kimg << 10)), 0, 1) 44 | 45 | classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits 46 | logits_x = classifier(xt_in, training=True) 47 | post_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) # Take only first call to update batch norm. 
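        # Pi-model consistency: both augmented copies in y_in go through the same
        # current weights (no EMA teacher); one pass is wrapped in tf.stop_gradient
        # and serves as the target, and the squared softmax difference is the
        # consistency loss. The EMA weights are consulted only by the evaluation
        # op (classify_op).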
48 | y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc) 49 | y_1, y_2 = tf.split(y, 2) 50 | logits_y = classifier(y_1, training=True) 51 | logits_teacher = tf.stop_gradient(logits_y) 52 | logits_student = classifier(y_2, training=True) 53 | loss_pm = tf.reduce_mean((tf.nn.softmax(logits_teacher) - tf.nn.softmax(logits_student)) ** 2, -1) 54 | loss_pm = tf.reduce_mean(loss_pm) 55 | 56 | loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=l, logits=logits_x) 57 | loss = tf.reduce_mean(loss) 58 | tf.summary.scalar('losses/xe', loss) 59 | tf.summary.scalar('losses/pm', loss_pm) 60 | 61 | ema = tf.train.ExponentialMovingAverage(decay=ema) 62 | ema_op = ema.apply(utils.model_vars()) 63 | ema_getter = functools.partial(utils.getter_ema, ema) 64 | post_ops.append(ema_op) 65 | post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name]) 66 | 67 | train_op = tf.train.AdamOptimizer(lr).minimize(loss + loss_pm * warmup * consistency_weight, 68 | colocate_gradients_with_ops=True) 69 | with tf.control_dependencies([train_op]): 70 | train_op = tf.group(*post_ops) 71 | 72 | return EasyDict( 73 | xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op, 74 | classify_raw=tf.nn.softmax(classifier(x_in, training=False)), # No EMA, for debugging. 75 | classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False))) 76 | 77 | 78 | def main(argv): 79 | utils.setup_main() 80 | del argv # Unused. 81 | dataset = PAIR_DATASETS()[FLAGS.dataset]() 82 | log_width = utils.ilog2(dataset.width) 83 | model = PiModel( 84 | os.path.join(FLAGS.train_dir, dataset.name), 85 | dataset, 86 | lr=FLAGS.lr, 87 | wd=FLAGS.wd, 88 | arch=FLAGS.arch, 89 | warmup_pos=FLAGS.warmup_pos, 90 | batch=FLAGS.batch, 91 | nclass=dataset.nclass, 92 | ema=FLAGS.ema, 93 | smoothing=FLAGS.smoothing, 94 | consistency_weight=FLAGS.consistency_weight, 95 | 96 | scales=FLAGS.scales or (log_width - 2), 97 | filters=FLAGS.filters, 98 | repeat=FLAGS.repeat) 99 | model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10) 100 | 101 | 102 | if __name__ == '__main__': 103 | utils.setup_tf() 104 | flags.DEFINE_float('consistency_weight', 10., 'Consistency weight.') 105 | flags.DEFINE_float('warmup_pos', 0.4, 'Relative position at which constraint loss warmup ends.') 106 | flags.DEFINE_float('wd', 0.02, 'Weight decay.') 107 | flags.DEFINE_float('ema', 0.999, 'Exponential moving average of params.') 108 | flags.DEFINE_float('smoothing', 0.1, 'Label smoothing.') 109 | flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.') 110 | flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.') 111 | flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.') 112 | FLAGS.set_default('augment', 'd.d.d') 113 | FLAGS.set_default('dataset', 'cifar10.3@250-5000') 114 | FLAGS.set_default('batch', 64) 115 | FLAGS.set_default('lr', 0.002) 116 | FLAGS.set_default('train_kimg', 1 << 16) 117 | app.run(main) 118 | -------------------------------------------------------------------------------- /pseudo_label.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Pseudo-label: The simple and efficient semi-supervised learning method fordeep neural networks. 15 | 16 | Reimplementation of http://deeplearning.net/wp-content/uploads/2013/03/pseudo_label_final.pdf 17 | """ 18 | 19 | import functools 20 | import os 21 | 22 | import tensorflow as tf 23 | from absl import app 24 | from absl import flags 25 | 26 | from libml import utils, data, models 27 | from libml.utils import EasyDict 28 | 29 | FLAGS = flags.FLAGS 30 | 31 | 32 | class PseudoLabel(models.MultiModel): 33 | 34 | def model(self, batch, lr, wd, ema, warmup_pos, consistency_weight, threshold, **kwargs): 35 | hwc = [self.dataset.height, self.dataset.width, self.dataset.colors] 36 | xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt') # For training 37 | x_in = tf.placeholder(tf.float32, [None] + hwc, 'x') 38 | y_in = tf.placeholder(tf.float32, [batch] + hwc, 'y') 39 | l_in = tf.placeholder(tf.int32, [batch], 'labels') 40 | l = tf.one_hot(l_in, self.nclass) 41 | wd *= lr 42 | warmup = tf.clip_by_value(tf.to_float(self.step) / (warmup_pos * (FLAGS.train_kimg << 10)), 0, 1) 43 | 44 | classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits 45 | logits_x = classifier(xt_in, training=True) 46 | post_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) # Take only first call to update batch norm. 47 | logits_y = classifier(y_in, training=True) 48 | # Get the pseudo-label loss 49 | loss_pl = tf.nn.sparse_softmax_cross_entropy_with_logits( 50 | labels=tf.argmax(logits_y, axis=-1), logits=logits_y 51 | ) 52 | # Masks denoting which data points have high-confidence predictions 53 | greater_than_thresh = tf.reduce_any( 54 | tf.greater(tf.nn.softmax(logits_y), threshold), 55 | axis=-1, 56 | keepdims=True, 57 | ) 58 | greater_than_thresh = tf.cast(greater_than_thresh, loss_pl.dtype) 59 | # Only enforce the loss when the model is confident 60 | loss_pl *= greater_than_thresh 61 | # Note that we also average over examples without confident outputs; 62 | # this is consistent with the realistic evaluation codebase 63 | loss_pl = tf.reduce_mean(loss_pl) 64 | 65 | loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=l, logits=logits_x) 66 | loss = tf.reduce_mean(loss) 67 | tf.summary.scalar('losses/xe', loss) 68 | tf.summary.scalar('losses/pl', loss_pl) 69 | 70 | ema = tf.train.ExponentialMovingAverage(decay=ema) 71 | ema_op = ema.apply(utils.model_vars()) 72 | ema_getter = functools.partial(utils.getter_ema, ema) 73 | post_ops.append(ema_op) 74 | post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name]) 75 | 76 | train_op = tf.train.AdamOptimizer(lr).minimize(loss + loss_pl * warmup * consistency_weight, 77 | colocate_gradients_with_ops=True) 78 | with tf.control_dependencies([train_op]): 79 | train_op = tf.group(*post_ops) 80 | 81 | return EasyDict( 82 | xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op, 83 | classify_raw=tf.nn.softmax(classifier(x_in, training=False)), # No EMA, for debugging. 
84 | classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False))) 85 | 86 | 87 | def main(argv): 88 | utils.setup_main() 89 | del argv # Unused. 90 | dataset = data.DATASETS()[FLAGS.dataset]() 91 | log_width = utils.ilog2(dataset.width) 92 | model = PseudoLabel( 93 | os.path.join(FLAGS.train_dir, dataset.name), 94 | dataset, 95 | lr=FLAGS.lr, 96 | wd=FLAGS.wd, 97 | arch=FLAGS.arch, 98 | warmup_pos=FLAGS.warmup_pos, 99 | batch=FLAGS.batch, 100 | nclass=dataset.nclass, 101 | ema=FLAGS.ema, 102 | smoothing=FLAGS.smoothing, 103 | consistency_weight=FLAGS.consistency_weight, 104 | threshold=FLAGS.threshold, 105 | 106 | scales=FLAGS.scales or (log_width - 2), 107 | filters=FLAGS.filters, 108 | repeat=FLAGS.repeat) 109 | model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10) 110 | 111 | 112 | if __name__ == '__main__': 113 | utils.setup_tf() 114 | flags.DEFINE_float('wd', 0.02, 'Weight decay.') 115 | flags.DEFINE_float('consistency_weight', 1., 'Consistency weight.') 116 | flags.DEFINE_float('threshold', 0.95, 'Pseudo-label threshold.') 117 | flags.DEFINE_float('warmup_pos', 0.4, 'Relative position at which constraint loss warmup ends.') 118 | flags.DEFINE_float('ema', 0.999, 'Exponential moving average of params.') 119 | flags.DEFINE_float('smoothing', 0.1, 'Label smoothing.') 120 | flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.') 121 | flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.') 122 | flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.') 123 | FLAGS.set_default('dataset', 'cifar10.3@250-5000') 124 | FLAGS.set_default('batch', 64) 125 | FLAGS.set_default('lr', 0.002) 126 | FLAGS.set_default('train_kimg', 1 << 16) 127 | app.run(main) 128 | -------------------------------------------------------------------------------- /remixmatch_no_cta.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import functools 16 | import os 17 | 18 | import numpy as np 19 | import tensorflow as tf 20 | from absl import app 21 | from absl import flags 22 | 23 | from libml import data, layers, utils 24 | from libml.utils import EasyDict 25 | from mixmatch import MixMatch 26 | 27 | FLAGS = flags.FLAGS 28 | 29 | 30 | class ReMixMatch(MixMatch): 31 | 32 | def classifier_rot(self, x): 33 | with tf.variable_scope('classify_rot', reuse=tf.AUTO_REUSE): 34 | return tf.layers.dense(x, 4, kernel_initializer=tf.glorot_normal_initializer()) 35 | 36 | def guess_label(self, logits_y, p_data, p_model, T, use_dm, redux, **kwargs): 37 | del kwargs 38 | if redux == 'swap': 39 | p_model_y = tf.concat([tf.nn.softmax(x) for x in logits_y[1:] + logits_y[:1]], axis=0) 40 | elif redux == 'mean': 41 | p_model_y = sum(tf.nn.softmax(x) for x in logits_y) / len(logits_y) 42 | p_model_y = tf.tile(p_model_y, [len(logits_y), 1]) 43 | elif redux == '1st': 44 | p_model_y = tf.nn.softmax(logits_y[0]) 45 | p_model_y = tf.tile(p_model_y, [len(logits_y), 1]) 46 | else: 47 | raise NotImplementedError() 48 | 49 | # Compute the target distribution. 50 | # 1. Rectify the distribution or not. 51 | if use_dm: 52 | p_ratio = (1e-6 + p_data) / (1e-6 + p_model) 53 | p_weighted = p_model_y * p_ratio 54 | p_weighted /= tf.reduce_sum(p_weighted, axis=1, keep_dims=True) 55 | else: 56 | p_weighted = p_model_y 57 | # 2. Apply sharpening. 58 | p_target = tf.pow(p_weighted, 1. / T) 59 | p_target /= tf.reduce_sum(p_target, axis=1, keep_dims=True) 60 | return EasyDict(p_target=p_target, p_model=p_model_y) 61 | 62 | def model(self, batch, lr, wd, beta, w_kl, w_match, w_rot, K, use_xe, warmup_kimg=1024, T=0.5, 63 | mixmode='xxy.yxy', dbuf=128, ema=0.999, **kwargs): 64 | hwc = [self.dataset.height, self.dataset.width, self.dataset.colors] 65 | xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt') # For training 66 | x_in = tf.placeholder(tf.float32, [None] + hwc, 'x') 67 | y_in = tf.placeholder(tf.float32, [batch, K + 1] + hwc, 'y') 68 | l_in = tf.placeholder(tf.int32, [batch], 'labels') 69 | wd *= lr 70 | w_match *= tf.clip_by_value(tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1) 71 | augment = layers.MixMode(mixmode) 72 | 73 | gpu = utils.get_gpu() 74 | 75 | def classifier_to_gpu(x, **kw): 76 | with tf.device(next(gpu)): 77 | return self.classifier(x, **kw, **kwargs).logits 78 | 79 | def random_rotate(x): 80 | b4 = batch // 4 81 | x, xt = x[:2 * b4], tf.transpose(x[2 * b4:], [0, 2, 1, 3]) 82 | l = np.zeros(b4, np.int32) 83 | l = tf.constant(np.concatenate([l, l + 1, l + 2, l + 3], axis=0)) 84 | return tf.concat([x[:b4], x[b4:, ::-1, ::-1], xt[:b4, ::-1], xt[b4:, :, ::-1]], axis=0), l 85 | 86 | # Moving average of the current estimated label distribution 87 | p_model = layers.PMovingAverage('p_model', self.nclass, dbuf) 88 | p_target = layers.PMovingAverage('p_target', self.nclass, dbuf) # Rectified distribution (only for plotting) 89 | 90 | # Known (or inferred) true unlabeled distribution 91 | p_data = layers.PData(self.dataset) 92 | 93 | if w_rot > 0: 94 | rot_y, rot_l = random_rotate(y_in[:, 1]) 95 | with tf.device(next(gpu)): 96 | rot_logits = self.classifier_rot(self.classifier(rot_y, training=True, **kwargs).embeds) 97 | loss_rot = tf.nn.softmax_cross_entropy_with_logits_v2(labels=tf.one_hot(rot_l, 4), logits=rot_logits) 98 | loss_rot = tf.reduce_mean(loss_rot) 99 | tf.summary.scalar('losses/rot', loss_rot) 100 | else: 101 | loss_rot = 0 102 | 103 | if kwargs['redux'] == '1st' and w_kl <= 0: 104 | logits_y = 
[classifier_to_gpu(y_in[:, 0], training=True)] * (K + 1) 105 | elif kwargs['redux'] == '1st': 106 | logits_y = [classifier_to_gpu(y_in[:, i], training=True) for i in range(2)] 107 | logits_y += logits_y[:1] * (K - 1) 108 | else: 109 | logits_y = [classifier_to_gpu(y_in[:, i], training=True) for i in range(K + 1)] 110 | 111 | guess = self.guess_label(logits_y, p_data(), p_model(), T=T, **kwargs) 112 | ly = tf.stop_gradient(guess.p_target) 113 | if w_kl > 0: 114 | w_kl *= tf.clip_by_value(tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1) 115 | loss_kl = tf.nn.softmax_cross_entropy_with_logits_v2(labels=ly[:batch], logits=logits_y[1]) 116 | loss_kl = tf.reduce_mean(loss_kl) 117 | tf.summary.scalar('losses/kl', loss_kl) 118 | else: 119 | loss_kl = 0 120 | del logits_y 121 | 122 | lx = tf.one_hot(l_in, self.nclass) 123 | xy, labels_xy = augment([xt_in] + [y_in[:, i] for i in range(K + 1)], [lx] + tf.split(ly, K + 1), 124 | [beta, beta]) 125 | x, y = xy[0], xy[1:] 126 | labels_x, labels_y = labels_xy[0], tf.concat(labels_xy[1:], 0) 127 | del xy, labels_xy 128 | 129 | batches = layers.interleave([x] + y, batch) 130 | logits = [classifier_to_gpu(yi, training=True) for yi in batches[:-1]] 131 | skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 132 | logits.append(classifier_to_gpu(batches[-1], training=True)) 133 | post_ops = [v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops] 134 | logits = layers.interleave(logits, batch) 135 | logits_x = logits[0] 136 | logits_y = tf.concat(logits[1:], 0) 137 | del batches, logits 138 | 139 | loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x, logits=logits_x) 140 | loss_xe = tf.reduce_mean(loss_xe) 141 | if use_xe: 142 | loss_xeu = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_y, logits=logits_y) 143 | else: 144 | loss_xeu = tf.square(labels_y - tf.nn.softmax(logits_y)) 145 | loss_xeu = tf.reduce_mean(loss_xeu) 146 | tf.summary.scalar('losses/xe', loss_xe) 147 | tf.summary.scalar('losses/%s' % ('xeu' if use_xe else 'l2u'), loss_xeu) 148 | self.distribution_summary(p_data(), p_model(), p_target()) 149 | 150 | ema = tf.train.ExponentialMovingAverage(decay=ema) 151 | ema_op = ema.apply(utils.model_vars()) 152 | ema_getter = functools.partial(utils.getter_ema, ema) 153 | post_ops.extend([ema_op, 154 | p_model.update(guess.p_model), 155 | p_target.update(guess.p_target)]) 156 | if p_data.has_update: 157 | post_ops.append(p_data.update(lx)) 158 | post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name]) 159 | 160 | train_op = tf.train.AdamOptimizer(lr).minimize(loss_xe 161 | + w_kl * loss_kl 162 | + w_match * loss_xeu 163 | + w_rot * loss_rot, 164 | colocate_gradients_with_ops=True) 165 | with tf.control_dependencies([train_op]): 166 | train_op = tf.group(*post_ops) 167 | 168 | return EasyDict( 169 | xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op, 170 | classify_op=tf.nn.softmax(classifier_to_gpu(x_in, getter=ema_getter, training=False)), 171 | classify_raw=tf.nn.softmax(classifier_to_gpu(x_in, training=False))) # No EMA, for debugging. 172 | 173 | 174 | def main(argv): 175 | utils.setup_main() 176 | del argv # Unused. 
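An aside on the `guess_label` method defined above (not part of remixmatch_no_cta.py): it is where distribution alignment and sharpening happen. The model's (averaged or swapped) predictions are rescaled by the ratio of the known class distribution `p_data` to the running estimate `p_model`, renormalized, and then sharpened with temperature `T`. A standalone NumPy sketch of those two steps, for illustration only (names and shapes are assumptions, not the repo's API):

```python
import numpy as np

def align_and_sharpen(p_model_y, p_data, p_model, T=0.5, eps=1e-6):
    """p_model_y: [N, K] per-example predictions; p_data, p_model: [K] marginals."""
    p = p_model_y * (eps + p_data) / (eps + p_model)   # distribution alignment
    p /= p.sum(axis=1, keepdims=True)
    p = p ** (1.0 / T)                                 # temperature sharpening
    return p / p.sum(axis=1, keepdims=True)
```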
177 | dataset = data.MANY_DATASETS()[FLAGS.dataset]() 178 | log_width = utils.ilog2(dataset.width) 179 | model = ReMixMatch( 180 | os.path.join(FLAGS.train_dir, dataset.name), 181 | dataset, 182 | lr=FLAGS.lr, 183 | wd=FLAGS.wd, 184 | arch=FLAGS.arch, 185 | batch=FLAGS.batch, 186 | nclass=dataset.nclass, 187 | 188 | K=FLAGS.K, 189 | beta=FLAGS.beta, 190 | w_kl=FLAGS.w_kl, 191 | w_match=FLAGS.w_match, 192 | w_rot=FLAGS.w_rot, 193 | redux=FLAGS.redux, 194 | use_dm=FLAGS.use_dm, 195 | use_xe=FLAGS.use_xe, 196 | warmup_kimg=FLAGS.warmup_kimg, 197 | 198 | scales=FLAGS.scales or (log_width - 2), 199 | filters=FLAGS.filters, 200 | repeat=FLAGS.repeat) 201 | model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10) 202 | 203 | 204 | if __name__ == '__main__': 205 | utils.setup_tf() 206 | flags.DEFINE_float('wd', 0.02, 'Weight decay.') 207 | flags.DEFINE_float('beta', 0.75, 'Mixup beta distribution.') 208 | flags.DEFINE_float('w_kl', 0.5, 'Weight for KL loss.') 209 | flags.DEFINE_float('w_match', 1.5, 'Weight for distribution matching loss.') 210 | flags.DEFINE_float('w_rot', 0.5, 'Weight for rotation loss.') 211 | flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.') 212 | flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.') 213 | flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.') 214 | flags.DEFINE_integer('warmup_kimg', 1024, 'Unannealing duration for SSL loss.') 215 | flags.DEFINE_enum('redux', 'swap', 'swap mean 1st'.split(), 'Logit selection.') 216 | flags.DEFINE_bool('use_dm', True, 'Whether to use distribution matching.') 217 | flags.DEFINE_bool('use_xe', True, 'Whether to use cross-entropy or Brier.') 218 | FLAGS.set_default('augment', 'd.d.d') 219 | FLAGS.set_default('dataset', 'cifar10.3@250-5000') 220 | FLAGS.set_default('batch', 64) 221 | FLAGS.set_default('lr', 0.002) 222 | FLAGS.set_default('train_kimg', 1 << 16) 223 | app.run(main) 224 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | easydict 3 | cython 4 | numpy 5 | tensorflow-gpu==1.14.0 6 | tqdm 7 | -------------------------------------------------------------------------------- /runs/ssl/ablation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2019 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 
16 | DEFAULT_ARGS="--dataset=cifar10.3@250-1 --train_dir experiments/Ablation --augment=d.d.d"
17 | 
18 | echo "# Vary number of augmentations"
19 | echo "python ablation/ab_cta_remixmatch.py $DEFAULT_ARGS --K=1"
20 | echo "python ablation/ab_cta_remixmatch.py $DEFAULT_ARGS --K=2"
21 | echo "python ablation/ab_cta_remixmatch.py $DEFAULT_ARGS --K=4"
22 | echo "python ablation/ab_cta_remixmatch.py $DEFAULT_ARGS --K=8"
23 | echo "python ablation/ab_cta_remixmatch.py $DEFAULT_ARGS --K=16"
24 | 
25 | echo "# No regularizer"
26 | echo "python ablation/ab_cta_remixmatch.py $DEFAULT_ARGS --K=8 --w_kl=0"
27 | echo "# No rotation"
28 | echo "python ablation/ab_cta_remixmatch.py $DEFAULT_ARGS --K=8 --w_rot=0"
29 | echo "# No distribution matching"
30 | echo "python ablation/ab_cta_remixmatch.py $DEFAULT_ARGS --K=8 --nouse_dm"
31 | echo "# Use L2 loss / no cross-entropy"
32 | echo "python ablation/ab_cta_remixmatch.py $DEFAULT_ARGS --w_match=75 --K=8 --nouse_xe"
33 | 
34 | echo "# Only strong augmentations"
35 | echo "python ablation/ab_cta_remixmatch_noweak.py $DEFAULT_ARGS --K=8"
36 | 
37 | echo "# Only weak augmentations"
38 | echo "python ablation/ab_remixmatch.py $DEFAULT_ARGS --K=1"
39 | echo "python ablation/ab_remixmatch.py $DEFAULT_ARGS --K=8"
40 | echo "python ablation/ab_remixmatch.py $DEFAULT_ARGS --redux=mean --K=1"
41 | echo "python ablation/ab_remixmatch.py $DEFAULT_ARGS --redux=mean --K=8"
42 | 
-------------------------------------------------------------------------------- /runs/ssl/cifar10.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2019 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | echo
17 | echo "# Standard"
18 | for seed in 1 2 3 4 5; do for size in 250 1000 4000; do
19 | echo "python cta/cta_remixmatch.py --dataset=cifar10.${seed}@${size}-1 --K=8"
20 | done; done
21 | 
22 | for seed in 1 2 3 4 5; do
23 | echo "python cta/cta_remixmatch.py --dataset=cifar10.${seed}@40-1 --K=1 --w_rot=2 --warmup_kimg=8192"
24 | done
25 | 
-------------------------------------------------------------------------------- /runs/ssl/stl10.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2019 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | echo
17 | echo "# Standard"
18 | for seed in 1 2 3 4 5; do
19 | echo "python cta/cta_remixmatch.py --dataset=stl10.${seed}@1000-1 --K=4 --scales=4 --augment=d.x.d"
20 | done
21 | 
-------------------------------------------------------------------------------- /runs/ssl/svhn.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2019 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | echo
17 | echo "# Standard"
18 | for seed in 1 2 3 4 5; do for size in 250 1000 4000; do
19 | echo "python cta/cta_remixmatch.py --dataset=svhn_noextra.${seed}@${size}-1 --K=8"
20 | done; done
21 | 
22 | for seed in 1 2 3 4 5; do
23 | echo "python cta/cta_remixmatch.py --dataset=svhn_noextra.${seed}@40-1 --K=1 --w_rot=3"
24 | done
25 | 
-------------------------------------------------------------------------------- /scripts/aggregate_accuracy.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """Report all 'stats/accuracy.json' into a json file on stdout.
17 | 
18 | All the accuracies are summarized.
19 | """
20 | import json
21 | import sys
22 | import threading
23 | 
24 | import tensorflow as tf
25 | import tqdm
26 | from absl import app
27 | from absl import flags
28 | 
29 | FLAGS = flags.FLAGS
30 | N_THREADS = 100
31 | 
32 | 
33 | def add_contents_to_dict(filename: str, target):
34 |     with tf.gfile.Open(filename, 'r') as f:
35 |         target[filename] = json.load(f)
36 | 
37 | 
38 | def main(argv):
39 |     files = []
40 |     for path in argv[1:]:
41 |         files.extend(tf.io.gfile.glob(path))
42 |     assert files, 'No files found'
43 |     print('Found %d files.'
% len(files), file=sys.stderr) 44 | summary = {} 45 | threads = [] 46 | for x in tqdm.tqdm(files, leave=False, desc='Collating'): 47 | t = threading.Thread( 48 | target=add_contents_to_dict, kwargs=dict(filename=x, target=summary)) 49 | threads.append(t) 50 | t.start() 51 | while len(threads) >= N_THREADS: 52 | dead = [p for p, t in enumerate(threads) if not t.is_alive()] 53 | while dead: 54 | p = dead.pop() 55 | del threads[p] 56 | if x == files[-1]: 57 | for t in threads: 58 | t.join() 59 | 60 | assert len(summary) == len(files) 61 | print(json.dumps(summary, sort_keys=True, indent=4)) 62 | 63 | 64 | if __name__ == '__main__': 65 | app.run(main) 66 | -------------------------------------------------------------------------------- /scripts/check_split.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright 2019 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Script to measure the overlap between data splits. 18 | 19 | There should not be any overlap unless the original dataset has duplicates. 20 | """ 21 | 22 | import hashlib 23 | import os 24 | 25 | import tensorflow as tf 26 | from absl import app 27 | from absl import flags 28 | from tqdm import trange 29 | 30 | from libml import data, utils 31 | 32 | flags.DEFINE_integer('batch', 1024, 'Batch size.') 33 | flags.DEFINE_integer('samples', 1 << 20, 'Number of samples to load.') 34 | 35 | FLAGS = flags.FLAGS 36 | 37 | 38 | def to_byte(d: dict): 39 | return tf.to_int32(tf.round(127.5 * (d['image'] + 1))) 40 | 41 | 42 | def collect_hashes(sess, group, data): 43 | data = data.parse().batch(FLAGS.batch).prefetch(1).make_one_shot_iterator().get_next() 44 | hashes = set() 45 | hasher = hashlib.sha512 46 | for _ in trange(0, FLAGS.samples, FLAGS.batch, desc='Building hashes for %s' % group, leave=False): 47 | try: 48 | batch = sess.run(data) 49 | except tf.errors.OutOfRangeError: 50 | break 51 | for img in batch: 52 | hashes.add(hasher(img).digest()) 53 | return hashes 54 | 55 | 56 | def main(argv): 57 | utils.setup_main() 58 | del argv 59 | utils.setup_tf() 60 | dataset = data.DATASETS()[FLAGS.dataset]() 61 | with tf.Session(config=utils.get_config()) as sess: 62 | hashes = (collect_hashes(sess, 'labeled', dataset.eval_labeled), 63 | collect_hashes(sess, 'unlabeled', dataset.eval_unlabeled), 64 | collect_hashes(sess, 'validation', dataset.valid), 65 | collect_hashes(sess, 'test', dataset.test)) 66 | print('Overlap matrix (should be an almost perfect diagonal matrix with counts).') 67 | groups = 'labeled unlabeled validation test'.split() 68 | fmt = '%-10s %10s %10s %10s %10s' 69 | print(fmt % tuple([''] + groups)) 70 | for p, x in enumerate(hashes): 71 | overlaps = [len(x & y) for y in hashes] 72 | print(fmt % tuple([groups[p]] + overlaps)) 73 | 74 | 75 | if __name__ == '__main__': 76 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 77 | app.run(main) 78 | 
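An aside on check_split.py above: it reduces every split to a set of SHA-512 digests of the decoded images and then prints the pairwise intersection counts. The report itself boils down to set intersections, e.g. this standalone sketch (function name and input layout are my own, not the script's API):

```python
import hashlib

def overlap_counts(groups):
    """groups: dict mapping split name -> iterable of raw image byte strings."""
    digests = {name: {hashlib.sha512(img).digest() for img in imgs}
               for name, imgs in groups.items()}
    return {(a, b): len(digests[a] & digests[b]) for a in digests for b in digests}
```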
-------------------------------------------------------------------------------- /scripts/create_datasets.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Script to download all datasets and create .tfrecord files. 18 | """ 19 | 20 | import collections 21 | import gzip 22 | import os 23 | import tarfile 24 | import tempfile 25 | from urllib import request 26 | 27 | import numpy as np 28 | import scipy.io 29 | import tensorflow as tf 30 | from absl import app 31 | from tqdm import trange 32 | 33 | from libml import data as libml_data 34 | from libml.utils import EasyDict 35 | 36 | URLS = { 37 | 'svhn': 'http://ufldl.stanford.edu/housenumbers/{}_32x32.mat', 38 | 'cifar10': 'https://www.cs.toronto.edu/~kriz/cifar-10-matlab.tar.gz', 39 | 'cifar100': 'https://www.cs.toronto.edu/~kriz/cifar-100-matlab.tar.gz', 40 | 'stl10': 'http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz', 41 | } 42 | 43 | 44 | def _encode_png(images): 45 | raw = [] 46 | with tf.Session() as sess, tf.device('cpu:0'): 47 | image_x = tf.placeholder(tf.uint8, [None, None, None], 'image_x') 48 | to_png = tf.image.encode_png(image_x) 49 | for x in trange(images.shape[0], desc='PNG Encoding', leave=False): 50 | raw.append(sess.run(to_png, feed_dict={image_x: images[x]})) 51 | return raw 52 | 53 | 54 | def _load_svhn(): 55 | splits = collections.OrderedDict() 56 | for split in ['train', 'test', 'extra']: 57 | with tempfile.NamedTemporaryFile() as f: 58 | request.urlretrieve(URLS['svhn'].format(split), f.name) 59 | data_dict = scipy.io.loadmat(f.name) 60 | dataset = {} 61 | dataset['images'] = np.transpose(data_dict['X'], [3, 0, 1, 2]) 62 | dataset['images'] = _encode_png(dataset['images']) 63 | dataset['labels'] = data_dict['y'].reshape((-1)) 64 | # SVHN raw data uses labels from 1 to 10; use 0 to 9 instead. 
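# (The raw .mat files store the digit '0' as label 10, so the shift below maps labels 1..10 onto classes 0..9.)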
65 | dataset['labels'] -= 1 66 | splits[split] = dataset 67 | return splits 68 | 69 | 70 | def _load_stl10(): 71 | def unflatten(images): 72 | return np.transpose(images.reshape((-1, 3, 96, 96)), 73 | [0, 3, 2, 1]) 74 | 75 | with tempfile.NamedTemporaryFile() as f: 76 | if tf.gfile.Exists('stl10/stl10_binary.tar.gz'): 77 | f = tf.gfile.Open('stl10/stl10_binary.tar.gz', 'rb') 78 | else: 79 | request.urlretrieve(URLS['stl10'], f.name) 80 | tar = tarfile.open(fileobj=f) 81 | train_X = tar.extractfile('stl10_binary/train_X.bin') 82 | train_y = tar.extractfile('stl10_binary/train_y.bin') 83 | 84 | test_X = tar.extractfile('stl10_binary/test_X.bin') 85 | test_y = tar.extractfile('stl10_binary/test_y.bin') 86 | 87 | unlabeled_X = tar.extractfile('stl10_binary/unlabeled_X.bin') 88 | 89 | train_set = {'images': np.frombuffer(train_X.read(), dtype=np.uint8), 90 | 'labels': np.frombuffer(train_y.read(), dtype=np.uint8) - 1} 91 | 92 | test_set = {'images': np.frombuffer(test_X.read(), dtype=np.uint8), 93 | 'labels': np.frombuffer(test_y.read(), dtype=np.uint8) - 1} 94 | 95 | _imgs = np.frombuffer(unlabeled_X.read(), dtype=np.uint8) 96 | unlabeled_set = {'images': _imgs, 97 | 'labels': np.zeros(100000, dtype=np.uint8)} 98 | 99 | fold_indices = tar.extractfile('stl10_binary/fold_indices.txt').read() 100 | 101 | train_set['images'] = _encode_png(unflatten(train_set['images'])) 102 | test_set['images'] = _encode_png(unflatten(test_set['images'])) 103 | unlabeled_set['images'] = _encode_png(unflatten(unlabeled_set['images'])) 104 | return dict(train=train_set, test=test_set, unlabeled=unlabeled_set, 105 | files=[EasyDict(filename="stl10_fold_indices.txt", data=fold_indices)]) 106 | 107 | 108 | def _load_cifar10(): 109 | def unflatten(images): 110 | return np.transpose(images.reshape((images.shape[0], 3, 32, 32)), 111 | [0, 2, 3, 1]) 112 | 113 | with tempfile.NamedTemporaryFile() as f: 114 | request.urlretrieve(URLS['cifar10'], f.name) 115 | tar = tarfile.open(fileobj=f) 116 | train_data_batches, train_data_labels = [], [] 117 | for batch in range(1, 6): 118 | data_dict = scipy.io.loadmat(tar.extractfile( 119 | 'cifar-10-batches-mat/data_batch_{}.mat'.format(batch))) 120 | train_data_batches.append(data_dict['data']) 121 | train_data_labels.append(data_dict['labels'].flatten()) 122 | train_set = {'images': np.concatenate(train_data_batches, axis=0), 123 | 'labels': np.concatenate(train_data_labels, axis=0)} 124 | data_dict = scipy.io.loadmat(tar.extractfile( 125 | 'cifar-10-batches-mat/test_batch.mat')) 126 | test_set = {'images': data_dict['data'], 127 | 'labels': data_dict['labels'].flatten()} 128 | train_set['images'] = _encode_png(unflatten(train_set['images'])) 129 | test_set['images'] = _encode_png(unflatten(test_set['images'])) 130 | return dict(train=train_set, test=test_set) 131 | 132 | 133 | def _load_cifar100(): 134 | def unflatten(images): 135 | return np.transpose(images.reshape((images.shape[0], 3, 32, 32)), 136 | [0, 2, 3, 1]) 137 | 138 | with tempfile.NamedTemporaryFile() as f: 139 | request.urlretrieve(URLS['cifar100'], f.name) 140 | tar = tarfile.open(fileobj=f) 141 | data_dict = scipy.io.loadmat(tar.extractfile('cifar-100-matlab/train.mat')) 142 | train_set = {'images': data_dict['data'], 143 | 'labels': data_dict['fine_labels'].flatten()} 144 | data_dict = scipy.io.loadmat(tar.extractfile('cifar-100-matlab/test.mat')) 145 | test_set = {'images': data_dict['data'], 146 | 'labels': data_dict['fine_labels'].flatten()} 147 | train_set['images'] = _encode_png(unflatten(train_set['images'])) 
148 | test_set['images'] = _encode_png(unflatten(test_set['images'])) 149 | return dict(train=train_set, test=test_set) 150 | 151 | 152 | def _load_fashionmnist(): 153 | def _read32(data): 154 | dt = np.dtype(np.uint32).newbyteorder('>') 155 | return np.frombuffer(data.read(4), dtype=dt)[0] 156 | 157 | image_filename = '{}-images-idx3-ubyte' 158 | label_filename = '{}-labels-idx1-ubyte' 159 | split_files = [('train', 'train'), ('test', 't10k')] 160 | splits = {} 161 | for split, split_file in split_files: 162 | with tempfile.NamedTemporaryFile() as f: 163 | request.urlretrieve(URLS['fashion_mnist'].format(image_filename.format(split_file)), f.name) 164 | with gzip.GzipFile(fileobj=f, mode='r') as data: 165 | assert _read32(data) == 2051 166 | n_images = _read32(data) 167 | row = _read32(data) 168 | col = _read32(data) 169 | images = np.frombuffer(data.read(n_images * row * col), dtype=np.uint8) 170 | images = images.reshape((n_images, row, col, 1)) 171 | with tempfile.NamedTemporaryFile() as f: 172 | request.urlretrieve(URLS['fashion_mnist'].format(label_filename.format(split_file)), f.name) 173 | with gzip.GzipFile(fileobj=f, mode='r') as data: 174 | assert _read32(data) == 2049 175 | n_labels = _read32(data) 176 | labels = np.frombuffer(data.read(n_labels), dtype=np.uint8) 177 | splits[split] = {'images': _encode_png(images), 'labels': labels} 178 | return splits 179 | 180 | 181 | def _int64_feature(value): 182 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 183 | 184 | 185 | def _bytes_feature(value): 186 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 187 | 188 | 189 | def _save_as_tfrecord(data, filename): 190 | assert len(data['images']) == len(data['labels']) 191 | filename = os.path.join(libml_data.DATA_DIR, filename + '.tfrecord') 192 | print('Saving dataset:', filename) 193 | with tf.python_io.TFRecordWriter(filename) as writer: 194 | for x in trange(len(data['images']), desc='Building records'): 195 | feat = dict(image=_bytes_feature(data['images'][x]), 196 | label=_int64_feature(data['labels'][x])) 197 | record = tf.train.Example(features=tf.train.Features(feature=feat)) 198 | writer.write(record.SerializeToString()) 199 | print('Saved:', filename) 200 | 201 | 202 | def _is_installed(name, checksums): 203 | for subset, checksum in checksums.items(): 204 | filename = os.path.join(libml_data.DATA_DIR, '%s-%s.tfrecord' % (name, subset)) 205 | if not tf.gfile.Exists(filename): 206 | return False 207 | return True 208 | 209 | 210 | def _save_files(files, *args, **kwargs): 211 | del args, kwargs 212 | for folder in frozenset(os.path.dirname(x) for x in files): 213 | tf.gfile.MakeDirs(os.path.join(libml_data.DATA_DIR, folder)) 214 | for filename, contents in files.items(): 215 | with tf.gfile.Open(os.path.join(libml_data.DATA_DIR, filename), 'w') as f: 216 | f.write(contents) 217 | 218 | 219 | def _is_installed_folder(name, folder): 220 | return tf.gfile.Exists(os.path.join(libml_data.DATA_DIR, name, folder)) 221 | 222 | 223 | CONFIGS = dict( 224 | cifar10=dict(loader=_load_cifar10, checksums=dict(train=None, test=None)), 225 | cifar100=dict(loader=_load_cifar100, checksums=dict(train=None, test=None)), 226 | svhn=dict(loader=_load_svhn, checksums=dict(train=None, test=None, extra=None)), 227 | stl10=dict(loader=_load_stl10, checksums=dict(train=None, test=None)), 228 | # fashion_mnist=dict(loader=_load_fashionmnist, checksums=dict(train=None, test=None)), 229 | ) 230 | 231 | 232 | def main(argv): 233 | if len(argv[1:]): 234 | subset = 
set(argv[1:]) 235 | else: 236 | subset = set(CONFIGS.keys()) 237 | tf.gfile.MakeDirs(libml_data.DATA_DIR) 238 | for name, config in CONFIGS.items(): 239 | if name not in subset: 240 | continue 241 | if 'is_installed' in config: 242 | if config['is_installed'](): 243 | print('Skipping already installed:', name) 244 | continue 245 | elif _is_installed(name, config['checksums']): 246 | print('Skipping already installed:', name) 247 | continue 248 | print('Preparing', name) 249 | datas = config['loader']() 250 | saver = config.get('saver', _save_as_tfrecord) 251 | for sub_name, data in datas.items(): 252 | if sub_name == 'readme': 253 | filename = os.path.join(libml_data.DATA_DIR, '%s-%s.txt' % (name, sub_name)) 254 | with tf.gfile.Open(filename, 'w') as f: 255 | f.write(data) 256 | elif sub_name == 'files': 257 | for file_and_data in data: 258 | path = os.path.join(libml_data.DATA_DIR, file_and_data.filename) 259 | with tf.gfile.Open(path, "wb") as f: 260 | f.write(file_and_data.data) 261 | else: 262 | saver(data, '%s-%s' % (name, sub_name)) 263 | 264 | 265 | if __name__ == '__main__': 266 | app.run(main) 267 | -------------------------------------------------------------------------------- /scripts/create_split.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright 2019 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Script to create SSL splits from a dataset. 
18 | """ 19 | 20 | import json 21 | import os 22 | from collections import defaultdict 23 | 24 | import numpy as np 25 | import tensorflow as tf 26 | from absl import app 27 | from absl import flags 28 | from tqdm import trange, tqdm 29 | 30 | from libml import data as libml_data 31 | from libml import utils 32 | 33 | flags.DEFINE_integer('seed', 0, 'Random seed to use, 0 for no shuffling.') 34 | flags.DEFINE_integer('size', 0, 'Size of labelled set.') 35 | 36 | FLAGS = flags.FLAGS 37 | 38 | 39 | def get_class(serialized_example): 40 | return tf.parse_single_example(serialized_example, features={'label': tf.FixedLenFeature([], tf.int64)})['label'] 41 | 42 | 43 | def main(argv): 44 | assert FLAGS.size 45 | argv.pop(0) 46 | if any(not tf.gfile.Exists(f) for f in argv[1:]): 47 | raise FileNotFoundError(argv[1:]) 48 | target = '%s.%d@%d' % (argv[0], FLAGS.seed, FLAGS.size) 49 | if tf.gfile.Exists(target): 50 | raise FileExistsError('For safety overwriting is not allowed', target) 51 | input_files = argv[1:] 52 | count = 0 53 | id_class = [] 54 | class_id = defaultdict(list) 55 | print('Computing class distribution') 56 | dataset = tf.data.TFRecordDataset(input_files).map(get_class, 4).batch(1 << 10) 57 | it = dataset.make_one_shot_iterator().get_next() 58 | try: 59 | with tf.Session() as session, tqdm(leave=False) as t: 60 | while 1: 61 | old_count = count 62 | for i in session.run(it): 63 | id_class.append(i) 64 | class_id[i].append(count) 65 | count += 1 66 | t.update(count - old_count) 67 | except tf.errors.OutOfRangeError: 68 | pass 69 | print('%d records found' % count) 70 | nclass = len(class_id) 71 | assert min(class_id.keys()) == 0 and max(class_id.keys()) == (nclass - 1) 72 | train_stats = np.array([len(class_id[i]) for i in range(nclass)], np.float64) 73 | train_stats /= train_stats.max() 74 | if 'stl10' in argv[1]: 75 | # All of the unlabeled data is given label 0, but we know that 76 | # STL has equally distributed data among the 10 classes. 77 | train_stats[:] = 1 78 | 79 | print(' Stats', ' '.join(['%.2f' % (100 * x) for x in train_stats])) 80 | assert min(class_id.keys()) == 0 and max(class_id.keys()) == (nclass - 1) 81 | class_id = [np.array(class_id[i], dtype=np.int64) for i in range(nclass)] 82 | if FLAGS.seed: 83 | np.random.seed(FLAGS.seed) 84 | for i in range(nclass): 85 | np.random.shuffle(class_id[i]) 86 | 87 | # Distribute labels to match the input distribution. 
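# (Added note: train_stats holds each class's frequency relative to the largest class;
#  the loop below greedily picks the class whose selected count, also normalized by the
#  current maximum, lags furthest behind that target, then appends its next shuffled index.)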
88 | npos = np.zeros(nclass, np.int64) 89 | label = [] 90 | for i in range(FLAGS.size): 91 | c = np.argmax(train_stats - npos / max(npos.max(), 1)) 92 | label.append(class_id[c][npos[c]]) 93 | npos[c] += 1 94 | 95 | del npos, class_id 96 | label = frozenset([int(x) for x in label]) 97 | if 'stl10' in argv[1] and FLAGS.size == 1000: 98 | data = tf.gfile.Open(os.path.join(libml_data.DATA_DIR, 'stl10_fold_indices.txt'), 'r').read() 99 | label = frozenset(list(map(int, data.split('\n')[FLAGS.seed].split()))) 100 | 101 | print('Creating split in %s' % target) 102 | tf.gfile.MakeDirs(os.path.dirname(target)) 103 | with tf.python_io.TFRecordWriter(target + '-label.tfrecord') as writer_label: 104 | pos, loop = 0, trange(count, desc='Writing records') 105 | for input_file in input_files: 106 | for record in tf.python_io.tf_record_iterator(input_file): 107 | if pos in label: 108 | writer_label.write(record) 109 | pos += 1 110 | loop.update() 111 | loop.close() 112 | with tf.gfile.Open(target + '-label.json', 'w') as writer: 113 | writer.write(json.dumps(dict(distribution=train_stats.tolist(), label=sorted(label)), indent=2, sort_keys=True)) 114 | 115 | 116 | if __name__ == '__main__': 117 | utils.setup_tf() 118 | app.run(main) 119 | -------------------------------------------------------------------------------- /scripts/create_unlabeled.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright 2019 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Script to create SSL splits from a dataset. 
18 | """ 19 | 20 | import json 21 | import os 22 | from collections import defaultdict 23 | 24 | import numpy as np 25 | import tensorflow as tf 26 | from absl import app 27 | from tqdm import trange, tqdm 28 | 29 | from libml import utils 30 | 31 | 32 | def get_class(serialized_example): 33 | return tf.parse_single_example(serialized_example, features={'label': tf.FixedLenFeature([], tf.int64)})['label'] 34 | 35 | 36 | def main(argv): 37 | argv.pop(0) 38 | if any(not tf.gfile.Exists(f) for f in argv[1:]): 39 | raise FileNotFoundError(argv[1:]) 40 | target = argv[0] 41 | input_files = argv[1:] 42 | count = 0 43 | id_class = [] 44 | class_id = defaultdict(list) 45 | print('Computing class distribution') 46 | dataset = tf.data.TFRecordDataset(input_files).map(get_class, 4).batch(1 << 10) 47 | it = dataset.make_one_shot_iterator().get_next() 48 | try: 49 | with tf.Session() as session, tqdm(leave=False) as t: 50 | while 1: 51 | old_count = count 52 | for i in session.run(it): 53 | id_class.append(i) 54 | class_id[i].append(count) 55 | count += 1 56 | t.update(count - old_count) 57 | except tf.errors.OutOfRangeError: 58 | pass 59 | print('%d records found' % count) 60 | nclass = len(class_id) 61 | assert min(class_id.keys()) == 0 and max(class_id.keys()) == (nclass - 1) 62 | train_stats = np.array([len(class_id[i]) for i in range(nclass)], np.float64) 63 | train_stats /= train_stats.max() 64 | if 'stl10' in argv[1]: 65 | # All of the unlabeled data is given label 0, but we know that 66 | # STL has equally distributed data among the 10 classes. 67 | train_stats[:] = 1 68 | 69 | print(' Stats', ' '.join(['%.2f' % (100 * x) for x in train_stats])) 70 | del class_id 71 | 72 | print('Creating unlabeled dataset for in %s' % target) 73 | npos = np.zeros(nclass, np.int64) 74 | class_data = [[] for _ in range(nclass)] 75 | unlabel = [] 76 | tf.gfile.MakeDirs(os.path.dirname(target)) 77 | with tf.python_io.TFRecordWriter(target + '-unlabel.tfrecord') as writer_unlabel: 78 | pos, loop = 0, trange(count, desc='Writing records') 79 | for input_file in input_files: 80 | for record in tf.python_io.tf_record_iterator(input_file): 81 | class_data[id_class[pos]].append((pos, record)) 82 | while True: 83 | c = np.argmax(train_stats - npos / max(npos.max(), 1)) 84 | if class_data[c]: 85 | p, v = class_data[c].pop(0) 86 | unlabel.append(p) 87 | writer_unlabel.write(v) 88 | npos[c] += 1 89 | else: 90 | break 91 | pos += 1 92 | loop.update() 93 | for remain in class_data: 94 | for p, v in remain: 95 | unlabel.append(p) 96 | writer_unlabel.write(v) 97 | loop.close() 98 | with tf.gfile.Open(target + '-unlabel.json', 'w') as writer: 99 | writer.write(json.dumps(dict(distribution=train_stats.tolist(), indexes=unlabel), indent=2, sort_keys=True)) 100 | 101 | 102 | if __name__ == '__main__': 103 | utils.setup_tf() 104 | app.run(main) 105 | -------------------------------------------------------------------------------- /scripts/extract_accuracy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright 2019 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Extract and save accuracy to 'stats/accuracy.json'. 18 | 19 | The accuracy is extracted from the most recent eventfile. 20 | """ 21 | 22 | import json 23 | import os.path 24 | 25 | import numpy as np 26 | import tensorflow as tf 27 | from absl import app 28 | from absl import flags 29 | 30 | FLAGS = flags.FLAGS 31 | TAG = 'accuracy' 32 | 33 | 34 | def summary_dict(accuracies): 35 | return { 36 | 'last%02d' % x: np.median(accuracies[-x:]) for x in [1, 10, 20, 50] 37 | } 38 | 39 | 40 | def main(argv): 41 | if len(argv) > 2: 42 | raise app.UsageError('Too many command-line arguments.') 43 | folder = argv[1] 44 | matches = sorted(tf.gfile.Glob(os.path.join(folder, 'tf/events.out.tfevents.*'))) 45 | assert matches, 'No events files found' 46 | tags = set() 47 | accuracies = [] 48 | for event_file in matches: 49 | try: 50 | for e in tf.train.summary_iterator(event_file): 51 | for v in e.summary.value: 52 | if v.tag == TAG: 53 | accuracies.append(v.simple_value) 54 | break 55 | elif not accuracies: 56 | tags.add(v.tag) 57 | except tf.errors.DataLossError: 58 | continue 59 | 60 | assert accuracies, 'No "accuracy" tag found. Found tags = %s' % tags 61 | target_dir = os.path.join(folder, 'stats') 62 | target_file = os.path.join(target_dir, 'accuracy.json') 63 | tf.gfile.MakeDirs(target_dir) 64 | 65 | with tf.gfile.Open(target_file, 'w') as f: 66 | json.dump(summary_dict(accuracies), f, sort_keys=True, indent=4) 67 | print('Saved: %s' % target_file) 68 | 69 | 70 | if __name__ == '__main__': 71 | app.run(main) 72 | -------------------------------------------------------------------------------- /scripts/inspect_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright 2019 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Script to inspect a dataset, in particular label distribution. 
18 | """ 19 | 20 | import numpy as np 21 | import tensorflow as tf 22 | from absl import app 23 | from absl import flags 24 | from tqdm import trange 25 | 26 | from libml import data, utils 27 | 28 | flags.DEFINE_integer('batch', 64, 'Batch size.') 29 | flags.DEFINE_integer('samples', 1 << 16, 'Number of samples to load.') 30 | 31 | FLAGS = flags.FLAGS 32 | 33 | 34 | def main(argv): 35 | utils.setup_main() 36 | del argv 37 | utils.setup_tf() 38 | nbatch = FLAGS.samples // FLAGS.batch 39 | dataset = data.DATASETS()[FLAGS.dataset]() 40 | groups = [('labeled', dataset.train_labeled), 41 | ('unlabeled', dataset.train_unlabeled), 42 | ('test', dataset.test.repeat())] 43 | groups = [(name, ds.batch(FLAGS.batch).prefetch(16).make_one_shot_iterator().get_next()) 44 | for name, ds in groups] 45 | with tf.train.MonitoredSession() as sess: 46 | for group, train_data in groups: 47 | stats = np.zeros(dataset.nclass, np.int32) 48 | minmax = [], [] 49 | for _ in trange(nbatch, leave=False, unit='img', unit_scale=FLAGS.batch, desc=group): 50 | v = sess.run(train_data)['label'] 51 | for u in v: 52 | stats[u] += 1 53 | minmax[0].append(v.min()) 54 | minmax[1].append(v.max()) 55 | print(group) 56 | print(' Label range', min(minmax[0]), max(minmax[1])) 57 | print(' Stats', ' '.join(['%.2f' % (100 * x) for x in (stats / stats.max())])) 58 | 59 | 60 | if __name__ == '__main__': 61 | app.run(main) 62 | -------------------------------------------------------------------------------- /third_party/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright 2015 Google Inc. 4 | This code has been modified from the original version at https://github.com/takerum/vat_tf 5 | Original license reproduced below. 6 | 7 | Copyright (c) 2017 Takeru Miyato 8 | 9 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
14 | -------------------------------------------------------------------------------- /third_party/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/remixmatch/f7061ebf055227cbeb5c6fced1ce054e0ceecfcd/third_party/__init__.py -------------------------------------------------------------------------------- /third_party/auto_augment/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/remixmatch/f7061ebf055227cbeb5c6fced1ce054e0ceecfcd/third_party/auto_augment/__init__.py -------------------------------------------------------------------------------- /third_party/auto_augment/custom_ops.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google UDA Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Contains convenience wrappers for typical Neural Network TensorFlow layers. 16 | 17 | Ops that have different behavior during training or eval have an is_training 18 | parameter. 19 | 20 | Copied from AutoAugment: https://github.com/tensorflow/models/blob/master/research/autoaugment/ 21 | """ 22 | from __future__ import absolute_import 23 | from __future__ import division 24 | from __future__ import print_function 25 | 26 | import numpy as np 27 | import tensorflow as tf 28 | 29 | arg_scope = tf.contrib.framework.arg_scope 30 | FLAGS = tf.flags.FLAGS 31 | 32 | 33 | def variable(name, shape, dtype, initializer, trainable): 34 | """Returns a TF variable with the passed in specifications.""" 35 | var = tf.get_variable( 36 | name, 37 | shape=shape, 38 | dtype=dtype, 39 | initializer=initializer, 40 | trainable=trainable) 41 | return var 42 | 43 | 44 | def global_avg_pool(x, scope=None): 45 | """Average pools away spatial height and width dimension of 4D tensor.""" 46 | assert x.get_shape().ndims == 4 47 | with tf.name_scope(scope, 'global_avg_pool', [x]): 48 | kernel_size = (1, int(x.shape[1]), int(x.shape[2]), 1) 49 | squeeze_dims = (1, 2) 50 | result = tf.nn.avg_pool( 51 | x, 52 | ksize=kernel_size, 53 | strides=(1, 1, 1, 1), 54 | padding='VALID', 55 | data_format='NHWC') 56 | return tf.squeeze(result, squeeze_dims) 57 | 58 | 59 | def zero_pad(inputs, in_filter, out_filter): 60 | """Zero pads `input` tensor to have `out_filter` number of filters.""" 61 | outputs = tf.pad(inputs, [[0, 0], [0, 0], [0, 0], 62 | [(out_filter - in_filter) // 2, 63 | (out_filter - in_filter) // 2]]) 64 | return outputs 65 | 66 | 67 | @tf.contrib.framework.add_arg_scope 68 | def batch_norm(inputs, 69 | update_stats=True, 70 | decay=0.999, 71 | center=True, 72 | scale=False, 73 | epsilon=0.001, 74 | is_training=True, 75 | reuse=None, 76 | scope=None, 77 | ): 78 | """Small wrapper around tf.contrib.layers.batch_norm.""" 79 | batch_norm_op = tf.layers.batch_normalization( 80 | inputs, 81 | axis=-1, 82 | 
momentum=decay, 83 | epsilon=epsilon, 84 | center=center, 85 | scale=scale, 86 | training=is_training, 87 | fused=True, 88 | trainable=True, 89 | ) 90 | return batch_norm_op 91 | 92 | 93 | def stride_arr(stride_h, stride_w): 94 | return [1, stride_h, stride_w, 1] 95 | 96 | 97 | @tf.contrib.framework.add_arg_scope 98 | def conv2d(inputs, 99 | num_filters_out, 100 | kernel_size, 101 | stride=1, 102 | scope=None, 103 | reuse=None): 104 | """Adds a 2D convolution. 105 | 106 | conv2d creates a variable called 'weights', representing the convolutional 107 | kernel, that is convolved with the input. 108 | 109 | Args: 110 | inputs: a 4D tensor in NHWC format. 111 | num_filters_out: the number of output filters. 112 | kernel_size: an int specifying the kernel height and width size. 113 | stride: an int specifying the height and width stride. 114 | scope: Optional scope for variable_scope. 115 | reuse: whether or not the layer and its variables should be reused. 116 | Returns: 117 | a tensor that is the result of a convolution being applied to `inputs`. 118 | """ 119 | with tf.variable_scope(scope, 'Conv', [inputs], reuse=reuse): 120 | num_filters_in = int(inputs.shape[3]) 121 | weights_shape = [kernel_size, kernel_size, num_filters_in, num_filters_out] 122 | 123 | # Initialization 124 | n = int(weights_shape[0] * weights_shape[1] * weights_shape[3]) 125 | weights_initializer = tf.random_normal_initializer( 126 | stddev=np.sqrt(2.0 / n)) 127 | 128 | weights = variable( 129 | name='weights', 130 | shape=weights_shape, 131 | dtype=tf.float32, 132 | initializer=weights_initializer, 133 | trainable=True) 134 | strides = stride_arr(stride, stride) 135 | outputs = tf.nn.conv2d( 136 | inputs, weights, strides, padding='SAME', data_format='NHWC') 137 | return outputs 138 | 139 | 140 | @tf.contrib.framework.add_arg_scope 141 | def fc(inputs, 142 | num_units_out, 143 | scope=None, 144 | reuse=None): 145 | """Creates a fully connected layer applied to `inputs`. 146 | 147 | Args: 148 | inputs: a tensor that the fully connected layer will be applied to. It 149 | will be reshaped if it is not 2D. 150 | num_units_out: the number of output units in the layer. 151 | scope: Optional scope for variable_scope. 152 | reuse: whether or not the layer and its variables should be reused. 153 | 154 | Returns: 155 | a tensor that is the result of applying a linear matrix to `inputs`. 
156 | """ 157 | if len(inputs.shape) > 2: 158 | inputs = tf.reshape(inputs, [int(inputs.shape[0]), -1]) 159 | 160 | with tf.variable_scope(scope, 'FC', [inputs], reuse=reuse): 161 | num_units_in = inputs.shape[1] 162 | weights_shape = [num_units_in, num_units_out] 163 | unif_init_range = 1.0 / (num_units_out) ** (0.5) 164 | weights_initializer = tf.random_uniform_initializer( 165 | -unif_init_range, unif_init_range) 166 | weights = variable( 167 | name='weights', 168 | shape=weights_shape, 169 | dtype=tf.float32, 170 | initializer=weights_initializer, 171 | trainable=True) 172 | bias_initializer = tf.constant_initializer(0.0) 173 | biases = variable( 174 | name='biases', 175 | shape=[num_units_out, ], 176 | dtype=tf.float32, 177 | initializer=bias_initializer, 178 | trainable=True) 179 | outputs = tf.nn.xw_plus_b(inputs, weights, biases) 180 | return outputs 181 | 182 | 183 | @tf.contrib.framework.add_arg_scope 184 | def avg_pool(inputs, kernel_size, stride=2, padding='VALID', scope=None): 185 | """Wrapper around tf.nn.avg_pool.""" 186 | with tf.name_scope(scope, 'AvgPool', [inputs]): 187 | kernel = stride_arr(kernel_size, kernel_size) 188 | strides = stride_arr(stride, stride) 189 | return tf.nn.avg_pool( 190 | inputs, 191 | ksize=kernel, 192 | strides=strides, 193 | padding=padding, 194 | data_format='NHWC') 195 | -------------------------------------------------------------------------------- /third_party/auto_augment/policies.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | 17 | def cifar10_policies(): 18 | """AutoAugment policies found on CIFAR-10.""" 19 | exp0_0 = [[('Invert', 0.1, 7), ('Contrast', 0.2, 6)], 20 | [('Rotate', 0.7, 2), ('TranslateX', 0.3, 9)], 21 | [('Sharpness', 0.8, 1), ('Sharpness', 0.9, 3)], 22 | [('ShearY', 0.5, 8), ('TranslateY', 0.7, 9)], 23 | [('AutoContrast', 0.5, 8), ('Equalize', 0.9, 2)]] 24 | exp0_1 = [[('Solarize', 0.4, 5), ('AutoContrast', 0.9, 3)], 25 | [('TranslateY', 0.9, 9), ('TranslateY', 0.7, 9)], 26 | [('AutoContrast', 0.9, 2), ('Solarize', 0.8, 3)], 27 | [('Equalize', 0.8, 8), ('Invert', 0.1, 3)], 28 | [('TranslateY', 0.7, 9), ('AutoContrast', 0.9, 1)]] 29 | exp0_2 = [[('Solarize', 0.4, 5), ('AutoContrast', 0.0, 2)], 30 | [('TranslateY', 0.7, 9), ('TranslateY', 0.7, 9)], 31 | [('AutoContrast', 0.9, 0), ('Solarize', 0.4, 3)], 32 | [('Equalize', 0.7, 5), ('Invert', 0.1, 3)], 33 | [('TranslateY', 0.7, 9), ('TranslateY', 0.7, 9)]] 34 | exp0_3 = [[('Solarize', 0.4, 5), ('AutoContrast', 0.9, 1)], 35 | [('TranslateY', 0.8, 9), ('TranslateY', 0.9, 9)], 36 | [('AutoContrast', 0.8, 0), ('TranslateY', 0.7, 9)], 37 | [('TranslateY', 0.2, 7), ('Color', 0.9, 6)], 38 | [('Equalize', 0.7, 6), ('Color', 0.4, 9)]] 39 | exp1_0 = [[('ShearY', 0.2, 7), ('Posterize', 0.3, 7)], 40 | [('Color', 0.4, 3), ('Brightness', 0.6, 7)], 41 | [('Sharpness', 0.3, 9), ('Brightness', 0.7, 9)], 42 | [('Equalize', 0.6, 5), ('Equalize', 0.5, 1)], 43 | [('Contrast', 0.6, 7), ('Sharpness', 0.6, 5)]] 44 | exp1_1 = [[('Brightness', 0.3, 7), ('AutoContrast', 0.5, 8)], 45 | [('AutoContrast', 0.9, 4), ('AutoContrast', 0.5, 6)], 46 | [('Solarize', 0.3, 5), ('Equalize', 0.6, 5)], 47 | [('TranslateY', 0.2, 4), ('Sharpness', 0.3, 3)], 48 | [('Brightness', 0.0, 8), ('Color', 0.8, 8)]] 49 | exp1_2 = [[('Solarize', 0.2, 6), ('Color', 0.8, 6)], 50 | [('Solarize', 0.2, 6), ('AutoContrast', 0.8, 1)], 51 | [('Solarize', 0.4, 1), ('Equalize', 0.6, 5)], 52 | [('Brightness', 0.0, 0), ('Solarize', 0.5, 2)], 53 | [('AutoContrast', 0.9, 5), ('Brightness', 0.5, 3)]] 54 | exp1_3 = [[('Contrast', 0.7, 5), ('Brightness', 0.0, 2)], 55 | [('Solarize', 0.2, 8), ('Solarize', 0.1, 5)], 56 | [('Contrast', 0.5, 1), ('TranslateY', 0.2, 9)], 57 | [('AutoContrast', 0.6, 5), ('TranslateY', 0.0, 9)], 58 | [('AutoContrast', 0.9, 4), ('Equalize', 0.8, 4)]] 59 | exp1_4 = [[('Brightness', 0.0, 7), ('Equalize', 0.4, 7)], 60 | [('Solarize', 0.2, 5), ('Equalize', 0.7, 5)], 61 | [('Equalize', 0.6, 8), ('Color', 0.6, 2)], 62 | [('Color', 0.3, 7), ('Color', 0.2, 4)], 63 | [('AutoContrast', 0.5, 2), ('Solarize', 0.7, 2)]] 64 | exp1_5 = [[('AutoContrast', 0.2, 0), ('Equalize', 0.1, 0)], 65 | [('ShearY', 0.6, 5), ('Equalize', 0.6, 5)], 66 | [('Brightness', 0.9, 3), ('AutoContrast', 0.4, 1)], 67 | [('Equalize', 0.8, 8), ('Equalize', 0.7, 7)], 68 | [('Equalize', 0.7, 7), ('Solarize', 0.5, 0)]] 69 | exp1_6 = [[('Equalize', 0.8, 4), ('TranslateY', 0.8, 9)], 70 | [('TranslateY', 0.8, 9), ('TranslateY', 0.6, 9)], 71 | [('TranslateY', 0.9, 0), ('TranslateY', 0.5, 9)], 72 | [('AutoContrast', 0.5, 3), ('Solarize', 0.3, 4)], 73 | [('Solarize', 0.5, 3), ('Equalize', 0.4, 4)]] 74 | exp2_0 = [[('Color', 0.7, 7), ('TranslateX', 0.5, 8)], 75 | [('Equalize', 0.3, 7), ('AutoContrast', 0.4, 8)], 76 | [('TranslateY', 0.4, 3), ('Sharpness', 0.2, 6)], 77 | [('Brightness', 0.9, 6), ('Color', 0.2, 8)], 78 | [('Solarize', 0.5, 2), ('Invert', 0.0, 3)]] 79 | exp2_1 = [[('AutoContrast', 0.1, 5), ('Brightness', 0.0, 0)], 80 | [('Cutout', 0.2, 
4), ('Equalize', 0.1, 1)], 81 | [('Equalize', 0.7, 7), ('AutoContrast', 0.6, 4)], 82 | [('Color', 0.1, 8), ('ShearY', 0.2, 3)], 83 | [('ShearY', 0.4, 2), ('Rotate', 0.7, 0)]] 84 | exp2_2 = [[('ShearY', 0.1, 3), ('AutoContrast', 0.9, 5)], 85 | [('TranslateY', 0.3, 6), ('Cutout', 0.3, 3)], 86 | [('Equalize', 0.5, 0), ('Solarize', 0.6, 6)], 87 | [('AutoContrast', 0.3, 5), ('Rotate', 0.2, 7)], 88 | [('Equalize', 0.8, 2), ('Invert', 0.4, 0)]] 89 | exp2_3 = [[('Equalize', 0.9, 5), ('Color', 0.7, 0)], 90 | [('Equalize', 0.1, 1), ('ShearY', 0.1, 3)], 91 | [('AutoContrast', 0.7, 3), ('Equalize', 0.7, 0)], 92 | [('Brightness', 0.5, 1), ('Contrast', 0.1, 7)], 93 | [('Contrast', 0.1, 4), ('Solarize', 0.6, 5)]] 94 | exp2_4 = [[('Solarize', 0.2, 3), ('ShearX', 0.0, 0)], 95 | [('TranslateX', 0.3, 0), ('TranslateX', 0.6, 0)], 96 | [('Equalize', 0.5, 9), ('TranslateY', 0.6, 7)], 97 | [('ShearX', 0.1, 0), ('Sharpness', 0.5, 1)], 98 | [('Equalize', 0.8, 6), ('Invert', 0.3, 6)]] 99 | exp2_5 = [[('AutoContrast', 0.3, 9), ('Cutout', 0.5, 3)], 100 | [('ShearX', 0.4, 4), ('AutoContrast', 0.9, 2)], 101 | [('ShearX', 0.0, 3), ('Posterize', 0.0, 3)], 102 | [('Solarize', 0.4, 3), ('Color', 0.2, 4)], 103 | [('Equalize', 0.1, 4), ('Equalize', 0.7, 6)]] 104 | exp2_6 = [[('Equalize', 0.3, 8), ('AutoContrast', 0.4, 3)], 105 | [('Solarize', 0.6, 4), ('AutoContrast', 0.7, 6)], 106 | [('AutoContrast', 0.2, 9), ('Brightness', 0.4, 8)], 107 | [('Equalize', 0.1, 0), ('Equalize', 0.0, 6)], 108 | [('Equalize', 0.8, 4), ('Equalize', 0.0, 4)]] 109 | exp2_7 = [[('Equalize', 0.5, 5), ('AutoContrast', 0.1, 2)], 110 | [('Solarize', 0.5, 5), ('AutoContrast', 0.9, 5)], 111 | [('AutoContrast', 0.6, 1), ('AutoContrast', 0.7, 8)], 112 | [('Equalize', 0.2, 0), ('AutoContrast', 0.1, 2)], 113 | [('Equalize', 0.6, 9), ('Equalize', 0.4, 4)]] 114 | exp0s = exp0_0 + exp0_1 + exp0_2 + exp0_3 115 | exp1s = exp1_0 + exp1_1 + exp1_2 + exp1_3 + exp1_4 + exp1_5 + exp1_6 116 | exp2s = exp2_0 + exp2_1 + exp2_2 + exp2_3 + exp2_4 + exp2_5 + exp2_6 + exp2_7 117 | return exp0s + exp1s + exp2s 118 | 119 | 120 | def svhn_policies(): 121 | """AutoAugment policies found on SVHN.""" 122 | policies = [ 123 | [('ShearX', 0.9, 4), ('Invert', 0.2, 3)], 124 | [('ShearY', 0.9, 8), ('Invert', 0.7, 5)], 125 | [('Equalize', 0.6, 5), ('Solarize', 0.6, 6)], 126 | [('Invert', 0.9, 3), ('Equalize', 0.6, 3)], 127 | [('Equalize', 0.6, 1), ('Rotate', 0.9, 3)], 128 | [('ShearX', 0.9, 4), ('AutoContrast', 0.8, 3)], 129 | [('ShearY', 0.9, 8), ('Invert', 0.4, 5)], 130 | [('ShearY', 0.9, 5), ('Solarize', 0.2, 6)], 131 | [('Invert', 0.9, 6), ('AutoContrast', 0.8, 1)], 132 | [('Equalize', 0.6, 3), ('Rotate', 0.9, 3)], 133 | [('ShearX', 0.9, 4), ('Solarize', 0.3, 3)], 134 | [('ShearY', 0.8, 8), ('Invert', 0.7, 4)], 135 | [('Equalize', 0.9, 5), ('TranslateY', 0.6, 6)], 136 | [('Invert', 0.9, 4), ('Equalize', 0.6, 7)], 137 | [('Contrast', 0.3, 3), ('Rotate', 0.8, 4)], 138 | [('ShearX', 0.9, 3), ('Invert', 0.5, 3)], 139 | [('ShearY', 0.9, 8), ('Invert', 0.4, 5)], 140 | [('Equalize', 0.6, 3), ('Solarize', 0.2, 3)], 141 | [('Invert', 0.9, 4), ('Equalize', 0.5, 6)], 142 | [('Equalize', 0.6, 1), ('Rotate', 0.9, 3)], 143 | [('Invert', 0.8, 5), ('TranslateY', 0.0, 2)], 144 | [('ShearY', 0.7, 6), ('Solarize', 0.4, 8)], 145 | [('Invert', 0.6, 4), ('Rotate', 0.8, 4)], 146 | [('ShearY', 0.3, 7), ('TranslateX', 0.9, 3)], 147 | [('ShearX', 0.1, 6), ('Invert', 0.6, 5)], 148 | [('Solarize', 0.7, 2), ('TranslateY', 0.6, 7)], 149 | [('ShearY', 0.8, 4), ('Invert', 0.8, 8)], 150 | [('ShearX', 0.7, 9), 
('TranslateY', 0.8, 3)], 151 | [('ShearY', 0.8, 5), ('AutoContrast', 0.7, 3)], 152 | [('ShearX', 0.7, 2), ('Invert', 0.1, 5)], 153 | [('ShearY', 0.8, 9), ('ShearX', 0.7, 7)], 154 | [('ShearY', 0.7, 4), ('Solarize', 0.9, 7)], 155 | [('ShearY', 0.9, 5), ('Invert', 0.0, 4)], 156 | [('TranslateX', 0.8, 3), ('ShearY', 0.7, 7)], 157 | [('Invert', 0.1, 7), ('Solarize', 0.3, 9)], 158 | [('Invert', 0.6, 2), ('Invert', 0.9, 4)], 159 | [('Equalize', 0.5, 2), ('Solarize', 0.9, 7)], 160 | [('ShearY', 0.6, 7), ('Solarize', 0.8, 3)], 161 | [('ShearY', 0.6, 3), ('Invert', 0.6, 1)], 162 | [('ShearX', 0.4, 2), ('Rotate', 0.7, 5)]] 163 | 164 | return policies 165 | 166 | 167 | def imagenet_policies(): 168 | """AutoAugment policies found on ImageNet. 169 | This policy also transfers to five FGVC datasets with image size similar to 170 | ImageNet including Oxford 102 Flowers, Caltech-101, Oxford-IIIT Pets, 171 | FGVC Aircraft and Stanford Cars. 172 | """ 173 | policies = [ 174 | [('Posterize', 0.4, 8), ('Rotate', 0.6, 9)], 175 | [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)], 176 | [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], 177 | [('Posterize', 0.6, 7), ('Posterize', 0.6, 6)], 178 | [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)], 179 | [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)], 180 | [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)], 181 | [('Posterize', 0.8, 5), ('Equalize', 1.0, 2)], 182 | [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)], 183 | [('Equalize', 0.6, 8), ('Posterize', 0.4, 6)], 184 | [('Rotate', 0.8, 8), ('Color', 0.4, 0)], 185 | [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)], 186 | [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)], 187 | [('Invert', 0.6, 4), ('Equalize', 1.0, 8)], 188 | [('Color', 0.6, 4), ('Contrast', 1.0, 8)], 189 | [('Rotate', 0.8, 8), ('Color', 1.0, 2)], 190 | [('Color', 0.8, 8), ('Solarize', 0.8, 7)], 191 | [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)], 192 | [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)], 193 | [('Color', 0.4, 0), ('Equalize', 0.6, 3)] 194 | ] 195 | return policies 196 | -------------------------------------------------------------------------------- /third_party/auto_augment/shake_drop.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google UDA Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Builds the Shake-Drop Model. 
16 | Copied from AutoAugment: https://github.com/tensorflow/models/blob/master/research/autoaugment/ 17 | """ 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import math 24 | 25 | import tensorflow as tf 26 | 27 | import third_party.auto_augment.custom_ops as ops 28 | 29 | 30 | def round_int(x): 31 | """Rounds `x` and then converts to an int.""" 32 | return int(math.floor(x + 0.5)) 33 | 34 | 35 | def shortcut(x, output_filters, stride): 36 | """Applies strided avg pool or zero padding to make output_filters match x.""" 37 | num_filters = int(x.shape[3]) 38 | if stride == 2: 39 | x = ops.avg_pool(x, 2, stride=stride, padding='SAME') 40 | if num_filters != output_filters: 41 | diff = output_filters - num_filters 42 | assert diff > 0 43 | # Zero padd diff zeros 44 | padding = [[0, 0], [0, 0], [0, 0], [0, diff]] 45 | x = tf.pad(x, padding) 46 | return x 47 | 48 | 49 | def calc_prob(curr_layer, total_layers, p_l): 50 | """Calculates drop prob depending on the current layer.""" 51 | return 1 - (float(curr_layer) / total_layers) * p_l 52 | 53 | 54 | def bottleneck_layer(x, n, stride, prob, is_training, alpha, beta): 55 | """Bottleneck layer for shake drop model.""" 56 | assert alpha[1] > alpha[0] 57 | assert beta[1] > beta[0] 58 | with tf.variable_scope('bottleneck_{}'.format(prob)): 59 | input_layer = x 60 | x = ops.batch_norm(x, scope='bn_1_pre') 61 | x = ops.conv2d(x, n, 1, scope='1x1_conv_contract') 62 | x = ops.batch_norm(x, scope='bn_1_post') 63 | x = tf.nn.relu(x) 64 | x = ops.conv2d(x, n, 3, stride=stride, scope='3x3') 65 | x = ops.batch_norm(x, scope='bn_2') 66 | x = tf.nn.relu(x) 67 | x = ops.conv2d(x, n * 4, 1, scope='1x1_conv_expand') 68 | x = ops.batch_norm(x, scope='bn_3') 69 | 70 | # Apply regularization here 71 | # Sample bernoulli with prob 72 | if is_training: 73 | batch_size = tf.shape(x)[0] 74 | bern_shape = [batch_size, 1, 1, 1] 75 | random_tensor = prob 76 | random_tensor += tf.random_uniform(bern_shape, dtype=tf.float32) 77 | binary_tensor = tf.floor(random_tensor) 78 | 79 | alpha_values = tf.random_uniform( 80 | [batch_size, 1, 1, 1], minval=alpha[0], maxval=alpha[1], 81 | dtype=tf.float32) 82 | beta_values = tf.random_uniform( 83 | [batch_size, 1, 1, 1], minval=beta[0], maxval=beta[1], 84 | dtype=tf.float32) 85 | rand_forward = ( 86 | binary_tensor + alpha_values - binary_tensor * alpha_values) 87 | rand_backward = ( 88 | binary_tensor + beta_values - binary_tensor * beta_values) 89 | x = x * rand_backward + tf.stop_gradient(x * rand_forward - 90 | x * rand_backward) 91 | else: 92 | expected_alpha = (alpha[1] + alpha[0]) / 2 93 | # prob is the expectation of the bernoulli variable 94 | x = (prob + expected_alpha - prob * expected_alpha) * x 95 | 96 | res = shortcut(input_layer, n * 4, stride) 97 | return x + res 98 | 99 | 100 | def build_shake_drop_model(images, num_classes, is_training): 101 | """Builds the PyramidNet Shake-Drop model. 102 | 103 | Build the PyramidNet Shake-Drop model from https://arxiv.org/abs/1802.02375. 104 | 105 | Args: 106 | images: Tensor of images that will be fed into the Wide ResNet Model. 107 | num_classes: Number of classed that the model needs to predict. 108 | is_training: Is the model training or not. 109 | 110 | Returns: 111 | The logits of the PyramidNet Shake-Drop model. 
112 | """ 113 | 114 | is_training = is_training 115 | # ShakeDrop Hparams 116 | p_l = 0.5 117 | alpha_shake = [-1, 1] 118 | beta_shake = [0, 1] 119 | 120 | # PyramidNet Hparams 121 | alpha = 200 122 | depth = 272 123 | # This is for the bottleneck architecture specifically 124 | n = int((depth - 2) / 9) 125 | start_channel = 16 126 | add_channel = alpha / (3 * n) 127 | 128 | # Building the models 129 | x = images 130 | x = ops.conv2d(x, 16, 3, scope='init_conv') 131 | x = ops.batch_norm(x, scope='init_bn') 132 | 133 | layer_num = 1 134 | total_layers = n * 3 135 | start_channel += add_channel 136 | prob = calc_prob(layer_num, total_layers, p_l) 137 | x = bottleneck_layer( 138 | x, round_int(start_channel), 1, prob, is_training, alpha_shake, 139 | beta_shake) 140 | layer_num += 1 141 | for _ in range(1, n): 142 | start_channel += add_channel 143 | prob = calc_prob(layer_num, total_layers, p_l) 144 | x = bottleneck_layer( 145 | x, round_int(start_channel), 1, prob, is_training, alpha_shake, 146 | beta_shake) 147 | layer_num += 1 148 | 149 | start_channel += add_channel 150 | prob = calc_prob(layer_num, total_layers, p_l) 151 | x = bottleneck_layer( 152 | x, round_int(start_channel), 2, prob, is_training, alpha_shake, 153 | beta_shake) 154 | layer_num += 1 155 | for _ in range(1, n): 156 | start_channel += add_channel 157 | prob = calc_prob(layer_num, total_layers, p_l) 158 | x = bottleneck_layer( 159 | x, round_int(start_channel), 1, prob, is_training, alpha_shake, 160 | beta_shake) 161 | layer_num += 1 162 | 163 | start_channel += add_channel 164 | prob = calc_prob(layer_num, total_layers, p_l) 165 | x = bottleneck_layer( 166 | x, round_int(start_channel), 2, prob, is_training, alpha_shake, 167 | beta_shake) 168 | layer_num += 1 169 | for _ in range(1, n): 170 | start_channel += add_channel 171 | prob = calc_prob(layer_num, total_layers, p_l) 172 | x = bottleneck_layer( 173 | x, round_int(start_channel), 1, prob, is_training, alpha_shake, 174 | beta_shake) 175 | layer_num += 1 176 | 177 | assert layer_num - 1 == total_layers 178 | x = ops.batch_norm(x, scope='final_bn') 179 | x = tf.nn.relu(x) 180 | x = ops.global_avg_pool(x) 181 | # Fully connected 182 | logits = ops.fc(x, num_classes) 183 | return logits 184 | -------------------------------------------------------------------------------- /third_party/auto_augment/shake_shake.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google UDA Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Builds the Shake-Shake Model. 
16 | Copied from AutoAugment: https://github.com/tensorflow/models/blob/master/research/autoaugment/ 17 | """ 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow as tf 24 | 25 | import third_party.auto_augment.custom_ops as ops 26 | 27 | 28 | def _shake_shake_skip_connection(x, output_filters, stride): 29 | """Adds a residual connection to the filter x for the shake-shake model.""" 30 | curr_filters = int(x.shape[3]) 31 | if curr_filters == output_filters: 32 | return x 33 | stride_spec = ops.stride_arr(stride, stride) 34 | # Skip path 1 35 | path1 = tf.nn.avg_pool( 36 | x, [1, 1, 1, 1], stride_spec, 'VALID', data_format='NHWC') 37 | path1 = ops.conv2d(path1, int(output_filters / 2), 1, scope='path1_conv') 38 | 39 | # Skip path 2 40 | # First pad with 0's then crop 41 | pad_arr = [[0, 0], [0, 1], [0, 1], [0, 0]] 42 | path2 = tf.pad(x, pad_arr)[:, 1:, 1:, :] 43 | concat_axis = 3 44 | 45 | path2 = tf.nn.avg_pool( 46 | path2, [1, 1, 1, 1], stride_spec, 'VALID', data_format='NHWC') 47 | path2 = ops.conv2d(path2, int(output_filters / 2), 1, scope='path2_conv') 48 | 49 | # Concat and apply BN 50 | final_path = tf.concat(values=[path1, path2], axis=concat_axis) 51 | final_path = ops.batch_norm(final_path, scope='final_path_bn') 52 | return final_path 53 | 54 | 55 | def _shake_shake_branch(x, output_filters, stride, rand_forward, rand_backward, 56 | is_training): 57 | """Building a 2 branching convnet.""" 58 | x = tf.nn.relu(x) 59 | x = ops.conv2d(x, output_filters, 3, stride=stride, scope='conv1') 60 | x = ops.batch_norm(x, scope='bn1') 61 | x = tf.nn.relu(x) 62 | x = ops.conv2d(x, output_filters, 3, scope='conv2') 63 | x = ops.batch_norm(x, scope='bn2') 64 | if is_training: 65 | x = x * rand_backward + tf.stop_gradient(x * rand_forward - 66 | x * rand_backward) 67 | else: 68 | x *= 1.0 / 2 69 | return x 70 | 71 | 72 | def _shake_shake_block(x, output_filters, stride, is_training): 73 | """Builds a full shake-shake sub layer.""" 74 | batch_size = tf.shape(x)[0] 75 | 76 | # Generate random numbers for scaling the branches 77 | rand_forward = [ 78 | tf.random_uniform( 79 | [batch_size, 1, 1, 1], minval=0, maxval=1, dtype=tf.float32) 80 | for _ in range(2) 81 | ] 82 | rand_backward = [ 83 | tf.random_uniform( 84 | [batch_size, 1, 1, 1], minval=0, maxval=1, dtype=tf.float32) 85 | for _ in range(2) 86 | ] 87 | # Normalize so that all sum to 1 88 | total_forward = tf.add_n(rand_forward) 89 | total_backward = tf.add_n(rand_backward) 90 | rand_forward = [samp / total_forward for samp in rand_forward] 91 | rand_backward = [samp / total_backward for samp in rand_backward] 92 | zipped_rand = zip(rand_forward, rand_backward) 93 | 94 | branches = [] 95 | for branch, (r_forward, r_backward) in enumerate(zipped_rand): 96 | with tf.variable_scope('branch_{}'.format(branch)): 97 | b = _shake_shake_branch(x, output_filters, stride, r_forward, r_backward, 98 | is_training) 99 | branches.append(b) 100 | res = _shake_shake_skip_connection(x, output_filters, stride) 101 | return res + tf.add_n(branches) 102 | 103 | 104 | def _shake_shake_layer(x, output_filters, num_blocks, stride, 105 | is_training): 106 | """Builds many sub layers into one full layer.""" 107 | for block_num in range(num_blocks): 108 | curr_stride = stride if (block_num == 0) else 1 109 | with tf.variable_scope('layer_{}'.format(block_num)): 110 | x = _shake_shake_block(x, output_filters, curr_stride, 111 | is_training) 112 | return x 113 | 114 | 
115 | def build_shake_shake_model(images, num_classes, hparams, is_training): 116 | """Builds the Shake-Shake model. 117 | 118 | Build the Shake-Shake model from https://arxiv.org/abs/1705.07485. 119 | 120 | Args: 121 | images: Tensor of images that will be fed into the Wide ResNet Model. 122 | num_classes: Number of classed that the model needs to predict. 123 | hparams: tf.HParams object that contains additional hparams needed to 124 | construct the model. In this case it is the `shake_shake_widen_factor` 125 | that is used to determine how many filters the model has. 126 | is_training: Is the model training or not. 127 | 128 | Returns: 129 | The logits of the Shake-Shake model. 130 | """ 131 | depth = 26 132 | k = hparams.shake_shake_widen_factor # The widen factor 133 | n = int((depth - 2) / 6) 134 | x = images 135 | 136 | x = ops.conv2d(x, 16, 3, scope='init_conv') 137 | x = ops.batch_norm(x, scope='init_bn') 138 | with tf.variable_scope('L1'): 139 | x = _shake_shake_layer(x, 16 * k, n, 1, is_training) 140 | with tf.variable_scope('L2'): 141 | x = _shake_shake_layer(x, 32 * k, n, 2, is_training) 142 | with tf.variable_scope('L3'): 143 | x = _shake_shake_layer(x, 64 * k, n, 2, is_training) 144 | x = tf.nn.relu(x) 145 | x = ops.global_avg_pool(x) 146 | 147 | # Fully connected 148 | logits = ops.fc(x, num_classes) 149 | return logits 150 | -------------------------------------------------------------------------------- /third_party/auto_augment/wrn.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google UDA Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Builds the WideResNet Model.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import numpy as np 22 | import tensorflow as tf 23 | 24 | import third_party.auto_augment.custom_ops as ops 25 | 26 | 27 | def residual_block( 28 | x, in_filter, out_filter, stride, update_bn=True, 29 | activate_before_residual=False): 30 | """Adds residual connection to `x` in addition to applying BN->ReLU->3x3 Conv. 31 | 32 | Args: 33 | x: Tensor that is the output of the previous layer in the model. 34 | in_filter: Number of filters `x` has. 35 | out_filter: Number of filters that the output of this layer will have. 36 | stride: Integer that specified what stride should be applied to `x`. 37 | activate_before_residual: Boolean on whether a BN->ReLU should be applied 38 | to x before the convolution is applied. 39 | 40 | Returns: 41 | A Tensor that is the result of applying two sequences of BN->ReLU->3x3 Conv 42 | and then adding that Tensor to `x`. 
43 | """ 44 | 45 | if activate_before_residual: # Pass up RELU and BN activation for resnet 46 | with tf.variable_scope('shared_activation'): 47 | x = ops.batch_norm(x, update_stats=update_bn, scope='init_bn') 48 | x = tf.nn.relu(x) 49 | orig_x = x 50 | else: 51 | orig_x = x 52 | 53 | block_x = x 54 | if not activate_before_residual: 55 | with tf.variable_scope('residual_only_activation'): 56 | block_x = ops.batch_norm(block_x, update_stats=update_bn, 57 | scope='init_bn') 58 | block_x = tf.nn.relu(block_x) 59 | 60 | with tf.variable_scope('sub1'): 61 | block_x = ops.conv2d( 62 | block_x, out_filter, 3, stride=stride, scope='conv1') 63 | 64 | with tf.variable_scope('sub2'): 65 | block_x = ops.batch_norm(block_x, update_stats=update_bn, scope='bn2') 66 | block_x = tf.nn.relu(block_x) 67 | block_x = ops.conv2d( 68 | block_x, out_filter, 3, stride=1, scope='conv2') 69 | 70 | with tf.variable_scope( 71 | 'sub_add'): # If number of filters do not agree then zero pad them 72 | if in_filter != out_filter: 73 | orig_x = ops.avg_pool(orig_x, stride, stride) 74 | orig_x = ops.zero_pad(orig_x, in_filter, out_filter) 75 | x = orig_x + block_x 76 | return x 77 | 78 | 79 | def _res_add(in_filter, out_filter, stride, x, orig_x): 80 | """Adds `x` with `orig_x`, both of which are layers in the model. 81 | 82 | Args: 83 | in_filter: Number of filters in `orig_x`. 84 | out_filter: Number of filters in `x`. 85 | stride: Integer specifying the stide that should be applied `orig_x`. 86 | x: Tensor that is the output of the previous layer. 87 | orig_x: Tensor that is the output of an earlier layer in the network. 88 | 89 | Returns: 90 | A Tensor that is the result of `x` and `orig_x` being added after 91 | zero padding and striding are applied to `orig_x` to get the shapes 92 | to match. 93 | """ 94 | if in_filter != out_filter: 95 | orig_x = ops.avg_pool(orig_x, stride, stride) 96 | orig_x = ops.zero_pad(orig_x, in_filter, out_filter) 97 | x = x + orig_x 98 | orig_x = x 99 | return x, orig_x 100 | 101 | 102 | def build_wrn_model(images, num_classes, wrn_size, update_bn=True): 103 | """Builds the WRN model. 104 | 105 | Build the Wide ResNet model from https://arxiv.org/abs/1605.07146. 106 | 107 | Args: 108 | images: Tensor of images that will be fed into the Wide ResNet Model. 109 | num_classes: Number of classed that the model needs to predict. 110 | wrn_size: Parameter that scales the number of filters in the Wide ResNet 111 | model. 112 | 113 | Returns: 114 | The logits of the Wide ResNet model. 
115 | """ 116 | # wrn_size = 16 * widening factor k 117 | kernel_size = wrn_size 118 | filter_size = 3 119 | # depth = num_blocks_per_resnet * 6 + 4 = 28 120 | num_blocks_per_resnet = 4 121 | filters = [ 122 | min(kernel_size, 16), kernel_size, kernel_size * 2, kernel_size * 4 123 | ] 124 | strides = [1, 2, 2] # stride for each resblock 125 | 126 | # Run the first conv 127 | with tf.variable_scope('init'): 128 | x = images 129 | output_filters = filters[0] 130 | x = ops.conv2d(x, output_filters, filter_size, scope='init_conv') 131 | 132 | first_x = x # Res from the beginning 133 | orig_x = x # Res from previous block 134 | 135 | for block_num in range(1, 4): 136 | with tf.variable_scope('unit_{}_0'.format(block_num)): 137 | activate_before_residual = True if block_num == 1 else False 138 | x = residual_block( 139 | x, 140 | filters[block_num - 1], 141 | filters[block_num], 142 | strides[block_num - 1], 143 | update_bn=update_bn, 144 | activate_before_residual=activate_before_residual) 145 | for i in range(1, num_blocks_per_resnet): 146 | with tf.variable_scope('unit_{}_{}'.format(block_num, i)): 147 | x = residual_block( 148 | x, 149 | filters[block_num], 150 | filters[block_num], 151 | 1, 152 | update_bn=update_bn, 153 | activate_before_residual=False) 154 | x, orig_x = _res_add(filters[block_num - 1], filters[block_num], 155 | strides[block_num - 1], x, orig_x) 156 | final_stride_val = np.prod(strides) 157 | x, _ = _res_add(filters[0], filters[3], final_stride_val, x, first_x) 158 | with tf.variable_scope('unit_last'): 159 | x = ops.batch_norm(x, scope='final_bn') 160 | x = tf.nn.relu(x) 161 | x = ops.global_avg_pool(x) 162 | logits = ops.fc(x, num_classes) 163 | return logits 164 | -------------------------------------------------------------------------------- /third_party/vat_utils.py: -------------------------------------------------------------------------------- 1 | """Utilities derived from the VAT code.""" 2 | 3 | import tensorflow as tf 4 | 5 | 6 | def generate_perturbation(x, logit, forward, epsilon, xi=1e-6): 7 | """Generate an adversarial perturbation. 8 | 9 | Args: 10 | x: Model inputs. 11 | logit: Original model output without perturbation. 12 | forward: Callable which computs logits given input. 13 | epsilon: Gradient multiplier. 14 | xi: Small constant. 15 | 16 | Returns: 17 | Aversarial perturbation to be applied to x. 
18 | """ 19 | d = tf.random_normal(shape=tf.shape(x)) 20 | 21 | for _ in range(1): 22 | d = xi * get_normalized_vector(d) 23 | logit_p = logit 24 | logit_m = forward(x + d) 25 | dist = kl_divergence_with_logit(logit_p, logit_m) 26 | grad = tf.gradients(tf.reduce_mean(dist), [d], aggregation_method=2)[0] 27 | d = tf.stop_gradient(grad) 28 | 29 | return epsilon * get_normalized_vector(d) 30 | 31 | 32 | def kl_divergence_with_logit(q_logit, p_logit): 33 | """Compute the per-element KL-divergence of a batch.""" 34 | q = tf.nn.softmax(q_logit) 35 | qlogq = tf.reduce_sum(q * logsoftmax(q_logit), 1) 36 | qlogp = tf.reduce_sum(q * logsoftmax(p_logit), 1) 37 | return qlogq - qlogp 38 | 39 | 40 | def get_normalized_vector(d): 41 | """Normalize d by infinity and L2 norms.""" 42 | d /= 1e-12 + tf.reduce_max( 43 | tf.abs(d), list(range(1, len(d.get_shape()))), keepdims=True 44 | ) 45 | d /= tf.sqrt( 46 | 1e-6 47 | + tf.reduce_sum( 48 | tf.pow(d, 2.0), list(range(1, len(d.get_shape()))), keepdims=True 49 | ) 50 | ) 51 | return d 52 | 53 | 54 | def logsoftmax(x): 55 | """Compute log-domain softmax of logits.""" 56 | xdev = x - tf.reduce_max(x, 1, keepdims=True) 57 | lsm = xdev - tf.log(tf.reduce_sum(tf.exp(xdev), 1, keepdims=True)) 58 | return lsm 59 | -------------------------------------------------------------------------------- /vat.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Virtual adversarial training:a regularization method for supervised and semi-supervised learning. 15 | 16 | Application to SSL of https://arxiv.org/abs/1704.03976 17 | """ 18 | 19 | import functools 20 | import os 21 | 22 | import tensorflow as tf 23 | from absl import app 24 | from absl import flags 25 | 26 | from libml import utils, data, layers, models 27 | from libml.utils import EasyDict 28 | from third_party import vat_utils 29 | 30 | FLAGS = flags.FLAGS 31 | 32 | 33 | class VAT(models.MultiModel): 34 | 35 | def model(self, batch, lr, wd, ema, warmup_pos, vat, vat_eps, entmin_weight, **kwargs): 36 | hwc = [self.dataset.height, self.dataset.width, self.dataset.colors] 37 | xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt') # For training 38 | x_in = tf.placeholder(tf.float32, [None] + hwc, 'x') 39 | y_in = tf.placeholder(tf.float32, [batch] + hwc, 'y') 40 | l_in = tf.placeholder(tf.int32, [batch], 'labels') 41 | wd *= lr 42 | warmup = tf.clip_by_value(tf.to_float(self.step) / (warmup_pos * (FLAGS.train_kimg << 10)), 0, 1) 43 | 44 | classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits 45 | l = tf.one_hot(l_in, self.nclass) 46 | logits_x = classifier(xt_in, training=True) 47 | post_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) # Take only first call to update batch norm. 
48 | logits_y = classifier(y_in, training=True) 49 | delta_y = vat_utils.generate_perturbation(y_in, logits_y, lambda x: classifier(x, training=True), vat_eps) 50 | logits_student = classifier(y_in + delta_y, training=True) 51 | logits_teacher = tf.stop_gradient(logits_y) 52 | loss_vat = layers.kl_divergence_from_logits(logits_student, logits_teacher) 53 | loss_vat = tf.reduce_mean(loss_vat) 54 | loss_entmin = tf.reduce_mean(tf.distributions.Categorical(logits=logits_y).entropy()) 55 | 56 | loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=l, logits=logits_x) 57 | loss = tf.reduce_mean(loss) 58 | tf.summary.scalar('losses/xe', loss) 59 | tf.summary.scalar('losses/vat', loss_vat) 60 | tf.summary.scalar('losses/entmin', loss_entmin) 61 | 62 | ema = tf.train.ExponentialMovingAverage(decay=ema) 63 | ema_op = ema.apply(utils.model_vars()) 64 | ema_getter = functools.partial(utils.getter_ema, ema) 65 | post_ops.append(ema_op) 66 | post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name]) 67 | 68 | train_op = tf.train.AdamOptimizer(lr).minimize(loss + loss_vat * warmup * vat + entmin_weight * loss_entmin, 69 | colocate_gradients_with_ops=True) 70 | with tf.control_dependencies([train_op]): 71 | train_op = tf.group(*post_ops) 72 | 73 | return EasyDict( 74 | xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op, 75 | classify_raw=tf.nn.softmax(classifier(x_in, training=False)), # No EMA, for debugging. 76 | classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False))) 77 | 78 | 79 | def main(argv): 80 | utils.setup_main() 81 | del argv # Unused. 82 | dataset = data.DATASETS()[FLAGS.dataset]() 83 | log_width = utils.ilog2(dataset.width) 84 | model = VAT( 85 | os.path.join(FLAGS.train_dir, dataset.name), 86 | dataset, 87 | lr=FLAGS.lr, 88 | wd=FLAGS.wd, 89 | arch=FLAGS.arch, 90 | warmup_pos=FLAGS.warmup_pos, 91 | batch=FLAGS.batch, 92 | nclass=dataset.nclass, 93 | ema=FLAGS.ema, 94 | smoothing=FLAGS.smoothing, 95 | vat=FLAGS.vat, 96 | vat_eps=FLAGS.vat_eps, 97 | entmin_weight=FLAGS.entmin_weight, 98 | 99 | scales=FLAGS.scales or (log_width - 2), 100 | filters=FLAGS.filters, 101 | repeat=FLAGS.repeat) 102 | model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10) 103 | 104 | 105 | if __name__ == '__main__': 106 | utils.setup_tf() 107 | flags.DEFINE_float('wd', 0.02, 'Weight decay.') 108 | flags.DEFINE_float('vat', 0.3, 'VAT weight.') 109 | flags.DEFINE_float('vat_eps', 6, 'VAT perturbation size.') 110 | flags.DEFINE_float('entmin_weight', 0.06, 'Entropy minimization weight.') 111 | flags.DEFINE_float('warmup_pos', 0.4, 'Relative position at which constraint loss warmup ends.') 112 | flags.DEFINE_float('ema', 0.999, 'Exponential moving average of params.') 113 | flags.DEFINE_float('smoothing', 0.1, 'Label smoothing.') 114 | flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.') 115 | flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.') 116 | flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.') 117 | FLAGS.set_default('dataset', 'cifar10.3@250-5000') 118 | FLAGS.set_default('batch', 64) 119 | FLAGS.set_default('lr', 0.002) 120 | FLAGS.set_default('train_kimg', 1 << 16) 121 | app.run(main) 122 | --------------------------------------------------------------------------------
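Note on the VAT perturbation: `generate_perturbation` in `third_party/vat_utils.py` finds the adversarial direction with a single power-iteration step — draw a random direction, take one gradient of the KL divergence between the clean and perturbed predictions, and renormalize. Below is a minimal, framework-free sketch of that step; it is not part of the repository. The toy linear classifier `W`, the finite-difference gradient, and the larger `xi` are illustrative assumptions only — the actual code uses exact TensorFlow gradients with `xi=1e-6`.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def kl_from_logits(q_logit, p_logit):
    # Per-example KL(q || p), mirroring kl_divergence_with_logit above.
    q = softmax(q_logit)
    return (q * (np.log(q + 1e-12) - np.log(softmax(p_logit) + 1e-12))).sum(axis=1)


def normalize(d):
    # Mirror get_normalized_vector: scale by the max-abs value, then by the L2 norm.
    d = d / (1e-12 + np.abs(d).max(axis=1, keepdims=True))
    return d / np.sqrt(1e-6 + (d ** 2).sum(axis=1, keepdims=True))


def vat_perturbation(x, W, epsilon=6.0, xi=0.1, delta=1e-4, rng=None):
    """One power-iteration step (the repository uses exact gradients and xi=1e-6)."""
    if rng is None:
        rng = np.random.RandomState(0)
    logits = x @ W                       # clean logits, treated as a fixed target
    d = normalize(rng.randn(*x.shape))   # random initial direction
    grad = np.zeros_like(d)
    for j in range(x.shape[1]):          # finite-difference gradient of the KL w.r.t. d
        bump = np.zeros_like(d)
        bump[:, j] = delta
        kl_plus = kl_from_logits(logits, (x + xi * (d + bump)) @ W)
        kl_minus = kl_from_logits(logits, (x + xi * (d - bump)) @ W)
        grad[:, j] = (kl_plus - kl_minus) / (2 * delta)
    return epsilon * normalize(grad)     # adversarial direction, scaled to size epsilon


x = np.random.RandomState(1).randn(4, 8)  # 4 toy examples with 8 features
W = np.random.RandomState(2).randn(8, 3)  # toy 3-class linear classifier
print(vat_perturbation(x, W).shape)       # (4, 8)
```

In `VAT.model` in `vat.py` above, the analogous `delta_y` returned by `generate_perturbation` is added to the unlabeled batch `y_in`, and the KL divergence between the perturbed predictions and the clean (stop-gradient) predictions forms `loss_vat`.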