├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── ablation
│   ├── ab_cta_remixmatch.py
│   ├── ab_cta_remixmatch_noweak.py
│   └── ab_remixmatch.py
├── cta
│   ├── cta_fsmixup.py
│   ├── cta_mixmatch.py
│   ├── cta_remixmatch.py
│   └── lib
│       ├── __init__.py
│       └── train.py
├── fully_supervised
│   ├── fs_baseline.py
│   ├── fs_mixup.py
│   ├── lib
│   │   ├── __init__.py
│   │   ├── data.py
│   │   └── train.py
│   └── runs
│       └── all.sh
├── ict.py
├── libml
│   ├── __init__.py
│   ├── augment.py
│   ├── ctaugment.py
│   ├── data.py
│   ├── layers.py
│   ├── models.py
│   ├── train.py
│   └── utils.py
├── mean_teacher.py
├── mixmatch.py
├── mixup.py
├── pi_model.py
├── pseudo_label.py
├── remixmatch_no_cta.py
├── requirements.txt
├── runs
│   └── ssl
│       ├── ablation.sh
│       ├── cifar10.sh
│       ├── stl10.sh
│       └── svhn.sh
├── scripts
│   ├── aggregate_accuracy.py
│   ├── check_split.py
│   ├── create_datasets.py
│   ├── create_split.py
│   ├── create_unlabeled.py
│   ├── extract_accuracy.py
│   └── inspect_dataset.py
├── third_party
│   ├── LICENSE
│   ├── __init__.py
│   ├── auto_augment
│   │   ├── __init__.py
│   │   ├── augmentations.py
│   │   ├── custom_ops.py
│   │   ├── policies.py
│   │   ├── shake_drop.py
│   │   ├── shake_shake.py
│   │   └── wrn.py
│   └── vat_utils.py
└── vat.py
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows [Google's Open Source Community
28 | Guidelines](https://opensource.google.com/conduct/).
29 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ReMixMatch
2 |
3 | Code for the paper: "[ReMixMatch: Semi-Supervised Learning with Distribution Alignment and Augmentation Anchoring](https://arxiv.org/abs/1911.09785)" by David Berthelot, Nicholas Carlini, Ekin D. Cubuk, Alex Kurakin, Kihyuk Sohn, Han Zhang, and Colin Raffel.
4 |
5 |
6 | This is not an officially supported Google product.
7 |
8 | ## Setup
9 |
10 | **Important**: `ML_DATA` is a shell environment variable that should point to the location where the datasets are installed. See the *Install datasets* section for more details.
11 |
12 | ### Install dependencies
13 |
14 | ```bash
15 | sudo apt install python3-dev python3-virtualenv python3-tk imagemagick
16 | virtualenv -p python3 --system-site-packages env3
17 | . env3/bin/activate
18 | pip install -r requirements.txt
19 | ```
20 |
21 | ### Install datasets
22 |
23 | ```bash
24 | export ML_DATA="path to where you want the datasets saved"
25 | # Download datasets
26 | CUDA_VISIBLE_DEVICES= ./scripts/create_datasets.py
27 | cp $ML_DATA/svhn-test.tfrecord $ML_DATA/svhn_noextra-test.tfrecord
28 |
29 | # Create unlabeled datasets
30 | CUDA_VISIBLE_DEVICES= scripts/create_unlabeled.py $ML_DATA/SSL2/svhn $ML_DATA/svhn-train.tfrecord $ML_DATA/svhn-extra.tfrecord &
31 | CUDA_VISIBLE_DEVICES= scripts/create_unlabeled.py $ML_DATA/SSL2/svhn_noextra $ML_DATA/svhn-train.tfrecord &
32 | CUDA_VISIBLE_DEVICES= scripts/create_unlabeled.py $ML_DATA/SSL2/cifar10 $ML_DATA/cifar10-train.tfrecord &
33 | CUDA_VISIBLE_DEVICES= scripts/create_unlabeled.py $ML_DATA/SSL2/cifar100 $ML_DATA/cifar100-train.tfrecord &
34 | CUDA_VISIBLE_DEVICES= scripts/create_unlabeled.py $ML_DATA/SSL2/stl10 $ML_DATA/stl10-train.tfrecord $ML_DATA/stl10-unlabeled.tfrecord &
35 | wait
36 |
37 | # Create semi-supervised subsets
38 | for seed in 0 1 2 3 4 5; do
39 | for size in 40 250 1000 4000; do
40 | CUDA_VISIBLE_DEVICES= scripts/create_split.py --seed=$seed --size=$size $ML_DATA/SSL2/svhn $ML_DATA/svhn-train.tfrecord $ML_DATA/svhn-extra.tfrecord &
41 | CUDA_VISIBLE_DEVICES= scripts/create_split.py --seed=$seed --size=$size $ML_DATA/SSL2/svhn_noextra $ML_DATA/svhn-train.tfrecord &
42 | CUDA_VISIBLE_DEVICES= scripts/create_split.py --seed=$seed --size=$size $ML_DATA/SSL2/cifar10 $ML_DATA/cifar10-train.tfrecord &
43 | done
44 | CUDA_VISIBLE_DEVICES= scripts/create_split.py --seed=$seed --size=10000 $ML_DATA/SSL2/cifar100 $ML_DATA/cifar100-train.tfrecord &
45 | CUDA_VISIBLE_DEVICES= scripts/create_split.py --seed=$seed --size=2500 $ML_DATA/SSL2/cifar100 $ML_DATA/cifar100-train.tfrecord &
46 | CUDA_VISIBLE_DEVICES= scripts/create_split.py --seed=$seed --size=1000 $ML_DATA/SSL2/stl10 $ML_DATA/stl10-train.tfrecord $ML_DATA/stl10-unlabeled.tfrecord &
47 | wait
48 | done
49 | CUDA_VISIBLE_DEVICES= scripts/create_split.py --seed=1 --size=5000 $ML_DATA/SSL2/stl10 $ML_DATA/stl10-train.tfrecord $ML_DATA/stl10-unlabeled.tfrecord
50 | ```
51 |
52 | ## Running
53 |
54 | ### Setup
55 |
56 | All commands must be run from the project root. The following environment variables must be defined:
57 | ```bash
58 | export ML_DATA="path to where you want the datasets saved"
59 | export PYTHONPATH=$PYTHONPATH:.
60 | ```
61 |
62 | ### Example
63 |
64 | For example, to train ReMixMatch with 32 filters and 4 augmentations on cifar10 shuffled with `seed=3`, using 250 labeled samples and 5000
65 | validation samples:
66 | ```bash
67 | CUDA_VISIBLE_DEVICES=0 python cta/cta_remixmatch.py --filters=32 --K=4 --dataset=cifar10.3@250-5000 --w_match=1.5 --beta=0.75 --train_dir ./experiments/remixmatch
68 | ```
69 |
70 | Available labeled sizes are 40, 100, 250, 1000 and 4000.
71 | For validation, available sizes are 1 and 5000.
72 | Possible shuffling seeds are 1, 2, 3, 4, 5, and 0 for no shuffling (0 is not used in practice since the data needs to be
73 | shuffled for gradient descent to work properly).
74 |
75 |
76 | #### Multi-GPU training
77 | Just pass more GPUs and ReMixMatch automatically scales to them. Here we assign GPUs 4-7 to the program:
78 | ```bash
79 | CUDA_VISIBLE_DEVICES=4,5,6,7 python cta/cta_remixmatch.py --filters=32 --K=4 --dataset=cifar10.3@250-5000 --w_match=1.5 --beta=0.75 --train_dir ./experiments/remixmatch
80 | ```
81 |
82 | ### Valid dataset names
83 | ```bash
84 | for dataset in cifar10 svhn svhn_noextra; do
85 | for seed in 0 1 2 3 4 5; do
86 | for valid in 1 5000; do
87 | for size in 40 250 1000 4000; do
88 | echo "${dataset}.${seed}@${size}-${valid}"
89 | done; done; done; done
90 |
91 | for seed in 0 1 2 3 4 5; do
92 | for valid in 1 5000; do
93 | echo "cifar100.${seed}@10000-${valid}"
94 | done; done
95 |
96 | for seed in 1 2 3 4 5; do
97 | for valid in 1 5000; do
98 | echo "stl10.${seed}@1000-${valid}"
99 | done; done
100 | echo "stl10.1@5000-1"
101 | ```
102 |
103 |
104 | ## Monitoring training progress
105 |
106 | You can point tensorboard to the training folder (by default it is `--train_dir=./experiments`) to monitor the training
107 | process:
108 |
109 | ```bash
110 | tensorboard.sh --port 6007 --logdir experiments
111 | ```
112 |
113 | ## Checkpoint accuracy
114 |
115 | In the paper we report the median accuracy of the last 20 checkpoints; this is computed with the following command:
116 |
117 | ```bash
118 | # Following the previous example in which we trained cifar10.3@250-5000, extracting accuracy:
119 | ./scripts/extract_accuracy.py experiments/cifar10.d.d.d.3\@250-5000/CTAugment_depth2_th0.80_decay0.990/CTAReMixMatch_K4_archresnet_batch64_beta0.75_filters32_lr0.002_nclass10_redux1st_repeat4_scales3_use_dmTrue_use_xeTrue_w_kl0.5_w_match1.5_w_rot0.5_warmup_kimg1024_wd0.02/
120 | # The command above will create a stats/accuracy.json file in the model folder.
121 | # The format is JSON so you can either see its content as a text file or process it to your liking.
122 | ```
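
For processing it programmatically, here is a minimal sketch (not part of the repository); the model folder below is a placeholder for the directory produced by your own run:

```python
import json

# Placeholder path: substitute the model folder created by your training run.
with open('experiments/<model folder>/stats/accuracy.json') as f:
    stats = json.load(f)
print(json.dumps(stats, indent=2))  # pretty-print whatever fields extract_accuracy.py wrote
```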
123 |
124 | ## Reproducing tables from the paper
125 |
126 | Check the contents of the `runs/*.sh` files; these give you the commands (and hyper-parameters) needed to reproduce the results from the paper.
127 |
128 | ## Citing this work
129 |
130 | ```
131 | @article{berthelot2019remixmatch,
132 | title={ReMixMatch: Semi-Supervised Learning with Distribution Alignment and Augmentation Anchoring},
133 | author={David Berthelot and Nicholas Carlini and Ekin D. Cubuk and Alex Kurakin and Kihyuk Sohn and Han Zhang and Colin Raffel},
134 | journal={arXiv preprint arXiv:1911.09785},
135 | year={2019},
136 | }
137 | ```
138 |
--------------------------------------------------------------------------------
/ablation/ab_cta_remixmatch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """ReMixMatch training, changes from MixMatch are:
15 | - Add distribution matching.
16 | """
17 |
18 | import os
19 |
20 | from absl import app
21 | from absl import flags
22 |
23 | from cta.cta_remixmatch import CTAReMixMatch
24 | from libml import utils, data
25 |
26 | FLAGS = flags.FLAGS
27 |
28 |
29 | class ABCTAReMixMatch(CTAReMixMatch):
30 | pass
31 |
32 |
33 | def main(argv):
34 | utils.setup_main()
35 | del argv # Unused.
36 | dataset = data.MANY_DATASETS()[FLAGS.dataset]()
37 | log_width = utils.ilog2(dataset.width)
38 | model = ABCTAReMixMatch(
39 | os.path.join(FLAGS.train_dir, dataset.name, ABCTAReMixMatch.cta_name()),
40 | dataset,
41 | lr=FLAGS.lr,
42 | wd=FLAGS.wd,
43 | arch=FLAGS.arch,
44 | batch=FLAGS.batch,
45 | nclass=dataset.nclass,
46 |
47 | K=FLAGS.K,
48 | beta=FLAGS.beta,
49 | w_kl=FLAGS.w_kl,
50 | w_match=FLAGS.w_match,
51 | w_rot=FLAGS.w_rot,
52 | redux=FLAGS.redux,
53 | use_dm=FLAGS.use_dm,
54 | use_xe=FLAGS.use_xe,
55 | warmup_kimg=FLAGS.warmup_kimg,
56 |
57 | scales=FLAGS.scales or (log_width - 2),
58 | filters=FLAGS.filters,
59 | repeat=FLAGS.repeat)
60 | model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
61 |
62 |
63 | if __name__ == '__main__':
64 | utils.setup_tf()
65 | flags.DEFINE_float('wd', 0.02, 'Weight decay.')
66 | flags.DEFINE_float('beta', 0.75, 'Mixup beta distribution.')
67 | flags.DEFINE_float('w_kl', 0.5, 'Weight for KL loss.')
68 | flags.DEFINE_float('w_match', 1.5, 'Weight for distribution matching loss.')
69 | flags.DEFINE_float('w_rot', 0.5, 'Weight for rotation loss.')
70 | flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.')
71 | flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.')
72 | flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.')
73 | flags.DEFINE_integer('warmup_kimg', 1024, 'Unannealing duration for SSL loss.')
74 | flags.DEFINE_enum('redux', '1st', 'swap mean 1st'.split(), 'Logit selection.')
75 | flags.DEFINE_bool('use_dm', True, 'Whether to use distribution matching.')
76 | flags.DEFINE_bool('use_xe', True, 'Whether to use cross-entropy or Brier.')
77 | FLAGS.set_default('augment', 'd.d.d')
78 | FLAGS.set_default('dataset', 'cifar10.3@250-5000')
79 | FLAGS.set_default('batch', 64)
80 | FLAGS.set_default('lr', 0.002)
81 | FLAGS.set_default('train_kimg', 1 << 16)
82 | app.run(main)
83 |
--------------------------------------------------------------------------------
/ablation/ab_cta_remixmatch_noweak.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """ReMixMatch training, changes from MixMatch are:
15 | - Add distribution matching.
16 | """
17 |
18 | import os
19 |
20 | import numpy as np
21 | from absl import app
22 | from absl import flags
23 |
24 | from ablation.ab_cta_remixmatch import ABCTAReMixMatch
25 | from libml import utils, data, ctaugment
26 | from libml.augment import AugmentPoolCTA
27 |
28 | FLAGS = flags.FLAGS
29 |
30 |
31 | class AugmentPoolCTANoWeak(AugmentPoolCTA):
32 | @staticmethod
33 | def numpy_apply_policies(arglist):
34 | x, cta, probe = arglist
35 | if x.ndim == 3:
36 | assert probe
37 | policy = cta.policy(probe=True)
38 | return dict(policy=policy,
39 | probe=ctaugment.apply(x, policy),
40 | image=ctaugment.apply(x, cta.policy(probe=False)))
41 | assert not probe
42 | return dict(image=np.stack([ctaugment.apply(y, cta.policy(probe=False)) for y in x]).astype('f'))
43 |
44 |
45 | class ABCTAReMixMatchNoWeak(ABCTAReMixMatch):
46 | AUGMENT_POOL_CLASS = AugmentPoolCTANoWeak
47 |
48 |
49 | def main(argv):
50 | utils.setup_main()
51 | del argv # Unused.
52 | dataset = data.MANY_DATASETS()[FLAGS.dataset]()
53 | log_width = utils.ilog2(dataset.width)
54 | model = ABCTAReMixMatchNoWeak(
55 | os.path.join(FLAGS.train_dir, dataset.name, ABCTAReMixMatchNoWeak.cta_name()),
56 | dataset,
57 | lr=FLAGS.lr,
58 | wd=FLAGS.wd,
59 | arch=FLAGS.arch,
60 | batch=FLAGS.batch,
61 | nclass=dataset.nclass,
62 |
63 | K=FLAGS.K,
64 | beta=FLAGS.beta,
65 | w_kl=FLAGS.w_kl,
66 | w_match=FLAGS.w_match,
67 | w_rot=FLAGS.w_rot,
68 | redux=FLAGS.redux,
69 | use_dm=FLAGS.use_dm,
70 | use_xe=FLAGS.use_xe,
71 | warmup_kimg=FLAGS.warmup_kimg,
72 |
73 | scales=FLAGS.scales or (log_width - 2),
74 | filters=FLAGS.filters,
75 | repeat=FLAGS.repeat)
76 | model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
77 |
78 |
79 | if __name__ == '__main__':
80 | utils.setup_tf()
81 | flags.DEFINE_float('wd', 0.02, 'Weight decay.')
82 | flags.DEFINE_float('beta', 0.75, 'Mixup beta distribution.')
83 | flags.DEFINE_float('w_kl', 0.5, 'Weight for KL loss.')
84 | flags.DEFINE_float('w_match', 1.5, 'Weight for distribution matching loss.')
85 | flags.DEFINE_float('w_rot', 0.5, 'Weight for rotation loss.')
86 | flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.')
87 | flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.')
88 | flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.')
89 | flags.DEFINE_integer('warmup_kimg', 1024, 'Unannealing duration for SSL loss.')
90 | flags.DEFINE_enum('redux', '1st', 'swap mean 1st'.split(), 'Logit selection.')
91 | flags.DEFINE_bool('use_dm', True, 'Whether to use distribution matching.')
92 | flags.DEFINE_bool('use_xe', True, 'Whether to use cross-entropy or Brier.')
93 | FLAGS.set_default('augment', 'd.d.d')
94 | FLAGS.set_default('dataset', 'cifar10.3@250-5000')
95 | FLAGS.set_default('batch', 64)
96 | FLAGS.set_default('lr', 0.002)
97 | FLAGS.set_default('train_kimg', 1 << 16)
98 | app.run(main)
99 |
--------------------------------------------------------------------------------
/ablation/ab_remixmatch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """ReMixMatch training, changes from MixMatch are:
15 | - Add distribution matching.
16 | """
17 |
18 | import os
19 |
20 | from absl import app
21 | from absl import flags
22 |
23 | from libml import data, utils
24 | from remixmatch_no_cta import ReMixMatch
25 |
26 | FLAGS = flags.FLAGS
27 |
28 |
29 | class ABReMixMatch(ReMixMatch):
30 | pass
31 |
32 |
33 | def main(argv):
34 | utils.setup_main()
35 | del argv # Unused.
36 | dataset = data.MANY_DATASETS()[FLAGS.dataset]()
37 | log_width = utils.ilog2(dataset.width)
38 | model = ABReMixMatch(
39 | os.path.join(FLAGS.train_dir, dataset.name),
40 | dataset,
41 | lr=FLAGS.lr,
42 | wd=FLAGS.wd,
43 | arch=FLAGS.arch,
44 | batch=FLAGS.batch,
45 | nclass=dataset.nclass,
46 |
47 | K=FLAGS.K,
48 | beta=FLAGS.beta,
49 | w_kl=FLAGS.w_kl,
50 | w_match=FLAGS.w_match,
51 | w_rot=FLAGS.w_rot,
52 | redux=FLAGS.redux,
53 | use_dm=FLAGS.use_dm,
54 | use_xe=FLAGS.use_xe,
55 | warmup_kimg=FLAGS.warmup_kimg,
56 |
57 | scales=FLAGS.scales or (log_width - 2),
58 | filters=FLAGS.filters,
59 | repeat=FLAGS.repeat)
60 | model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
61 |
62 |
63 | if __name__ == '__main__':
64 | utils.setup_tf()
65 | flags.DEFINE_float('wd', 0.02, 'Weight decay.')
66 | flags.DEFINE_float('beta', 0.75, 'Mixup beta distribution.')
67 | flags.DEFINE_float('w_kl', 0.5, 'Weight for KL loss.')
68 | flags.DEFINE_float('w_match', 1.5, 'Weight for distribution matching loss.')
69 | flags.DEFINE_float('w_rot', 0.5, 'Weight for rotation loss.')
70 | flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.')
71 | flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.')
72 | flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.')
73 | flags.DEFINE_integer('warmup_kimg', 1024, 'Unannealing duration for SSL loss.')
74 | flags.DEFINE_enum('redux', 'swap', 'swap mean 1st'.split(), 'Logit selection.')
75 | flags.DEFINE_bool('use_dm', True, 'Whether to use distribution matching.')
76 | flags.DEFINE_bool('use_xe', True, 'Whether to use cross-entropy or Brier.')
77 | FLAGS.set_default('augment', 'd.d.d')
78 | FLAGS.set_default('dataset', 'cifar10.3@250-5000')
79 | FLAGS.set_default('batch', 64)
80 | FLAGS.set_default('lr', 0.002)
81 | FLAGS.set_default('train_kimg', 1 << 16)
82 | app.run(main)
83 |
--------------------------------------------------------------------------------
/cta/cta_fsmixup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """ReMixMatch training, changes from MixMatch are:
15 | - Add distribution matching.
16 | """
17 | import os
18 |
19 | from absl import app
20 | from absl import flags
21 |
22 | from cta.lib.train import CTAClassifyFullySupervised
23 | from fully_supervised.fs_mixup import FSMixup
24 | from fully_supervised.lib import data
25 | from libml import utils
26 |
27 | FLAGS = flags.FLAGS
28 |
29 |
30 | class CTAFSMixup(FSMixup, CTAClassifyFullySupervised):
31 | pass
32 |
33 |
34 | def main(argv):
35 | utils.setup_main()
36 | del argv # Unused.
37 | dataset = data.DATASETS()[FLAGS.dataset]()
38 | log_width = utils.ilog2(dataset.width)
39 | model = CTAFSMixup(
40 | os.path.join(FLAGS.train_dir, dataset.name, CTAFSMixup.cta_name()),
41 | dataset,
42 | lr=FLAGS.lr,
43 | wd=FLAGS.wd,
44 | arch=FLAGS.arch,
45 | batch=FLAGS.batch,
46 | nclass=dataset.nclass,
47 | ema=FLAGS.ema,
48 | beta=FLAGS.beta,
49 | dropout=FLAGS.dropout,
50 |
51 | scales=FLAGS.scales or (log_width - 2),
52 | filters=FLAGS.filters,
53 | repeat=FLAGS.repeat)
54 | model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
55 |
56 |
57 | if __name__ == '__main__':
58 | utils.setup_tf()
59 | flags.DEFINE_float('wd', 0.002, 'Weight decay.')
60 | flags.DEFINE_float('ema', 0.999, 'Exponential moving average of params.')
61 | flags.DEFINE_float('beta', 0.5, 'Mixup beta distribution.')
62 | flags.DEFINE_float('dropout', 0, 'Dropout on embedding layer.')
63 | flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.')
64 | flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.')
65 | flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.')
66 | FLAGS.set_default('augment', 'd.d')
67 | FLAGS.set_default('dataset', 'cifar10-1')
68 | FLAGS.set_default('batch', 64)
69 | FLAGS.set_default('lr', 0.002)
70 | FLAGS.set_default('train_kimg', 1 << 16)
71 | app.run(main)
72 |
--------------------------------------------------------------------------------
/cta/cta_mixmatch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """ReMixMatch training, changes from MixMatch are:
15 | - Add distribution matching.
16 | """
17 |
18 | import os
19 |
20 | from absl import app
21 | from absl import flags
22 |
23 | from cta.lib.train import CTAClassifySemi
24 | from libml import data, utils
25 | from mixmatch import MixMatch
26 |
27 | FLAGS = flags.FLAGS
28 |
29 |
30 | class CTAMixMatch(MixMatch, CTAClassifySemi):
31 | pass
32 |
33 |
34 | def main(argv):
35 | utils.setup_main()
36 | del argv # Unused.
37 | dataset = data.PAIR_DATASETS()[FLAGS.dataset]()
38 | log_width = utils.ilog2(dataset.width)
39 | model = CTAMixMatch(
40 | os.path.join(FLAGS.train_dir, dataset.name, CTAMixMatch.cta_name()),
41 | dataset,
42 | lr=FLAGS.lr,
43 | wd=FLAGS.wd,
44 | arch=FLAGS.arch,
45 | batch=FLAGS.batch,
46 | nclass=dataset.nclass,
47 | ema=FLAGS.ema,
48 | beta=FLAGS.beta,
49 | w_match=FLAGS.w_match,
50 | scales=FLAGS.scales or (log_width - 2),
51 | filters=FLAGS.filters,
52 | repeat=FLAGS.repeat)
53 | model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
54 |
55 |
56 | if __name__ == '__main__':
57 | utils.setup_tf()
58 | flags.DEFINE_float('wd', 0.02, 'Weight decay.')
59 | flags.DEFINE_float('ema', 0.999, 'Exponential moving average of params.')
60 | flags.DEFINE_float('beta', 0.75, 'Mixup beta distribution.')
61 | flags.DEFINE_float('w_match', 100, 'Weight for distribution matching loss.')
62 | flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.')
63 | flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.')
64 | flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.')
65 | FLAGS.set_default('augment', 'd.d.d')
66 | FLAGS.set_default('dataset', 'cifar10.3@250-5000')
67 | FLAGS.set_default('batch', 64)
68 | FLAGS.set_default('lr', 0.002)
69 | FLAGS.set_default('train_kimg', 1 << 16)
70 | app.run(main)
71 |
--------------------------------------------------------------------------------
/cta/cta_remixmatch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 | from absl import app
18 | from absl import flags
19 |
20 | from cta.lib.train import CTAClassifySemi
21 | from libml import utils, data
22 | from remixmatch_no_cta import ReMixMatch
23 |
24 | FLAGS = flags.FLAGS
25 |
26 |
27 | class CTAReMixMatch(ReMixMatch, CTAClassifySemi):
28 | pass
29 |
30 |
31 | def main(argv):
32 | utils.setup_main()
33 | del argv # Unused.
34 | dataset = data.MANY_DATASETS()[FLAGS.dataset]()
35 | log_width = utils.ilog2(dataset.width)
36 | model = CTAReMixMatch(
37 | os.path.join(FLAGS.train_dir, dataset.name, CTAReMixMatch.cta_name()),
38 | dataset,
39 | lr=FLAGS.lr,
40 | wd=FLAGS.wd,
41 | arch=FLAGS.arch,
42 | batch=FLAGS.batch,
43 | nclass=dataset.nclass,
44 |
45 | K=FLAGS.K,
46 | beta=FLAGS.beta,
47 | w_kl=FLAGS.w_kl,
48 | w_match=FLAGS.w_match,
49 | w_rot=FLAGS.w_rot,
50 | redux=FLAGS.redux,
51 | use_dm=FLAGS.use_dm,
52 | use_xe=FLAGS.use_xe,
53 | warmup_kimg=FLAGS.warmup_kimg,
54 |
55 | scales=FLAGS.scales or (log_width - 2),
56 | filters=FLAGS.filters,
57 | repeat=FLAGS.repeat)
58 | model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
59 |
60 |
61 | if __name__ == '__main__':
62 | utils.setup_tf()
63 | flags.DEFINE_float('wd', 0.02, 'Weight decay.')
64 | flags.DEFINE_float('beta', 0.75, 'Mixup beta distribution.')
65 | flags.DEFINE_float('w_kl', 0.5, 'Weight for KL loss.')
66 | flags.DEFINE_float('w_match', 1.5, 'Weight for distribution matching loss.')
67 | flags.DEFINE_float('w_rot', 0.5, 'Weight for rotation loss.')
68 | flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.')
69 | flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.')
70 | flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.')
71 | flags.DEFINE_integer('warmup_kimg', 1024, 'Unannealing duration for SSL loss.')
72 | flags.DEFINE_enum('redux', '1st', 'swap mean 1st'.split(), 'Logit selection.')
73 | flags.DEFINE_bool('use_dm', True, 'Whether to use distribution matching.')
74 | flags.DEFINE_bool('use_xe', True, 'Whether to use cross-entropy or Brier.')
75 | FLAGS.set_default('augment', 'd.d.d')
76 | FLAGS.set_default('dataset', 'cifar10.3@250-5000')
77 | FLAGS.set_default('batch', 64)
78 | FLAGS.set_default('lr', 0.002)
79 | FLAGS.set_default('train_kimg', 1 << 16)
80 | app.run(main)
81 |
--------------------------------------------------------------------------------
/cta/lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/remixmatch/f7061ebf055227cbeb5c6fced1ce054e0ceecfcd/cta/lib/__init__.py
--------------------------------------------------------------------------------
/cta/lib/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import numpy as np
16 | from absl import flags
17 |
18 | from fully_supervised.lib.train import ClassifyFullySupervised
19 | from libml import data
20 | from libml.augment import AugmentPoolCTA
21 | from libml.ctaugment import CTAugment
22 | from libml.train import ClassifySemi
23 |
24 | FLAGS = flags.FLAGS
25 |
26 | flags.DEFINE_integer('adepth', 2, 'Augmentation depth.')
27 | flags.DEFINE_float('adecay', 0.99, 'Augmentation decay.')
28 | flags.DEFINE_float('ath', 0.80, 'Augmentation threshold.')
29 |
30 |
31 | class CTAClassifySemi(ClassifySemi):
32 | """Semi-supervised classification."""
33 | AUGMENTER_CLASS = CTAugment
34 | AUGMENT_POOL_CLASS = AugmentPoolCTA
35 |
36 | @classmethod
37 | def cta_name(cls):
38 | return '%s_depth%d_th%.2f_decay%.3f' % (cls.AUGMENTER_CLASS.__name__,
39 | FLAGS.adepth, FLAGS.ath, FLAGS.adecay)
40 |
41 | def __init__(self, train_dir: str, dataset: data.DataSets, nclass: int, **kwargs):
42 | ClassifySemi.__init__(self, train_dir, dataset, nclass, **kwargs)
43 | self.augmenter = self.AUGMENTER_CLASS(FLAGS.adepth, FLAGS.ath, FLAGS.adecay)
44 |
45 | def gen_labeled_fn(self, data_iterator):
46 | def wrap():
47 | batch = self.session.run(data_iterator)
48 | batch['cta'] = self.augmenter
49 | batch['probe'] = True
50 | return batch
51 |
52 | return self.AUGMENT_POOL_CLASS(wrap)
53 |
54 | def gen_unlabeled_fn(self, data_iterator):
55 | def wrap():
56 | batch = self.session.run(data_iterator)
57 | batch['cta'] = self.augmenter
58 | batch['probe'] = False
59 | return batch
60 |
61 | return self.AUGMENT_POOL_CLASS(wrap)
62 |
63 | def train_step(self, train_session, gen_labeled, gen_unlabeled):
64 | x, y = gen_labeled(), gen_unlabeled()
65 | v = train_session.run([self.ops.classify_op, self.ops.train_op, self.ops.update_step],
66 | feed_dict={self.ops.y: y['image'],
67 | self.ops.x: x['probe'],
68 | self.ops.xt: x['image'],
69 | self.ops.label: x['label']})
70 | self.tmp.step = v[-1]
71 | lx = v[0]
72 | for p in range(lx.shape[0]):
73 | error = lx[p]
74 | error[x['label'][p]] -= 1
75 | error = np.abs(error).sum()
76 | self.augmenter.update_rates(x['policy'][p], 1 - 0.5 * error)
77 |
78 | def eval_stats(self, batch=None, feed_extra=None, classify_op=None):
79 | """Evaluate model on train, valid and test."""
80 | batch = batch or FLAGS.batch
81 | classify_op = self.ops.classify_op if classify_op is None else classify_op
82 | accuracies = []
83 | for subset in ('train_labeled', 'valid', 'test'):
84 | images, labels = self.tmp.cache[subset]
85 | predicted = []
86 |
87 | for x in range(0, images.shape[0], batch):
88 | p = self.session.run(
89 | classify_op,
90 | feed_dict={
91 | self.ops.x: images[x:x + batch],
92 | **(feed_extra or {})
93 | })
94 | predicted.append(p)
95 | predicted = np.concatenate(predicted, axis=0)
96 | accuracies.append((predicted.argmax(1) == labels).mean() * 100)
97 | self.train_print('kimg %-5d accuracy train/valid/test %.2f %.2f %.2f' %
98 | tuple([self.tmp.step >> 10] + accuracies))
99 | self.train_print(self.augmenter.stats())
100 | return np.array(accuracies, 'f')
101 |
102 |
103 | class CTAClassifyFullySupervised(ClassifyFullySupervised, CTAClassifySemi):
104 | """Fully-supervised classification."""
105 |
106 | def train_step(self, train_session, gen_labeled):
107 | x = gen_labeled()
108 | v = train_session.run([self.ops.classify_op, self.ops.train_op, self.ops.update_step],
109 | feed_dict={self.ops.x: x['probe'],
110 | self.ops.xt: x['image'],
111 | self.ops.label: x['label']})
112 | self.tmp.step = v[-1]
113 | lx = v[0]
114 | for p in range(lx.shape[0]):
115 | error = lx[p]
116 | error[x['label'][p]] -= 1
117 | error = np.abs(error).sum()
118 | self.augmenter.update_rates(x['policy'][p], 1 - 0.5 * error)
119 |
--------------------------------------------------------------------------------
/fully_supervised/fs_baseline.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Fully supervised training.
15 | """
16 |
17 | import functools
18 | import os
19 |
20 | import tensorflow as tf
21 | from absl import app
22 | from absl import flags
23 |
24 | from fully_supervised.lib.data import DATASETS
25 | from fully_supervised.lib.train import ClassifyFullySupervised
26 | from libml import models, utils
27 | from libml.utils import EasyDict
28 |
29 | FLAGS = flags.FLAGS
30 |
31 |
32 | class FSBaseline(ClassifyFullySupervised, models.MultiModel):
33 |
34 | def model(self, batch, lr, wd, ema, **kwargs):
35 | hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
36 | xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt') # For training
37 | x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
38 | l_in = tf.placeholder(tf.int32, [batch], 'labels')
39 | wd *= lr
40 | classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
41 |
42 | x, labels_x = self.augment(xt_in, tf.one_hot(l_in, self.nclass), **kwargs)
43 | logits_x = classifier(x, training=True)
44 |
45 | loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x, logits=logits_x)
46 | loss_xe = tf.reduce_mean(loss_xe)
47 | tf.summary.scalar('losses/xe', loss_xe)
48 |
49 | ema = tf.train.ExponentialMovingAverage(decay=ema)
50 | ema_op = ema.apply(utils.model_vars())
51 | ema_getter = functools.partial(utils.getter_ema, ema)
52 | post_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) + [ema_op]
53 | post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])
54 |
55 | train_op = tf.train.AdamOptimizer(lr).minimize(loss_xe, colocate_gradients_with_ops=True)
56 | with tf.control_dependencies([train_op]):
57 | train_op = tf.group(*post_ops)
58 |
59 | return EasyDict(
60 | xt=xt_in, x=x_in, label=l_in, train_op=train_op,
61 | classify_raw=tf.nn.softmax(classifier(x_in, training=False)), # No EMA, for debugging.
62 | classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))
63 |
64 |
65 | def main(argv):
66 | utils.setup_main()
67 | del argv # Unused.
68 | dataset = DATASETS()[FLAGS.dataset]()
69 | log_width = utils.ilog2(dataset.width)
70 | model = FSBaseline(
71 | os.path.join(FLAGS.train_dir, dataset.name),
72 | dataset,
73 | lr=FLAGS.lr,
74 | wd=FLAGS.wd,
75 | arch=FLAGS.arch,
76 | batch=FLAGS.batch,
77 | nclass=dataset.nclass,
78 | ema=FLAGS.ema,
79 | dropout=FLAGS.dropout,
80 | smoothing=FLAGS.smoothing,
81 |
82 | scales=FLAGS.scales or (log_width - 2),
83 | filters=FLAGS.filters,
84 | repeat=FLAGS.repeat)
85 | model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
86 |
87 |
88 | if __name__ == '__main__':
89 | utils.setup_tf()
90 | flags.DEFINE_float('wd', 0.002, 'Weight decay.')
91 | flags.DEFINE_float('ema', 0.999, 'Exponential moving average of params.')
92 | flags.DEFINE_float('dropout', 0, 'Dropout on embedding layer.')
93 | flags.DEFINE_float('smoothing', 0.001, 'Label smoothing.')
94 | flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.')
95 | flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.')
96 | flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.')
97 | FLAGS.set_default('dataset', 'cifar10')
98 | FLAGS.set_default('batch', 64)
99 | FLAGS.set_default('lr', 0.002)
100 | FLAGS.set_default('train_kimg', 1 << 16)
101 | app.run(main)
102 |
--------------------------------------------------------------------------------
/fully_supervised/fs_mixup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Mixup fully supervised training.
15 | """
16 |
17 | import os
18 |
19 | import tensorflow as tf
20 | from absl import app
21 | from absl import flags
22 |
23 | from fully_supervised.fs_baseline import FSBaseline
24 | from fully_supervised.lib.data import DATASETS
25 | from libml import utils
26 |
27 | FLAGS = flags.FLAGS
28 |
29 |
30 | class FSMixup(FSBaseline):
31 |
32 | def augment(self, x, l, beta, **kwargs):
33 | del kwargs
34 | with tf.device('/cpu'):
35 | mix = tf.distributions.Beta(beta, beta).sample([tf.shape(x)[0], 1, 1, 1])
36 | mix = tf.maximum(mix, 1 - mix)
37 | xmix = x * mix + x[::-1] * (1 - mix)
38 | lmix = l * mix[:, :, 0, 0] + l[::-1] * (1 - mix[:, :, 0, 0])
39 | return xmix, lmix
40 |
41 |
42 | def main(argv):
43 | utils.setup_main()
44 | del argv # Unused.
45 | dataset = DATASETS()[FLAGS.dataset]()
46 | log_width = utils.ilog2(dataset.width)
47 | model = FSMixup(
48 | os.path.join(FLAGS.train_dir, dataset.name),
49 | dataset,
50 | lr=FLAGS.lr,
51 | wd=FLAGS.wd,
52 | arch=FLAGS.arch,
53 | batch=FLAGS.batch,
54 | nclass=dataset.nclass,
55 | ema=FLAGS.ema,
56 | beta=FLAGS.beta,
57 | dropout=FLAGS.dropout,
58 |
59 | scales=FLAGS.scales or (log_width - 2),
60 | filters=FLAGS.filters,
61 | repeat=FLAGS.repeat)
62 | model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
63 |
64 |
65 | if __name__ == '__main__':
66 | utils.setup_tf()
67 | flags.DEFINE_float('wd', 0.002, 'Weight decay.')
68 | flags.DEFINE_float('ema', 0.999, 'Exponential moving average of params.')
69 | flags.DEFINE_float('beta', 0.5, 'Mixup beta distribution.')
70 | flags.DEFINE_float('dropout', 0, 'Dropout on embedding layer.')
71 | flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.')
72 | flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.')
73 | flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.')
74 | FLAGS.set_default('dataset', 'cifar10-1')
75 | FLAGS.set_default('batch', 64)
76 | FLAGS.set_default('lr', 0.002)
77 | FLAGS.set_default('train_kimg', 1 << 16)
78 | app.run(main)
79 |
--------------------------------------------------------------------------------
/fully_supervised/lib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/fully_supervised/lib/data.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 | import tensorflow as tf
18 | from absl import flags
19 |
20 | from libml import augment as augment_module
21 | from libml import data
22 |
23 | FLAGS = flags.FLAGS
24 |
25 |
26 | class DataSetsFS(data.DataSets):
27 | @classmethod
28 | def creator(cls, name, train_files, test_files, valid, augment, parse_fn=data.record_parse, do_memoize=True,
29 | nclass=10, height=32, width=32, colors=3):
30 | train_files = [os.path.join(data.DATA_DIR, x) for x in train_files]
31 | test_files = [os.path.join(data.DATA_DIR, x) for x in test_files]
32 | if not isinstance(augment, list):
33 | augment = augment(name)
34 | else:
35 | assert len(augment) == 1
36 | augment = augment[0]
37 |
38 | def create():
39 | image_shape = [height, width, colors]
40 | kwargs = dict(parse_fn=parse_fn, image_shape=image_shape)
41 | train_labeled = data.DataSet.from_files(train_files, augment, **kwargs)
42 | if do_memoize:
43 | train_labeled = train_labeled.memoize()
44 | if FLAGS.whiten:
45 | mean, std = data.compute_mean_std(train_labeled)
46 | else:
47 | mean, std = 0, 1
48 |
49 | valid_data = data.DataSet.from_files(train_files, augment_module.NOAUGMENT, **kwargs).take(valid)
50 | test_data = data.DataSet.from_files(test_files, augment_module.NOAUGMENT, **kwargs)
51 |
52 | return cls(name + '.' + FLAGS.augment.split('.')[0] + '-' + str(valid),
53 | train_labeled=train_labeled.skip(valid),
54 | train_unlabeled=None,
55 | valid=valid_data,
56 | test=test_data,
57 | nclass=nclass, colors=colors, height=height, width=width, mean=mean, std=std)
58 |
59 | return name + '-' + str(valid), create
60 |
61 |
62 | def augment_function(dataset: str):
63 | return augment_module.get_augmentation(dataset, FLAGS.augment.split('.')[0])
64 |
65 |
66 | def create_datasets():
67 | d = {}
68 | d.update([DataSetsFS.creator('cifar10', ['cifar10-train.tfrecord'], ['cifar10-test.tfrecord'], valid,
69 | augment_function) for valid in [1, 5000]])
70 | d.update([DataSetsFS.creator('cifar100', ['cifar100-train.tfrecord'], ['cifar100-test.tfrecord'], valid,
71 | augment_function, nclass=100) for valid in [1, 5000]])
72 | d.update([DataSetsFS.creator('fashion_mnist', ['fashion_mnist-train.tfrecord'], ['fashion_mnist-test.tfrecord'],
73 | valid, augment_function, height=32, width=32, colors=1,
74 | parse_fn=data.record_parse_mnist)
75 | for valid in [1, 5000]])
76 | d.update(
77 | [DataSetsFS.creator('stl10', [], [], valid, augment_function, height=96, width=96, do_memoize=False)
78 | for valid in [1, 5000]])
79 | d.update([DataSetsFS.creator('svhn', ['svhn-train.tfrecord', 'svhn-extra.tfrecord'], ['svhn-test.tfrecord'],
80 | valid, augment_function, do_memoize=False) for valid in [1, 5000]])
81 | d.update([DataSetsFS.creator('svhn_noextra', ['svhn-train.tfrecord'], ['svhn-test.tfrecord'],
82 | valid, augment_function, do_memoize=False) for valid in [1, 5000]])
83 | return d
84 |
85 |
86 | DATASETS = create_datasets
87 |
--------------------------------------------------------------------------------
/fully_supervised/lib/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import tensorflow as tf
16 | from absl import flags
17 | from tqdm import trange
18 |
19 | from libml import utils
20 | from libml.train import ClassifySemi
21 |
22 | FLAGS = flags.FLAGS
23 |
24 |
25 | class ClassifyFullySupervised(ClassifySemi):
26 | """Fully supervised classification.
27 | """
28 |
29 | def train_step(self, train_session, gen_labeled):
30 | x = gen_labeled()
31 | self.tmp.step = train_session.run([self.ops.train_op, self.ops.update_step],
32 | feed_dict={self.ops.xt: x['image'],
33 | self.ops.label: x['label']})[1]
34 |
35 | def train(self, train_nimg, report_nimg):
36 | if FLAGS.eval_ckpt:
37 | self.eval_checkpoint(FLAGS.eval_ckpt)
38 | return
39 | batch = FLAGS.batch
40 | train_labeled = self.dataset.train_labeled.repeat().shuffle(FLAGS.shuffle).parse().augment()
41 | train_labeled = train_labeled.batch(batch).prefetch(16).make_one_shot_iterator().get_next()
42 | scaffold = tf.train.Scaffold(saver=tf.train.Saver(max_to_keep=FLAGS.keep_ckpt, pad_step_number=10))
43 |
44 | with tf.Session(config=utils.get_config()) as sess:
45 | self.session = sess
46 | self.cache_eval()
47 |
48 | with tf.train.MonitoredTrainingSession(
49 | scaffold=scaffold,
50 | checkpoint_dir=self.checkpoint_dir,
51 | config=utils.get_config(),
52 | save_checkpoint_steps=FLAGS.save_kimg << 10,
53 | save_summaries_steps=report_nimg - batch) as train_session:
54 | self.session = train_session._tf_sess()
55 | gen_labeled = self.gen_labeled_fn(train_labeled)
56 | self.tmp.step = self.session.run(self.step)
57 | while self.tmp.step < train_nimg:
58 | loop = trange(self.tmp.step % report_nimg, report_nimg, batch,
59 | leave=False, unit='img', unit_scale=batch,
60 | desc='Epoch %d/%d' % (1 + (self.tmp.step // report_nimg), train_nimg // report_nimg))
61 | for _ in loop:
62 | self.train_step(train_session, gen_labeled)
63 | while self.tmp.print_queue:
64 | loop.write(self.tmp.print_queue.pop(0))
65 | while self.tmp.print_queue:
66 | print(self.tmp.print_queue.pop(0))
67 |
--------------------------------------------------------------------------------
/fully_supervised/runs/all.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # Fully supervised baseline without mixup (not shown in paper since Mixup is better)
18 | python fully_supervised/fs_baseline.py --train_dir experiments/fs --dataset=cifar10-1 --wd=0.02 --smoothing=0.001
19 | python fully_supervised/fs_baseline.py --train_dir experiments/fs --dataset=cifar100-1 --wd=0.02 --smoothing=0.001
20 | python fully_supervised/fs_baseline.py --train_dir experiments/fs --dataset=svhn-1 --wd=0.002 --smoothing=0.01
21 | python fully_supervised/fs_baseline.py --train_dir experiments/fs --dataset=svhn_noextra-1 --wd=0.002 --smoothing=0.01
22 |
23 | # Fully supervised Mixup baselines (in paper)
24 | # Uses default parameters: --wd=0.002 --beta=0.5
25 | python fully_supervised/fs_mixup.py --train_dir experiments/fs --dataset=cifar10-1
26 | python fully_supervised/fs_mixup.py --train_dir experiments/fs --dataset=svhn-1
27 | python fully_supervised/fs_mixup.py --train_dir experiments/fs --dataset=svhn_noextra-1
28 |
29 | # Fully supervised Mixup baselines on 26M parameter large network (in paper)
30 | # Uses default parameters: --wd=0.002 --beta=0.5
31 | python fully_supervised/fs_mixup.py --train_dir experiments/fs --dataset=cifar10-1 --filters=135
32 | python fully_supervised/fs_mixup.py --train_dir experiments/fs --dataset=cifar100-1 --filters=135
33 |
--------------------------------------------------------------------------------
/ict.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Interpolation Consistency Training for Semi-Supervised Learning.
15 |
16 | Reimplementation of https://arxiv.org/abs/1903.03825
17 | """
18 |
19 | import functools
20 | import os
21 |
22 | import tensorflow as tf
23 | from absl import app
24 | from absl import flags
25 |
26 | from libml import models, utils
27 | from libml.data import PAIR_DATASETS
28 | from libml.utils import EasyDict
29 |
30 | FLAGS = flags.FLAGS
31 |
32 |
33 | class ICT(models.MultiModel):
34 |
35 | def model(self, batch, lr, wd, ema, warmup_pos, consistency_weight, beta, **kwargs):
36 | hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
37 | xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt') # For training
38 | x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
39 | y_in = tf.placeholder(tf.float32, [batch, 2] + hwc, 'y')
40 | l_in = tf.placeholder(tf.int32, [batch], 'labels')
41 | l = tf.one_hot(l_in, self.nclass)
42 | wd *= lr
43 | warmup = tf.clip_by_value(tf.to_float(self.step) / (warmup_pos * (FLAGS.train_kimg << 10)), 0, 1)
44 |
45 | y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)
46 | y_1, y_2 = tf.split(y, 2)
47 |
48 | mix = tf.distributions.Beta(beta, beta).sample([tf.shape(x_in)[0], 1, 1, 1])
49 | mix = tf.maximum(mix, 1 - mix)
50 |
51 | classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
52 | logits_x = classifier(xt_in, training=True)
53 | post_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) # Take only first call to update batch norm.
54 |
55 | ema = tf.train.ExponentialMovingAverage(decay=ema)
56 | ema_op = ema.apply(utils.model_vars())
57 | ema_getter = functools.partial(utils.getter_ema, ema)
58 | logits_teacher = classifier(y_1, training=True, getter=ema_getter)
59 | labels_teacher = tf.stop_gradient(tf.nn.softmax(logits_teacher))
60 | labels_teacher = labels_teacher * mix[:, :, 0, 0] + labels_teacher[::-1] * (1 - mix[:, :, 0, 0])
61 | logits_student = classifier(y_1 * mix + y_1[::-1] * (1 - mix), training=True)
62 | loss_mt = tf.reduce_mean((labels_teacher - tf.nn.softmax(logits_student)) ** 2, -1)
63 | loss_mt = tf.reduce_mean(loss_mt)
64 |
65 | loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=l, logits=logits_x)
66 | loss = tf.reduce_mean(loss)
67 | tf.summary.scalar('losses/xe', loss)
68 | tf.summary.scalar('losses/mt', loss_mt)
69 |
70 | post_ops.append(ema_op)
71 | post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])
72 |
73 | train_op = tf.train.AdamOptimizer(lr).minimize(loss + loss_mt * warmup * consistency_weight,
74 | colocate_gradients_with_ops=True)
75 | with tf.control_dependencies([train_op]):
76 | train_op = tf.group(*post_ops)
77 |
78 | return EasyDict(
79 | xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op,
80 | classify_raw=tf.nn.softmax(classifier(x_in, training=False)), # No EMA, for debugging.
81 | classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))
82 |
83 |
84 | def main(argv):
85 | utils.setup_main()
86 | del argv # Unused.
87 | dataset = PAIR_DATASETS()[FLAGS.dataset]()
88 | log_width = utils.ilog2(dataset.width)
89 | model = ICT(
90 | os.path.join(FLAGS.train_dir, dataset.name),
91 | dataset,
92 | lr=FLAGS.lr,
93 | wd=FLAGS.wd,
94 | arch=FLAGS.arch,
95 | warmup_pos=FLAGS.warmup_pos,
96 | batch=FLAGS.batch,
97 | nclass=dataset.nclass,
98 | ema=FLAGS.ema,
99 | beta=FLAGS.beta,
100 | consistency_weight=FLAGS.consistency_weight,
101 |
102 | scales=FLAGS.scales or (log_width - 2),
103 | filters=FLAGS.filters,
104 | repeat=FLAGS.repeat)
105 | model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
106 |
107 |
108 | if __name__ == '__main__':
109 | utils.setup_tf()
110 | flags.DEFINE_float('consistency_weight', 50., 'Consistency weight.')
111 | flags.DEFINE_float('warmup_pos', 0.4, 'Relative position at which constraint loss warmup ends.')
112 | flags.DEFINE_float('wd', 0.02, 'Weight decay.')
113 | flags.DEFINE_float('ema', 0.999, 'Exponential moving average of params.')
114 | flags.DEFINE_float('beta', 0.5, 'Mixup beta.')
115 | flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.')
116 | flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.')
117 | flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.')
118 | FLAGS.set_default('augment', 'd.d.d')
119 | FLAGS.set_default('dataset', 'cifar10.3@250-5000')
120 | FLAGS.set_default('batch', 64)
121 | FLAGS.set_default('lr', 0.002)
122 | FLAGS.set_default('train_kimg', 1 << 16)
123 | app.run(main)
124 |
--------------------------------------------------------------------------------
/libml/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/libml/ctaugment.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Control Theory based self-augmentation."""
15 | import random
16 | from collections import namedtuple
17 |
18 | import numpy as np
19 | from PIL import Image, ImageOps, ImageEnhance, ImageFilter
20 |
21 | OPS = {}
22 | OP = namedtuple('OP', ('f', 'bins'))
23 | Sample = namedtuple('Sample', ('train', 'probe'))
24 |
25 |
26 | def register(*bins):
27 | def wrap(f):
28 | OPS[f.__name__] = OP(f, bins)
29 | return f
30 |
31 | return wrap
32 |
33 |
34 | def apply(x, ops):
35 | if ops is None:
36 | return x
37 | y = Image.fromarray(np.round(127.5 * (1 + x)).clip(0, 255).astype('uint8'))
38 | for op, args in ops:
39 | y = OPS[op].f(y, *args)
40 | return np.asarray(y).astype('f') / 127.5 - 1
41 |
42 |
43 | class CTAugment:
44 | def __init__(self, depth=2, th=0.85, decay=0.99):
45 | self.decay = decay
46 | self.depth = depth
47 | self.th = th
48 | self.rates = {}
49 | for k, op in OPS.items():
50 | self.rates[k] = tuple([np.ones(x, 'f') for x in op.bins])
51 |
52 | def rate_to_p(self, rate):
53 | p = rate + (1 - self.decay)
54 | p = p / p.max()
55 | p[p < self.th] = 0
56 | return p
57 |
58 | def policy(self, probe):
59 | kl = list(OPS.keys())
60 | v = []
61 | if probe:
62 | for _ in range(self.depth):
63 | k = random.choice(kl)
64 | bins = self.rates[k]
65 | rnd = np.random.uniform(0, 1, len(bins))
66 | v.append(OP(k, rnd.tolist()))
67 | return v
68 | for _ in range(self.depth):
69 | vt = []
70 | k = random.choice(kl)
71 | bins = self.rates[k]
72 | rnd = np.random.uniform(0, 1, len(bins))
73 | for r, bin in zip(rnd, bins):
74 | p = self.rate_to_p(bin)
75 | segments = p[1:] + p[:-1]
76 | segment = np.random.choice(segments.shape[0], p=segments / segments.sum())
77 | vt.append((segment + r) / segments.shape[0])
78 | v.append(OP(k, vt))
79 | return v
80 |
81 | def update_rates(self, policy, accuracy):
82 | for k, bins in policy:
83 | for p, rate in zip(bins, self.rates[k]):
84 | p = int(p * len(rate) * 0.999)
85 | rate[p] = rate[p] * self.decay + accuracy * (1 - self.decay)
86 |
87 | def stats(self):
88 | return '\n'.join('%-16s %s' % (k, ' / '.join(' '.join('%.2f' % x for x in self.rate_to_p(rate))
89 | for rate in self.rates[k]))
90 | for k in sorted(OPS.keys()))
91 |
92 |
93 | def _enhance(x, op, level):
94 | return op(x).enhance(0.1 + 1.9 * level)
95 |
96 |
97 | def _imageop(x, op, level):
98 | return Image.blend(x, op(x), level)
99 |
100 |
101 | def _filter(x, op, level):
102 | return Image.blend(x, x.filter(op), level)
103 |
104 |
105 | @register(17)
106 | def autocontrast(x, level):
107 | return _imageop(x, ImageOps.autocontrast, level)
108 |
109 |
110 | @register(17)
111 | def blur(x, level):
112 | return _filter(x, ImageFilter.BLUR, level)
113 |
114 |
115 | @register(17)
116 | def brightness(x, brightness):
117 | return _enhance(x, ImageEnhance.Brightness, brightness)
118 |
119 |
120 | @register(17)
121 | def color(x, color):
122 | return _enhance(x, ImageEnhance.Color, color)
123 |
124 |
125 | @register(17)
126 | def contrast(x, contrast):
127 | return _enhance(x, ImageEnhance.Contrast, contrast)
128 |
129 |
130 | @register(17)
131 | def cutout(x, level):
132 | """Apply cutout to pil_img at the specified level."""
133 | size = 1 + int(level * min(x.size) * 0.499)
134 | img_height, img_width = x.size
135 | height_loc = np.random.randint(low=0, high=img_height)
136 | width_loc = np.random.randint(low=0, high=img_width)
137 | upper_coord = (max(0, height_loc - size // 2), max(0, width_loc - size // 2))
138 | lower_coord = (min(img_height, height_loc + size // 2), min(img_width, width_loc + size // 2))
139 | pixels = x.load() # create the pixel map
140 | for i in range(upper_coord[0], lower_coord[0]): # for every col:
141 | for j in range(upper_coord[1], lower_coord[1]): # For every row
142 | pixels[i, j] = (127, 127, 127) # set the colour accordingly
143 | return x
144 |
145 |
146 | @register(17)
147 | def equalize(x, level):
148 | return _imageop(x, ImageOps.equalize, level)
149 |
150 |
151 | @register(17)
152 | def invert(x, level):
153 | return _imageop(x, ImageOps.invert, level)
154 |
155 |
156 | @register()
157 | def identity(x):
158 | return x
159 |
160 |
161 | @register(8)
162 | def posterize(x, level):
163 | level = 1 + int(level * 7.999)
164 | return ImageOps.posterize(x, level)
165 |
166 |
167 | @register(17, 6)
168 | def rescale(x, scale, method):
169 | s = x.size
170 | scale *= 0.25
171 | crop = (scale, scale, s[0] - scale, s[1] - scale)
172 | methods = (Image.ANTIALIAS, Image.BICUBIC, Image.BILINEAR, Image.BOX, Image.HAMMING, Image.NEAREST)
173 | method = methods[int(method * 5.99)]
174 | return x.crop(crop).resize(x.size, method)
175 |
176 |
177 | @register(17)
178 | def rotate(x, angle):
179 | angle = int(np.round((2 * angle - 1) * 45))
180 | return x.rotate(angle)
181 |
182 |
183 | @register(17)
184 | def sharpness(x, sharpness):
185 | return _enhance(x, ImageEnhance.Sharpness, sharpness)
186 |
187 |
188 | @register(17)
189 | def shear_x(x, shear):
190 | shear = (2 * shear - 1) * 0.3
191 | return x.transform(x.size, Image.AFFINE, (1, shear, 0, 0, 1, 0))
192 |
193 |
194 | @register(17)
195 | def shear_y(x, shear):
196 | shear = (2 * shear - 1) * 0.3
197 | return x.transform(x.size, Image.AFFINE, (1, 0, 0, shear, 1, 0))
198 |
199 |
200 | @register(17)
201 | def smooth(x, level):
202 | return _filter(x, ImageFilter.SMOOTH, level)
203 |
204 |
205 | @register(17)
206 | def solarize(x, th):
207 | th = int(th * 255.999)
208 | return ImageOps.solarize(x, th)
209 |
210 |
211 | @register(17)
212 | def translate_x(x, delta):
213 | delta = (2 * delta - 1) * 0.3
214 | return x.transform(x.size, Image.AFFINE, (1, 0, delta, 0, 1, 0))
215 |
216 |
217 | @register(17)
218 | def translate_y(x, delta):
219 | delta = (2 * delta - 1) * 0.3
220 | return x.transform(x.size, Image.AFFINE, (1, 0, 0, 0, 1, delta))
221 |
--------------------------------------------------------------------------------
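A minimal sketch (not part of the repository) of how the pieces above fit together: `policy(probe=False)` draws a rate-weighted policy for training images, `policy(probe=True)` draws a uniform policy for a probe image, and the quality of the model's prediction on the probe is fed back through `update_rates`. The helper names and the `proximity` score are illustrative assumptions; the actual wiring lives in the CTA training code and is not reproduced here.

```python
# Illustrative sketch only; `augment_batch`, `feedback` and `proximity`
# are assumed names, not repository APIs.
import numpy as np
from libml import ctaugment

cta = ctaugment.CTAugment(depth=2, th=0.85, decay=0.99)

def augment_batch(images):
    """images: float array in [-1, 1] with shape (n, h, w, c)."""
    out, probe_policies = [], []
    for x in images:
        out.append(ctaugment.apply(x, cta.policy(probe=False)))  # rate-weighted bins
        probe_policies.append(cta.policy(probe=True))             # uniform bins
    return np.stack(out), probe_policies

def feedback(probe_policy, proximity):
    # proximity in [0, 1]: how close the prediction on the augmented probe
    # image is to its known label; it updates the per-bin rates.
    cta.update_rates(probe_policy, proximity)
```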
/libml/layers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Custom neural network layers and primitives.
15 | """
16 | import numbers
17 |
18 | import numpy as np
19 | import tensorflow as tf
20 |
21 | from libml.data import DataSets
22 |
23 |
24 | def smart_shape(x):
25 | s, t = x.shape, tf.shape(x)
26 | return [t[i] if s[i].value is None else s[i] for i in range(len(s))]
27 |
28 |
29 | def entropy_from_logits(logits):
30 | """Computes entropy from classifier logits.
31 |
32 | Args:
33 | logits: a tensor of shape (batch_size, class_count) representing the
34 | logits of a classifier.
35 |
36 | Returns:
37 | A tensor of shape (batch_size,) of floats giving the entropies
38 | batchwise.
39 | """
40 | distribution = tf.contrib.distributions.Categorical(logits=logits)
41 | return distribution.entropy()
42 |
43 |
44 | def entropy_penalty(logits, entropy_penalty_multiplier, mask):
45 | """Computes an entropy penalty using the classifier logits.
46 |
47 | Args:
48 | logits: a tensor of shape (batch_size, class_count) representing the
49 | logits of a classifier.
50 | entropy_penalty_multiplier: A float by which the entropy is multiplied.
51 | mask: A tensor that optionally masks out some of the costs.
52 |
53 | Returns:
54 | The mean entropy penalty
55 | """
56 | entropy = entropy_from_logits(logits)
57 | losses = entropy * entropy_penalty_multiplier
58 | losses *= tf.cast(mask, tf.float32)
59 | return tf.reduce_mean(losses)
60 |
61 |
62 | def kl_divergence_from_logits(logits_a, logits_b):
63 | """Gets KL divergence from logits parameterizing categorical distributions.
64 |
65 | Args:
66 | logits_a: A tensor of logits parameterizing the first distribution.
67 | logits_b: A tensor of logits parameterizing the second distribution.
68 |
69 | Returns:
70 | The (batch_size,) shaped tensor of KL divergences.
71 | """
72 | distribution1 = tf.contrib.distributions.Categorical(logits=logits_a)
73 | distribution2 = tf.contrib.distributions.Categorical(logits=logits_b)
74 | return tf.contrib.distributions.kl_divergence(distribution1, distribution2)
75 |
76 |
77 | def mse_from_logits(output_logits, target_logits):
78 | """Computes MSE between predictions associated with logits.
79 |
80 | Args:
81 | output_logits: A tensor of logits from the primary model.
82 | target_logits: A tensor of logits from the secondary model.
83 |
84 | Returns:
85 |         A tensor of shape (batch_size,) with the per-example MSE between softmax outputs.
86 | """
87 | diffs = tf.nn.softmax(output_logits) - tf.nn.softmax(target_logits)
88 | squared_diffs = tf.square(diffs)
89 | return tf.reduce_mean(squared_diffs, -1)
90 |
91 |
92 | def interleave_offsets(batch, nu):
93 | groups = [batch // (nu + 1)] * (nu + 1)
94 | for x in range(batch - sum(groups)):
95 | groups[-x - 1] += 1
96 | offsets = [0]
97 | for g in groups:
98 | offsets.append(offsets[-1] + g)
99 | assert offsets[-1] == batch
100 | return offsets
101 |
102 |
103 | def interleave(xy, batch):
104 | nu = len(xy) - 1
105 | offsets = interleave_offsets(batch, nu)
106 | xy = [[v[offsets[p]:offsets[p + 1]] for p in range(nu + 1)] for v in xy]
107 | for i in range(1, nu + 1):
108 | xy[0][i], xy[i][i] = xy[i][i], xy[0][i]
109 | return [tf.concat(v, axis=0) for v in xy]
110 |
111 |
112 | def renorm(v):
113 | return v / tf.reduce_sum(v, axis=-1, keepdims=True)
114 |
115 |
116 | def shakeshake(a, b, training):
117 | if not training:
118 | return 0.5 * (a + b)
119 | mu = tf.random_uniform([tf.shape(a)[0]] + [1] * (len(a.shape) - 1), 0, 1)
120 | mixf = a + mu * (b - a)
121 | mixb = a + mu[::1] * (b - a)
122 | return tf.stop_gradient(mixf - mixb) + mixb
123 |
124 |
125 | class PMovingAverage:
126 | def __init__(self, name, nclass, buf_size):
127 | # MEAN aggregation is used by DistributionStrategy to aggregate
128 | # variable updates across shards
129 | self.ma = tf.Variable(tf.ones([buf_size, nclass]) / nclass,
130 | trainable=False,
131 | name=name,
132 | aggregation=tf.VariableAggregation.MEAN)
133 |
134 | def __call__(self):
135 | v = tf.reduce_mean(self.ma, axis=0)
136 | return v / tf.reduce_sum(v)
137 |
138 | def update(self, entry):
139 | entry = tf.reduce_mean(entry, axis=0)
140 | return tf.assign(self.ma, tf.concat([self.ma[1:], [entry]], axis=0))
141 |
142 |
143 | class PData:
144 | def __init__(self, dataset: DataSets):
145 | self.has_update = False
146 | if dataset.p_unlabeled is not None:
147 | self.p_data = tf.constant(dataset.p_unlabeled, name='p_data')
148 | elif dataset.p_labeled is not None:
149 | self.p_data = tf.constant(dataset.p_labeled, name='p_data')
150 | else:
151 | # MEAN aggregation is used by DistributionStrategy to aggregate
152 | # variable updates across shards
153 | self.p_data = tf.Variable(renorm(tf.ones([dataset.nclass])),
154 | trainable=False,
155 | name='p_data',
156 | aggregation=tf.VariableAggregation.MEAN)
157 | self.has_update = True
158 |
159 | def __call__(self):
160 | return self.p_data / tf.reduce_sum(self.p_data)
161 |
162 | def update(self, entry, decay=0.999):
163 | entry = tf.reduce_mean(entry, axis=0)
164 | return tf.assign(self.p_data, self.p_data * decay + entry * (1 - decay))
165 |
166 |
167 | class MixMode:
168 |     # A class for mixing various combinations of labeled and unlabeled data.
169 | # x = labeled example
170 | # y = unlabeled example
171 | # For example "xx.yxy" means: mix x with x, mix y with both x and y.
172 | MODES = 'xx.yy xxy.yxy xx.yxy xx.yx xx. .yy xxy. .yxy .'.split()
173 |
174 | def __init__(self, mode):
175 | assert mode in self.MODES
176 | self.mode = mode
177 |
178 | @staticmethod
179 | def augment_pair(x0, l0, x1, l1, beta, **kwargs):
180 | del kwargs
181 | if isinstance(beta, numbers.Integral) and beta <= 0:
182 | return x0, l0
183 |
184 | def np_beta(s, beta): # TF implementation seems unreliable for beta below 0.2
185 | return np.random.beta(beta, beta, s).astype('f')
186 |
187 | with tf.device('/cpu'):
188 | mix = tf.py_func(np_beta, [tf.shape(x0)[0], beta], tf.float32)
189 | mix = tf.reshape(tf.maximum(mix, 1 - mix), [tf.shape(x0)[0], 1, 1, 1])
190 | index = tf.random_shuffle(tf.range(tf.shape(x0)[0]))
191 | xs = tf.gather(x1, index)
192 | ls = tf.gather(l1, index)
193 | xmix = x0 * mix + xs * (1 - mix)
194 | lmix = l0 * mix[:, :, 0, 0] + ls * (1 - mix[:, :, 0, 0])
195 | return xmix, lmix
196 |
197 | @staticmethod
198 | def augment(x, l, beta, **kwargs):
199 | return MixMode.augment_pair(x, l, x, l, beta, **kwargs)
200 |
201 | def __call__(self, xl: list, ll: list, betal: list):
202 | assert len(xl) == len(ll) >= 2
203 | assert len(betal) == 2
204 | if self.mode == '.':
205 | return xl, ll
206 | elif self.mode == 'xx.':
207 | mx0, ml0 = self.augment(xl[0], ll[0], betal[0])
208 | return [mx0] + xl[1:], [ml0] + ll[1:]
209 | elif self.mode == '.yy':
210 | mx1, ml1 = self.augment(
211 | tf.concat(xl[1:], 0), tf.concat(ll[1:], 0), betal[1])
212 | return (xl[:1] + tf.split(mx1, len(xl) - 1),
213 | ll[:1] + tf.split(ml1, len(ll) - 1))
214 | elif self.mode == 'xx.yy':
215 | mx0, ml0 = self.augment(xl[0], ll[0], betal[0])
216 | mx1, ml1 = self.augment(
217 | tf.concat(xl[1:], 0), tf.concat(ll[1:], 0), betal[1])
218 | return ([mx0] + tf.split(mx1, len(xl) - 1),
219 | [ml0] + tf.split(ml1, len(ll) - 1))
220 | elif self.mode == 'xxy.':
221 | mx, ml = self.augment(
222 | tf.concat(xl, 0), tf.concat(ll, 0),
223 | sum(betal) / len(betal))
224 | return (tf.split(mx, len(xl))[:1] + xl[1:],
225 | tf.split(ml, len(ll))[:1] + ll[1:])
226 | elif self.mode == '.yxy':
227 | mx, ml = self.augment(
228 | tf.concat(xl, 0), tf.concat(ll, 0),
229 | sum(betal) / len(betal))
230 | return (xl[:1] + tf.split(mx, len(xl))[1:],
231 | ll[:1] + tf.split(ml, len(ll))[1:])
232 | elif self.mode == 'xxy.yxy':
233 | mx, ml = self.augment(
234 | tf.concat(xl, 0), tf.concat(ll, 0),
235 | sum(betal) / len(betal))
236 | return tf.split(mx, len(xl)), tf.split(ml, len(ll))
237 | elif self.mode == 'xx.yxy':
238 | mx0, ml0 = self.augment(xl[0], ll[0], betal[0])
239 | mx1, ml1 = self.augment(tf.concat(xl, 0), tf.concat(ll, 0), betal[1])
240 | mx1, ml1 = [tf.split(m, len(xl))[1:] for m in (mx1, ml1)]
241 | return [mx0] + mx1, [ml0] + ml1
242 | elif self.mode == 'xx.yx':
243 | mx0, ml0 = self.augment(xl[0], ll[0], betal[0])
244 | mx1, ml1 = zip(*[
245 | self.augment_pair(xl[i], ll[i], xl[0], ll[0], betal[1])
246 | for i in range(1, len(xl))
247 | ])
248 | return [mx0] + list(mx1), [ml0] + list(ml1)
249 | raise NotImplementedError(self.mode)
250 |
--------------------------------------------------------------------------------
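The `interleave`/`interleave_offsets` pair above swaps slices of the labeled batch with slices of each unlabeled batch, so that the single forward pass whose batch-norm updates are kept (see `mixmatch.py`) sees a representative mix of labeled and unlabeled examples; applying the transform twice restores the original batches. A minimal numpy restatement of that property, assuming equal-length batches:

```python
# Numpy sketch of the interleave trick; mirrors the TF code above.
import numpy as np

def interleave_offsets(batch, nu):
    groups = [batch // (nu + 1)] * (nu + 1)
    for x in range(batch - sum(groups)):
        groups[-x - 1] += 1
    offsets = [0]
    for g in groups:
        offsets.append(offsets[-1] + g)
    assert offsets[-1] == batch
    return offsets

def interleave_np(xy, batch):
    nu = len(xy) - 1
    offsets = interleave_offsets(batch, nu)
    xy = [[v[offsets[p]:offsets[p + 1]] for p in range(nu + 1)] for v in xy]
    for i in range(1, nu + 1):
        xy[0][i], xy[i][i] = xy[i][i], xy[0][i]
    return [np.concatenate(v, axis=0) for v in xy]

batch, nu = 8, 2
batches = [np.full(batch, i) for i in range(nu + 1)]   # 0 = labeled, 1..nu = unlabeled
mixed = interleave_np(batches, batch)
restored = interleave_np(mixed, batch)
assert all(np.array_equal(a, b) for a, b in zip(batches, restored))
```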
/libml/models.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Classifier architectures."""
15 | import functools
16 | import itertools
17 |
18 | import tensorflow as tf
19 | from absl import flags
20 |
21 | from libml import layers
22 | from libml.train import ClassifySemi
23 | from libml.utils import EasyDict
24 |
25 |
26 | class CNN13(ClassifySemi):
27 | """Simplified reproduction of the Mean Teacher paper network. filters=128 in original implementation.
28 | Removed dropout, Gaussians, forked dense layers, basically all non-standard things."""
29 |
30 | def classifier(self, x, scales, filters, training, getter=None, **kwargs):
31 | del kwargs
32 | assert scales == 3 # Only specified for 32x32 inputs.
33 | conv_args = dict(kernel_size=3, activation=tf.nn.leaky_relu, padding='same')
34 | bn_args = dict(training=training, momentum=0.999)
35 |
36 | with tf.variable_scope('classify', reuse=tf.AUTO_REUSE, custom_getter=getter):
37 | y = tf.layers.conv2d((x - self.dataset.mean) / self.dataset.std, filters, **conv_args)
38 | y = tf.layers.batch_normalization(y, **bn_args)
39 | y = tf.layers.conv2d(y, filters, **conv_args)
40 | y = tf.layers.batch_normalization(y, **bn_args)
41 | y = tf.layers.conv2d(y, filters, **conv_args)
42 | y = tf.layers.batch_normalization(y, **bn_args)
43 | y = tf.layers.max_pooling2d(y, 2, 2)
44 | y = tf.layers.conv2d(y, 2 * filters, **conv_args)
45 | y = tf.layers.batch_normalization(y, **bn_args)
46 | y = tf.layers.conv2d(y, 2 * filters, **conv_args)
47 | y = tf.layers.batch_normalization(y, **bn_args)
48 | y = tf.layers.conv2d(y, 2 * filters, **conv_args)
49 | y = tf.layers.batch_normalization(y, **bn_args)
50 | y = tf.layers.max_pooling2d(y, 2, 2)
51 | y = tf.layers.conv2d(y, 4 * filters, kernel_size=3, activation=tf.nn.leaky_relu, padding='valid')
52 | y = tf.layers.batch_normalization(y, **bn_args)
53 | y = tf.layers.conv2d(y, 2 * filters, kernel_size=1, activation=tf.nn.leaky_relu, padding='same')
54 | y = tf.layers.batch_normalization(y, **bn_args)
55 | y = tf.layers.conv2d(y, 1 * filters, kernel_size=1, activation=tf.nn.leaky_relu, padding='same')
56 | y = tf.layers.batch_normalization(y, **bn_args)
57 | y = tf.reduce_mean(y, [1, 2]) # (b, 6, 6, 128) -> (b, 128)
58 | logits = tf.layers.dense(y, self.nclass)
59 | return EasyDict(logits=logits, embeds=y)
60 |
61 |
62 | class ResNet(ClassifySemi):
63 | def classifier(self, x, scales, filters, repeat, training, getter=None, dropout=0, **kwargs):
64 | del kwargs
65 | leaky_relu = functools.partial(tf.nn.leaky_relu, alpha=0.1)
66 | bn_args = dict(training=training, momentum=0.999)
67 |
68 | def conv_args(k, f):
69 | return dict(padding='same',
70 | kernel_initializer=tf.random_normal_initializer(stddev=tf.rsqrt(0.5 * k * k * f)))
71 |
72 | def residual(x0, filters, stride=1, activate_before_residual=False):
73 | x = leaky_relu(tf.layers.batch_normalization(x0, **bn_args))
74 | if activate_before_residual:
75 | x0 = x
76 |
77 | x = tf.layers.conv2d(x, filters, 3, strides=stride, **conv_args(3, filters))
78 | x = leaky_relu(tf.layers.batch_normalization(x, **bn_args))
79 | x = tf.layers.conv2d(x, filters, 3, **conv_args(3, filters))
80 |
81 | if x0.get_shape()[3] != filters:
82 | x0 = tf.layers.conv2d(x0, filters, 1, strides=stride, **conv_args(1, filters))
83 |
84 | return x0 + x
85 |
86 | with tf.variable_scope('classify', reuse=tf.AUTO_REUSE, custom_getter=getter):
87 | y = tf.layers.conv2d((x - self.dataset.mean) / self.dataset.std, 16, 3, **conv_args(3, 16))
88 | for scale in range(scales):
89 | y = residual(y, filters << scale, stride=2 if scale else 1, activate_before_residual=scale == 0)
90 | for i in range(repeat - 1):
91 | y = residual(y, filters << scale)
92 |
93 | y = leaky_relu(tf.layers.batch_normalization(y, **bn_args))
94 | y = embeds = tf.reduce_mean(y, [1, 2])
95 | if dropout and training:
96 | y = tf.nn.dropout(y, 1 - dropout)
97 | logits = tf.layers.dense(y, self.nclass, kernel_initializer=tf.glorot_normal_initializer())
98 | return EasyDict(logits=logits, embeds=embeds)
99 |
100 |
101 | class ShakeNet(ClassifySemi):
102 | def classifier(self, x, scales, filters, repeat, training, getter=None, dropout=0, **kwargs):
103 | del kwargs
104 | bn_args = dict(training=training, momentum=0.999)
105 |
106 | def conv_args(k, f):
107 | return dict(padding='same', use_bias=False,
108 | kernel_initializer=tf.random_normal_initializer(stddev=tf.rsqrt(0.5 * k * k * f)))
109 |
110 | def residual(x0, filters, stride=1):
111 | def branch():
112 | x = tf.nn.relu(x0)
113 | x = tf.layers.conv2d(x, filters, 3, strides=stride, **conv_args(3, filters))
114 | x = tf.nn.relu(tf.layers.batch_normalization(x, **bn_args))
115 | x = tf.layers.conv2d(x, filters, 3, **conv_args(3, filters))
116 | x = tf.layers.batch_normalization(x, **bn_args)
117 | return x
118 |
119 | x = layers.shakeshake(branch(), branch(), training)
120 |
121 | if stride == 2:
122 | x1 = tf.layers.conv2d(tf.nn.relu(x0[:, ::2, ::2]), filters >> 1, 1, **conv_args(1, filters >> 1))
123 | x2 = tf.layers.conv2d(tf.nn.relu(x0[:, 1::2, 1::2]), filters >> 1, 1, **conv_args(1, filters >> 1))
124 | x0 = tf.concat([x1, x2], axis=3)
125 | x0 = tf.layers.batch_normalization(x0, **bn_args)
126 | elif x0.get_shape()[3] != filters:
127 | x0 = tf.layers.conv2d(x0, filters, 1, **conv_args(1, filters))
128 | x0 = tf.layers.batch_normalization(x0, **bn_args)
129 |
130 | return x0 + x
131 |
132 | with tf.variable_scope('classify', reuse=tf.AUTO_REUSE, custom_getter=getter):
133 | y = tf.layers.conv2d((x - self.dataset.mean) / self.dataset.std, 16, 3, **conv_args(3, 16))
134 | for scale, i in itertools.product(range(scales), range(repeat)):
135 | with tf.variable_scope('layer%d.%d' % (scale + 1, i)):
136 | if i == 0:
137 | y = residual(y, filters << scale, stride=2 if scale else 1)
138 | else:
139 | y = residual(y, filters << scale)
140 |
141 | y = embeds = tf.reduce_mean(y, [1, 2])
142 | if dropout and training:
143 | y = tf.nn.dropout(y, 1 - dropout)
144 | logits = tf.layers.dense(y, self.nclass, kernel_initializer=tf.glorot_normal_initializer())
145 | return EasyDict(logits=logits, embeds=embeds)
146 |
147 |
148 | class MultiModel(CNN13, ResNet, ShakeNet):
149 | MODELS = ('cnn13', 'resnet', 'shake')
150 | MODEL_CNN13, MODEL_RESNET, MODEL_SHAKE = MODELS
151 |
152 | def augment(self, x, l, smoothing, **kwargs):
153 | del kwargs
154 | return x, l - smoothing * (l - 1. / self.nclass)
155 |
156 | def classifier(self, x, arch, **kwargs):
157 | if arch == self.MODEL_CNN13:
158 | return CNN13.classifier(self, x, **kwargs)
159 | elif arch == self.MODEL_RESNET:
160 | return ResNet.classifier(self, x, **kwargs)
161 | elif arch == self.MODEL_SHAKE:
162 | return ShakeNet.classifier(self, x, **kwargs)
163 |         raise ValueError('Model %s does not exist, available ones are %s' % (arch, self.MODELS))
164 |
165 |
166 | flags.DEFINE_enum('arch', MultiModel.MODEL_RESNET, MultiModel.MODELS, 'Architecture.')
167 |
--------------------------------------------------------------------------------
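The training scripts in this repository pass `scales=FLAGS.scales or (log_width - 2)`, so for 32x32 inputs the ResNet/ShakeNet classifiers above run three resolution stages whose channel widths double per stage via `filters << scale`. A quick check of that arithmetic (illustrative only, not repository code):

```python
import numpy as np

def ilog2(x):
    """Integer log2, as in libml/utils.py."""
    return int(np.ceil(np.log2(x)))

width, filters = 32, 32
scales = ilog2(width) - 2                      # 3 stages for CIFAR/SVHN-sized inputs
stage_widths = [filters << s for s in range(scales)]
print(scales, stage_widths)                    # 3 [32, 64, 128]
```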
/libml/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Training loop, checkpoint saving and loading, evaluation code."""
15 |
16 | import json
17 | import os.path
18 | import shutil
19 |
20 | import numpy as np
21 | import tensorflow as tf
22 | from absl import flags
23 | from tqdm import trange, tqdm
24 |
25 | from libml import data, utils
26 | from libml.utils import EasyDict
27 |
28 | FLAGS = flags.FLAGS
29 | flags.DEFINE_string('train_dir', './experiments',
30 | 'Folder where to save training data.')
31 | flags.DEFINE_float('lr', 0.0001, 'Learning rate.')
32 | flags.DEFINE_integer('batch', 64, 'Batch size.')
33 | flags.DEFINE_integer('train_kimg', 1 << 14, 'Training duration in kibi-samples.')
34 | flags.DEFINE_integer('report_kimg', 64, 'Report summary period in kibi-samples.')
35 | flags.DEFINE_integer('save_kimg', 64, 'Save checkpoint period in kibi-samples.')
36 | flags.DEFINE_integer('keep_ckpt', 50, 'Number of checkpoints to keep.')
37 | flags.DEFINE_string('eval_ckpt', '', 'Checkpoint to evaluate. If provided, do not do training, just do eval.')
38 | flags.DEFINE_string('rerun', '', 'A string to identify a run if running multiple ones with same parameters.')
39 |
40 |
41 | class Model:
42 | def __init__(self, train_dir: str, dataset: data.DataSets, **kwargs):
43 | self.train_dir = os.path.join(train_dir, FLAGS.rerun, self.experiment_name(**kwargs))
44 | self.params = EasyDict(kwargs)
45 | self.dataset = dataset
46 | self.session = None
47 | self.tmp = EasyDict(print_queue=[], cache=EasyDict())
48 | self.step = tf.train.get_or_create_global_step()
49 | self.ops = self.model(**kwargs)
50 | self.ops.update_step = tf.assign_add(self.step, FLAGS.batch)
51 | self.add_summaries(**kwargs)
52 |
53 | print(' Config '.center(80, '-'))
54 | print('train_dir', self.train_dir)
55 | print('%-32s %s' % ('Model', self.__class__.__name__))
56 | print('%-32s %s' % ('Dataset', dataset.name))
57 | for k, v in sorted(kwargs.items()):
58 | print('%-32s %s' % (k, v))
59 | print(' Model '.center(80, '-'))
60 | to_print = [tuple(['%s' % x for x in (v.name, np.prod(v.shape), v.shape)]) for v in utils.model_vars(None)]
61 | to_print.append(('Total', str(sum(int(x[1]) for x in to_print)), ''))
62 | sizes = [max([len(x[i]) for x in to_print]) for i in range(3)]
63 | fmt = '%%-%ds %%%ds %%%ds' % tuple(sizes)
64 | for x in to_print[:-1]:
65 | print(fmt % x)
66 | print()
67 | print(fmt % to_print[-1])
68 | print('-' * 80)
69 | self._create_initial_files()
70 |
71 | @property
72 | def arg_dir(self):
73 | return os.path.join(self.train_dir, 'args')
74 |
75 | @property
76 | def checkpoint_dir(self):
77 | return os.path.join(self.train_dir, 'tf')
78 |
79 | def train_print(self, text):
80 | self.tmp.print_queue.append(text)
81 |
82 | def _create_initial_files(self):
83 | for dir in (self.checkpoint_dir, self.arg_dir):
84 | tf.gfile.MakeDirs(dir)
85 | self.save_args()
86 |
87 | def _reset_files(self):
88 | shutil.rmtree(self.train_dir)
89 | self._create_initial_files()
90 |
91 | def save_args(self, **extra_params):
92 | with tf.gfile.Open(os.path.join(self.arg_dir, 'args.json'), 'w') as f:
93 | json.dump({**self.params, **extra_params}, f, sort_keys=True, indent=4)
94 |
95 | @classmethod
96 | def load(cls, train_dir):
97 | with tf.gfile.Open(os.path.join(train_dir, 'args/args.json'), 'r') as f:
98 | params = json.load(f)
99 | instance = cls(train_dir=train_dir, **params)
100 | instance.train_dir = train_dir
101 | return instance
102 |
103 | def experiment_name(self, **kwargs):
104 | args = [x + str(y) for x, y in sorted(kwargs.items())]
105 | return '_'.join([self.__class__.__name__] + args)
106 |
107 | def eval_mode(self, ckpt=None):
108 | self.session = tf.Session(config=utils.get_config())
109 | saver = tf.train.Saver()
110 | if ckpt is None:
111 | ckpt = utils.find_latest_checkpoint(self.checkpoint_dir)
112 | else:
113 | ckpt = os.path.abspath(ckpt)
114 | saver.restore(self.session, ckpt)
115 | self.tmp.step = self.session.run(self.step)
116 | print('Eval model %s at global_step %d' % (self.__class__.__name__, self.tmp.step))
117 | return self
118 |
119 | def model(self, **kwargs):
120 | raise NotImplementedError()
121 |
122 | def add_summaries(self, **kwargs):
123 | raise NotImplementedError()
124 |
125 |
126 | class ClassifySemi(Model):
127 | """Semi-supervised classification."""
128 |
129 | def __init__(self, train_dir: str, dataset: data.DataSets, nclass: int, **kwargs):
130 | self.nclass = nclass
131 | Model.__init__(self, train_dir, dataset, nclass=nclass, **kwargs)
132 |
133 | def train_step(self, train_session, gen_labeled, gen_unlabeled):
134 | x, y = gen_labeled(), gen_unlabeled()
135 | self.tmp.step = train_session.run([self.ops.train_op, self.ops.update_step],
136 | feed_dict={self.ops.y: y['image'],
137 | self.ops.xt: x['image'],
138 | self.ops.label: x['label']})[1]
139 |
140 | def gen_labeled_fn(self, data_iterator):
141 | return self.dataset.train_labeled.numpy_augment(lambda: self.session.run(data_iterator))
142 |
143 | def gen_unlabeled_fn(self, data_iterator):
144 | return self.dataset.train_unlabeled.numpy_augment(lambda: self.session.run(data_iterator))
145 |
146 | def train(self, train_nimg, report_nimg):
147 | if FLAGS.eval_ckpt:
148 | self.eval_checkpoint(FLAGS.eval_ckpt)
149 | return
150 | batch = FLAGS.batch
151 | train_labeled = self.dataset.train_labeled.repeat().shuffle(FLAGS.shuffle).parse().augment()
152 | train_labeled = train_labeled.batch(batch).prefetch(16).make_one_shot_iterator().get_next()
153 | train_unlabeled = self.dataset.train_unlabeled.repeat().shuffle(FLAGS.shuffle).parse().augment()
154 | train_unlabeled = train_unlabeled.batch(batch).prefetch(16).make_one_shot_iterator().get_next()
155 | scaffold = tf.train.Scaffold(saver=tf.train.Saver(max_to_keep=FLAGS.keep_ckpt, pad_step_number=10))
156 |
157 | with tf.Session(config=utils.get_config()) as sess:
158 | self.session = sess
159 | self.cache_eval()
160 |
161 | with tf.train.MonitoredTrainingSession(
162 | scaffold=scaffold,
163 | checkpoint_dir=self.checkpoint_dir,
164 | config=utils.get_config(),
165 | save_checkpoint_steps=FLAGS.save_kimg << 10,
166 | save_summaries_steps=report_nimg - batch) as train_session:
167 | self.session = train_session._tf_sess()
168 | gen_labeled = self.gen_labeled_fn(train_labeled)
169 | gen_unlabeled = self.gen_unlabeled_fn(train_unlabeled)
170 | self.tmp.step = self.session.run(self.step)
171 | while self.tmp.step < train_nimg:
172 | loop = trange(self.tmp.step % report_nimg, report_nimg, batch,
173 | leave=False, unit='img', unit_scale=batch,
174 | desc='Epoch %d/%d' % (1 + (self.tmp.step // report_nimg), train_nimg // report_nimg))
175 | for _ in loop:
176 | self.train_step(train_session, gen_labeled, gen_unlabeled)
177 | while self.tmp.print_queue:
178 | loop.write(self.tmp.print_queue.pop(0))
179 | while self.tmp.print_queue:
180 | print(self.tmp.print_queue.pop(0))
181 |
182 | def eval_checkpoint(self, ckpt=None):
183 | self.eval_mode(ckpt)
184 | self.cache_eval()
185 | raw = self.eval_stats(classify_op=self.ops.classify_raw)
186 | ema = self.eval_stats(classify_op=self.ops.classify_op)
187 | print('%16s %8s %8s %8s' % ('', 'labeled', 'valid', 'test'))
188 | print('%16s %8s %8s %8s' % (('raw',) + tuple('%.2f' % x for x in raw)))
189 | print('%16s %8s %8s %8s' % (('ema',) + tuple('%.2f' % x for x in ema)))
190 |
191 | def cache_eval(self):
192 | """Cache datasets for computing eval stats."""
193 |
194 | def collect_samples(dataset, name):
195 | """Return numpy arrays of all the samples from a dataset."""
196 | pbar = tqdm(desc='Caching %s examples' % name)
197 | it = dataset.batch(1).prefetch(16).make_one_shot_iterator().get_next()
198 | images, labels = [], []
199 | while 1:
200 | try:
201 | v = self.session.run(it)
202 | except tf.errors.OutOfRangeError:
203 | break
204 | images.append(v['image'])
205 | labels.append(v['label'])
206 | pbar.update()
207 |
208 | images = np.concatenate(images, axis=0)
209 | labels = np.concatenate(labels, axis=0)
210 | pbar.close()
211 | return images, labels
212 |
213 | if 'test' not in self.tmp.cache:
214 | self.tmp.cache.test = collect_samples(self.dataset.test.parse(), name='test')
215 | self.tmp.cache.valid = collect_samples(self.dataset.valid.parse(), name='valid')
216 | self.tmp.cache.train_labeled = collect_samples(self.dataset.train_labeled.take(10000).parse(),
217 | name='train_labeled')
218 |
219 | def eval_stats(self, batch=None, feed_extra=None, classify_op=None):
220 | """Evaluate model on train, valid and test."""
221 | batch = batch or FLAGS.batch
222 | classify_op = self.ops.classify_op if classify_op is None else classify_op
223 | accuracies = []
224 | for subset in ('train_labeled', 'valid', 'test'):
225 | images, labels = self.tmp.cache[subset]
226 | predicted = []
227 |
228 | for x in range(0, images.shape[0], batch):
229 | p = self.session.run(
230 | classify_op,
231 | feed_dict={
232 | self.ops.x: images[x:x + batch],
233 | **(feed_extra or {})
234 | })
235 | predicted.append(p)
236 | predicted = np.concatenate(predicted, axis=0)
237 | accuracies.append((predicted.argmax(1) == labels).mean() * 100)
238 | self.train_print('kimg %-5d accuracy train/valid/test %.2f %.2f %.2f' %
239 | tuple([self.tmp.step >> 10] + accuracies))
240 | return np.array(accuracies, 'f')
241 |
242 | def add_summaries(self, feed_extra=None, **kwargs):
243 | del kwargs
244 |
245 | def gen_stats():
246 | return self.eval_stats(feed_extra=feed_extra)
247 |
248 | accuracies = tf.py_func(gen_stats, [], tf.float32)
249 | tf.summary.scalar('accuracy/train_labeled', accuracies[0])
250 | tf.summary.scalar('accuracy/valid', accuracies[1])
251 | tf.summary.scalar('accuracy', accuracies[2])
252 |
--------------------------------------------------------------------------------
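Durations in `libml/train.py` are counted in individual images, while the flags are given in kibi-images (kimg); the global step advances by `batch` images per optimizer step. A back-of-the-envelope check (not part of the repository) using the defaults the SSL scripts set (`train_kimg=1<<16`, `report_kimg=64`, `batch=64`):

```python
train_kimg, report_kimg, batch = 1 << 16, 64, 64

train_nimg = train_kimg << 10           # total images to process: 67,108,864
report_nimg = report_kimg << 10         # images per reporting "epoch": 65,536
steps_per_epoch = report_nimg // batch  # 1024 optimizer steps per epoch
epochs = train_nimg // report_nimg      # 1024 reporting epochs
print(train_nimg, report_nimg, steps_per_epoch, epochs)
```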
/libml/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Utilities."""
15 |
16 | import os
17 | import re
18 |
19 | import numpy as np
20 | import tensorflow as tf
21 | from absl import flags, logging
22 | from tensorflow.python.client import device_lib
23 |
24 | _GPUS = None
25 | FLAGS = flags.FLAGS
26 | flags.DEFINE_bool('log_device_placement', False, 'For debugging purpose.')
27 |
28 |
29 | class EasyDict(dict):
30 | def __init__(self, *args, **kwargs):
31 | super(EasyDict, self).__init__(*args, **kwargs)
32 | self.__dict__ = self
33 |
34 |
35 | def get_config():
36 | config = tf.ConfigProto()
37 | if len(get_available_gpus()) > 1:
38 | config.allow_soft_placement = True
39 | if FLAGS.log_device_placement:
40 | config.log_device_placement = True
41 | config.gpu_options.allow_growth = True
42 | return config
43 |
44 |
45 | def setup_main():
46 | pass
47 |
48 |
49 | def setup_tf():
50 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
51 | logging.set_verbosity(logging.ERROR)
52 |
53 |
54 | def smart_shape(x):
55 | s = x.shape
56 | st = tf.shape(x)
57 | return [s[i] if s[i].value is not None else st[i] for i in range(4)]
58 |
59 |
60 | def ilog2(x):
61 | """Integer log2."""
62 | return int(np.ceil(np.log2(x)))
63 |
64 |
65 | def find_latest_checkpoint(dir, glob_term='model.ckpt-*.meta'):
66 | """Replacement for tf.train.latest_checkpoint.
67 |
68 | It does not rely on the "checkpoint" file which sometimes contains
69 | absolute path and is generally hard to work with when sharing files
70 | between users / computers.
71 | """
72 |     r_step = re.compile(r'.*model\.ckpt-(?P<step>\d+)\.meta')
73 | matches = tf.gfile.Glob(os.path.join(dir, glob_term))
74 | matches = [(int(r_step.match(x).group('step')), x) for x in matches]
75 | ckpt_file = max(matches)[1][:-5]
76 | return ckpt_file
77 |
78 |
79 | def get_latest_global_step(dir):
80 | """Loads the global step from the latest checkpoint in directory.
81 |
82 | Args:
83 | dir: string, path to the checkpoint directory.
84 |
85 | Returns:
86 | int, the global step of the latest checkpoint or 0 if none was found.
87 | """
88 | try:
89 | checkpoint_reader = tf.train.NewCheckpointReader(find_latest_checkpoint(dir))
90 | return checkpoint_reader.get_tensor(tf.GraphKeys.GLOBAL_STEP)
91 | except: # pylint: disable=bare-except
92 | return 0
93 |
94 |
95 | def get_latest_global_step_in_subdir(dir):
96 | """Loads the global step from the latest checkpoint in sub-directories.
97 |
98 | Args:
99 | dir: string, parent of the checkpoint directories.
100 |
101 | Returns:
102 | int, the global step of the latest checkpoint or 0 if none was found.
103 | """
104 | sub_dirs = (x for x in tf.gfile.Glob(os.path.join(dir, '*')) if os.path.isdir(x))
105 | step = 0
106 | for x in sub_dirs:
107 | step = max(step, get_latest_global_step(x))
108 | return step
109 |
110 |
111 | def getter_ema(ema, getter, name, *args, **kwargs):
112 | """Exponential moving average getter for variable scopes.
113 |
114 | Args:
115 | ema: ExponentialMovingAverage object, where to get variable moving averages.
116 | getter: default variable scope getter.
117 | name: variable name.
118 | *args: extra args passed to default getter.
119 | **kwargs: extra args passed to default getter.
120 |
121 | Returns:
122 | If found the moving average variable, otherwise the default variable.
123 | """
124 | var = getter(name, *args, **kwargs)
125 | ema_var = ema.average(var)
126 | return ema_var if ema_var else var
127 |
128 |
129 | def model_vars(scope=None):
130 | return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
131 |
132 |
133 | def gpu(x):
134 | return '/gpu:%d' % (x % max(1, len(get_available_gpus())))
135 |
136 |
137 | def get_available_gpus():
138 | global _GPUS
139 | if _GPUS is None:
140 | config = tf.ConfigProto()
141 | config.gpu_options.allow_growth = True
142 | local_device_protos = device_lib.list_local_devices(session_config=config)
143 | _GPUS = tuple([x.name for x in local_device_protos if x.device_type == 'GPU'])
144 | return _GPUS
145 |
146 |
147 | def get_gpu():
148 | gpus = get_available_gpus()
149 | pos = 0
150 | while 1:
151 | yield gpus[pos]
152 | pos = (pos + 1) % len(gpus)
153 |
154 |
155 | def average_gradients(tower_grads):
156 | # Adapted from:
157 | # https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10_multi_gpu_train.py
158 | """Calculate the average gradient for each shared variable across all towers.
159 | Note that this function provides a synchronization point across all towers.
160 | Args:
161 | tower_grads: List of lists of (gradient, variable) tuples. For each tower, a list of its gradients.
162 | Returns:
163 | List of pairs of (gradient, variable) where the gradient has been averaged
164 | across all towers.
165 | """
166 | if len(tower_grads) <= 1:
167 | return tower_grads[0]
168 |
169 | average_grads = []
170 | for grads_and_vars in zip(*tower_grads):
171 | grad = tf.reduce_mean([gv[0] for gv in grads_and_vars], 0)
172 | average_grads.append((grad, grads_and_vars[0][1]))
173 | return average_grads
174 |
175 |
176 | def para_list(fn, *args):
177 | """Run on multiple GPUs in parallel and return list of results."""
178 | gpus = len(get_available_gpus())
179 | if gpus <= 1:
180 | return zip(*[fn(*args)])
181 | splitted = [tf.split(x, gpus) for x in args]
182 | outputs = []
183 | for gpu, x in enumerate(zip(*splitted)):
184 | with tf.name_scope('tower%d' % gpu):
185 | with tf.device(tf.train.replica_device_setter(
186 | worker_device='/gpu:%d' % gpu, ps_device='/cpu:0', ps_tasks=1)):
187 | outputs.append(fn(*x))
188 | return zip(*outputs)
189 |
190 |
191 | def para_mean(fn, *args):
192 | """Run on multiple GPUs in parallel and return means."""
193 | gpus = len(get_available_gpus())
194 | if gpus <= 1:
195 | return fn(*args)
196 | splitted = [tf.split(x, gpus) for x in args]
197 | outputs = []
198 | for gpu, x in enumerate(zip(*splitted)):
199 | with tf.name_scope('tower%d' % gpu):
200 | with tf.device(tf.train.replica_device_setter(
201 | worker_device='/gpu:%d' % gpu, ps_device='/cpu:0', ps_tasks=1)):
202 | outputs.append(fn(*x))
203 | if isinstance(outputs[0], (tuple, list)):
204 | return [tf.reduce_mean(x, 0) for x in zip(*outputs)]
205 | return tf.reduce_mean(outputs, 0)
206 |
207 |
208 | def para_cat(fn, *args):
209 | """Run on multiple GPUs in parallel and return concatenated outputs."""
210 | gpus = len(get_available_gpus())
211 | if gpus <= 1:
212 | return fn(*args)
213 | splitted = [tf.split(x, gpus) for x in args]
214 | outputs = []
215 | for gpu, x in enumerate(zip(*splitted)):
216 | with tf.name_scope('tower%d' % gpu):
217 | with tf.device(tf.train.replica_device_setter(
218 | worker_device='/gpu:%d' % gpu, ps_device='/cpu:0', ps_tasks=1)):
219 | outputs.append(fn(*x))
220 | if isinstance(outputs[0], (tuple, list)):
221 | return [tf.concat(x, axis=0) for x in zip(*outputs)]
222 | return tf.concat(outputs, axis=0)
223 |
224 |
225 | def combine_dicts(*args):
226 | # Python 2 compatible way to combine several dictionaries
227 | # We need it because currently TPU code does not work with python 3
228 | result = {}
229 | for d in args:
230 | result.update(d)
231 | return result
232 |
--------------------------------------------------------------------------------
/mean_teacher.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Mean teachers are better role models:
16 | Weight-averaged consistency targets improve semi-supervised deep learning results.
17 |
18 | Reimplementation of https://arxiv.org/abs/1703.01780
19 | """
20 | import functools
21 | import os
22 |
23 | import tensorflow as tf
24 | from absl import app
25 | from absl import flags
26 |
27 | from libml import models, utils
28 | from libml.data import PAIR_DATASETS
29 | from libml.utils import EasyDict
30 |
31 | FLAGS = flags.FLAGS
32 |
33 |
34 | class MeanTeacher(models.MultiModel):
35 |
36 | def model(self, batch, lr, wd, ema, warmup_pos, consistency_weight, **kwargs):
37 | hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
38 | xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt') # For training
39 | x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
40 | y_in = tf.placeholder(tf.float32, [batch, 2] + hwc, 'y')
41 | l_in = tf.placeholder(tf.int32, [batch], 'labels')
42 | l = tf.one_hot(l_in, self.nclass)
43 | wd *= lr
44 | warmup = tf.clip_by_value(tf.to_float(self.step) / (warmup_pos * (FLAGS.train_kimg << 10)), 0, 1)
45 |
46 | classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
47 | logits_x = classifier(xt_in, training=True)
48 | post_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) # Take only first call to update batch norm.
49 | y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)
50 | y_1, y_2 = tf.split(y, 2)
51 | ema = tf.train.ExponentialMovingAverage(decay=ema)
52 | ema_op = ema.apply(utils.model_vars())
53 | ema_getter = functools.partial(utils.getter_ema, ema)
54 | logits_y = classifier(y_1, training=True, getter=ema_getter)
55 | logits_teacher = tf.stop_gradient(logits_y)
56 | logits_student = classifier(y_2, training=True)
57 | loss_mt = tf.reduce_mean((tf.nn.softmax(logits_teacher) - tf.nn.softmax(logits_student)) ** 2, -1)
58 | loss_mt = tf.reduce_mean(loss_mt)
59 |
60 | loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=l, logits=logits_x)
61 | loss = tf.reduce_mean(loss)
62 | tf.summary.scalar('losses/xe', loss)
63 | tf.summary.scalar('losses/mt', loss_mt)
64 |
65 | post_ops.append(ema_op)
66 | post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])
67 |
68 | train_op = tf.train.AdamOptimizer(lr).minimize(loss + loss_mt * warmup * consistency_weight,
69 | colocate_gradients_with_ops=True)
70 | with tf.control_dependencies([train_op]):
71 | train_op = tf.group(*post_ops)
72 |
73 | return EasyDict(
74 | xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op,
75 | classify_raw=tf.nn.softmax(classifier(x_in, training=False)), # No EMA, for debugging.
76 | classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))
77 |
78 |
79 | def main(argv):
80 | utils.setup_main()
81 | del argv # Unused.
82 | dataset = PAIR_DATASETS()[FLAGS.dataset]()
83 | log_width = utils.ilog2(dataset.width)
84 | model = MeanTeacher(
85 | os.path.join(FLAGS.train_dir, dataset.name),
86 | dataset,
87 | lr=FLAGS.lr,
88 | wd=FLAGS.wd,
89 | arch=FLAGS.arch,
90 | warmup_pos=FLAGS.warmup_pos,
91 | batch=FLAGS.batch,
92 | nclass=dataset.nclass,
93 | ema=FLAGS.ema,
94 | smoothing=FLAGS.smoothing,
95 | consistency_weight=FLAGS.consistency_weight,
96 |
97 | scales=FLAGS.scales or (log_width - 2),
98 | filters=FLAGS.filters,
99 | repeat=FLAGS.repeat)
100 | model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
101 |
102 |
103 | if __name__ == '__main__':
104 | utils.setup_tf()
105 | flags.DEFINE_float('consistency_weight', 50., 'Consistency weight.')
106 | flags.DEFINE_float('warmup_pos', 0.4, 'Relative position at which constraint loss warmup ends.')
107 | flags.DEFINE_float('wd', 0.02, 'Weight decay.')
108 | flags.DEFINE_float('ema', 0.999, 'Exponential moving average of params.')
109 | flags.DEFINE_float('smoothing', 0.001, 'Label smoothing.')
110 | flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.')
111 | flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.')
112 | flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.')
113 | FLAGS.set_default('augment', 'd.d.d')
114 | FLAGS.set_default('dataset', 'cifar10.3@250-5000')
115 | FLAGS.set_default('batch', 64)
116 | FLAGS.set_default('lr', 0.002)
117 | FLAGS.set_default('train_kimg', 1 << 16)
118 | app.run(main)
119 |
--------------------------------------------------------------------------------
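A numpy restatement (illustrative, not repository code) of the consistency term built in `MeanTeacher.model` above: the mean squared difference between the EMA teacher's and the student's class probabilities, linearly ramped until `warmup_pos` of training has elapsed and scaled by `consistency_weight`:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def consistency_loss(logits_teacher, logits_student, step, train_nimg,
                     warmup_pos=0.4, consistency_weight=50.0):
    warmup = np.clip(step / (warmup_pos * train_nimg), 0.0, 1.0)
    mse = np.mean((softmax(logits_teacher) - softmax(logits_student)) ** 2)
    return consistency_weight * warmup * mse
```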
/mixmatch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """MixMatch training.
15 | - Ensure class consistency by producing a group of `nu` augmentations of the same image and guessing the label for the
16 | group.
17 | - Sharpen the target distribution.
18 | - Use the sharpened distribution directly as a smooth label in MixUp.
19 | """
20 |
21 | import functools
22 | import os
23 |
24 | import tensorflow as tf
25 | from absl import app
26 | from absl import flags
27 |
28 | from libml import layers, utils, models
29 | from libml.data import PAIR_DATASETS
30 | from libml.layers import MixMode
31 | from libml.utils import EasyDict
32 |
33 | FLAGS = flags.FLAGS
34 |
35 |
36 | class MixMatch(models.MultiModel):
37 |
38 | def distribution_summary(self, p_data, p_model, p_target=None):
39 | def kl(p, q):
40 | p /= tf.reduce_sum(p)
41 | q /= tf.reduce_sum(q)
42 | return -tf.reduce_sum(p * tf.log(q / p))
43 |
44 | tf.summary.scalar('metrics/kld', kl(p_data, p_model))
45 | if p_target is not None:
46 | tf.summary.scalar('metrics/kld_target', kl(p_data, p_target))
47 |
48 | for i in range(self.nclass):
49 | tf.summary.scalar('matching/class%d_ratio' % i, p_model[i] / p_data[i])
50 | for i in range(self.nclass):
51 | tf.summary.scalar('matching/val%d' % i, p_model[i])
52 |
53 | def augment(self, x, l, beta, **kwargs):
54 | assert 0, 'Do not call.'
55 |
56 | def guess_label(self, y, classifier, T, **kwargs):
57 | del kwargs
58 | logits_y = [classifier(yi, training=True) for yi in y]
59 | logits_y = tf.concat(logits_y, 0)
60 | # Compute predicted probability distribution py.
61 | p_model_y = tf.reshape(tf.nn.softmax(logits_y), [len(y), -1, self.nclass])
62 | p_model_y = tf.reduce_mean(p_model_y, axis=0)
63 | # Compute the target distribution.
64 | p_target = tf.pow(p_model_y, 1. / T)
65 | p_target /= tf.reduce_sum(p_target, axis=1, keep_dims=True)
66 | return EasyDict(p_target=p_target, p_model=p_model_y)
67 |
68 | def model(self, batch, lr, wd, ema, beta, w_match, warmup_kimg=1024, nu=2, mixmode='xxy.yxy', dbuf=128, **kwargs):
69 | hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
70 | xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt') # For training
71 | x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
72 | y_in = tf.placeholder(tf.float32, [batch, nu] + hwc, 'y')
73 | l_in = tf.placeholder(tf.int32, [batch], 'labels')
74 | wd *= lr
75 | w_match *= tf.clip_by_value(tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
76 | augment = MixMode(mixmode)
77 | classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
78 |
79 | # Moving average of the current estimated label distribution
80 | p_model = layers.PMovingAverage('p_model', self.nclass, dbuf)
81 | p_target = layers.PMovingAverage('p_target', self.nclass, dbuf) # Rectified distribution (only for plotting)
82 |
83 | # Known (or inferred) true unlabeled distribution
84 | p_data = layers.PData(self.dataset)
85 |
86 | y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)
87 | guess = self.guess_label(tf.split(y, nu), classifier, T=0.5, **kwargs)
88 | ly = tf.stop_gradient(guess.p_target)
89 | lx = tf.one_hot(l_in, self.nclass)
90 | xy, labels_xy = augment([xt_in] + tf.split(y, nu), [lx] + [ly] * nu, [beta, beta])
91 | x, y = xy[0], xy[1:]
92 | labels_x, labels_y = labels_xy[0], tf.concat(labels_xy[1:], 0)
93 | del xy, labels_xy
94 |
95 | batches = layers.interleave([x] + y, batch)
96 | skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
97 | logits = [classifier(batches[0], training=True)]
98 | post_ops = [v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops]
99 | for batchi in batches[1:]:
100 | logits.append(classifier(batchi, training=True))
101 | logits = layers.interleave(logits, batch)
102 | logits_x = logits[0]
103 | logits_y = tf.concat(logits[1:], 0)
104 |
105 | loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x, logits=logits_x)
106 | loss_xe = tf.reduce_mean(loss_xe)
107 | loss_l2u = tf.square(labels_y - tf.nn.softmax(logits_y))
108 | loss_l2u = tf.reduce_mean(loss_l2u)
109 | tf.summary.scalar('losses/xe', loss_xe)
110 | tf.summary.scalar('losses/l2u', loss_l2u)
111 | self.distribution_summary(p_data(), p_model(), p_target())
112 |
113 | ema = tf.train.ExponentialMovingAverage(decay=ema)
114 | ema_op = ema.apply(utils.model_vars())
115 | ema_getter = functools.partial(utils.getter_ema, ema)
116 | post_ops.extend([ema_op,
117 | p_model.update(guess.p_model),
118 | p_target.update(guess.p_target)])
119 | if p_data.has_update:
120 | post_ops.append(p_data.update(lx))
121 | post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])
122 |
123 | train_op = tf.train.AdamOptimizer(lr).minimize(loss_xe + w_match * loss_l2u, colocate_gradients_with_ops=True)
124 | with tf.control_dependencies([train_op]):
125 | train_op = tf.group(*post_ops)
126 |
127 | return EasyDict(
128 | xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op,
129 | classify_raw=tf.nn.softmax(classifier(x_in, training=False)), # No EMA, for debugging.
130 | classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))
131 |
132 |
133 | def main(argv):
134 | utils.setup_main()
135 | del argv # Unused.
136 | dataset = PAIR_DATASETS()[FLAGS.dataset]()
137 | log_width = utils.ilog2(dataset.width)
138 | model = MixMatch(
139 | os.path.join(FLAGS.train_dir, dataset.name),
140 | dataset,
141 | lr=FLAGS.lr,
142 | wd=FLAGS.wd,
143 | arch=FLAGS.arch,
144 | batch=FLAGS.batch,
145 | nclass=dataset.nclass,
146 | ema=FLAGS.ema,
147 |
148 | beta=FLAGS.beta,
149 | w_match=FLAGS.w_match,
150 |
151 | scales=FLAGS.scales or (log_width - 2),
152 | filters=FLAGS.filters,
153 | repeat=FLAGS.repeat)
154 | model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
155 |
156 |
157 | if __name__ == '__main__':
158 | utils.setup_tf()
159 | flags.DEFINE_float('wd', 0.02, 'Weight decay.')
160 | flags.DEFINE_float('ema', 0.999, 'Exponential moving average of params.')
161 | flags.DEFINE_float('beta', 0.5, 'Mixup beta distribution.')
162 | flags.DEFINE_float('w_match', 100, 'Weight for distribution matching loss.')
163 | flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.')
164 | flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.')
165 | flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.')
166 | FLAGS.set_default('augment', 'd.d.d')
167 | FLAGS.set_default('dataset', 'cifar10.3@250-5000')
168 | FLAGS.set_default('batch', 64)
169 | FLAGS.set_default('lr', 0.002)
170 | FLAGS.set_default('train_kimg', 1 << 16)
171 | app.run(main)
172 |
--------------------------------------------------------------------------------
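A numpy restatement (illustrative only) of `guess_label` above: average the model's softmax over the `nu` augmentations of each unlabeled image, then sharpen the average with temperature `T` (the model calls it with `T=0.5`):

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def guess_label(logits_y, T=0.5):
    """logits_y: array of shape (nu, batch, nclass)."""
    p_model = softmax(logits_y).mean(axis=0)          # average over augmentations
    p_target = p_model ** (1.0 / T)                   # sharpen
    p_target /= p_target.sum(axis=1, keepdims=True)   # renormalize
    return p_target, p_model
```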
/mixup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """mixup: Beyond Empirical Risk Minimization.
15 |
16 | Adaptation of MixUp to SSL: https://arxiv.org/abs/1710.09412
17 | """
18 | import functools
19 | import os
20 |
21 | import tensorflow as tf
22 | from absl import app
23 | from absl import flags
24 |
25 | from libml import data, utils, models
26 | from libml.utils import EasyDict
27 |
28 | FLAGS = flags.FLAGS
29 |
30 |
31 | class Mixup(models.MultiModel):
32 |
33 | def augment(self, x, l, beta, **kwargs):
34 | del kwargs
35 | mix = tf.distributions.Beta(beta, beta).sample([tf.shape(x)[0], 1, 1, 1])
36 | mix = tf.maximum(mix, 1 - mix)
37 | xmix = x * mix + x[::-1] * (1 - mix)
38 | lmix = l * mix[:, :, 0, 0] + l[::-1] * (1 - mix[:, :, 0, 0])
39 | return xmix, lmix
40 |
41 | def model(self, batch, lr, wd, ema, **kwargs):
42 | hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
43 | xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt') # For training
44 | x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
45 | y_in = tf.placeholder(tf.float32, [batch] + hwc, 'y')
46 | l_in = tf.placeholder(tf.int32, [batch], 'labels')
47 | wd *= lr
48 | classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
49 |
50 | def get_logits(x):
51 | logits = classifier(x, training=True)
52 | return logits
53 |
54 | x, labels_x = self.augment(xt_in, tf.one_hot(l_in, self.nclass), **kwargs)
55 | logits_x = get_logits(x)
56 | post_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
57 | y, labels_y = self.augment(y_in, tf.nn.softmax(get_logits(y_in)), **kwargs)
58 | labels_y = tf.stop_gradient(labels_y)
59 | logits_y = get_logits(y)
60 |
61 | loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x, logits=logits_x)
62 | loss_xe = tf.reduce_mean(loss_xe)
63 | loss_xeu = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_y, logits=logits_y)
64 | loss_xeu = tf.reduce_mean(loss_xeu)
65 | tf.summary.scalar('losses/xe', loss_xe)
66 | tf.summary.scalar('losses/xeu', loss_xeu)
67 |
68 | ema = tf.train.ExponentialMovingAverage(decay=ema)
69 | ema_op = ema.apply(utils.model_vars())
70 | ema_getter = functools.partial(utils.getter_ema, ema)
71 | post_ops.append(ema_op)
72 | post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])
73 |
74 | train_op = tf.train.AdamOptimizer(lr).minimize(loss_xe + loss_xeu, colocate_gradients_with_ops=True)
75 | with tf.control_dependencies([train_op]):
76 | train_op = tf.group(*post_ops)
77 |
78 | return EasyDict(
79 | xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op,
80 | classify_raw=tf.nn.softmax(classifier(x_in, training=False)), # No EMA, for debugging.
81 | classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))
82 |
83 |
84 | def main(argv):
85 | utils.setup_main()
86 | del argv # Unused.
87 | dataset = data.DATASETS()[FLAGS.dataset]()
88 | log_width = utils.ilog2(dataset.width)
89 | model = Mixup(
90 | os.path.join(FLAGS.train_dir, dataset.name),
91 | dataset,
92 | lr=FLAGS.lr,
93 | wd=FLAGS.wd,
94 | arch=FLAGS.arch,
95 | batch=FLAGS.batch,
96 | nclass=dataset.nclass,
97 | ema=FLAGS.ema,
98 | beta=FLAGS.beta,
99 |
100 | scales=FLAGS.scales or (log_width - 2),
101 | filters=FLAGS.filters,
102 | repeat=FLAGS.repeat)
103 | model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
104 |
105 |
106 | if __name__ == '__main__':
107 | utils.setup_tf()
108 | flags.DEFINE_float('wd', 0.02, 'Weight decay.')
109 | flags.DEFINE_float('ema', 0.999, 'Exponential moving average of params.')
110 | flags.DEFINE_float('beta', 0.5, 'Mixup beta distribution.')
111 | flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.')
112 | flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.')
113 | flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.')
114 | FLAGS.set_default('dataset', 'cifar10.3@250-5000')
115 | FLAGS.set_default('batch', 64)
116 | FLAGS.set_default('lr', 0.002)
117 | FLAGS.set_default('train_kimg', 1 << 16)
118 | app.run(main)
119 |
--------------------------------------------------------------------------------
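The interesting part of `mixup.py` is `Mixup.augment`: each image and its (possibly soft) label are blended with the reversed batch using a coefficient drawn from Beta(beta, beta) and clamped to at least 0.5, and for unlabeled data the softmax predictions are mixed and stop-gradiented as targets. A minimal NumPy sketch of that interpolation (function name and shapes are illustrative, not part of the repository):

```python
# Minimal NumPy sketch of the mixing used in Mixup.augment above.
# Each example is blended with its mirror in the batch (x[::-1]);
# clamping the coefficient to >= 0.5 keeps each row dominated by its original example.
import numpy as np

def mixup_batch(x, labels, beta=0.5, rng=np.random):
    """x: [batch, h, w, c] images; labels: [batch, nclass] one-hot or soft labels."""
    mix = rng.beta(beta, beta, size=(x.shape[0], 1, 1, 1))
    mix = np.maximum(mix, 1 - mix)              # favor the original example
    x_mix = x * mix + x[::-1] * (1 - mix)       # blend with the reversed batch
    l_mix = labels * mix[:, :, 0, 0] + labels[::-1] * (1 - mix[:, :, 0, 0])
    return x_mix, l_mix
```
--------------------------------------------------------------------------------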
/pi_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Temporal Ensembling for Semi-Supervised Learning.
15 |
16 | Pi-model reimplementation of https://arxiv.org/abs/1610.02242
17 | """
18 |
19 | import functools
20 | import os
21 |
22 | import tensorflow as tf
23 | from absl import app
24 | from absl import flags
25 |
26 | from libml import models, utils
27 | from libml.data import PAIR_DATASETS
28 | from libml.utils import EasyDict
29 |
30 | FLAGS = flags.FLAGS
31 |
32 |
33 | class PiModel(models.MultiModel):
34 |
35 | def model(self, batch, lr, wd, ema, warmup_pos, consistency_weight, **kwargs):
36 | hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
37 | xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt') # For training
38 | x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
39 | y_in = tf.placeholder(tf.float32, [batch, 2] + hwc, 'y')
40 | l_in = tf.placeholder(tf.int32, [batch], 'labels')
41 | l = tf.one_hot(l_in, self.nclass)
42 | wd *= lr
43 | warmup = tf.clip_by_value(tf.to_float(self.step) / (warmup_pos * (FLAGS.train_kimg << 10)), 0, 1)
44 |
45 | classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
46 | logits_x = classifier(xt_in, training=True)
47 | post_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) # Take only first call to update batch norm.
48 | y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)
49 | y_1, y_2 = tf.split(y, 2)
50 | logits_y = classifier(y_1, training=True)
51 | logits_teacher = tf.stop_gradient(logits_y)
52 | logits_student = classifier(y_2, training=True)
53 | loss_pm = tf.reduce_mean((tf.nn.softmax(logits_teacher) - tf.nn.softmax(logits_student)) ** 2, -1)
54 | loss_pm = tf.reduce_mean(loss_pm)
55 |
56 | loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=l, logits=logits_x)
57 | loss = tf.reduce_mean(loss)
58 | tf.summary.scalar('losses/xe', loss)
59 | tf.summary.scalar('losses/pm', loss_pm)
60 |
61 | ema = tf.train.ExponentialMovingAverage(decay=ema)
62 | ema_op = ema.apply(utils.model_vars())
63 | ema_getter = functools.partial(utils.getter_ema, ema)
64 | post_ops.append(ema_op)
65 | post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])
66 |
67 | train_op = tf.train.AdamOptimizer(lr).minimize(loss + loss_pm * warmup * consistency_weight,
68 | colocate_gradients_with_ops=True)
69 | with tf.control_dependencies([train_op]):
70 | train_op = tf.group(*post_ops)
71 |
72 | return EasyDict(
73 | xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op,
74 | classify_raw=tf.nn.softmax(classifier(x_in, training=False)), # No EMA, for debugging.
75 | classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))
76 |
77 |
78 | def main(argv):
79 | utils.setup_main()
80 | del argv # Unused.
81 | dataset = PAIR_DATASETS()[FLAGS.dataset]()
82 | log_width = utils.ilog2(dataset.width)
83 | model = PiModel(
84 | os.path.join(FLAGS.train_dir, dataset.name),
85 | dataset,
86 | lr=FLAGS.lr,
87 | wd=FLAGS.wd,
88 | arch=FLAGS.arch,
89 | warmup_pos=FLAGS.warmup_pos,
90 | batch=FLAGS.batch,
91 | nclass=dataset.nclass,
92 | ema=FLAGS.ema,
93 | smoothing=FLAGS.smoothing,
94 | consistency_weight=FLAGS.consistency_weight,
95 |
96 | scales=FLAGS.scales or (log_width - 2),
97 | filters=FLAGS.filters,
98 | repeat=FLAGS.repeat)
99 | model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
100 |
101 |
102 | if __name__ == '__main__':
103 | utils.setup_tf()
104 | flags.DEFINE_float('consistency_weight', 10., 'Consistency weight.')
105 | flags.DEFINE_float('warmup_pos', 0.4, 'Relative position at which consistency loss warmup ends.')
106 | flags.DEFINE_float('wd', 0.02, 'Weight decay.')
107 | flags.DEFINE_float('ema', 0.999, 'Exponential moving average of params.')
108 | flags.DEFINE_float('smoothing', 0.1, 'Label smoothing.')
109 | flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.')
110 | flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.')
111 | flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.')
112 | FLAGS.set_default('augment', 'd.d.d')
113 | FLAGS.set_default('dataset', 'cifar10.3@250-5000')
114 | FLAGS.set_default('batch', 64)
115 | FLAGS.set_default('lr', 0.002)
116 | FLAGS.set_default('train_kimg', 1 << 16)
117 | app.run(main)
118 |
--------------------------------------------------------------------------------
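The unlabeled objective in `pi_model.py` is a consistency loss: two stochastic augmentations of the same image (the pair stored in `y_in`) are classified, the first pass is stop-gradiented as a teacher, and the mean squared difference of the softmax outputs is added to the supervised cross-entropy, ramped up linearly until `warmup_pos` of training and scaled by `consistency_weight`. A minimal NumPy sketch of that term (names are illustrative):

```python
# Minimal NumPy sketch of the Pi-model consistency term computed in PiModel.model above:
# two augmented views of the same unlabeled image are classified and their softmax
# outputs are pulled together with an L2 penalty.
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def pi_consistency(logits_view1, logits_view2):
    """logits_*: [batch, nclass]; view 1 acts as the (stop-gradient) teacher."""
    p_teacher = softmax(logits_view1)   # no gradient flows through this branch
    p_student = softmax(logits_view2)
    return np.mean((p_teacher - p_student) ** 2)
```
--------------------------------------------------------------------------------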
/pseudo_label.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Pseudo-label: The simple and efficient semi-supervised learning method fordeep neural networks.
15 |
16 | Reimplementation of http://deeplearning.net/wp-content/uploads/2013/03/pseudo_label_final.pdf
17 | """
18 |
19 | import functools
20 | import os
21 |
22 | import tensorflow as tf
23 | from absl import app
24 | from absl import flags
25 |
26 | from libml import utils, data, models
27 | from libml.utils import EasyDict
28 |
29 | FLAGS = flags.FLAGS
30 |
31 |
32 | class PseudoLabel(models.MultiModel):
33 |
34 | def model(self, batch, lr, wd, ema, warmup_pos, consistency_weight, threshold, **kwargs):
35 | hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
36 | xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt') # For training
37 | x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
38 | y_in = tf.placeholder(tf.float32, [batch] + hwc, 'y')
39 | l_in = tf.placeholder(tf.int32, [batch], 'labels')
40 | l = tf.one_hot(l_in, self.nclass)
41 | wd *= lr
42 | warmup = tf.clip_by_value(tf.to_float(self.step) / (warmup_pos * (FLAGS.train_kimg << 10)), 0, 1)
43 |
44 | classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
45 | logits_x = classifier(xt_in, training=True)
46 | post_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) # Take only first call to update batch norm.
47 | logits_y = classifier(y_in, training=True)
48 | # Get the pseudo-label loss
49 | loss_pl = tf.nn.sparse_softmax_cross_entropy_with_logits(
50 | labels=tf.argmax(logits_y, axis=-1), logits=logits_y
51 | )
52 | # Masks denoting which data points have high-confidence predictions
53 | greater_than_thresh = tf.reduce_any(
54 | tf.greater(tf.nn.softmax(logits_y), threshold),
55 | axis=-1,
56 | keepdims=True,
57 | )
58 | greater_than_thresh = tf.cast(greater_than_thresh, loss_pl.dtype)
59 | # Only enforce the loss when the model is confident
60 | loss_pl *= greater_than_thresh
61 | # Note that we also average over examples without confident outputs;
62 | # this is consistent with the realistic evaluation codebase
63 | loss_pl = tf.reduce_mean(loss_pl)
64 |
65 | loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=l, logits=logits_x)
66 | loss = tf.reduce_mean(loss)
67 | tf.summary.scalar('losses/xe', loss)
68 | tf.summary.scalar('losses/pl', loss_pl)
69 |
70 | ema = tf.train.ExponentialMovingAverage(decay=ema)
71 | ema_op = ema.apply(utils.model_vars())
72 | ema_getter = functools.partial(utils.getter_ema, ema)
73 | post_ops.append(ema_op)
74 | post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])
75 |
76 | train_op = tf.train.AdamOptimizer(lr).minimize(loss + loss_pl * warmup * consistency_weight,
77 | colocate_gradients_with_ops=True)
78 | with tf.control_dependencies([train_op]):
79 | train_op = tf.group(*post_ops)
80 |
81 | return EasyDict(
82 | xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op,
83 | classify_raw=tf.nn.softmax(classifier(x_in, training=False)), # No EMA, for debugging.
84 | classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))
85 |
86 |
87 | def main(argv):
88 | utils.setup_main()
89 | del argv # Unused.
90 | dataset = data.DATASETS()[FLAGS.dataset]()
91 | log_width = utils.ilog2(dataset.width)
92 | model = PseudoLabel(
93 | os.path.join(FLAGS.train_dir, dataset.name),
94 | dataset,
95 | lr=FLAGS.lr,
96 | wd=FLAGS.wd,
97 | arch=FLAGS.arch,
98 | warmup_pos=FLAGS.warmup_pos,
99 | batch=FLAGS.batch,
100 | nclass=dataset.nclass,
101 | ema=FLAGS.ema,
102 | smoothing=FLAGS.smoothing,
103 | consistency_weight=FLAGS.consistency_weight,
104 | threshold=FLAGS.threshold,
105 |
106 | scales=FLAGS.scales or (log_width - 2),
107 | filters=FLAGS.filters,
108 | repeat=FLAGS.repeat)
109 | model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
110 |
111 |
112 | if __name__ == '__main__':
113 | utils.setup_tf()
114 | flags.DEFINE_float('wd', 0.02, 'Weight decay.')
115 | flags.DEFINE_float('consistency_weight', 1., 'Consistency weight.')
116 | flags.DEFINE_float('threshold', 0.95, 'Pseudo-label threshold.')
117 | flags.DEFINE_float('warmup_pos', 0.4, 'Relative position at which consistency loss warmup ends.')
118 | flags.DEFINE_float('ema', 0.999, 'Exponential moving average of params.')
119 | flags.DEFINE_float('smoothing', 0.1, 'Label smoothing.')
120 | flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.')
121 | flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.')
122 | flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.')
123 | FLAGS.set_default('dataset', 'cifar10.3@250-5000')
124 | FLAGS.set_default('batch', 64)
125 | FLAGS.set_default('lr', 0.002)
126 | FLAGS.set_default('train_kimg', 1 << 16)
127 | app.run(main)
128 |
--------------------------------------------------------------------------------
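`pseudo_label.py` turns the model's own argmax predictions on unlabeled data into training targets, but only where the predicted probability exceeds `threshold` (0.95 by default); below-threshold examples contribute zero loss yet still count in the average, as the in-code comment notes. A minimal NumPy sketch of that masked loss (names are illustrative):

```python
# Minimal NumPy sketch of the thresholded pseudo-label loss computed in
# PseudoLabel.model above. Examples below the confidence threshold contribute
# zero loss but are still averaged over.
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def pseudo_label_loss(logits_u, threshold=0.95):
    """logits_u: [batch, nclass] logits on unlabeled data."""
    probs = softmax(logits_u)
    pseudo = probs.argmax(axis=-1)                       # hard pseudo-labels
    xe = -np.log(probs[np.arange(len(probs)), pseudo])   # cross-entropy against own argmax
    confident = (probs.max(axis=-1) > threshold).astype(xe.dtype)
    return np.mean(xe * confident)                       # average over all examples
```
--------------------------------------------------------------------------------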
/remixmatch_no_cta.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import functools
16 | import os
17 |
18 | import numpy as np
19 | import tensorflow as tf
20 | from absl import app
21 | from absl import flags
22 |
23 | from libml import data, layers, utils
24 | from libml.utils import EasyDict
25 | from mixmatch import MixMatch
26 |
27 | FLAGS = flags.FLAGS
28 |
29 |
30 | class ReMixMatch(MixMatch):
31 |
32 | def classifier_rot(self, x):
33 | with tf.variable_scope('classify_rot', reuse=tf.AUTO_REUSE):
34 | return tf.layers.dense(x, 4, kernel_initializer=tf.glorot_normal_initializer())
35 |
36 | def guess_label(self, logits_y, p_data, p_model, T, use_dm, redux, **kwargs):
37 | del kwargs
38 | if redux == 'swap':
39 | p_model_y = tf.concat([tf.nn.softmax(x) for x in logits_y[1:] + logits_y[:1]], axis=0)
40 | elif redux == 'mean':
41 | p_model_y = sum(tf.nn.softmax(x) for x in logits_y) / len(logits_y)
42 | p_model_y = tf.tile(p_model_y, [len(logits_y), 1])
43 | elif redux == '1st':
44 | p_model_y = tf.nn.softmax(logits_y[0])
45 | p_model_y = tf.tile(p_model_y, [len(logits_y), 1])
46 | else:
47 | raise NotImplementedError()
48 |
49 | # Compute the target distribution.
50 | # 1. Rectify the distribution or not.
51 | if use_dm:
52 | p_ratio = (1e-6 + p_data) / (1e-6 + p_model)
53 | p_weighted = p_model_y * p_ratio
54 | p_weighted /= tf.reduce_sum(p_weighted, axis=1, keep_dims=True)
55 | else:
56 | p_weighted = p_model_y
57 | # 2. Apply sharpening.
58 | p_target = tf.pow(p_weighted, 1. / T)
59 | p_target /= tf.reduce_sum(p_target, axis=1, keep_dims=True)
60 | return EasyDict(p_target=p_target, p_model=p_model_y)
61 |
62 | def model(self, batch, lr, wd, beta, w_kl, w_match, w_rot, K, use_xe, warmup_kimg=1024, T=0.5,
63 | mixmode='xxy.yxy', dbuf=128, ema=0.999, **kwargs):
64 | hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
65 | xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt') # For training
66 | x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
67 | y_in = tf.placeholder(tf.float32, [batch, K + 1] + hwc, 'y')
68 | l_in = tf.placeholder(tf.int32, [batch], 'labels')
69 | wd *= lr
70 | w_match *= tf.clip_by_value(tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
71 | augment = layers.MixMode(mixmode)
72 |
73 | gpu = utils.get_gpu()
74 |
75 | def classifier_to_gpu(x, **kw):
76 | with tf.device(next(gpu)):
77 | return self.classifier(x, **kw, **kwargs).logits
78 |
79 | def random_rotate(x):
80 | b4 = batch // 4
81 | x, xt = x[:2 * b4], tf.transpose(x[2 * b4:], [0, 2, 1, 3])
82 | l = np.zeros(b4, np.int32)
83 | l = tf.constant(np.concatenate([l, l + 1, l + 2, l + 3], axis=0))
84 | return tf.concat([x[:b4], x[b4:, ::-1, ::-1], xt[:b4, ::-1], xt[b4:, :, ::-1]], axis=0), l
85 |
86 | # Moving average of the current estimated label distribution
87 | p_model = layers.PMovingAverage('p_model', self.nclass, dbuf)
88 | p_target = layers.PMovingAverage('p_target', self.nclass, dbuf) # Rectified distribution (only for plotting)
89 |
90 | # Known (or inferred) true unlabeled distribution
91 | p_data = layers.PData(self.dataset)
92 |
93 | if w_rot > 0:
94 | rot_y, rot_l = random_rotate(y_in[:, 1])
95 | with tf.device(next(gpu)):
96 | rot_logits = self.classifier_rot(self.classifier(rot_y, training=True, **kwargs).embeds)
97 | loss_rot = tf.nn.softmax_cross_entropy_with_logits_v2(labels=tf.one_hot(rot_l, 4), logits=rot_logits)
98 | loss_rot = tf.reduce_mean(loss_rot)
99 | tf.summary.scalar('losses/rot', loss_rot)
100 | else:
101 | loss_rot = 0
102 |
103 | if kwargs['redux'] == '1st' and w_kl <= 0:
104 | logits_y = [classifier_to_gpu(y_in[:, 0], training=True)] * (K + 1)
105 | elif kwargs['redux'] == '1st':
106 | logits_y = [classifier_to_gpu(y_in[:, i], training=True) for i in range(2)]
107 | logits_y += logits_y[:1] * (K - 1)
108 | else:
109 | logits_y = [classifier_to_gpu(y_in[:, i], training=True) for i in range(K + 1)]
110 |
111 | guess = self.guess_label(logits_y, p_data(), p_model(), T=T, **kwargs)
112 | ly = tf.stop_gradient(guess.p_target)
113 | if w_kl > 0:
114 | w_kl *= tf.clip_by_value(tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
115 | loss_kl = tf.nn.softmax_cross_entropy_with_logits_v2(labels=ly[:batch], logits=logits_y[1])
116 | loss_kl = tf.reduce_mean(loss_kl)
117 | tf.summary.scalar('losses/kl', loss_kl)
118 | else:
119 | loss_kl = 0
120 | del logits_y
121 |
122 | lx = tf.one_hot(l_in, self.nclass)
123 | xy, labels_xy = augment([xt_in] + [y_in[:, i] for i in range(K + 1)], [lx] + tf.split(ly, K + 1),
124 | [beta, beta])
125 | x, y = xy[0], xy[1:]
126 | labels_x, labels_y = labels_xy[0], tf.concat(labels_xy[1:], 0)
127 | del xy, labels_xy
128 |
129 | batches = layers.interleave([x] + y, batch)
130 | logits = [classifier_to_gpu(yi, training=True) for yi in batches[:-1]]
131 | skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
132 | logits.append(classifier_to_gpu(batches[-1], training=True))
133 | post_ops = [v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops]
134 | logits = layers.interleave(logits, batch)
135 | logits_x = logits[0]
136 | logits_y = tf.concat(logits[1:], 0)
137 | del batches, logits
138 |
139 | loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x, logits=logits_x)
140 | loss_xe = tf.reduce_mean(loss_xe)
141 | if use_xe:
142 | loss_xeu = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_y, logits=logits_y)
143 | else:
144 | loss_xeu = tf.square(labels_y - tf.nn.softmax(logits_y))
145 | loss_xeu = tf.reduce_mean(loss_xeu)
146 | tf.summary.scalar('losses/xe', loss_xe)
147 | tf.summary.scalar('losses/%s' % ('xeu' if use_xe else 'l2u'), loss_xeu)
148 | self.distribution_summary(p_data(), p_model(), p_target())
149 |
150 | ema = tf.train.ExponentialMovingAverage(decay=ema)
151 | ema_op = ema.apply(utils.model_vars())
152 | ema_getter = functools.partial(utils.getter_ema, ema)
153 | post_ops.extend([ema_op,
154 | p_model.update(guess.p_model),
155 | p_target.update(guess.p_target)])
156 | if p_data.has_update:
157 | post_ops.append(p_data.update(lx))
158 | post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])
159 |
160 | train_op = tf.train.AdamOptimizer(lr).minimize(loss_xe
161 | + w_kl * loss_kl
162 | + w_match * loss_xeu
163 | + w_rot * loss_rot,
164 | colocate_gradients_with_ops=True)
165 | with tf.control_dependencies([train_op]):
166 | train_op = tf.group(*post_ops)
167 |
168 | return EasyDict(
169 | xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op,
170 | classify_op=tf.nn.softmax(classifier_to_gpu(x_in, getter=ema_getter, training=False)),
171 | classify_raw=tf.nn.softmax(classifier_to_gpu(x_in, training=False))) # No EMA, for debugging.
172 |
173 |
174 | def main(argv):
175 | utils.setup_main()
176 | del argv # Unused.
177 | dataset = data.MANY_DATASETS()[FLAGS.dataset]()
178 | log_width = utils.ilog2(dataset.width)
179 | model = ReMixMatch(
180 | os.path.join(FLAGS.train_dir, dataset.name),
181 | dataset,
182 | lr=FLAGS.lr,
183 | wd=FLAGS.wd,
184 | arch=FLAGS.arch,
185 | batch=FLAGS.batch,
186 | nclass=dataset.nclass,
187 |
188 | K=FLAGS.K,
189 | beta=FLAGS.beta,
190 | w_kl=FLAGS.w_kl,
191 | w_match=FLAGS.w_match,
192 | w_rot=FLAGS.w_rot,
193 | redux=FLAGS.redux,
194 | use_dm=FLAGS.use_dm,
195 | use_xe=FLAGS.use_xe,
196 | warmup_kimg=FLAGS.warmup_kimg,
197 |
198 | scales=FLAGS.scales or (log_width - 2),
199 | filters=FLAGS.filters,
200 | repeat=FLAGS.repeat)
201 | model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
202 |
203 |
204 | if __name__ == '__main__':
205 | utils.setup_tf()
206 | flags.DEFINE_float('wd', 0.02, 'Weight decay.')
207 | flags.DEFINE_float('beta', 0.75, 'Mixup beta distribution.')
208 | flags.DEFINE_float('w_kl', 0.5, 'Weight for KL loss.')
209 | flags.DEFINE_float('w_match', 1.5, 'Weight for distribution matching loss.')
210 | flags.DEFINE_float('w_rot', 0.5, 'Weight for rotation loss.')
211 | flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.')
212 | flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.')
213 | flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.')
214 | flags.DEFINE_integer('warmup_kimg', 1024, 'Warmup duration in kimg for the SSL losses.')
215 | flags.DEFINE_enum('redux', 'swap', 'swap mean 1st'.split(), 'Logit selection.')
216 | flags.DEFINE_bool('use_dm', True, 'Whether to use distribution matching.')
217 | flags.DEFINE_bool('use_xe', True, 'Whether to use cross-entropy or Brier.')
218 | FLAGS.set_default('augment', 'd.d.d')
219 | FLAGS.set_default('dataset', 'cifar10.3@250-5000')
220 | FLAGS.set_default('batch', 64)
221 | FLAGS.set_default('lr', 0.002)
222 | FLAGS.set_default('train_kimg', 1 << 16)
223 | app.run(main)
224 |
--------------------------------------------------------------------------------
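The core of `ReMixMatch.guess_label` above is distribution alignment followed by sharpening: the model predictions on the augmented views are rescaled by the ratio of the running data distribution `p_data` to the running model distribution `p_model`, renormalized, then raised to the power 1/T and renormalized again. A minimal NumPy sketch of that target computation (names are illustrative):

```python
# Minimal NumPy sketch of the label-guessing step in ReMixMatch.guess_label above:
# distribution alignment (rectification) followed by temperature sharpening.
import numpy as np

def guess_label(p_model_y, p_data, p_model, T=0.5, use_dm=True, eps=1e-6):
    """p_model_y: [n, nclass] predicted class probabilities on unlabeled data.
    p_data / p_model: [nclass] running estimates of the true and model label distributions."""
    if use_dm:
        p = p_model_y * (eps + p_data) / (eps + p_model)   # align predictions to p_data
        p /= p.sum(axis=1, keepdims=True)
    else:
        p = p_model_y
    p_target = p ** (1.0 / T)                              # sharpen (T < 1 peaks the distribution)
    p_target /= p_target.sum(axis=1, keepdims=True)
    return p_target
```
--------------------------------------------------------------------------------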
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py
2 | easydict
3 | cython
4 | numpy
5 | tensorflow-gpu==1.14.0
6 | tqdm
7 |
--------------------------------------------------------------------------------
/runs/ssl/ablation.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2019 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | DEFAULT_ARGS="--dataset=cifar10.3@250-1 --train_dir experiments/Ablation --augment=d.d.d"
17 |
18 | echo "# Vary number of augmentations"
19 | echo "python ablation/ab_cta_remixmatch.py $DEFAULT_ARGS --K=1"
20 | echo "python ablation/ab_cta_remixmatch.py $DEFAULT_ARGS --K=2"
21 | echo "python ablation/ab_cta_remixmatch.py $DEFAULT_ARGS --K=4"
22 | echo "python ablation/ab_cta_remixmatch.py $DEFAULT_ARGS --K=8"
23 | echo "python ablation/ab_cta_remixmatch.py $DEFAULT_ARGS --K=16"
24 |
25 | echo "# No regularizer"
26 | echo "python ablation/ab_cta_remixmatch.py $DEFAULT_ARGS --K=8 --w_kl=0"
27 | echo "# No rotation"
28 | echo "python ablation/ab_cta_remixmatch.py $DEFAULT_ARGS --K=8 --w_rot=0"
29 | echo "# No Distribution matching"
30 | echo "python ablation/ab_cta_remixmatch.py $DEFAULT_ARGS --K=8 --nouse_dm"
31 | echo "# Use L2 loss / no cross-entropy"
32 | echo "python ablation/ab_cta_remixmatch.py $DEFAULT_ARGS --w_match=75 --K=8 --nouse_xe"
33 |
34 | echo "# Only strong augmentations"
35 | echo "python ablation/ab_cta_remixmatch_noweak.py $DEFAULT_ARGS --K=8"
36 |
37 | echo "# Only weak augmentations"
38 | echo "python ablation/ab_remixmatch.py $DEFAULT_ARGS --K=1"
39 | echo "python ablation/ab_remixmatch.py $DEFAULT_ARGS --K=8"
40 | echo "python ablation/ab_remixmatch.py $DEFAULT_ARGS --redux=mean --K=1"
41 | echo "python ablation/ab_remixmatch.py $DEFAULT_ARGS --redux=mean --K=8"
42 |
--------------------------------------------------------------------------------
/runs/ssl/cifar10.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2019 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | echo
17 | echo "# Standard"
18 | for seed in 1 2 3 4 5; do for size in 250 1000 4000; do
19 | echo "python cta/cta_remixmatch.py --dataset=cifar10.${seed}@{$size}-1 --K=8"
20 | done; done
21 |
22 | for seed in 1 2 3 4 5; do
23 | echo "python cta/cta_remixmatch.py --dataset=cifar10.${seed}@40-1 --K=1 --w_rot=2 --warmup_kimg=8192"
24 | done
25 |
--------------------------------------------------------------------------------
/runs/ssl/stl10.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2019 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | echo
17 | echo "# Standard"
18 | for seed in 1 2 3 4 5; do
19 | echo "python cta/cta_remixmatch.py --dataset=stl10.${seed}@1000-1 --K=4 --scales=4 --augment=d.x.d"
20 | done
21 |
--------------------------------------------------------------------------------
/runs/ssl/svhn.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2019 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | echo
17 | echo "# Standard"
18 | for seed in 1 2 3 4 5; do for size in 250 1000 4000; do
19 | echo "python cta/cta_remixmatch.py --dataset=svhn_noextra.${seed}@{$size}-1 --K=8"
20 | done; done
21 |
22 | for seed in 1 2 3 4 5; do
23 | echo "python cta/cta_remixmatch.py --dataset=svhn_noextra.${seed}@40-1 --K=1 --w_rot=3"
24 | done
25 |
--------------------------------------------------------------------------------
/scripts/aggregate_accuracy.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """Report all 'stats/accuracy.json' into a json file on stdout.
17 |
18 | All the accuracies are summarized.
19 | """
20 | import json
21 | import sys
22 | import threading
23 |
24 | import tensorflow as tf
25 | import tqdm
26 | from absl import app
27 | from absl import flags
28 |
29 | FLAGS = flags.FLAGS
30 | N_THREADS = 100
31 |
32 |
33 | def add_contents_to_dict(filename: str, target):
34 | with tf.gfile.Open(filename, 'r') as f:
35 | target[filename] = json.load(f)
36 |
37 |
38 | def main(argv):
39 | files = []
40 | for path in argv[1:]:
41 | files.extend(tf.io.gfile.glob(path))
42 | assert files, 'No files found'
43 | print('Found %d files.' % len(files), file=sys.stderr)
44 | summary = {}
45 | threads = []
46 | for x in tqdm.tqdm(files, leave=False, desc='Collating'):
47 | t = threading.Thread(
48 | target=add_contents_to_dict, kwargs=dict(filename=x, target=summary))
49 | threads.append(t)
50 | t.start()
51 | while len(threads) >= N_THREADS:
52 | dead = [p for p, t in enumerate(threads) if not t.is_alive()]
53 | while dead:
54 | p = dead.pop()
55 | del threads[p]
56 | if x == files[-1]:
57 | for t in threads:
58 | t.join()
59 |
60 | assert len(summary) == len(files)
61 | print(json.dumps(summary, sort_keys=True, indent=4))
62 |
63 |
64 | if __name__ == '__main__':
65 | app.run(main)
66 |
--------------------------------------------------------------------------------
/scripts/check_split.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Script to measure the overlap between data splits.
18 |
19 | There should not be any overlap unless the original dataset has duplicates.
20 | """
21 |
22 | import hashlib
23 | import os
24 |
25 | import tensorflow as tf
26 | from absl import app
27 | from absl import flags
28 | from tqdm import trange
29 |
30 | from libml import data, utils
31 |
32 | flags.DEFINE_integer('batch', 1024, 'Batch size.')
33 | flags.DEFINE_integer('samples', 1 << 20, 'Number of samples to load.')
34 |
35 | FLAGS = flags.FLAGS
36 |
37 |
38 | def to_byte(d: dict):
39 | return tf.to_int32(tf.round(127.5 * (d['image'] + 1)))
40 |
41 |
42 | def collect_hashes(sess, group, data):
43 | data = data.parse().batch(FLAGS.batch).prefetch(1).make_one_shot_iterator().get_next()
44 | hashes = set()
45 | hasher = hashlib.sha512
46 | for _ in trange(0, FLAGS.samples, FLAGS.batch, desc='Building hashes for %s' % group, leave=False):
47 | try:
48 | batch = sess.run(data)
49 | except tf.errors.OutOfRangeError:
50 | break
51 | for img in batch:
52 | hashes.add(hasher(img).digest())
53 | return hashes
54 |
55 |
56 | def main(argv):
57 | utils.setup_main()
58 | del argv
59 | utils.setup_tf()
60 | dataset = data.DATASETS()[FLAGS.dataset]()
61 | with tf.Session(config=utils.get_config()) as sess:
62 | hashes = (collect_hashes(sess, 'labeled', dataset.eval_labeled),
63 | collect_hashes(sess, 'unlabeled', dataset.eval_unlabeled),
64 | collect_hashes(sess, 'validation', dataset.valid),
65 | collect_hashes(sess, 'test', dataset.test))
66 | print('Overlap matrix (should be an almost perfect diagonal matrix with counts).')
67 | groups = 'labeled unlabeled validation test'.split()
68 | fmt = '%-10s %10s %10s %10s %10s'
69 | print(fmt % tuple([''] + groups))
70 | for p, x in enumerate(hashes):
71 | overlaps = [len(x & y) for y in hashes]
72 | print(fmt % tuple([groups[p]] + overlaps))
73 |
74 |
75 | if __name__ == '__main__':
76 | os.environ['CUDA_VISIBLE_DEVICES'] = ''
77 | app.run(main)
78 |
--------------------------------------------------------------------------------
/scripts/create_datasets.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Script to download all datasets and create .tfrecord files.
18 | """
19 |
20 | import collections
21 | import gzip
22 | import os
23 | import tarfile
24 | import tempfile
25 | from urllib import request
26 |
27 | import numpy as np
28 | import scipy.io
29 | import tensorflow as tf
30 | from absl import app
31 | from tqdm import trange
32 |
33 | from libml import data as libml_data
34 | from libml.utils import EasyDict
35 |
36 | URLS = {
37 | 'svhn': 'http://ufldl.stanford.edu/housenumbers/{}_32x32.mat',
38 | 'cifar10': 'https://www.cs.toronto.edu/~kriz/cifar-10-matlab.tar.gz',
39 | 'cifar100': 'https://www.cs.toronto.edu/~kriz/cifar-100-matlab.tar.gz',
40 | 'stl10': 'http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz',
41 | }
42 |
43 |
44 | def _encode_png(images):
45 | raw = []
46 | with tf.Session() as sess, tf.device('cpu:0'):
47 | image_x = tf.placeholder(tf.uint8, [None, None, None], 'image_x')
48 | to_png = tf.image.encode_png(image_x)
49 | for x in trange(images.shape[0], desc='PNG Encoding', leave=False):
50 | raw.append(sess.run(to_png, feed_dict={image_x: images[x]}))
51 | return raw
52 |
53 |
54 | def _load_svhn():
55 | splits = collections.OrderedDict()
56 | for split in ['train', 'test', 'extra']:
57 | with tempfile.NamedTemporaryFile() as f:
58 | request.urlretrieve(URLS['svhn'].format(split), f.name)
59 | data_dict = scipy.io.loadmat(f.name)
60 | dataset = {}
61 | dataset['images'] = np.transpose(data_dict['X'], [3, 0, 1, 2])
62 | dataset['images'] = _encode_png(dataset['images'])
63 | dataset['labels'] = data_dict['y'].reshape((-1))
64 | # SVHN raw data uses labels from 1 to 10; use 0 to 9 instead.
65 | dataset['labels'] -= 1
66 | splits[split] = dataset
67 | return splits
68 |
69 |
70 | def _load_stl10():
71 | def unflatten(images):
72 | return np.transpose(images.reshape((-1, 3, 96, 96)),
73 | [0, 3, 2, 1])
74 |
75 | with tempfile.NamedTemporaryFile() as f:
76 | if tf.gfile.Exists('stl10/stl10_binary.tar.gz'):
77 | f = tf.gfile.Open('stl10/stl10_binary.tar.gz', 'rb')
78 | else:
79 | request.urlretrieve(URLS['stl10'], f.name)
80 | tar = tarfile.open(fileobj=f)
81 | train_X = tar.extractfile('stl10_binary/train_X.bin')
82 | train_y = tar.extractfile('stl10_binary/train_y.bin')
83 |
84 | test_X = tar.extractfile('stl10_binary/test_X.bin')
85 | test_y = tar.extractfile('stl10_binary/test_y.bin')
86 |
87 | unlabeled_X = tar.extractfile('stl10_binary/unlabeled_X.bin')
88 |
89 | train_set = {'images': np.frombuffer(train_X.read(), dtype=np.uint8),
90 | 'labels': np.frombuffer(train_y.read(), dtype=np.uint8) - 1}
91 |
92 | test_set = {'images': np.frombuffer(test_X.read(), dtype=np.uint8),
93 | 'labels': np.frombuffer(test_y.read(), dtype=np.uint8) - 1}
94 |
95 | _imgs = np.frombuffer(unlabeled_X.read(), dtype=np.uint8)
96 | unlabeled_set = {'images': _imgs,
97 | 'labels': np.zeros(100000, dtype=np.uint8)}
98 |
99 | fold_indices = tar.extractfile('stl10_binary/fold_indices.txt').read()
100 |
101 | train_set['images'] = _encode_png(unflatten(train_set['images']))
102 | test_set['images'] = _encode_png(unflatten(test_set['images']))
103 | unlabeled_set['images'] = _encode_png(unflatten(unlabeled_set['images']))
104 | return dict(train=train_set, test=test_set, unlabeled=unlabeled_set,
105 | files=[EasyDict(filename="stl10_fold_indices.txt", data=fold_indices)])
106 |
107 |
108 | def _load_cifar10():
109 | def unflatten(images):
110 | return np.transpose(images.reshape((images.shape[0], 3, 32, 32)),
111 | [0, 2, 3, 1])
112 |
113 | with tempfile.NamedTemporaryFile() as f:
114 | request.urlretrieve(URLS['cifar10'], f.name)
115 | tar = tarfile.open(fileobj=f)
116 | train_data_batches, train_data_labels = [], []
117 | for batch in range(1, 6):
118 | data_dict = scipy.io.loadmat(tar.extractfile(
119 | 'cifar-10-batches-mat/data_batch_{}.mat'.format(batch)))
120 | train_data_batches.append(data_dict['data'])
121 | train_data_labels.append(data_dict['labels'].flatten())
122 | train_set = {'images': np.concatenate(train_data_batches, axis=0),
123 | 'labels': np.concatenate(train_data_labels, axis=0)}
124 | data_dict = scipy.io.loadmat(tar.extractfile(
125 | 'cifar-10-batches-mat/test_batch.mat'))
126 | test_set = {'images': data_dict['data'],
127 | 'labels': data_dict['labels'].flatten()}
128 | train_set['images'] = _encode_png(unflatten(train_set['images']))
129 | test_set['images'] = _encode_png(unflatten(test_set['images']))
130 | return dict(train=train_set, test=test_set)
131 |
132 |
133 | def _load_cifar100():
134 | def unflatten(images):
135 | return np.transpose(images.reshape((images.shape[0], 3, 32, 32)),
136 | [0, 2, 3, 1])
137 |
138 | with tempfile.NamedTemporaryFile() as f:
139 | request.urlretrieve(URLS['cifar100'], f.name)
140 | tar = tarfile.open(fileobj=f)
141 | data_dict = scipy.io.loadmat(tar.extractfile('cifar-100-matlab/train.mat'))
142 | train_set = {'images': data_dict['data'],
143 | 'labels': data_dict['fine_labels'].flatten()}
144 | data_dict = scipy.io.loadmat(tar.extractfile('cifar-100-matlab/test.mat'))
145 | test_set = {'images': data_dict['data'],
146 | 'labels': data_dict['fine_labels'].flatten()}
147 | train_set['images'] = _encode_png(unflatten(train_set['images']))
148 | test_set['images'] = _encode_png(unflatten(test_set['images']))
149 | return dict(train=train_set, test=test_set)
150 |
151 |
152 | def _load_fashionmnist():
153 | def _read32(data):
154 | dt = np.dtype(np.uint32).newbyteorder('>')
155 | return np.frombuffer(data.read(4), dtype=dt)[0]
156 |
157 | image_filename = '{}-images-idx3-ubyte'
158 | label_filename = '{}-labels-idx1-ubyte'
159 | split_files = [('train', 'train'), ('test', 't10k')]
160 | splits = {}
161 | for split, split_file in split_files:
162 | with tempfile.NamedTemporaryFile() as f:
163 | request.urlretrieve(URLS['fashion_mnist'].format(image_filename.format(split_file)), f.name)
164 | with gzip.GzipFile(fileobj=f, mode='r') as data:
165 | assert _read32(data) == 2051
166 | n_images = _read32(data)
167 | row = _read32(data)
168 | col = _read32(data)
169 | images = np.frombuffer(data.read(n_images * row * col), dtype=np.uint8)
170 | images = images.reshape((n_images, row, col, 1))
171 | with tempfile.NamedTemporaryFile() as f:
172 | request.urlretrieve(URLS['fashion_mnist'].format(label_filename.format(split_file)), f.name)
173 | with gzip.GzipFile(fileobj=f, mode='r') as data:
174 | assert _read32(data) == 2049
175 | n_labels = _read32(data)
176 | labels = np.frombuffer(data.read(n_labels), dtype=np.uint8)
177 | splits[split] = {'images': _encode_png(images), 'labels': labels}
178 | return splits
179 |
180 |
181 | def _int64_feature(value):
182 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
183 |
184 |
185 | def _bytes_feature(value):
186 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
187 |
188 |
189 | def _save_as_tfrecord(data, filename):
190 | assert len(data['images']) == len(data['labels'])
191 | filename = os.path.join(libml_data.DATA_DIR, filename + '.tfrecord')
192 | print('Saving dataset:', filename)
193 | with tf.python_io.TFRecordWriter(filename) as writer:
194 | for x in trange(len(data['images']), desc='Building records'):
195 | feat = dict(image=_bytes_feature(data['images'][x]),
196 | label=_int64_feature(data['labels'][x]))
197 | record = tf.train.Example(features=tf.train.Features(feature=feat))
198 | writer.write(record.SerializeToString())
199 | print('Saved:', filename)
200 |
201 |
202 | def _is_installed(name, checksums):
203 | for subset, checksum in checksums.items():
204 | filename = os.path.join(libml_data.DATA_DIR, '%s-%s.tfrecord' % (name, subset))
205 | if not tf.gfile.Exists(filename):
206 | return False
207 | return True
208 |
209 |
210 | def _save_files(files, *args, **kwargs):
211 | del args, kwargs
212 | for folder in frozenset(os.path.dirname(x) for x in files):
213 | tf.gfile.MakeDirs(os.path.join(libml_data.DATA_DIR, folder))
214 | for filename, contents in files.items():
215 | with tf.gfile.Open(os.path.join(libml_data.DATA_DIR, filename), 'w') as f:
216 | f.write(contents)
217 |
218 |
219 | def _is_installed_folder(name, folder):
220 | return tf.gfile.Exists(os.path.join(libml_data.DATA_DIR, name, folder))
221 |
222 |
223 | CONFIGS = dict(
224 | cifar10=dict(loader=_load_cifar10, checksums=dict(train=None, test=None)),
225 | cifar100=dict(loader=_load_cifar100, checksums=dict(train=None, test=None)),
226 | svhn=dict(loader=_load_svhn, checksums=dict(train=None, test=None, extra=None)),
227 | stl10=dict(loader=_load_stl10, checksums=dict(train=None, test=None)),
228 |     # fashion_mnist=dict(loader=_load_fashionmnist, checksums=dict(train=None, test=None)),  # Disabled: URLS has no 'fashion_mnist' entry.
229 | )
230 |
231 |
232 | def main(argv):
233 | if len(argv[1:]):
234 | subset = set(argv[1:])
235 | else:
236 | subset = set(CONFIGS.keys())
237 | tf.gfile.MakeDirs(libml_data.DATA_DIR)
238 | for name, config in CONFIGS.items():
239 | if name not in subset:
240 | continue
241 | if 'is_installed' in config:
242 | if config['is_installed']():
243 | print('Skipping already installed:', name)
244 | continue
245 | elif _is_installed(name, config['checksums']):
246 | print('Skipping already installed:', name)
247 | continue
248 | print('Preparing', name)
249 | datas = config['loader']()
250 | saver = config.get('saver', _save_as_tfrecord)
251 | for sub_name, data in datas.items():
252 | if sub_name == 'readme':
253 | filename = os.path.join(libml_data.DATA_DIR, '%s-%s.txt' % (name, sub_name))
254 | with tf.gfile.Open(filename, 'w') as f:
255 | f.write(data)
256 | elif sub_name == 'files':
257 | for file_and_data in data:
258 | path = os.path.join(libml_data.DATA_DIR, file_and_data.filename)
259 | with tf.gfile.Open(path, "wb") as f:
260 | f.write(file_and_data.data)
261 | else:
262 | saver(data, '%s-%s' % (name, sub_name))
263 |
264 |
265 | if __name__ == '__main__':
266 | app.run(main)
267 |
--------------------------------------------------------------------------------
/scripts/create_split.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Script to create SSL splits from a dataset.
18 | """
19 |
20 | import json
21 | import os
22 | from collections import defaultdict
23 |
24 | import numpy as np
25 | import tensorflow as tf
26 | from absl import app
27 | from absl import flags
28 | from tqdm import trange, tqdm
29 |
30 | from libml import data as libml_data
31 | from libml import utils
32 |
33 | flags.DEFINE_integer('seed', 0, 'Random seed to use, 0 for no shuffling.')
34 | flags.DEFINE_integer('size', 0, 'Size of labelled set.')
35 |
36 | FLAGS = flags.FLAGS
37 |
38 |
39 | def get_class(serialized_example):
40 | return tf.parse_single_example(serialized_example, features={'label': tf.FixedLenFeature([], tf.int64)})['label']
41 |
42 |
43 | def main(argv):
44 | assert FLAGS.size
45 | argv.pop(0)
46 | if any(not tf.gfile.Exists(f) for f in argv[1:]):
47 | raise FileNotFoundError(argv[1:])
48 | target = '%s.%d@%d' % (argv[0], FLAGS.seed, FLAGS.size)
49 | if tf.gfile.Exists(target):
50 | raise FileExistsError('For safety overwriting is not allowed', target)
51 | input_files = argv[1:]
52 | count = 0
53 | id_class = []
54 | class_id = defaultdict(list)
55 | print('Computing class distribution')
56 | dataset = tf.data.TFRecordDataset(input_files).map(get_class, 4).batch(1 << 10)
57 | it = dataset.make_one_shot_iterator().get_next()
58 | try:
59 | with tf.Session() as session, tqdm(leave=False) as t:
60 | while 1:
61 | old_count = count
62 | for i in session.run(it):
63 | id_class.append(i)
64 | class_id[i].append(count)
65 | count += 1
66 | t.update(count - old_count)
67 | except tf.errors.OutOfRangeError:
68 | pass
69 | print('%d records found' % count)
70 | nclass = len(class_id)
71 | assert min(class_id.keys()) == 0 and max(class_id.keys()) == (nclass - 1)
72 | train_stats = np.array([len(class_id[i]) for i in range(nclass)], np.float64)
73 | train_stats /= train_stats.max()
74 | if 'stl10' in argv[1]:
75 | # All of the unlabeled data is given label 0, but we know that
76 | # STL has equally distributed data among the 10 classes.
77 | train_stats[:] = 1
78 |
79 | print(' Stats', ' '.join(['%.2f' % (100 * x) for x in train_stats]))
80 | assert min(class_id.keys()) == 0 and max(class_id.keys()) == (nclass - 1)
81 | class_id = [np.array(class_id[i], dtype=np.int64) for i in range(nclass)]
82 | if FLAGS.seed:
83 | np.random.seed(FLAGS.seed)
84 | for i in range(nclass):
85 | np.random.shuffle(class_id[i])
86 |
87 | # Distribute labels to match the input distribution.
88 | npos = np.zeros(nclass, np.int64)
89 | label = []
90 | for i in range(FLAGS.size):
91 | c = np.argmax(train_stats - npos / max(npos.max(), 1))
92 | label.append(class_id[c][npos[c]])
93 | npos[c] += 1
94 |
95 | del npos, class_id
96 | label = frozenset([int(x) for x in label])
97 | if 'stl10' in argv[1] and FLAGS.size == 1000:
98 | data = tf.gfile.Open(os.path.join(libml_data.DATA_DIR, 'stl10_fold_indices.txt'), 'r').read()
99 | label = frozenset(list(map(int, data.split('\n')[FLAGS.seed].split())))
100 |
101 | print('Creating split in %s' % target)
102 | tf.gfile.MakeDirs(os.path.dirname(target))
103 | with tf.python_io.TFRecordWriter(target + '-label.tfrecord') as writer_label:
104 | pos, loop = 0, trange(count, desc='Writing records')
105 | for input_file in input_files:
106 | for record in tf.python_io.tf_record_iterator(input_file):
107 | if pos in label:
108 | writer_label.write(record)
109 | pos += 1
110 | loop.update()
111 | loop.close()
112 | with tf.gfile.Open(target + '-label.json', 'w') as writer:
113 | writer.write(json.dumps(dict(distribution=train_stats.tolist(), label=sorted(label)), indent=2, sort_keys=True))
114 |
115 |
116 | if __name__ == '__main__':
117 | utils.setup_tf()
118 | app.run(main)
119 |
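`create_split.py` picks the labeled subset greedily so that its per-class counts track the class frequencies of the full training set: at every step it selects the class whose current quota lags its (max-normalized) target frequency the most. A minimal sketch of that selection loop (names are illustrative, and shuffling is always applied here rather than gated on the seed as above):

```python
# Minimal sketch of the greedy class-balancing loop used by create_split.py above:
# at each step, pick the class whose labeled count is furthest below its relative
# frequency in the full training set.
import numpy as np

def pick_label_indices(class_id, train_stats, size, seed=1):
    """class_id: list of per-class record-index arrays; train_stats: [nclass]
    max-normalized class frequencies; size: number of labels to draw."""
    rng = np.random.RandomState(seed)
    class_id = [rng.permutation(c) for c in class_id]
    npos = np.zeros(len(class_id), np.int64)     # labels drawn per class so far
    label = []
    for _ in range(size):
        c = np.argmax(train_stats - npos / max(npos.max(), 1))
        label.append(int(class_id[c][npos[c]]))
        npos[c] += 1
    return sorted(label)
```
--------------------------------------------------------------------------------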
--------------------------------------------------------------------------------
/scripts/create_unlabeled.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Script to create SSL splits from a dataset.
18 | """
19 |
20 | import json
21 | import os
22 | from collections import defaultdict
23 |
24 | import numpy as np
25 | import tensorflow as tf
26 | from absl import app
27 | from tqdm import trange, tqdm
28 |
29 | from libml import utils
30 |
31 |
32 | def get_class(serialized_example):
33 | return tf.parse_single_example(serialized_example, features={'label': tf.FixedLenFeature([], tf.int64)})['label']
34 |
35 |
36 | def main(argv):
37 | argv.pop(0)
38 | if any(not tf.gfile.Exists(f) for f in argv[1:]):
39 | raise FileNotFoundError(argv[1:])
40 | target = argv[0]
41 | input_files = argv[1:]
42 | count = 0
43 | id_class = []
44 | class_id = defaultdict(list)
45 | print('Computing class distribution')
46 | dataset = tf.data.TFRecordDataset(input_files).map(get_class, 4).batch(1 << 10)
47 | it = dataset.make_one_shot_iterator().get_next()
48 | try:
49 | with tf.Session() as session, tqdm(leave=False) as t:
50 | while 1:
51 | old_count = count
52 | for i in session.run(it):
53 | id_class.append(i)
54 | class_id[i].append(count)
55 | count += 1
56 | t.update(count - old_count)
57 | except tf.errors.OutOfRangeError:
58 | pass
59 | print('%d records found' % count)
60 | nclass = len(class_id)
61 | assert min(class_id.keys()) == 0 and max(class_id.keys()) == (nclass - 1)
62 | train_stats = np.array([len(class_id[i]) for i in range(nclass)], np.float64)
63 | train_stats /= train_stats.max()
64 | if 'stl10' in argv[1]:
65 | # All of the unlabeled data is given label 0, but we know that
66 | # STL has equally distributed data among the 10 classes.
67 | train_stats[:] = 1
68 |
69 | print(' Stats', ' '.join(['%.2f' % (100 * x) for x in train_stats]))
70 | del class_id
71 |
72 |     print('Creating unlabeled dataset in %s' % target)
73 | npos = np.zeros(nclass, np.int64)
74 | class_data = [[] for _ in range(nclass)]
75 | unlabel = []
76 | tf.gfile.MakeDirs(os.path.dirname(target))
77 | with tf.python_io.TFRecordWriter(target + '-unlabel.tfrecord') as writer_unlabel:
78 | pos, loop = 0, trange(count, desc='Writing records')
79 | for input_file in input_files:
80 | for record in tf.python_io.tf_record_iterator(input_file):
81 | class_data[id_class[pos]].append((pos, record))
82 | while True:
83 | c = np.argmax(train_stats - npos / max(npos.max(), 1))
84 | if class_data[c]:
85 | p, v = class_data[c].pop(0)
86 | unlabel.append(p)
87 | writer_unlabel.write(v)
88 | npos[c] += 1
89 | else:
90 | break
91 | pos += 1
92 | loop.update()
93 | for remain in class_data:
94 | for p, v in remain:
95 | unlabel.append(p)
96 | writer_unlabel.write(v)
97 | loop.close()
98 | with tf.gfile.Open(target + '-unlabel.json', 'w') as writer:
99 | writer.write(json.dumps(dict(distribution=train_stats.tolist(), indexes=unlabel), indent=2, sort_keys=True))
100 |
101 |
102 | if __name__ == '__main__':
103 | utils.setup_tf()
104 | app.run(main)
105 |
--------------------------------------------------------------------------------
/scripts/extract_accuracy.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Extract and save accuracy to 'stats/accuracy.json'.
18 |
19 | The accuracy is extracted from the most recent eventfile.
20 | """
21 |
22 | import json
23 | import os.path
24 |
25 | import numpy as np
26 | import tensorflow as tf
27 | from absl import app
28 | from absl import flags
29 |
30 | FLAGS = flags.FLAGS
31 | TAG = 'accuracy'
32 |
33 |
34 | def summary_dict(accuracies):
35 | return {
36 | 'last%02d' % x: np.median(accuracies[-x:]) for x in [1, 10, 20, 50]
37 | }
38 |
39 |
40 | def main(argv):
41 | if len(argv) > 2:
42 | raise app.UsageError('Too many command-line arguments.')
43 | folder = argv[1]
44 | matches = sorted(tf.gfile.Glob(os.path.join(folder, 'tf/events.out.tfevents.*')))
45 | assert matches, 'No events files found'
46 | tags = set()
47 | accuracies = []
48 | for event_file in matches:
49 | try:
50 | for e in tf.train.summary_iterator(event_file):
51 | for v in e.summary.value:
52 | if v.tag == TAG:
53 | accuracies.append(v.simple_value)
54 | break
55 | elif not accuracies:
56 | tags.add(v.tag)
57 | except tf.errors.DataLossError:
58 | continue
59 |
60 | assert accuracies, 'No "accuracy" tag found. Found tags = %s' % tags
61 | target_dir = os.path.join(folder, 'stats')
62 | target_file = os.path.join(target_dir, 'accuracy.json')
63 | tf.gfile.MakeDirs(target_dir)
64 |
65 | with tf.gfile.Open(target_file, 'w') as f:
66 | json.dump(summary_dict(accuracies), f, sort_keys=True, indent=4)
67 | print('Saved: %s' % target_file)
68 |
69 |
70 | if __name__ == '__main__':
71 | app.run(main)
72 |
--------------------------------------------------------------------------------
/scripts/inspect_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Script to inspect a dataset, in particular label distribution.
18 | """
19 |
20 | import numpy as np
21 | import tensorflow as tf
22 | from absl import app
23 | from absl import flags
24 | from tqdm import trange
25 |
26 | from libml import data, utils
27 |
28 | flags.DEFINE_integer('batch', 64, 'Batch size.')
29 | flags.DEFINE_integer('samples', 1 << 16, 'Number of samples to load.')
30 |
31 | FLAGS = flags.FLAGS
32 |
33 |
34 | def main(argv):
35 | utils.setup_main()
36 | del argv
37 | utils.setup_tf()
38 | nbatch = FLAGS.samples // FLAGS.batch
39 | dataset = data.DATASETS()[FLAGS.dataset]()
40 | groups = [('labeled', dataset.train_labeled),
41 | ('unlabeled', dataset.train_unlabeled),
42 | ('test', dataset.test.repeat())]
43 | groups = [(name, ds.batch(FLAGS.batch).prefetch(16).make_one_shot_iterator().get_next())
44 | for name, ds in groups]
45 | with tf.train.MonitoredSession() as sess:
46 | for group, train_data in groups:
47 | stats = np.zeros(dataset.nclass, np.int32)
48 | minmax = [], []
49 | for _ in trange(nbatch, leave=False, unit='img', unit_scale=FLAGS.batch, desc=group):
50 | v = sess.run(train_data)['label']
51 | for u in v:
52 | stats[u] += 1
53 | minmax[0].append(v.min())
54 | minmax[1].append(v.max())
55 | print(group)
56 | print(' Label range', min(minmax[0]), max(minmax[1]))
57 | print(' Stats', ' '.join(['%.2f' % (100 * x) for x in (stats / stats.max())]))
58 |
59 |
60 | if __name__ == '__main__':
61 | app.run(main)
62 |
--------------------------------------------------------------------------------
/third_party/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright 2015 Google Inc.
4 | This code has been modified from the original version at https://github.com/takerum/vat_tf
5 | Original license reproduced below.
6 |
7 | Copyright (c) 2017 Takeru Miyato
8 |
9 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
10 |
11 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
14 |
--------------------------------------------------------------------------------
/third_party/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/remixmatch/f7061ebf055227cbeb5c6fced1ce054e0ceecfcd/third_party/__init__.py
--------------------------------------------------------------------------------
/third_party/auto_augment/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/remixmatch/f7061ebf055227cbeb5c6fced1ce054e0ceecfcd/third_party/auto_augment/__init__.py
--------------------------------------------------------------------------------
/third_party/auto_augment/custom_ops.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google UDA Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Contains convenience wrappers for typical Neural Network TensorFlow layers.
16 |
17 | Ops that have different behavior during training or eval have an is_training
18 | parameter.
19 |
20 | Copied from AutoAugment: https://github.com/tensorflow/models/blob/master/research/autoaugment/
21 | """
22 | from __future__ import absolute_import
23 | from __future__ import division
24 | from __future__ import print_function
25 |
26 | import numpy as np
27 | import tensorflow as tf
28 |
29 | arg_scope = tf.contrib.framework.arg_scope
30 | FLAGS = tf.flags.FLAGS
31 |
32 |
33 | def variable(name, shape, dtype, initializer, trainable):
34 | """Returns a TF variable with the passed in specifications."""
35 | var = tf.get_variable(
36 | name,
37 | shape=shape,
38 | dtype=dtype,
39 | initializer=initializer,
40 | trainable=trainable)
41 | return var
42 |
43 |
44 | def global_avg_pool(x, scope=None):
45 | """Average pools away spatial height and width dimension of 4D tensor."""
46 | assert x.get_shape().ndims == 4
47 | with tf.name_scope(scope, 'global_avg_pool', [x]):
48 | kernel_size = (1, int(x.shape[1]), int(x.shape[2]), 1)
49 | squeeze_dims = (1, 2)
50 | result = tf.nn.avg_pool(
51 | x,
52 | ksize=kernel_size,
53 | strides=(1, 1, 1, 1),
54 | padding='VALID',
55 | data_format='NHWC')
56 | return tf.squeeze(result, squeeze_dims)
57 |
58 |
59 | def zero_pad(inputs, in_filter, out_filter):
60 | """Zero pads `input` tensor to have `out_filter` number of filters."""
61 | outputs = tf.pad(inputs, [[0, 0], [0, 0], [0, 0],
62 | [(out_filter - in_filter) // 2,
63 | (out_filter - in_filter) // 2]])
64 | return outputs
65 |
66 |
67 | @tf.contrib.framework.add_arg_scope
68 | def batch_norm(inputs,
69 | update_stats=True,
70 | decay=0.999,
71 | center=True,
72 | scale=False,
73 | epsilon=0.001,
74 | is_training=True,
75 | reuse=None,
76 | scope=None,
77 | ):
78 | """Small wrapper around tf.contrib.layers.batch_norm."""
79 | batch_norm_op = tf.layers.batch_normalization(
80 | inputs,
81 | axis=-1,
82 | momentum=decay,
83 | epsilon=epsilon,
84 | center=center,
85 | scale=scale,
86 | training=is_training,
87 | fused=True,
88 | trainable=True,
89 | )
90 | return batch_norm_op
91 |
92 |
93 | def stride_arr(stride_h, stride_w):
94 | return [1, stride_h, stride_w, 1]
95 |
96 |
97 | @tf.contrib.framework.add_arg_scope
98 | def conv2d(inputs,
99 | num_filters_out,
100 | kernel_size,
101 | stride=1,
102 | scope=None,
103 | reuse=None):
104 | """Adds a 2D convolution.
105 |
106 | conv2d creates a variable called 'weights', representing the convolutional
107 | kernel, that is convolved with the input.
108 |
109 | Args:
110 | inputs: a 4D tensor in NHWC format.
111 | num_filters_out: the number of output filters.
112 | kernel_size: an int specifying the kernel height and width size.
113 | stride: an int specifying the height and width stride.
114 | scope: Optional scope for variable_scope.
115 | reuse: whether or not the layer and its variables should be reused.
116 | Returns:
117 | a tensor that is the result of a convolution being applied to `inputs`.
118 | """
119 | with tf.variable_scope(scope, 'Conv', [inputs], reuse=reuse):
120 | num_filters_in = int(inputs.shape[3])
121 | weights_shape = [kernel_size, kernel_size, num_filters_in, num_filters_out]
122 |
123 | # Initialization
124 | n = int(weights_shape[0] * weights_shape[1] * weights_shape[3])
125 | weights_initializer = tf.random_normal_initializer(
126 | stddev=np.sqrt(2.0 / n))
127 |
128 | weights = variable(
129 | name='weights',
130 | shape=weights_shape,
131 | dtype=tf.float32,
132 | initializer=weights_initializer,
133 | trainable=True)
134 | strides = stride_arr(stride, stride)
135 | outputs = tf.nn.conv2d(
136 | inputs, weights, strides, padding='SAME', data_format='NHWC')
137 | return outputs
138 |
139 |
140 | @tf.contrib.framework.add_arg_scope
141 | def fc(inputs,
142 | num_units_out,
143 | scope=None,
144 | reuse=None):
145 | """Creates a fully connected layer applied to `inputs`.
146 |
147 | Args:
148 | inputs: a tensor that the fully connected layer will be applied to. It
149 | will be reshaped if it is not 2D.
150 | num_units_out: the number of output units in the layer.
151 | scope: Optional scope for variable_scope.
152 | reuse: whether or not the layer and its variables should be reused.
153 |
154 | Returns:
155 | a tensor that is the result of applying a linear matrix to `inputs`.
156 | """
157 | if len(inputs.shape) > 2:
158 | inputs = tf.reshape(inputs, [int(inputs.shape[0]), -1])
159 |
160 | with tf.variable_scope(scope, 'FC', [inputs], reuse=reuse):
161 | num_units_in = inputs.shape[1]
162 | weights_shape = [num_units_in, num_units_out]
163 | unif_init_range = 1.0 / (num_units_out) ** (0.5)
164 | weights_initializer = tf.random_uniform_initializer(
165 | -unif_init_range, unif_init_range)
166 | weights = variable(
167 | name='weights',
168 | shape=weights_shape,
169 | dtype=tf.float32,
170 | initializer=weights_initializer,
171 | trainable=True)
172 | bias_initializer = tf.constant_initializer(0.0)
173 | biases = variable(
174 | name='biases',
175 | shape=[num_units_out, ],
176 | dtype=tf.float32,
177 | initializer=bias_initializer,
178 | trainable=True)
179 | outputs = tf.nn.xw_plus_b(inputs, weights, biases)
180 | return outputs
181 |
182 |
183 | @tf.contrib.framework.add_arg_scope
184 | def avg_pool(inputs, kernel_size, stride=2, padding='VALID', scope=None):
185 | """Wrapper around tf.nn.avg_pool."""
186 | with tf.name_scope(scope, 'AvgPool', [inputs]):
187 | kernel = stride_arr(kernel_size, kernel_size)
188 | strides = stride_arr(stride, stride)
189 | return tf.nn.avg_pool(
190 | inputs,
191 | ksize=kernel,
192 | strides=strides,
193 | padding=padding,
194 | data_format='NHWC')
195 |
--------------------------------------------------------------------------------
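For reference, the two initializers used above can be restated in plain NumPy: `conv2d` samples weights from a normal distribution with standard deviation `sqrt(2 / (k * k * out_filters))` (a He-style, fan-out initialization), and `fc` samples uniformly from `±1/sqrt(num_units_out)`. A small sketch with hypothetical layer sizes:

```python
import numpy as np

# Hypothetical layer sizes, chosen only to illustrate the formulas above.
k, c_in, c_out = 3, 16, 32
conv_std = np.sqrt(2.0 / (k * k * c_out))                 # stddev used by conv2d
conv_w = np.random.normal(0.0, conv_std, size=(k, k, c_in, c_out))

units_in, units_out = 128, 10
fc_range = 1.0 / np.sqrt(units_out)                       # uniform range used by fc
fc_w = np.random.uniform(-fc_range, fc_range, size=(units_in, units_out))
print(conv_std, fc_range)
```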
/third_party/auto_augment/policies.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The TensorFlow Authors All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 |
17 | def cifar10_policies():
18 | """AutoAugment policies found on CIFAR-10."""
19 | exp0_0 = [[('Invert', 0.1, 7), ('Contrast', 0.2, 6)],
20 | [('Rotate', 0.7, 2), ('TranslateX', 0.3, 9)],
21 | [('Sharpness', 0.8, 1), ('Sharpness', 0.9, 3)],
22 | [('ShearY', 0.5, 8), ('TranslateY', 0.7, 9)],
23 | [('AutoContrast', 0.5, 8), ('Equalize', 0.9, 2)]]
24 | exp0_1 = [[('Solarize', 0.4, 5), ('AutoContrast', 0.9, 3)],
25 | [('TranslateY', 0.9, 9), ('TranslateY', 0.7, 9)],
26 | [('AutoContrast', 0.9, 2), ('Solarize', 0.8, 3)],
27 | [('Equalize', 0.8, 8), ('Invert', 0.1, 3)],
28 | [('TranslateY', 0.7, 9), ('AutoContrast', 0.9, 1)]]
29 | exp0_2 = [[('Solarize', 0.4, 5), ('AutoContrast', 0.0, 2)],
30 | [('TranslateY', 0.7, 9), ('TranslateY', 0.7, 9)],
31 | [('AutoContrast', 0.9, 0), ('Solarize', 0.4, 3)],
32 | [('Equalize', 0.7, 5), ('Invert', 0.1, 3)],
33 | [('TranslateY', 0.7, 9), ('TranslateY', 0.7, 9)]]
34 | exp0_3 = [[('Solarize', 0.4, 5), ('AutoContrast', 0.9, 1)],
35 | [('TranslateY', 0.8, 9), ('TranslateY', 0.9, 9)],
36 | [('AutoContrast', 0.8, 0), ('TranslateY', 0.7, 9)],
37 | [('TranslateY', 0.2, 7), ('Color', 0.9, 6)],
38 | [('Equalize', 0.7, 6), ('Color', 0.4, 9)]]
39 | exp1_0 = [[('ShearY', 0.2, 7), ('Posterize', 0.3, 7)],
40 | [('Color', 0.4, 3), ('Brightness', 0.6, 7)],
41 | [('Sharpness', 0.3, 9), ('Brightness', 0.7, 9)],
42 | [('Equalize', 0.6, 5), ('Equalize', 0.5, 1)],
43 | [('Contrast', 0.6, 7), ('Sharpness', 0.6, 5)]]
44 | exp1_1 = [[('Brightness', 0.3, 7), ('AutoContrast', 0.5, 8)],
45 | [('AutoContrast', 0.9, 4), ('AutoContrast', 0.5, 6)],
46 | [('Solarize', 0.3, 5), ('Equalize', 0.6, 5)],
47 | [('TranslateY', 0.2, 4), ('Sharpness', 0.3, 3)],
48 | [('Brightness', 0.0, 8), ('Color', 0.8, 8)]]
49 | exp1_2 = [[('Solarize', 0.2, 6), ('Color', 0.8, 6)],
50 | [('Solarize', 0.2, 6), ('AutoContrast', 0.8, 1)],
51 | [('Solarize', 0.4, 1), ('Equalize', 0.6, 5)],
52 | [('Brightness', 0.0, 0), ('Solarize', 0.5, 2)],
53 | [('AutoContrast', 0.9, 5), ('Brightness', 0.5, 3)]]
54 | exp1_3 = [[('Contrast', 0.7, 5), ('Brightness', 0.0, 2)],
55 | [('Solarize', 0.2, 8), ('Solarize', 0.1, 5)],
56 | [('Contrast', 0.5, 1), ('TranslateY', 0.2, 9)],
57 | [('AutoContrast', 0.6, 5), ('TranslateY', 0.0, 9)],
58 | [('AutoContrast', 0.9, 4), ('Equalize', 0.8, 4)]]
59 | exp1_4 = [[('Brightness', 0.0, 7), ('Equalize', 0.4, 7)],
60 | [('Solarize', 0.2, 5), ('Equalize', 0.7, 5)],
61 | [('Equalize', 0.6, 8), ('Color', 0.6, 2)],
62 | [('Color', 0.3, 7), ('Color', 0.2, 4)],
63 | [('AutoContrast', 0.5, 2), ('Solarize', 0.7, 2)]]
64 | exp1_5 = [[('AutoContrast', 0.2, 0), ('Equalize', 0.1, 0)],
65 | [('ShearY', 0.6, 5), ('Equalize', 0.6, 5)],
66 | [('Brightness', 0.9, 3), ('AutoContrast', 0.4, 1)],
67 | [('Equalize', 0.8, 8), ('Equalize', 0.7, 7)],
68 | [('Equalize', 0.7, 7), ('Solarize', 0.5, 0)]]
69 | exp1_6 = [[('Equalize', 0.8, 4), ('TranslateY', 0.8, 9)],
70 | [('TranslateY', 0.8, 9), ('TranslateY', 0.6, 9)],
71 | [('TranslateY', 0.9, 0), ('TranslateY', 0.5, 9)],
72 | [('AutoContrast', 0.5, 3), ('Solarize', 0.3, 4)],
73 | [('Solarize', 0.5, 3), ('Equalize', 0.4, 4)]]
74 | exp2_0 = [[('Color', 0.7, 7), ('TranslateX', 0.5, 8)],
75 | [('Equalize', 0.3, 7), ('AutoContrast', 0.4, 8)],
76 | [('TranslateY', 0.4, 3), ('Sharpness', 0.2, 6)],
77 | [('Brightness', 0.9, 6), ('Color', 0.2, 8)],
78 | [('Solarize', 0.5, 2), ('Invert', 0.0, 3)]]
79 | exp2_1 = [[('AutoContrast', 0.1, 5), ('Brightness', 0.0, 0)],
80 | [('Cutout', 0.2, 4), ('Equalize', 0.1, 1)],
81 | [('Equalize', 0.7, 7), ('AutoContrast', 0.6, 4)],
82 | [('Color', 0.1, 8), ('ShearY', 0.2, 3)],
83 | [('ShearY', 0.4, 2), ('Rotate', 0.7, 0)]]
84 | exp2_2 = [[('ShearY', 0.1, 3), ('AutoContrast', 0.9, 5)],
85 | [('TranslateY', 0.3, 6), ('Cutout', 0.3, 3)],
86 | [('Equalize', 0.5, 0), ('Solarize', 0.6, 6)],
87 | [('AutoContrast', 0.3, 5), ('Rotate', 0.2, 7)],
88 | [('Equalize', 0.8, 2), ('Invert', 0.4, 0)]]
89 | exp2_3 = [[('Equalize', 0.9, 5), ('Color', 0.7, 0)],
90 | [('Equalize', 0.1, 1), ('ShearY', 0.1, 3)],
91 | [('AutoContrast', 0.7, 3), ('Equalize', 0.7, 0)],
92 | [('Brightness', 0.5, 1), ('Contrast', 0.1, 7)],
93 | [('Contrast', 0.1, 4), ('Solarize', 0.6, 5)]]
94 | exp2_4 = [[('Solarize', 0.2, 3), ('ShearX', 0.0, 0)],
95 | [('TranslateX', 0.3, 0), ('TranslateX', 0.6, 0)],
96 | [('Equalize', 0.5, 9), ('TranslateY', 0.6, 7)],
97 | [('ShearX', 0.1, 0), ('Sharpness', 0.5, 1)],
98 | [('Equalize', 0.8, 6), ('Invert', 0.3, 6)]]
99 | exp2_5 = [[('AutoContrast', 0.3, 9), ('Cutout', 0.5, 3)],
100 | [('ShearX', 0.4, 4), ('AutoContrast', 0.9, 2)],
101 | [('ShearX', 0.0, 3), ('Posterize', 0.0, 3)],
102 | [('Solarize', 0.4, 3), ('Color', 0.2, 4)],
103 | [('Equalize', 0.1, 4), ('Equalize', 0.7, 6)]]
104 | exp2_6 = [[('Equalize', 0.3, 8), ('AutoContrast', 0.4, 3)],
105 | [('Solarize', 0.6, 4), ('AutoContrast', 0.7, 6)],
106 | [('AutoContrast', 0.2, 9), ('Brightness', 0.4, 8)],
107 | [('Equalize', 0.1, 0), ('Equalize', 0.0, 6)],
108 | [('Equalize', 0.8, 4), ('Equalize', 0.0, 4)]]
109 | exp2_7 = [[('Equalize', 0.5, 5), ('AutoContrast', 0.1, 2)],
110 | [('Solarize', 0.5, 5), ('AutoContrast', 0.9, 5)],
111 | [('AutoContrast', 0.6, 1), ('AutoContrast', 0.7, 8)],
112 | [('Equalize', 0.2, 0), ('AutoContrast', 0.1, 2)],
113 | [('Equalize', 0.6, 9), ('Equalize', 0.4, 4)]]
114 | exp0s = exp0_0 + exp0_1 + exp0_2 + exp0_3
115 | exp1s = exp1_0 + exp1_1 + exp1_2 + exp1_3 + exp1_4 + exp1_5 + exp1_6
116 | exp2s = exp2_0 + exp2_1 + exp2_2 + exp2_3 + exp2_4 + exp2_5 + exp2_6 + exp2_7
117 | return exp0s + exp1s + exp2s
118 |
119 |
120 | def svhn_policies():
121 | """AutoAugment policies found on SVHN."""
122 | policies = [
123 | [('ShearX', 0.9, 4), ('Invert', 0.2, 3)],
124 | [('ShearY', 0.9, 8), ('Invert', 0.7, 5)],
125 | [('Equalize', 0.6, 5), ('Solarize', 0.6, 6)],
126 | [('Invert', 0.9, 3), ('Equalize', 0.6, 3)],
127 | [('Equalize', 0.6, 1), ('Rotate', 0.9, 3)],
128 | [('ShearX', 0.9, 4), ('AutoContrast', 0.8, 3)],
129 | [('ShearY', 0.9, 8), ('Invert', 0.4, 5)],
130 | [('ShearY', 0.9, 5), ('Solarize', 0.2, 6)],
131 | [('Invert', 0.9, 6), ('AutoContrast', 0.8, 1)],
132 | [('Equalize', 0.6, 3), ('Rotate', 0.9, 3)],
133 | [('ShearX', 0.9, 4), ('Solarize', 0.3, 3)],
134 | [('ShearY', 0.8, 8), ('Invert', 0.7, 4)],
135 | [('Equalize', 0.9, 5), ('TranslateY', 0.6, 6)],
136 | [('Invert', 0.9, 4), ('Equalize', 0.6, 7)],
137 | [('Contrast', 0.3, 3), ('Rotate', 0.8, 4)],
138 | [('ShearX', 0.9, 3), ('Invert', 0.5, 3)],
139 | [('ShearY', 0.9, 8), ('Invert', 0.4, 5)],
140 | [('Equalize', 0.6, 3), ('Solarize', 0.2, 3)],
141 | [('Invert', 0.9, 4), ('Equalize', 0.5, 6)],
142 | [('Equalize', 0.6, 1), ('Rotate', 0.9, 3)],
143 | [('Invert', 0.8, 5), ('TranslateY', 0.0, 2)],
144 | [('ShearY', 0.7, 6), ('Solarize', 0.4, 8)],
145 | [('Invert', 0.6, 4), ('Rotate', 0.8, 4)],
146 | [('ShearY', 0.3, 7), ('TranslateX', 0.9, 3)],
147 | [('ShearX', 0.1, 6), ('Invert', 0.6, 5)],
148 | [('Solarize', 0.7, 2), ('TranslateY', 0.6, 7)],
149 | [('ShearY', 0.8, 4), ('Invert', 0.8, 8)],
150 | [('ShearX', 0.7, 9), ('TranslateY', 0.8, 3)],
151 | [('ShearY', 0.8, 5), ('AutoContrast', 0.7, 3)],
152 | [('ShearX', 0.7, 2), ('Invert', 0.1, 5)],
153 | [('ShearY', 0.8, 9), ('ShearX', 0.7, 7)],
154 | [('ShearY', 0.7, 4), ('Solarize', 0.9, 7)],
155 | [('ShearY', 0.9, 5), ('Invert', 0.0, 4)],
156 | [('TranslateX', 0.8, 3), ('ShearY', 0.7, 7)],
157 | [('Invert', 0.1, 7), ('Solarize', 0.3, 9)],
158 | [('Invert', 0.6, 2), ('Invert', 0.9, 4)],
159 | [('Equalize', 0.5, 2), ('Solarize', 0.9, 7)],
160 | [('ShearY', 0.6, 7), ('Solarize', 0.8, 3)],
161 | [('ShearY', 0.6, 3), ('Invert', 0.6, 1)],
162 | [('ShearX', 0.4, 2), ('Rotate', 0.7, 5)]]
163 |
164 | return policies
165 |
166 |
167 | def imagenet_policies():
168 | """AutoAugment policies found on ImageNet.
169 | This policy also transfers to five FGVC datasets with image size similar to
170 | ImageNet including Oxford 102 Flowers, Caltech-101, Oxford-IIIT Pets,
171 | FGVC Aircraft and Stanford Cars.
172 | """
173 | policies = [
174 | [('Posterize', 0.4, 8), ('Rotate', 0.6, 9)],
175 | [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
176 | [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
177 | [('Posterize', 0.6, 7), ('Posterize', 0.6, 6)],
178 | [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
179 | [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
180 | [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
181 | [('Posterize', 0.8, 5), ('Equalize', 1.0, 2)],
182 | [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
183 | [('Equalize', 0.6, 8), ('Posterize', 0.4, 6)],
184 | [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
185 | [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
186 | [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
187 | [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
188 | [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
189 | [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
190 | [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
191 | [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
192 | [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
193 | [('Color', 0.4, 0), ('Equalize', 0.6, 3)]
194 | ]
195 | return policies
196 |
--------------------------------------------------------------------------------
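Each policy above is a list of sub-policies, and each sub-policy is a pair of `(operation, probability, magnitude)` triples, where the magnitude is an integer level from 0 to 10. A sketch of how such a policy would typically be consumed; `apply_op` is a hypothetical callable standing in for the augmentation operations implemented elsewhere in this repository:

```python
import random

def apply_policy(image, policies, apply_op):
    """Applies one randomly chosen sub-policy to `image` (illustrative only)."""
    sub_policy = random.choice(policies)   # e.g. [('Invert', 0.1, 7), ('Contrast', 0.2, 6)]
    for op_name, prob, level in sub_policy:
        if random.random() < prob:         # each op fires independently with its probability
            image = apply_op(image, op_name, level)
    return image
```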
/third_party/auto_augment/shake_drop.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google UDA Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Builds the Shake-Drop Model.
16 | Copied from AutoAugment: https://github.com/tensorflow/models/blob/master/research/autoaugment/
17 | """
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 |
23 | import math
24 |
25 | import tensorflow as tf
26 |
27 | import third_party.auto_augment.custom_ops as ops
28 |
29 |
30 | def round_int(x):
31 | """Rounds `x` and then converts to an int."""
32 | return int(math.floor(x + 0.5))
33 |
34 |
35 | def shortcut(x, output_filters, stride):
36 | """Applies strided avg pool or zero padding to make output_filters match x."""
37 | num_filters = int(x.shape[3])
38 | if stride == 2:
39 | x = ops.avg_pool(x, 2, stride=stride, padding='SAME')
40 | if num_filters != output_filters:
41 | diff = output_filters - num_filters
42 | assert diff > 0
43 |     # Zero pad with diff extra channels
44 | padding = [[0, 0], [0, 0], [0, 0], [0, diff]]
45 | x = tf.pad(x, padding)
46 | return x
47 |
48 |
49 | def calc_prob(curr_layer, total_layers, p_l):
50 | """Calculates drop prob depending on the current layer."""
51 | return 1 - (float(curr_layer) / total_layers) * p_l
52 |
53 |
54 | def bottleneck_layer(x, n, stride, prob, is_training, alpha, beta):
55 | """Bottleneck layer for shake drop model."""
56 | assert alpha[1] > alpha[0]
57 | assert beta[1] > beta[0]
58 | with tf.variable_scope('bottleneck_{}'.format(prob)):
59 | input_layer = x
60 | x = ops.batch_norm(x, scope='bn_1_pre')
61 | x = ops.conv2d(x, n, 1, scope='1x1_conv_contract')
62 | x = ops.batch_norm(x, scope='bn_1_post')
63 | x = tf.nn.relu(x)
64 | x = ops.conv2d(x, n, 3, stride=stride, scope='3x3')
65 | x = ops.batch_norm(x, scope='bn_2')
66 | x = tf.nn.relu(x)
67 | x = ops.conv2d(x, n * 4, 1, scope='1x1_conv_expand')
68 | x = ops.batch_norm(x, scope='bn_3')
69 |
70 | # Apply regularization here
71 | # Sample bernoulli with prob
72 | if is_training:
73 | batch_size = tf.shape(x)[0]
74 | bern_shape = [batch_size, 1, 1, 1]
75 | random_tensor = prob
76 | random_tensor += tf.random_uniform(bern_shape, dtype=tf.float32)
77 | binary_tensor = tf.floor(random_tensor)
78 |
79 | alpha_values = tf.random_uniform(
80 | [batch_size, 1, 1, 1], minval=alpha[0], maxval=alpha[1],
81 | dtype=tf.float32)
82 | beta_values = tf.random_uniform(
83 | [batch_size, 1, 1, 1], minval=beta[0], maxval=beta[1],
84 | dtype=tf.float32)
85 | rand_forward = (
86 | binary_tensor + alpha_values - binary_tensor * alpha_values)
87 | rand_backward = (
88 | binary_tensor + beta_values - binary_tensor * beta_values)
89 | x = x * rand_backward + tf.stop_gradient(x * rand_forward -
90 | x * rand_backward)
91 | else:
92 | expected_alpha = (alpha[1] + alpha[0]) / 2
93 | # prob is the expectation of the bernoulli variable
94 | x = (prob + expected_alpha - prob * expected_alpha) * x
95 |
96 | res = shortcut(input_layer, n * 4, stride)
97 | return x + res
98 |
99 |
100 | def build_shake_drop_model(images, num_classes, is_training):
101 | """Builds the PyramidNet Shake-Drop model.
102 |
103 | Build the PyramidNet Shake-Drop model from https://arxiv.org/abs/1802.02375.
104 |
105 | Args:
106 | images: Tensor of images that will be fed into the Wide ResNet Model.
107 |     num_classes: Number of classes that the model needs to predict.
108 | is_training: Is the model training or not.
109 |
110 | Returns:
111 | The logits of the PyramidNet Shake-Drop model.
112 | """
113 |
114 | is_training = is_training
115 | # ShakeDrop Hparams
116 | p_l = 0.5
117 | alpha_shake = [-1, 1]
118 | beta_shake = [0, 1]
119 |
120 | # PyramidNet Hparams
121 | alpha = 200
122 | depth = 272
123 | # This is for the bottleneck architecture specifically
124 | n = int((depth - 2) / 9)
125 | start_channel = 16
126 | add_channel = alpha / (3 * n)
127 |
128 | # Building the models
129 | x = images
130 | x = ops.conv2d(x, 16, 3, scope='init_conv')
131 | x = ops.batch_norm(x, scope='init_bn')
132 |
133 | layer_num = 1
134 | total_layers = n * 3
135 | start_channel += add_channel
136 | prob = calc_prob(layer_num, total_layers, p_l)
137 | x = bottleneck_layer(
138 | x, round_int(start_channel), 1, prob, is_training, alpha_shake,
139 | beta_shake)
140 | layer_num += 1
141 | for _ in range(1, n):
142 | start_channel += add_channel
143 | prob = calc_prob(layer_num, total_layers, p_l)
144 | x = bottleneck_layer(
145 | x, round_int(start_channel), 1, prob, is_training, alpha_shake,
146 | beta_shake)
147 | layer_num += 1
148 |
149 | start_channel += add_channel
150 | prob = calc_prob(layer_num, total_layers, p_l)
151 | x = bottleneck_layer(
152 | x, round_int(start_channel), 2, prob, is_training, alpha_shake,
153 | beta_shake)
154 | layer_num += 1
155 | for _ in range(1, n):
156 | start_channel += add_channel
157 | prob = calc_prob(layer_num, total_layers, p_l)
158 | x = bottleneck_layer(
159 | x, round_int(start_channel), 1, prob, is_training, alpha_shake,
160 | beta_shake)
161 | layer_num += 1
162 |
163 | start_channel += add_channel
164 | prob = calc_prob(layer_num, total_layers, p_l)
165 | x = bottleneck_layer(
166 | x, round_int(start_channel), 2, prob, is_training, alpha_shake,
167 | beta_shake)
168 | layer_num += 1
169 | for _ in range(1, n):
170 | start_channel += add_channel
171 | prob = calc_prob(layer_num, total_layers, p_l)
172 | x = bottleneck_layer(
173 | x, round_int(start_channel), 1, prob, is_training, alpha_shake,
174 | beta_shake)
175 | layer_num += 1
176 |
177 | assert layer_num - 1 == total_layers
178 | x = ops.batch_norm(x, scope='final_bn')
179 | x = tf.nn.relu(x)
180 | x = ops.global_avg_pool(x)
181 | # Fully connected
182 | logits = ops.fc(x, num_classes)
183 | return logits
184 |
--------------------------------------------------------------------------------
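`calc_prob` gives the linear ShakeDrop schedule used by `build_shake_drop_model`: with `p_l = 0.5`, the survival probability decays from close to 1.0 at the first bottleneck to 0.5 at the last. For the depth-272 PyramidNet above, `n = (272 - 2) / 9 = 30`, so there are 90 bottlenecks in total. A quick check of the schedule, assuming those defaults:

```python
def calc_prob(curr_layer, total_layers, p_l):
    """Linear decay from ~1.0 (first layer) to 1 - p_l (last layer)."""
    return 1 - (float(curr_layer) / total_layers) * p_l

total_layers = 90  # depth 272 => 30 bottlenecks per stage, 3 stages
for layer in (1, 45, 90):
    print(layer, calc_prob(layer, total_layers, 0.5))
# 1 0.994..., 45 0.75, 90 0.5
```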
/third_party/auto_augment/shake_shake.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google UDA Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Builds the Shake-Shake Model.
16 | Copied from AutoAugment: https://github.com/tensorflow/models/blob/master/research/autoaugment/
17 | """
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 |
23 | import tensorflow as tf
24 |
25 | import third_party.auto_augment.custom_ops as ops
26 |
27 |
28 | def _shake_shake_skip_connection(x, output_filters, stride):
29 | """Adds a residual connection to the filter x for the shake-shake model."""
30 | curr_filters = int(x.shape[3])
31 | if curr_filters == output_filters:
32 | return x
33 | stride_spec = ops.stride_arr(stride, stride)
34 | # Skip path 1
35 | path1 = tf.nn.avg_pool(
36 | x, [1, 1, 1, 1], stride_spec, 'VALID', data_format='NHWC')
37 | path1 = ops.conv2d(path1, int(output_filters / 2), 1, scope='path1_conv')
38 |
39 | # Skip path 2
40 | # First pad with 0's then crop
41 | pad_arr = [[0, 0], [0, 1], [0, 1], [0, 0]]
42 | path2 = tf.pad(x, pad_arr)[:, 1:, 1:, :]
43 | concat_axis = 3
44 |
45 | path2 = tf.nn.avg_pool(
46 | path2, [1, 1, 1, 1], stride_spec, 'VALID', data_format='NHWC')
47 | path2 = ops.conv2d(path2, int(output_filters / 2), 1, scope='path2_conv')
48 |
49 | # Concat and apply BN
50 | final_path = tf.concat(values=[path1, path2], axis=concat_axis)
51 | final_path = ops.batch_norm(final_path, scope='final_path_bn')
52 | return final_path
53 |
54 |
55 | def _shake_shake_branch(x, output_filters, stride, rand_forward, rand_backward,
56 | is_training):
57 | """Building a 2 branching convnet."""
58 | x = tf.nn.relu(x)
59 | x = ops.conv2d(x, output_filters, 3, stride=stride, scope='conv1')
60 | x = ops.batch_norm(x, scope='bn1')
61 | x = tf.nn.relu(x)
62 | x = ops.conv2d(x, output_filters, 3, scope='conv2')
63 | x = ops.batch_norm(x, scope='bn2')
64 | if is_training:
65 | x = x * rand_backward + tf.stop_gradient(x * rand_forward -
66 | x * rand_backward)
67 | else:
68 | x *= 1.0 / 2
69 | return x
70 |
71 |
72 | def _shake_shake_block(x, output_filters, stride, is_training):
73 | """Builds a full shake-shake sub layer."""
74 | batch_size = tf.shape(x)[0]
75 |
76 | # Generate random numbers for scaling the branches
77 | rand_forward = [
78 | tf.random_uniform(
79 | [batch_size, 1, 1, 1], minval=0, maxval=1, dtype=tf.float32)
80 | for _ in range(2)
81 | ]
82 | rand_backward = [
83 | tf.random_uniform(
84 | [batch_size, 1, 1, 1], minval=0, maxval=1, dtype=tf.float32)
85 | for _ in range(2)
86 | ]
87 | # Normalize so that all sum to 1
88 | total_forward = tf.add_n(rand_forward)
89 | total_backward = tf.add_n(rand_backward)
90 | rand_forward = [samp / total_forward for samp in rand_forward]
91 | rand_backward = [samp / total_backward for samp in rand_backward]
92 | zipped_rand = zip(rand_forward, rand_backward)
93 |
94 | branches = []
95 | for branch, (r_forward, r_backward) in enumerate(zipped_rand):
96 | with tf.variable_scope('branch_{}'.format(branch)):
97 | b = _shake_shake_branch(x, output_filters, stride, r_forward, r_backward,
98 | is_training)
99 | branches.append(b)
100 | res = _shake_shake_skip_connection(x, output_filters, stride)
101 | return res + tf.add_n(branches)
102 |
103 |
104 | def _shake_shake_layer(x, output_filters, num_blocks, stride,
105 | is_training):
106 | """Builds many sub layers into one full layer."""
107 | for block_num in range(num_blocks):
108 | curr_stride = stride if (block_num == 0) else 1
109 | with tf.variable_scope('layer_{}'.format(block_num)):
110 | x = _shake_shake_block(x, output_filters, curr_stride,
111 | is_training)
112 | return x
113 |
114 |
115 | def build_shake_shake_model(images, num_classes, hparams, is_training):
116 | """Builds the Shake-Shake model.
117 |
118 | Build the Shake-Shake model from https://arxiv.org/abs/1705.07485.
119 |
120 | Args:
121 | images: Tensor of images that will be fed into the Wide ResNet Model.
122 |     num_classes: Number of classes that the model needs to predict.
123 | hparams: tf.HParams object that contains additional hparams needed to
124 | construct the model. In this case it is the `shake_shake_widen_factor`
125 | that is used to determine how many filters the model has.
126 | is_training: Is the model training or not.
127 |
128 | Returns:
129 | The logits of the Shake-Shake model.
130 | """
131 | depth = 26
132 | k = hparams.shake_shake_widen_factor # The widen factor
133 | n = int((depth - 2) / 6)
134 | x = images
135 |
136 | x = ops.conv2d(x, 16, 3, scope='init_conv')
137 | x = ops.batch_norm(x, scope='init_bn')
138 | with tf.variable_scope('L1'):
139 | x = _shake_shake_layer(x, 16 * k, n, 1, is_training)
140 | with tf.variable_scope('L2'):
141 | x = _shake_shake_layer(x, 32 * k, n, 2, is_training)
142 | with tf.variable_scope('L3'):
143 | x = _shake_shake_layer(x, 64 * k, n, 2, is_training)
144 | x = tf.nn.relu(x)
145 | x = ops.global_avg_pool(x)
146 |
147 | # Fully connected
148 | logits = ops.fc(x, num_classes)
149 | return logits
150 |
--------------------------------------------------------------------------------
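The scaling trick in `_shake_shake_branch` decouples the forward and backward passes: numerically the branch output equals `x * rand_forward`, but because the difference term sits behind `tf.stop_gradient`, gradients only flow through `x * rand_backward`. A NumPy sketch of the forward-pass identity (treating `stop_gradient` as the identity, which is what it is numerically):

```python
import numpy as np

def shake(x, rand_forward, rand_backward):
    stop_gradient = lambda t: t  # numerically a no-op; blocks gradients in TF
    return x * rand_backward + stop_gradient(x * rand_forward - x * rand_backward)

x = np.ones((2, 3))
f, b = 0.3, 0.7
print(np.allclose(shake(x, f, b), x * f))  # True: the forward pass sees rand_forward
```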
/third_party/auto_augment/wrn.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google UDA Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Builds the WideResNet Model."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import numpy as np
22 | import tensorflow as tf
23 |
24 | import third_party.auto_augment.custom_ops as ops
25 |
26 |
27 | def residual_block(
28 | x, in_filter, out_filter, stride, update_bn=True,
29 | activate_before_residual=False):
30 | """Adds residual connection to `x` in addition to applying BN->ReLU->3x3 Conv.
31 |
32 | Args:
33 | x: Tensor that is the output of the previous layer in the model.
34 | in_filter: Number of filters `x` has.
35 | out_filter: Number of filters that the output of this layer will have.
36 |     stride: Integer that specifies what stride should be applied to `x`.
37 | activate_before_residual: Boolean on whether a BN->ReLU should be applied
38 | to x before the convolution is applied.
39 |
40 | Returns:
41 | A Tensor that is the result of applying two sequences of BN->ReLU->3x3 Conv
42 | and then adding that Tensor to `x`.
43 | """
44 |
45 | if activate_before_residual: # Pass up RELU and BN activation for resnet
46 | with tf.variable_scope('shared_activation'):
47 | x = ops.batch_norm(x, update_stats=update_bn, scope='init_bn')
48 | x = tf.nn.relu(x)
49 | orig_x = x
50 | else:
51 | orig_x = x
52 |
53 | block_x = x
54 | if not activate_before_residual:
55 | with tf.variable_scope('residual_only_activation'):
56 | block_x = ops.batch_norm(block_x, update_stats=update_bn,
57 | scope='init_bn')
58 | block_x = tf.nn.relu(block_x)
59 |
60 | with tf.variable_scope('sub1'):
61 | block_x = ops.conv2d(
62 | block_x, out_filter, 3, stride=stride, scope='conv1')
63 |
64 | with tf.variable_scope('sub2'):
65 | block_x = ops.batch_norm(block_x, update_stats=update_bn, scope='bn2')
66 | block_x = tf.nn.relu(block_x)
67 | block_x = ops.conv2d(
68 | block_x, out_filter, 3, stride=1, scope='conv2')
69 |
70 | with tf.variable_scope(
71 |       'sub_add'):  # If the numbers of filters do not agree then zero pad them
72 | if in_filter != out_filter:
73 | orig_x = ops.avg_pool(orig_x, stride, stride)
74 | orig_x = ops.zero_pad(orig_x, in_filter, out_filter)
75 | x = orig_x + block_x
76 | return x
77 |
78 |
79 | def _res_add(in_filter, out_filter, stride, x, orig_x):
80 | """Adds `x` with `orig_x`, both of which are layers in the model.
81 |
82 | Args:
83 | in_filter: Number of filters in `orig_x`.
84 | out_filter: Number of filters in `x`.
85 |     stride: Integer specifying the stride that should be applied to `orig_x`.
86 | x: Tensor that is the output of the previous layer.
87 | orig_x: Tensor that is the output of an earlier layer in the network.
88 |
89 | Returns:
90 | A Tensor that is the result of `x` and `orig_x` being added after
91 | zero padding and striding are applied to `orig_x` to get the shapes
92 | to match.
93 | """
94 | if in_filter != out_filter:
95 | orig_x = ops.avg_pool(orig_x, stride, stride)
96 | orig_x = ops.zero_pad(orig_x, in_filter, out_filter)
97 | x = x + orig_x
98 | orig_x = x
99 | return x, orig_x
100 |
101 |
102 | def build_wrn_model(images, num_classes, wrn_size, update_bn=True):
103 | """Builds the WRN model.
104 |
105 | Build the Wide ResNet model from https://arxiv.org/abs/1605.07146.
106 |
107 | Args:
108 | images: Tensor of images that will be fed into the Wide ResNet Model.
109 |     num_classes: Number of classes that the model needs to predict.
110 | wrn_size: Parameter that scales the number of filters in the Wide ResNet
111 | model.
112 |
113 | Returns:
114 | The logits of the Wide ResNet model.
115 | """
116 | # wrn_size = 16 * widening factor k
117 | kernel_size = wrn_size
118 | filter_size = 3
119 | # depth = num_blocks_per_resnet * 6 + 4 = 28
120 | num_blocks_per_resnet = 4
121 | filters = [
122 | min(kernel_size, 16), kernel_size, kernel_size * 2, kernel_size * 4
123 | ]
124 | strides = [1, 2, 2] # stride for each resblock
125 |
126 | # Run the first conv
127 | with tf.variable_scope('init'):
128 | x = images
129 | output_filters = filters[0]
130 | x = ops.conv2d(x, output_filters, filter_size, scope='init_conv')
131 |
132 | first_x = x # Res from the beginning
133 | orig_x = x # Res from previous block
134 |
135 | for block_num in range(1, 4):
136 | with tf.variable_scope('unit_{}_0'.format(block_num)):
137 | activate_before_residual = True if block_num == 1 else False
138 | x = residual_block(
139 | x,
140 | filters[block_num - 1],
141 | filters[block_num],
142 | strides[block_num - 1],
143 | update_bn=update_bn,
144 | activate_before_residual=activate_before_residual)
145 | for i in range(1, num_blocks_per_resnet):
146 | with tf.variable_scope('unit_{}_{}'.format(block_num, i)):
147 | x = residual_block(
148 | x,
149 | filters[block_num],
150 | filters[block_num],
151 | 1,
152 | update_bn=update_bn,
153 | activate_before_residual=False)
154 | x, orig_x = _res_add(filters[block_num - 1], filters[block_num],
155 | strides[block_num - 1], x, orig_x)
156 | final_stride_val = np.prod(strides)
157 | x, _ = _res_add(filters[0], filters[3], final_stride_val, x, first_x)
158 | with tf.variable_scope('unit_last'):
159 | x = ops.batch_norm(x, scope='final_bn')
160 | x = tf.nn.relu(x)
161 | x = ops.global_avg_pool(x)
162 | logits = ops.fc(x, num_classes)
163 | return logits
164 |
--------------------------------------------------------------------------------
/third_party/vat_utils.py:
--------------------------------------------------------------------------------
1 | """Utilities derived from the VAT code."""
2 |
3 | import tensorflow as tf
4 |
5 |
6 | def generate_perturbation(x, logit, forward, epsilon, xi=1e-6):
7 | """Generate an adversarial perturbation.
8 |
9 | Args:
10 | x: Model inputs.
11 | logit: Original model output without perturbation.
12 |     forward: Callable which computes logits given input.
13 | epsilon: Gradient multiplier.
14 | xi: Small constant.
15 |
16 | Returns:
17 |     Adversarial perturbation to be applied to x.
18 | """
19 | d = tf.random_normal(shape=tf.shape(x))
20 |
21 | for _ in range(1):
22 | d = xi * get_normalized_vector(d)
23 | logit_p = logit
24 | logit_m = forward(x + d)
25 | dist = kl_divergence_with_logit(logit_p, logit_m)
26 | grad = tf.gradients(tf.reduce_mean(dist), [d], aggregation_method=2)[0]
27 | d = tf.stop_gradient(grad)
28 |
29 | return epsilon * get_normalized_vector(d)
30 |
31 |
32 | def kl_divergence_with_logit(q_logit, p_logit):
33 | """Compute the per-element KL-divergence of a batch."""
34 | q = tf.nn.softmax(q_logit)
35 | qlogq = tf.reduce_sum(q * logsoftmax(q_logit), 1)
36 | qlogp = tf.reduce_sum(q * logsoftmax(p_logit), 1)
37 | return qlogq - qlogp
38 |
39 |
40 | def get_normalized_vector(d):
41 | """Normalize d by infinity and L2 norms."""
42 | d /= 1e-12 + tf.reduce_max(
43 | tf.abs(d), list(range(1, len(d.get_shape()))), keepdims=True
44 | )
45 | d /= tf.sqrt(
46 | 1e-6
47 | + tf.reduce_sum(
48 | tf.pow(d, 2.0), list(range(1, len(d.get_shape()))), keepdims=True
49 | )
50 | )
51 | return d
52 |
53 |
54 | def logsoftmax(x):
55 | """Compute log-domain softmax of logits."""
56 | xdev = x - tf.reduce_max(x, 1, keepdims=True)
57 | lsm = xdev - tf.log(tf.reduce_sum(tf.exp(xdev), 1, keepdims=True))
58 | return lsm
59 |
--------------------------------------------------------------------------------
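`kl_divergence_with_logit` and `logsoftmax` operate directly on logits, returning one KL value per example in the batch. A NumPy restatement of the same computations (assuming a 2D batch of logits), useful as a sanity check:

```python
import numpy as np

def logsoftmax(x):
    xdev = x - x.max(axis=1, keepdims=True)
    return xdev - np.log(np.exp(xdev).sum(axis=1, keepdims=True))

def kl_divergence_with_logit(q_logit, p_logit):
    q = np.exp(logsoftmax(q_logit))
    return (q * (logsoftmax(q_logit) - logsoftmax(p_logit))).sum(axis=1)

q = np.array([[2.0, 0.0, -1.0]])
print(kl_divergence_with_logit(q, q))  # ~[0.]: KL of a distribution with itself
```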
/vat.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Virtual adversarial training:a regularization method for supervised and semi-supervised learning.
15 |
16 | Application to SSL of https://arxiv.org/abs/1704.03976
17 | """
18 |
19 | import functools
20 | import os
21 |
22 | import tensorflow as tf
23 | from absl import app
24 | from absl import flags
25 |
26 | from libml import utils, data, layers, models
27 | from libml.utils import EasyDict
28 | from third_party import vat_utils
29 |
30 | FLAGS = flags.FLAGS
31 |
32 |
33 | class VAT(models.MultiModel):
34 |
35 | def model(self, batch, lr, wd, ema, warmup_pos, vat, vat_eps, entmin_weight, **kwargs):
36 | hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
37 | xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt') # For training
38 | x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
39 | y_in = tf.placeholder(tf.float32, [batch] + hwc, 'y')
40 | l_in = tf.placeholder(tf.int32, [batch], 'labels')
41 | wd *= lr
42 | warmup = tf.clip_by_value(tf.to_float(self.step) / (warmup_pos * (FLAGS.train_kimg << 10)), 0, 1)
43 |
44 | classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
45 | l = tf.one_hot(l_in, self.nclass)
46 | logits_x = classifier(xt_in, training=True)
47 | post_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) # Take only first call to update batch norm.
48 | logits_y = classifier(y_in, training=True)
49 | delta_y = vat_utils.generate_perturbation(y_in, logits_y, lambda x: classifier(x, training=True), vat_eps)
50 | logits_student = classifier(y_in + delta_y, training=True)
51 | logits_teacher = tf.stop_gradient(logits_y)
52 | loss_vat = layers.kl_divergence_from_logits(logits_student, logits_teacher)
53 | loss_vat = tf.reduce_mean(loss_vat)
54 | loss_entmin = tf.reduce_mean(tf.distributions.Categorical(logits=logits_y).entropy())
55 |
56 | loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=l, logits=logits_x)
57 | loss = tf.reduce_mean(loss)
58 | tf.summary.scalar('losses/xe', loss)
59 | tf.summary.scalar('losses/vat', loss_vat)
60 | tf.summary.scalar('losses/entmin', loss_entmin)
61 |
62 | ema = tf.train.ExponentialMovingAverage(decay=ema)
63 | ema_op = ema.apply(utils.model_vars())
64 | ema_getter = functools.partial(utils.getter_ema, ema)
65 | post_ops.append(ema_op)
66 | post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])
67 |
68 | train_op = tf.train.AdamOptimizer(lr).minimize(loss + loss_vat * warmup * vat + entmin_weight * loss_entmin,
69 | colocate_gradients_with_ops=True)
70 | with tf.control_dependencies([train_op]):
71 | train_op = tf.group(*post_ops)
72 |
73 | return EasyDict(
74 | xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op,
75 | classify_raw=tf.nn.softmax(classifier(x_in, training=False)), # No EMA, for debugging.
76 | classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))
77 |
78 |
79 | def main(argv):
80 | utils.setup_main()
81 | del argv # Unused.
82 | dataset = data.DATASETS()[FLAGS.dataset]()
83 | log_width = utils.ilog2(dataset.width)
84 | model = VAT(
85 | os.path.join(FLAGS.train_dir, dataset.name),
86 | dataset,
87 | lr=FLAGS.lr,
88 | wd=FLAGS.wd,
89 | arch=FLAGS.arch,
90 | warmup_pos=FLAGS.warmup_pos,
91 | batch=FLAGS.batch,
92 | nclass=dataset.nclass,
93 | ema=FLAGS.ema,
94 | smoothing=FLAGS.smoothing,
95 | vat=FLAGS.vat,
96 | vat_eps=FLAGS.vat_eps,
97 | entmin_weight=FLAGS.entmin_weight,
98 |
99 | scales=FLAGS.scales or (log_width - 2),
100 | filters=FLAGS.filters,
101 | repeat=FLAGS.repeat)
102 | model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
103 |
104 |
105 | if __name__ == '__main__':
106 | utils.setup_tf()
107 | flags.DEFINE_float('wd', 0.02, 'Weight decay.')
108 | flags.DEFINE_float('vat', 0.3, 'VAT weight.')
109 | flags.DEFINE_float('vat_eps', 6, 'VAT perturbation size.')
110 | flags.DEFINE_float('entmin_weight', 0.06, 'Entropy minimization weight.')
111 | flags.DEFINE_float('warmup_pos', 0.4, 'Relative position at which constraint loss warmup ends.')
112 | flags.DEFINE_float('ema', 0.999, 'Exponential moving average of params.')
113 | flags.DEFINE_float('smoothing', 0.1, 'Label smoothing.')
114 | flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.')
115 | flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.')
116 | flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.')
117 | FLAGS.set_default('dataset', 'cifar10.3@250-5000')
118 | FLAGS.set_default('batch', 64)
119 | FLAGS.set_default('lr', 0.002)
120 | FLAGS.set_default('train_kimg', 1 << 16)
121 | app.run(main)
122 |
--------------------------------------------------------------------------------
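The objective minimized in `VAT.model` is the labeled cross-entropy plus a linearly warmed-up VAT consistency term and an entropy-minimization term. A sketch of just the weighting schedule, using the default flags (`vat=0.3`, `entmin_weight=0.06`, `warmup_pos=0.4`, `train_kimg=1<<16`), so the consistency weight ramps from 0 to `vat` over the first `warmup_pos` fraction of training:

```python
def total_loss(xe, loss_vat, loss_entmin, step, train_imgs,
               warmup_pos=0.4, vat=0.3, entmin_weight=0.06):
    """Illustrative restatement of the loss weighting in VAT.model."""
    warmup = min(max(step / (warmup_pos * train_imgs), 0.0), 1.0)
    return xe + warmup * vat * loss_vat + entmin_weight * loss_entmin

train_imgs = (1 << 16) << 10  # train_kimg << 10
print(total_loss(1.0, 0.5, 0.2, step=0, train_imgs=train_imgs))           # warmup = 0
print(total_loss(1.0, 0.5, 0.2, step=train_imgs, train_imgs=train_imgs))  # warmup = 1
```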