├── LICENSE ├── README.md ├── moon_data_exp.py ├── parser.py ├── scripts ├── fixmatch-setup │ ├── cifar10 │ │ ├── base.sh │ │ ├── fixmatch.sh │ │ └── uda.sh │ ├── cifar100 │ │ ├── base.sh │ │ ├── fixmatch.sh │ │ └── uda.sh │ ├── stl10 │ │ ├── base.sh │ │ ├── fixmatch.sh │ │ └── uda.sh │ └── svhn │ │ ├── base.sh │ │ ├── fixmatch.sh │ │ └── uda.sh └── realistic-evaluation-setup │ ├── cifar10 │ ├── base.sh │ ├── fixmatch.sh │ ├── ict.sh │ ├── mean_teacher.sh │ ├── pi_model.sh │ ├── pseudo_label.sh │ ├── supervised.sh │ ├── uda.sh │ └── vat.sh │ └── svhn │ ├── base.sh │ ├── fixmatch.sh │ ├── ict.sh │ ├── mean_teacher.sh │ ├── pi_model.sh │ ├── pseudo_label.sh │ ├── supervised.sh │ ├── uda.sh │ └── vat.sh ├── ssl_lib ├── __init__.py ├── algs │ ├── __init__.py │ ├── builder.py │ ├── consistency.py │ ├── ict.py │ ├── pseudo_label.py │ ├── utils.py │ └── vat.py ├── augmentation │ ├── __init__.py │ ├── augmentation_class.py │ ├── augmentation_pool.py │ ├── builder.py │ ├── rand_augment.py │ └── utils.py ├── consistency │ ├── __init__.py │ ├── builder.py │ ├── cross_entropy.py │ └── mean_squared.py ├── datasets │ ├── __init__.py │ ├── builder.py │ ├── dataset_class.py │ └── utils.py ├── misc │ ├── __init__.py │ └── meter.py ├── models │ ├── __init__.py │ ├── builder.py │ ├── cnn13.py │ ├── resnet.py │ ├── shakenet.py │ └── utils.py └── param_scheduler │ ├── __init__.py │ └── scheduler.py ├── train_test.py └── train_val_test.py /LICENSE: --------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 teppei suzuki
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # [WIP] Consistency Regularization for Semi-supervised Learning with PyTorch
2 | This repository includes consistency regularization algorithms for semi-supervised learning:
3 | - Pi-Model
4 | - Pseudo-label
5 | - Mean Teacher
6 | - Virtual Adversarial Training
7 | - Interpolation Consistency Training
8 | - Unsupervised Data Augmentation
9 | - FixMatch (with RandAugment)
10 |
11 | Training and evaluation settings follow Oliver+ 2018 and FixMatch.
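All of these methods share one training pattern: a supervised cross-entropy loss on the labeled batch plus a weighted consistency loss that pushes predictions on perturbed unlabeled inputs toward targets computed from another view of the same inputs. The snippet below is a minimal illustration of that pattern with FixMatch-style confidence-thresholded pseudo-labels; it is not the training code of this repository (see ```train_test.py``` and ```ssl_lib``` for that), and every name in it is made up.

```python
import torch
import torch.nn.functional as F

def consistency_step(model, optimizer, x_l, y_l, x_u_weak, x_u_strong,
                     coef=1.0, threshold=0.95):
    """One generic consistency-regularization update (illustration only)."""
    # supervised loss on the labeled mini-batch
    sup_loss = F.cross_entropy(model(x_l), y_l)

    # teacher targets: confident hard pseudo-labels from the weakly augmented view
    with torch.no_grad():
        probs = model(x_u_weak).softmax(dim=1)
        conf, pseudo_label = probs.max(dim=1)
        mask = (conf >= threshold).float()

    # student loss: make the strongly augmented view match the pseudo-labels
    unsup_loss = (F.cross_entropy(model(x_u_strong), pseudo_label,
                                  reduction="none") * mask).mean()

    loss = sup_loss + coef * unsup_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

Pi-Model and Mean Teacher use soft targets with a mean-squared consistency instead of hard, thresholded labels, and VAT replaces the second augmented view with an adversarial perturbation, but the overall structure is the same.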
12 |
13 | # Requirements
14 | - Python >= 3.7
15 | - PyTorch >= 1.0
16 | - torchvision >= 0.4
17 | - NumPy
18 | - sklearn (optional)
19 |
20 | sklearn is used for moon_data_exp.py (two moons dataset experiment)
21 |
22 | # Usage
23 | One can run ```sh ./scripts/SETUP_NAME/DATASET_NAME/ALGORITHM.sh /PATH/TO/OUTPUT_DIR NUM_LABELS```;
24 | for example, to reproduce the FixMatch result on CIFAR-10 with 250 labels, run
25 |
26 | ```
27 | sh ./scripts/fixmatch-setup/cifar10/fixmatch.sh ./results/cifar10-fixmatch-250labels 250
28 | ```
29 |
30 | The scripts in ```scripts/fixmatch-setup``` train and evaluate a model with the FixMatch setting,
31 | and the scripts in ```scripts/realistic-evaluation-setup``` train and evaluate a model with the Oliver+ 2018 setting.
32 |
33 | If you would like to train a model with your own settings, please see ```parser.py```.
34 |
35 | NOTE: ```train_test.py``` evaluates model performance as the median of the last [1, 10, 20, 50] checkpoint accuracies (FixMatch setting),
36 | and ```train_val_test.py``` evaluates the test accuracy of the best model on the validation data (Oliver+ 2018 setting).
37 |
38 | # Performance
39 | WIP
40 | ||Oliver+ 2018||this repo| |
41 | |--|--|--|--|--|
42 | ||CIFAR-10 4000 labels|SVHN 1000 labels|CIFAR-10 4000 labels|SVHN 1000 labels|
43 | |Supervised|20.26 ±0.38|12.83 ±0.47|19.85|11.03
44 | |Pi-Model|16.37 ±0.63|7.19 ±0.27|14.84|7.87
45 | |Mean Teacher|15.87 ±0.28|5.65 ±0.47|14.28|5.83
46 | |VAT|13.13 ±0.39|5.35 ±0.19|12.15|6.38
47 |
48 | NOTE: Our implementation differs from Oliver+ 2018 as follows:
49 | 1. we use not only purely unlabeled data but also labeled data as unlabeled data (following Sohn+ 2020)
50 | 2. our VAT implementation follows Miyato+, whereas Oliver+ use KLD with the reverse direction as the loss function;
51 | see [issue](https://github.com/brain-research/realistic-ssl-evaluation/issues/27).
52 | 3. parameter initialization of WRN-28 (following Sohn+ 2020)
53 |
54 | If you would like to evaluate the model under the same conditions as Oliver+ 2018, please see [this repo](https://github.com/perrying/realistic-ssl-evaluation-pytorch).
55 |
56 | ||Sohn+ 2020||this repo| |
57 | |--|--|--|--|--|
58 | ||CIFAR-10 250 labels|CIFAR-10 4000 labels|CIFAR-10 250 labels|CIFAR-10 4000 labels|
59 | |UDA|8.82±1.08|4.88±0.18|10.08|6.32
60 | |FixMatch|5.07±0.65|4.26±0.05|9.88|6.84
61 |
62 | Reported error rates are the median of the last 20 checkpoint accuracies.
63 |
64 | # Citation
65 | ```
66 | @misc{suzuki2020consistency,
67 |     author = {Teppei Suzuki},
68 |     title = {Consistency Regularization for Semi-supervised Learning with PyTorch},
69 |     year = {2020},
70 |     publisher = {GitHub},
71 |     journal = {GitHub repository},
72 |     howpublished = {\url{https://github.com/perrying/pytorch-consistency-regularization}},
73 | }
74 | ```
75 |
76 | # References
77 | - Miyato, Takeru, et al. "Distributional smoothing with virtual adversarial training." arXiv preprint arXiv:1507.00677 (2015).
78 | - Laine, Samuli, and Timo Aila. "Temporal ensembling for semi-supervised learning." arXiv preprint arXiv:1610.02242 (2016).
79 | - Sajjadi, Mehdi, Mehran Javanmardi, and Tolga Tasdizen. "Regularization with stochastic transformations and perturbations for deep semi-supervised learning." Advances in Neural Information Processing Systems. 2016.
80 | - Tarvainen, Antti, and Harri Valpola. "Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results." Advances in Neural Information Processing Systems. 2017.
81 | - Miyato, Takeru, et al. "Virtual adversarial training: a regularization method for supervised and semi-supervised learning." IEEE transactions on pattern analysis and machine intelligence 41.8 (2018): 1979-1993. 82 | - Oliver, Avital, et al. "Realistic evaluation of deep semi-supervised learning algorithms." Advances in Neural Information Processing Systems. 2018. 83 | - Verma, Vikas, et al. "Interpolation consistency training for semi-supervised learning." arXiv preprint arXiv:1903.03825 (2019). 84 | - Sohn, Kihyuk, et al. "Fixmatch: Simplifying semi-supervised learning with consistency and confidence." arXiv preprint arXiv:2001.07685 (2020). 85 | -------------------------------------------------------------------------------- /moon_data_exp.py: -------------------------------------------------------------------------------- 1 | """ 2 | Two moons experiment for visualization 3 | """ 4 | import os 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | from torch.utils.data import DataLoader 11 | import matplotlib.pyplot as plt 12 | from sklearn.datasets import make_moons 13 | from tqdm import tqdm 14 | 15 | from ssl_lib.algs.builder import gen_ssl_alg 16 | from ssl_lib.models.utils import ema_update 17 | from ssl_lib.consistency.builder import gen_consistency 18 | 19 | 20 | def gen_model(): 21 | return nn.Sequential( 22 | nn.Linear(2, 128), 23 | nn.ReLU(), 24 | nn.Linear(128, 256), 25 | nn.ReLU(), 26 | nn.Linear(256, 2) 27 | ) 28 | 29 | 30 | def gen_ssl_moon_dataset(seed, num_samples, labeled_sample, noise_factor=0.1): 31 | assert num_samples > labeled_sample 32 | data, label = make_moons(num_samples, False, noise_factor, random_state=seed) 33 | data = (data - data.mean(0, keepdims=True)) / data.std(0, keepdims=True) 34 | 35 | l0_idx = (label == 0) 36 | l1_idx = (label == 1) 37 | 38 | l0_data = data[l0_idx] 39 | l1_data = data[l1_idx] 40 | 41 | np.random.seed(seed) 42 | 43 | l0_data = np.random.permutation(l0_data) 44 | l1_data = np.random.permutation(l1_data) 45 | 46 | labeled_l0 = l0_data[:labeled_sample//2] 47 | labeled_l1 = l1_data[:labeled_sample//2] 48 | 49 | unlabeled = np.concatenate([ 50 | l0_data[labeled_sample//2:], l1_data[labeled_sample//2:] 51 | ]) 52 | 53 | l0_label = np.zeros(labeled_l0.shape[0]) 54 | l1_label = np.ones(labeled_l1.shape[0]) 55 | label = np.concatenate([l0_label, l1_label]) 56 | 57 | return labeled_l0, labeled_l1, unlabeled, label 58 | 59 | 60 | def scatter_plot_with_confidence(l0_data, l1_data, all_data, model, device, out_dir=None, show=False): 61 | xx, yy = np.meshgrid( 62 | np.linspace(all_data[:,0].min()-0.1, all_data[:,0].max()+0.1, 1000), 63 | np.linspace(all_data[:,1].min()-0.1, all_data[:,1].max()+0.1, 1000)) 64 | np_points = np.stack([xx.ravel(),yy.ravel()],1).reshape(-1, 2) 65 | points = torch.from_numpy(np_points).to(device).float() 66 | outputs = model(points).softmax(1)[:,1].detach().to("cpu").numpy().reshape(xx.shape) 67 | plt.contourf(xx, yy, outputs, alpha=0.5, cmap=plt.cm.jet) 68 | plt.scatter(all_data[:,0], all_data[:,1], c="gray") 69 | plt.scatter(l0_data[:,0], l0_data[:,1], c="blue") 70 | plt.scatter(l1_data[:,0], l1_data[:,1], c="red") 71 | plt.xlim(-2, 2) 72 | plt.ylim(-2, 2) 73 | # plt.grid() 74 | plt.tight_layout() 75 | if out_dir is not None: 76 | plt.savefig(os.path.join(out_dir, "confidence_with_labeled.png")) 77 | if show: 78 | plt.show() 79 | plt.contourf(xx, yy, outputs, alpha=0.5, cmap=plt.cm.jet) 80 | plt.scatter(l0_data[:,0], 
l0_data[:,1], c="blue")
81 |     plt.scatter(l1_data[:,0], l1_data[:,1], c="red")
82 |     plt.xlim(-2, 2)
83 |     plt.ylim(-2, 2)
84 |     # plt.grid()
85 |     plt.tight_layout()
86 |     if out_dir is not None:
87 |         plt.savefig(os.path.join(out_dir, "confidence.png"))
88 |     if show:
89 |         plt.show()
90 |
91 |
92 | def scatter_plot(l0_data, l1_data, unlabeled_data, out_dir=None, show=False):
93 |     plt.scatter(unlabeled_data[:,0], unlabeled_data[:,1], c="gray")
94 |     plt.scatter(l0_data[:,0], l0_data[:,1], c="blue")
95 |     plt.scatter(l1_data[:,0], l1_data[:,1], c="red")
96 |     plt.xlim(-2, 2)
97 |     plt.ylim(-2, 2)
98 |     # plt.grid()
99 |     plt.tight_layout()
100 |     if out_dir is not None:
101 |         plt.savefig(os.path.join(out_dir, "labeled_raw_data.png"))
102 |     if show:
103 |         plt.show()
104 |     plt.scatter(l0_data[:,0], l0_data[:,1], c="blue")
105 |     plt.scatter(l1_data[:,0], l1_data[:,1], c="red")
106 |     plt.xlim(-2, 2)
107 |     plt.ylim(-2, 2)
108 |     # plt.grid()
109 |     plt.tight_layout()
110 |     if out_dir is not None:
111 |         plt.savefig(os.path.join(out_dir, "raw_data.png"))
112 |     if show:
113 |         plt.show()
114 |
115 | def fit(cfg):
116 |     torch.manual_seed(cfg.seed)
117 |     if torch.cuda.is_available():
118 |         device = "cuda"
119 |         torch.backends.cudnn.benchmark = True
120 |     else:
121 |         device = "cpu"
122 |
123 |     model = gen_model().to(device)
124 |     model.train()
125 |
126 |     optimizer = optim.Adam(model.parameters(), cfg.lr)
127 |
128 |     weak_augmentation = lambda x: x + torch.randn_like(x) * cfg.gauss_std
129 |
130 |     # set consistency type
131 |     consistency = gen_consistency(cfg.consistency, cfg)
132 |     # set ssl algorithm
133 |     ssl_alg = gen_ssl_alg(
134 |         cfg.alg,
135 |         cfg
136 |     )
137 |
138 |     l0_data, l1_data, u_data, label = gen_ssl_moon_dataset(
139 |         cfg.seed, cfg.n_sample, cfg.n_labeled, cfg.noise_factor
140 |     )
141 |
142 |     labeled_data = np.concatenate([l0_data, l1_data])
143 |
144 |     scatter_plot(l0_data, l1_data, u_data, cfg.out_dir, cfg.vis_data)
145 |
146 |     tch_labeled_data = torch.from_numpy(labeled_data).float().to(device)
147 |     tch_unlabeled_data = torch.from_numpy(u_data).float().to(device)
148 |     label = torch.from_numpy(label).long().to(device)
149 |
150 |     for i in range(cfg.iterations):
151 |         unlabeled_weak1 = weak_augmentation(tch_unlabeled_data)
152 |         unlabeled_weak2 = weak_augmentation(tch_unlabeled_data)
153 |         all_data = torch.cat([
154 |             tch_labeled_data,
155 |             unlabeled_weak1,
156 |             unlabeled_weak2], 0)
157 |
158 |         outputs = model(all_data)
159 |         labeled_logits = outputs[:tch_labeled_data.shape[0]]
160 |         loss = F.cross_entropy(labeled_logits, label)
161 |         if cfg.coef > 0:
162 |             unlabeled_logits, unlabeled_logits_target = torch.chunk(outputs[tch_labeled_data.shape[0]:], 2, dim=0)
163 |
164 |             y, targets, mask = ssl_alg(
165 |                 stu_preds = unlabeled_logits,
166 |                 tea_logits = unlabeled_logits_target.detach(),
167 |                 w_data = unlabeled_weak1,
168 |                 s_data = unlabeled_weak2,
169 |                 stu_forward = model,
170 |                 tea_forward = model
171 |             )
172 |
173 |             L_consistency = consistency(y, targets, mask)
174 |             loss += cfg.coef * L_consistency
175 |         else:
176 |             L_consistency = torch.zeros_like(loss)
177 |
178 |         if cfg.entropy_minimize > 0:
179 |             loss -= cfg.entropy_minimize * (unlabeled_logits.softmax(1) * F.log_softmax(unlabeled_logits, 1)).sum(1).mean()
180 |
181 |         print("[{}/{}] loss {} | ssl loss {}".format(
182 |             i+1, cfg.iterations, loss.item(), L_consistency.item()))
183 |
184 |         optimizer.zero_grad()
185 |         loss.backward()
186 |         optimizer.step()
187 |
188 |     scatter_plot_with_confidence(l0_data, l1_data, all_data, model, device,
cfg.out_dir, cfg.vis_data) 189 | 190 | 191 | if __name__ == "__main__": 192 | import argparse 193 | parser = argparse.ArgumentParser() 194 | # dataset config 195 | parser.add_argument("--n_sample", default=1000, type=int, help="number of samples") 196 | parser.add_argument("--n_labeled", default=10, type=int, help="number of labeled samples") 197 | parser.add_argument("--noise_factor", default=0.1, type=float, help="std of gaussian noise") 198 | # optimization config 199 | parser.add_argument("--iterations", default=1000, type=int, help="number of training iteration") 200 | parser.add_argument("--lr", default=0.01, type=float, help="learning rate") 201 | # SSL common config 202 | parser.add_argument("--alg", default="cr", type=str, help="ssl algorithm, ['ict', 'cr', 'pl', 'vat']") 203 | parser.add_argument("--coef", default=1, type=float, help="coefficient for consistency loss") 204 | parser.add_argument("--ema_teacher", action="store_true", help="consistency with mean teacher") 205 | parser.add_argument("--ema_factor", default=0.999, type=float, help="exponential mean avarage factor") 206 | parser.add_argument("--entropy_minimize", "-em", default=0, type=float, help="coefficient of entropy minimization") 207 | parser.add_argument("--threshold", default=None, type=float, help="pseudo label threshold") 208 | parser.add_argument("--sharpen", default=None, type=float, help="tempereture parameter for sharpening") 209 | parser.add_argument("--temp_softmax", default=None, type=float, help="tempereture for softmax") 210 | parser.add_argument("--gauss_std", default=0.1, type=float, help="standard deviation for gaussian noise") 211 | ## SSL alg parameter 212 | ### ICT config 213 | parser.add_argument("--alpha", default=0.1, type=float, help="parameter for beta distribution in ICT") 214 | ### VAT config 215 | parser.add_argument("--eps", default=6, type=float, help="norm of virtual adversarial noise") 216 | parser.add_argument("--xi", default=1e-6, type=float, help="perturbation for finite difference method") 217 | parser.add_argument("--vat_iter", default=1, type=int, help="number of iteration for power iteration") 218 | ## consistency config 219 | parser.add_argument("--consistency", "-consis", default="ce", type=str, help="consistency type, ['ce', 'ms']") 220 | parser.add_argument("--sinkhorn_tau", default=10, type=float, help="tempereture parameter for sinkhorn distance") 221 | parser.add_argument("--sinkhorn_iter", default=10, type=int, help="number of iterations for sinkhorn normalization") 222 | # evaluation config 223 | parser.add_argument("--weight_average", action="store_true", help="evaluation with weight-averaged model") 224 | # misc 225 | parser.add_argument("--out_dir", default="log", type=str, help="output directory") 226 | parser.add_argument("--seed", default=96, type=int, help="random seed") 227 | parser.add_argument("--vis_data", action="store_true", help="visualize input data") 228 | 229 | args = parser.parse_args() 230 | 231 | fit(args) 232 | -------------------------------------------------------------------------------- /parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_args(): 5 | parser = argparse.ArgumentParser() 6 | # dataset config 7 | parser.add_argument("--root", "-r", default="./data", type=str, help="/path/to/dataset") 8 | parser.add_argument("--dataset", "-d", default="cifar10", choices=['stl10', 'svhn', 'cifar10', 'cifar100'], type=str, help="dataset name") 9 | 
parser.add_argument("--num_labels", default=4000, type=int, help="number of labeled data") 10 | parser.add_argument("--val_ratio", default=0.1, type=float, help="the ratio of evaluation data to training data.") 11 | parser.add_argument("--random_split", action="store_true", help="random sampleing from training data for validation") 12 | parser.add_argument("--num_workers", default=8, type=int, help="number of thread for CPU parallel") 13 | parser.add_argument("--whiten", action="store_true", help="use whitening as preprocessing") 14 | parser.add_argument("--zca", action="store_true", help="use zca whitening as preprocessing") 15 | # augmentation config 16 | parser.add_argument("--labeled_aug", default="WA", choices=['WA', 'RA'], type=str, help="type of augmentation for labeled data") 17 | parser.add_argument("--unlabeled_aug", default="WA", choices=['WA', 'RA'], type=str, help="type of augmentation for unlabeled data") 18 | parser.add_argument("--wa", default="t.t.f", type=str, help="transformations (flip, crop, noise) for weak augmentation. t and f indicate true and false.") 19 | parser.add_argument("--strong_aug", action="store_true", help="use strong augmentation (RandAugment) for unlabeled data") 20 | # optimization config 21 | parser.add_argument("--model", default="wrn", choices=['wrn', 'shake', 'cnn13'], type=str, help="model architecture") 22 | parser.add_argument("--ul_batch_size", "-ul_bs", default=50, type=int, help="mini-batch size of unlabeled data") 23 | parser.add_argument("--l_batch_size", "-l_bs", default=50, type=int, help="mini-batch size of labeled data") 24 | parser.add_argument("--optimizer", "-opt", default="sgd", choices=['sgd', 'adam'], type=str, help="optimizer") 25 | parser.add_argument("--lr", default=3e-2, type=float, help="learning rate") 26 | parser.add_argument("--weight_decay", "-wd", default=0.0005, type=float, help="weight decay") 27 | parser.add_argument("--momentum", default=0.9, type=float, help="momentum for sgd or beta_1 for adam") 28 | parser.add_argument("--iteration", default=500000, type=int, help="number of training iteration") 29 | parser.add_argument("--lr_decay", default="cos", choices=['cos', 'step'], type=str, help="way to decay learning rate") 30 | parser.add_argument("--lr_decay_rate", default=0.2, type=float, help="decay rate for step lr decay") 31 | parser.add_argument("--only_validation", action="store_true", help="only training and validation for hyperparameter tuning") 32 | parser.add_argument("--warmup_iter", default=0, type=int, help="wnumber of armup iteration for SSL loss coefficient") 33 | parser.add_argument("--tsa", action="store_true", help="use training signal annealing proposed by UDA") 34 | parser.add_argument("--tsa_schedule", default="linear", choices=['linear', 'exp', 'log'], type=str, help="tsa schedule") 35 | # SSL common config 36 | parser.add_argument("--alg", default="cr", choices=['ict', 'cr', 'pl', 'vat'], type=str, help="ssl algorithm") 37 | parser.add_argument("--coef", default=1, type=float, help="coefficient for consistency loss") 38 | parser.add_argument("--ema_teacher", action="store_true", help="use mean teacher") 39 | parser.add_argument("--ema_teacher_warmup", action="store_true", help="warmup for mean teacher") 40 | parser.add_argument("--ema_teacher_factor", default=0.999, type=float, help="exponential mean avarage factor for mean teacher") 41 | parser.add_argument("--ema_apply_wd", action="store_true", help="apply weight decay to ema model") 42 | parser.add_argument("--entropy_minimization", "-em", 
default=0, type=float, help="coefficient of entropy minimization") 43 | parser.add_argument("--threshold", default=None, type=float, help="pseudo label threshold") 44 | parser.add_argument("--sharpen", default=None, type=float, help="tempereture parameter for sharpening") 45 | parser.add_argument("--temp_softmax", default=None, type=float, help="tempereture for softmax") 46 | parser.add_argument("--consistency", "-consis", default="ce", choices=['ce', 'ms'], type=str, help="consistency type") 47 | ## SSL alg parameter 48 | ### ICT config 49 | parser.add_argument("--alpha", default=0.1, type=float, help="parameter for beta distribution in ICT") 50 | ### VAT config 51 | parser.add_argument("--eps", default=6, type=float, help="norm of virtual adversarial noise") 52 | parser.add_argument("--xi", default=1e-6, type=float, help="perturbation for finite difference method") 53 | parser.add_argument("--vat_iter", default=1, type=int, help="number of iteration for power iteration") 54 | # evaluation config 55 | parser.add_argument("--weight_average", action="store_true", help="evaluation with weight-averaged model") 56 | parser.add_argument("--wa_ema_factor", default=0.999, type=float, help="exponential mean avarage factor for weight-averaged model") 57 | parser.add_argument("--wa_apply_wd", action="store_true", help="apply weight decay to weight-averaged model") 58 | parser.add_argument("--checkpoint", default=10000, type=int, help="checkpoint every N samples") 59 | # training from checkpoint 60 | parser.add_argument("--checkpoint_model", default=None, type=str, help="path to checkpoint model") 61 | parser.add_argument("--checkpoint_optimizer", default=None, type=str, help="path to checkpoint optimizer") 62 | parser.add_argument("--start_iter", default=None, type=int, help="start iteration") 63 | # misc 64 | parser.add_argument("--out_dir", default="log", type=str, help="output directory") 65 | parser.add_argument("--seed", default=96, type=int, help="random seed") 66 | parser.add_argument("--disp", default=256, type=int, help="display loss every N") 67 | return parser.parse_args() 68 | -------------------------------------------------------------------------------- /scripts/fixmatch-setup/cifar10/base.sh: -------------------------------------------------------------------------------- 1 | python3 train_test.py \ 2 | --lr 3e-2 \ 3 | -wd 5e-4 \ 4 | --dataset cifar10 \ 5 | -ul_bs 448 \ 6 | -l_bs 64 \ 7 | --weight_average \ 8 | --iteration 1048576 \ 9 | --checkpoint 1024 \ 10 | --wa_apply_wd \ 11 | $* -------------------------------------------------------------------------------- /scripts/fixmatch-setup/cifar10/fixmatch.sh: -------------------------------------------------------------------------------- 1 | python3 train_test.py \ 2 | --lr 3e-2 \ 3 | -wd 5e-4 \ 4 | --dataset cifar10 \ 5 | -ul_bs 448 \ 6 | -l_bs 64 \ 7 | --weight_average \ 8 | --iteration 1048576 \ 9 | --checkpoint 1024 \ 10 | --wa_apply_wd \ 11 | --alg pl \ 12 | --strong_aug \ 13 | --threshold 0.95 \ 14 | --coef 1 \ 15 | --out_dir $1 \ 16 | --num_labels $2 -------------------------------------------------------------------------------- /scripts/fixmatch-setup/cifar10/uda.sh: -------------------------------------------------------------------------------- 1 | python3 train_test.py \ 2 | --lr 3e-2 \ 3 | -wd 5e-4 \ 4 | --dataset cifar10 \ 5 | -ul_bs 448 \ 6 | -l_bs 64 \ 7 | --weight_average \ 8 | --iteration 1048576 \ 9 | --checkpoint 1024 \ 10 | --wa_apply_wd \ 11 | --strong_aug \ 12 | --threshold 0.8 \ 13 | --temp_softmax 0.4 \ 14 | 
--tsa \ 15 | --coef 1 \ 16 | --out_dir $1 \ 17 | --num_labels $2 -------------------------------------------------------------------------------- /scripts/fixmatch-setup/cifar100/base.sh: -------------------------------------------------------------------------------- 1 | python3 train_test.py \ 2 | --lr 3e-2 \ 3 | -wd 1e-3 \ 4 | --dataset cifar100 \ 5 | -ul_bs 448 \ 6 | -l_bs 64 \ 7 | --weight_average \ 8 | --iteration 1048576 \ 9 | --checkpoint 1024 \ 10 | --wa_apply_wd \ 11 | $* -------------------------------------------------------------------------------- /scripts/fixmatch-setup/cifar100/fixmatch.sh: -------------------------------------------------------------------------------- 1 | python3 train_test.py \ 2 | --lr 3e-2 \ 3 | -wd 1e-3 \ 4 | --dataset cifar100 \ 5 | -ul_bs 448 \ 6 | -l_bs 64 \ 7 | --weight_average \ 8 | --iteration 1048576 \ 9 | --checkpoint 1024 \ 10 | --wa_apply_wd \ 11 | --alg pl \ 12 | --strong_aug \ 13 | --threshold 0.95 \ 14 | --coef 1 \ 15 | --out_dir $1 \ 16 | --num_labels $2 -------------------------------------------------------------------------------- /scripts/fixmatch-setup/cifar100/uda.sh: -------------------------------------------------------------------------------- 1 | python3 train_test.py \ 2 | --lr 3e-2 \ 3 | -wd 1e-3 \ 4 | --dataset cifar100 \ 5 | -ul_bs 448 \ 6 | -l_bs 64 \ 7 | --weight_average \ 8 | --iteration 1048576 \ 9 | --checkpoint 1024 \ 10 | --wa_apply_wd \ 11 | --strong_aug \ 12 | --threshold 0.8 \ 13 | --temp_softmax 0.4 \ 14 | --tsa \ 15 | --coef 1 \ 16 | --out_dir $1 \ 17 | --num_labels $2 -------------------------------------------------------------------------------- /scripts/fixmatch-setup/stl10/base.sh: -------------------------------------------------------------------------------- 1 | python3 train_test.py \ 2 | --lr 3e-2 \ 3 | -wd 5e-4 \ 4 | --dataset stl10 \ 5 | -ul_bs 448 \ 6 | -l_bs 64 \ 7 | --weight_average \ 8 | --iteration 1048576 \ 9 | --checkpoint 1024 \ 10 | --wa_apply_wd \ 11 | $* -------------------------------------------------------------------------------- /scripts/fixmatch-setup/stl10/fixmatch.sh: -------------------------------------------------------------------------------- 1 | python3 train_test.py \ 2 | --lr 3e-2 \ 3 | -wd 5e-4 \ 4 | --dataset stl10 \ 5 | -ul_bs 448 \ 6 | -l_bs 64 \ 7 | --weight_average \ 8 | --iteration 1048576 \ 9 | --checkpoint 1024 \ 10 | --wa_apply_wd \ 11 | --alg pl \ 12 | --strong_aug \ 13 | --threshold 0.95 \ 14 | --coef 1 \ 15 | --out_dir $1 \ 16 | --num_labels $2 -------------------------------------------------------------------------------- /scripts/fixmatch-setup/stl10/uda.sh: -------------------------------------------------------------------------------- 1 | python3 train_test.py \ 2 | --lr 3e-2 \ 3 | -wd 5e-4 \ 4 | --dataset stl10 \ 5 | -ul_bs 448 \ 6 | -l_bs 64 \ 7 | --weight_average \ 8 | --iteration 1048576 \ 9 | --checkpoint 1024 \ 10 | --wa_apply_wd \ 11 | --strong_aug \ 12 | --threshold 0.8 \ 13 | --temp_softmax 0.4 \ 14 | --tsa \ 15 | --coef 1 \ 16 | --out_dir $1 \ 17 | --num_labels $2 -------------------------------------------------------------------------------- /scripts/fixmatch-setup/svhn/base.sh: -------------------------------------------------------------------------------- 1 | python3 train_test.py \ 2 | --lr 3e-2 \ 3 | -wd 5e-4 \ 4 | --dataset svhn \ 5 | -ul_bs 448 \ 6 | -l_bs 64 \ 7 | --weight_average \ 8 | --iteration 1048576 \ 9 | --checkpoint 1024 \ 10 | --wa f.t.f \ 11 | --wa_apply_wd \ 12 | $* 
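Each ```base.sh``` above ends with ```$*```, which forwards any extra command-line arguments straight to ```train_test.py```, so the shared dataset and optimizer settings can be reused with different algorithm flags; the per-algorithm scripts simply spell out the full command instead. (Using ```"$@"``` would preserve argument quoting, but ```$*``` is enough for these simple flag lists.) A minimal usage sketch follows; the output directory and label count are made-up example values, and all flags are defined in ```parser.py```:

```sh
# Reuse the shared SVHN base configuration and append FixMatch-style flags.
sh ./scripts/fixmatch-setup/svhn/base.sh \
    --alg pl --strong_aug --threshold 0.95 --coef 1 \
    --out_dir ./results/svhn-fixmatch-example --num_labels 250
```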
-------------------------------------------------------------------------------- /scripts/fixmatch-setup/svhn/fixmatch.sh: -------------------------------------------------------------------------------- 1 | python3 train_test.py \ 2 | --lr 3e-2 \ 3 | -wd 5e-4 \ 4 | --dataset svhn \ 5 | -ul_bs 448 \ 6 | -l_bs 64 \ 7 | --weight_average \ 8 | --iteration 1048576 \ 9 | --checkpoint 1024 \ 10 | --wa f.t.f \ 11 | --wa_apply_wd \ 12 | --alg pl \ 13 | --strong_aug \ 14 | --threshold 0.95 \ 15 | --coef 1 \ 16 | --out_dir $1 \ 17 | --num_labels $2 -------------------------------------------------------------------------------- /scripts/fixmatch-setup/svhn/uda.sh: -------------------------------------------------------------------------------- 1 | python3 train_test.py \ 2 | --lr 3e-2 \ 3 | -wd 5e-4 \ 4 | --dataset svhn \ 5 | -ul_bs 448 \ 6 | -l_bs 64 \ 7 | --weight_average \ 8 | --iteration 1048576 \ 9 | --checkpoint 1024 \ 10 | --wa f.t.f \ 11 | --wa_apply_wd \ 12 | --strong_aug \ 13 | --threshold 0.8 \ 14 | --temp_softmax 0.4 \ 15 | --tsa \ 16 | --coef 1 \ 17 | --out_dir $1 \ 18 | --num_labels $2 -------------------------------------------------------------------------------- /scripts/realistic-evaluation-setup/cifar10/base.sh: -------------------------------------------------------------------------------- 1 | python3 train_val_test.py \ 2 | --optimizer adam \ 3 | --lr_decay step \ 4 | --weight_decay 0 \ 5 | --dataset cifar10 \ 6 | --zca \ 7 | --wa t.t.t \ 8 | $* -------------------------------------------------------------------------------- /scripts/realistic-evaluation-setup/cifar10/fixmatch.sh: -------------------------------------------------------------------------------- 1 | python3 train_val_test.py \ 2 | --optimizer adam \ 3 | --lr_decay step \ 4 | --weight_decay 0 \ 5 | --dataset cifar10 \ 6 | --zca \ 7 | --wa t.t.t \ 8 | --lr 3e-2 \ 9 | --coef 1 \ 10 | --alg pl \ 11 | --strong_aug \ 12 | --warmup_iter 0 \ 13 | --threshold 0.95 \ 14 | --out_dir $1 \ 15 | --num_labels $2 -------------------------------------------------------------------------------- /scripts/realistic-evaluation-setup/cifar10/ict.sh: -------------------------------------------------------------------------------- 1 | python3 train_val_test.py \ 2 | --optimizer adam \ 3 | --lr_decay step \ 4 | --weight_decay 0 \ 5 | --dataset cifar10 \ 6 | --zca \ 7 | --wa t.t.t \ 8 | --warmup_iter 200000 \ 9 | --lr 4e-4 \ 10 | --coef 100 \ 11 | --alg ict \ 12 | --alpha 0.1 \ 13 | -consis ms \ 14 | --ema_teacher \ 15 | --ema_teacher_warmup \ 16 | --out_dir $1 \ 17 | --num_labels $2 -------------------------------------------------------------------------------- /scripts/realistic-evaluation-setup/cifar10/mean_teacher.sh: -------------------------------------------------------------------------------- 1 | python3 train_val_test.py \ 2 | --optimizer adam \ 3 | --lr_decay step \ 4 | --weight_decay 0 \ 5 | --dataset cifar10 \ 6 | --zca \ 7 | --wa t.t.t \ 8 | --warmup_iter 200000 \ 9 | --lr 4e-4 \ 10 | --coef 8 \ 11 | -consis ms \ 12 | --ema_teacher \ 13 | --ema_teacher_factor 0.95 \ 14 | --ema_teacher_warmup \ 15 | --out_dir $1 \ 16 | --num_labels $2 -------------------------------------------------------------------------------- /scripts/realistic-evaluation-setup/cifar10/pi_model.sh: -------------------------------------------------------------------------------- 1 | python3 train_val_test.py \ 2 | --optimizer adam \ 3 | --lr_decay step \ 4 | --weight_decay 0 \ 5 | --dataset cifar10 \ 6 | --zca \ 7 | --wa t.t.t \ 8 | --warmup_iter 200000 
\ 9 | --lr 3e-4 \ 10 | --coef 20 \ 11 | -consis ms \ 12 | --out_dir $1 \ 13 | --num_labels $2 -------------------------------------------------------------------------------- /scripts/realistic-evaluation-setup/cifar10/pseudo_label.sh: -------------------------------------------------------------------------------- 1 | python3 train_val_test.py \ 2 | --optimizer adam \ 3 | --lr_decay step \ 4 | --weight_decay 0 \ 5 | --dataset cifar10 \ 6 | --zca \ 7 | --wa t.t.t \ 8 | --warmup_iter 200000 \ 9 | --alg pl \ 10 | --lr 3e-3 \ 11 | --coef 1 \ 12 | -consis ms \ 13 | --threshold 0.95 \ 14 | --out_dir $1 \ 15 | --num_labels $2 -------------------------------------------------------------------------------- /scripts/realistic-evaluation-setup/cifar10/supervised.sh: -------------------------------------------------------------------------------- 1 | python3 train_val_test.py \ 2 | --optimizer adam \ 3 | --lr_decay step \ 4 | --weight_decay 0 \ 5 | --dataset cifar10 \ 6 | --zca \ 7 | --wa t.t.t \ 8 | --lr 3e-3 \ 9 | --coef 0 \ 10 | --out_dir $1 \ 11 | -------------------------------------------------------------------------------- /scripts/realistic-evaluation-setup/cifar10/uda.sh: -------------------------------------------------------------------------------- 1 | python3 train_val_test.py \ 2 | --optimizer adam \ 3 | --lr_decay step \ 4 | --weight_decay 0 \ 5 | --dataset cifar10 \ 6 | --zca \ 7 | --wa t.t.t \ 8 | --lr 3e-2 \ 9 | --coef 1 \ 10 | --strong_aug \ 11 | --threshold 0.8 \ 12 | --temp_softmax 0.4 \ 13 | --warmup 0 \ 14 | --weight_average \ 15 | --tsa \ 16 | --out_dir $1 \ 17 | --num_labels $2 -------------------------------------------------------------------------------- /scripts/realistic-evaluation-setup/cifar10/vat.sh: -------------------------------------------------------------------------------- 1 | python3 train_val_test.py \ 2 | --optimizer adam \ 3 | --lr_decay step \ 4 | --weight_decay 0 \ 5 | --dataset cifar10 \ 6 | --zca \ 7 | --wa t.t.t \ 8 | --warmup_iter 200000 \ 9 | --lr 3e-3 \ 10 | --coef 0.3 \ 11 | --alg vat \ 12 | -em 0.06 \ 13 | --eps 6 \ 14 | --out_dir $1 \ 15 | --num_labels $2 -------------------------------------------------------------------------------- /scripts/realistic-evaluation-setup/svhn/base.sh: -------------------------------------------------------------------------------- 1 | python3 train_val_test.py \ 2 | --optimizer adam \ 3 | --lr_decay step \ 4 | --weight_decay 0 \ 5 | --dataset svhn \ 6 | --wa f.t.f \ 7 | $* -------------------------------------------------------------------------------- /scripts/realistic-evaluation-setup/svhn/fixmatch.sh: -------------------------------------------------------------------------------- 1 | python3 train_val_test.py \ 2 | --optimizer adam \ 3 | --lr_decay step \ 4 | --weight_decay 0 \ 5 | --dataset cifar10 \ 6 | --wa f.t.f \ 7 | --lr 3e-2 \ 8 | --coef 1 \ 9 | --alg pl \ 10 | --strong_aug \ 11 | --warmup_iter 0 \ 12 | --threshold 0.95 \ 13 | --out_dir $1 \ 14 | --num_labels $2 -------------------------------------------------------------------------------- /scripts/realistic-evaluation-setup/svhn/ict.sh: -------------------------------------------------------------------------------- 1 | python3 train_val_test.py \ 2 | --optimizer adam \ 3 | --lr_decay step \ 4 | --weight_decay 0 \ 5 | --dataset svhn \ 6 | --wa f.t.f \ 7 | --warmup_iter 200000 \ 8 | --lr 4e-4 \ 9 | --coef 100 \ 10 | --alg ict \ 11 | --alpha 0.1 \ 12 | -consis ms \ 13 | --ema_teacher \ 14 | --ema_teacher_warmup \ 15 | --out_dir $1 \ 16 | 
--num_labels $2 -------------------------------------------------------------------------------- /scripts/realistic-evaluation-setup/svhn/mean_teacher.sh: -------------------------------------------------------------------------------- 1 | python3 train_val_test.py \ 2 | --optimizer adam \ 3 | --lr_decay step \ 4 | --weight_decay 0 \ 5 | --dataset svhn \ 6 | --wa f.t.f \ 7 | --warmup_iter 200000 \ 8 | --lr 4e-4 \ 9 | --coef 8 \ 10 | -consis ms \ 11 | --ema_teacher \ 12 | --ema_teacher_factor 0.95 \ 13 | --ema_teacher_warmup \ 14 | --out_dir $1 \ 15 | --num_labels $2 -------------------------------------------------------------------------------- /scripts/realistic-evaluation-setup/svhn/pi_model.sh: -------------------------------------------------------------------------------- 1 | python3 train_val_test.py \ 2 | --optimizer adam \ 3 | --lr_decay step \ 4 | --weight_decay 0 \ 5 | --dataset svhn \ 6 | --wa f.t.f \ 7 | --warmup_iter 200000 \ 8 | --lr 3e-4 \ 9 | --coef 20 \ 10 | -consis ms \ 11 | --out_dir $1 \ 12 | --num_labels $2 -------------------------------------------------------------------------------- /scripts/realistic-evaluation-setup/svhn/pseudo_label.sh: -------------------------------------------------------------------------------- 1 | python3 train_val_test.py \ 2 | --optimizer adam \ 3 | --lr_decay step \ 4 | --weight_decay 0 \ 5 | --dataset svhn \ 6 | --wa f.t.f \ 7 | --warmup_iter 200000 \ 8 | --alg pl \ 9 | --lr 3e-3 \ 10 | --coef 1 \ 11 | -consis ms \ 12 | --threshold 0.95 \ 13 | --out_dir $1 \ 14 | --num_labels $2 -------------------------------------------------------------------------------- /scripts/realistic-evaluation-setup/svhn/supervised.sh: -------------------------------------------------------------------------------- 1 | python3 train_val_test.py \ 2 | --optimizer adam \ 3 | --lr_decay step \ 4 | --weight_decay 0 \ 5 | --dataset svhn \ 6 | --wa f.t.f \ 7 | --lr 3e-3 \ 8 | --coef 0 \ 9 | --out_dir $1 \ 10 | --num_labels $2 11 | -------------------------------------------------------------------------------- /scripts/realistic-evaluation-setup/svhn/uda.sh: -------------------------------------------------------------------------------- 1 | python3 train_val_test.py \ 2 | --optimizer adam \ 3 | --lr_decay step \ 4 | --weight_decay 0 \ 5 | --dataset svhn \ 6 | --wa f.t.f \ 7 | --lr 3e-2 \ 8 | --coef 1 \ 9 | --strong_aug \ 10 | --threshold 0.8 \ 11 | --temp_softmax 0.4 \ 12 | --warmup 0 \ 13 | --weight_average \ 14 | --tsa \ 15 | --out_dir $1 \ 16 | --num_labels $2 -------------------------------------------------------------------------------- /scripts/realistic-evaluation-setup/svhn/vat.sh: -------------------------------------------------------------------------------- 1 | python3 train_val_test.py \ 2 | --optimizer adam \ 3 | --lr_decay step \ 4 | --weight_decay 0 \ 5 | --dataset svhn \ 6 | --whiten \ 7 | --warmup_iter 200000 \ 8 | --lr 3e-3 \ 9 | --coef 0.3 \ 10 | --alg vat \ 11 | -em 0.06 \ 12 | --eps 1 \ 13 | --out_dir $1 \ 14 | --num_labels $2 -------------------------------------------------------------------------------- /ssl_lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perrying/pytorch-consistency-regularization/6624c4e0bb1813b5952445ce34f9d4e52484ce38/ssl_lib/__init__.py -------------------------------------------------------------------------------- /ssl_lib/algs/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/perrying/pytorch-consistency-regularization/6624c4e0bb1813b5952445ce34f9d4e52484ce38/ssl_lib/algs/__init__.py -------------------------------------------------------------------------------- /ssl_lib/algs/builder.py: -------------------------------------------------------------------------------- 1 | from .ict import ICT 2 | from .consistency import ConsistencyRegularization 3 | from .pseudo_label import PseudoLabel 4 | from .vat import VAT 5 | 6 | 7 | def gen_ssl_alg(name, cfg): 8 | if name == "ict": # mixed target <-> mixed input 9 | return ICT( 10 | cfg.consistency, 11 | cfg.threshold, 12 | cfg.sharpen, 13 | cfg.temp_softmax, 14 | cfg.alpha 15 | ) 16 | elif name == "cr": # base augment <-> another augment 17 | return ConsistencyRegularization( 18 | cfg.consistency, 19 | cfg.threshold, 20 | cfg.sharpen, 21 | cfg.temp_softmax 22 | ) 23 | elif name == "pl": # hard label <-> strong augment 24 | return PseudoLabel( 25 | cfg.consistency, 26 | cfg.threshold, 27 | cfg.sharpen, 28 | cfg.temp_softmax 29 | ) 30 | elif name == "vat": # base augment <-> adversarial 31 | from ..consistency import builder 32 | return VAT( 33 | cfg.consistency, 34 | cfg.threshold, 35 | cfg.sharpen, 36 | cfg.temp_softmax, 37 | builder.gen_consistency(cfg.consistency, cfg), 38 | cfg.eps, 39 | cfg.xi, 40 | cfg.vat_iter 41 | ) 42 | else: 43 | raise NotImplementedError 44 | -------------------------------------------------------------------------------- /ssl_lib/algs/consistency.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .utils import sharpening, tempereture_softmax 3 | 4 | class ConsistencyRegularization: 5 | """ 6 | Basis Consistency Regularization 7 | 8 | Parameters 9 | -------- 10 | consistency: str 11 | consistency objective name 12 | threshold: float 13 | threshold to make mask 14 | sharpen: float 15 | sharpening temperature for target value 16 | temp_softmax: float 17 | temperature for temperature softmax 18 | """ 19 | def __init__( 20 | self, 21 | consistency, 22 | threshold: float = None, 23 | sharpen: float = None, 24 | temp_softmax: float = None 25 | ): 26 | self.consistency = consistency 27 | self.threshold = threshold 28 | self.sharpen = sharpen 29 | self.tau = temp_softmax 30 | 31 | def __call__( 32 | self, 33 | stu_preds, 34 | tea_logits, 35 | *args, 36 | **kwargs 37 | ): 38 | mask = self.gen_mask(tea_logits) 39 | targets = self.adjust_target(tea_logits) 40 | return stu_preds, targets, mask 41 | 42 | def adjust_target(self, targets): 43 | if self.sharpen is not None: 44 | targets = targets.softmax(1) 45 | targets = sharpening(targets, self.sharpen) 46 | elif self.tau is not None: 47 | targets = tempereture_softmax(targets, self.tau) 48 | else: 49 | targets = targets.softmax(1) 50 | return targets 51 | 52 | def gen_mask(self, targets): 53 | targets = targets.softmax(1) 54 | if self.threshold is None or self.threshold == 0: 55 | return torch.ones_like(targets.max(1)[0]) 56 | return (targets.max(1)[0] >= self.threshold).float() 57 | 58 | def __repr__(self): 59 | return f"Consistency(threshold={self.threshold}, sharpen={self.sharpen}, tau={self.tau})" 60 | -------------------------------------------------------------------------------- /ssl_lib/algs/ict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .consistency import ConsistencyRegularization 3 | from .utils import mixup 4 | 5 | class ICT(ConsistencyRegularization): 6 | """ 7 | Interpolation 
Consistency Training https://arxiv.org/abs/1903.03825 8 | 9 | Parameters 10 | -------- 11 | consistency: str 12 | consistency objective name 13 | threshold: float 14 | threshold to make mask 15 | sharpen: float 16 | sharpening temperature for target value 17 | temp_softmax: float 18 | temperature for temperature softmax 19 | alpha: float 20 | beta distribution parameter 21 | """ 22 | def __init__( 23 | self, 24 | consistency, 25 | threshold: float = 1., 26 | sharpen: float = None, 27 | temp_softmax: float = None, 28 | alpha: float = 0.1 29 | ): 30 | super().__init__( 31 | consistency, 32 | threshold, 33 | sharpen, 34 | temp_softmax 35 | ) 36 | self.alpha = alpha 37 | 38 | def __call__( 39 | self, 40 | tea_logits, 41 | w_data, 42 | stu_forward, 43 | *args, 44 | **kwargs 45 | ): 46 | mask = self.gen_mask(tea_logits) 47 | targets = self.adjust_target(tea_logits) 48 | mixed_x, mixed_targets = mixup(w_data, targets, self.alpha) 49 | y = stu_forward(mixed_x) 50 | return y, mixed_targets, mask 51 | 52 | def __repr__(self): 53 | return f"ICT(threshold={self.threshold}, sharpen={self.sharpen}, tau={self.tau}, alpha={self.alpha})" 54 | 55 | -------------------------------------------------------------------------------- /ssl_lib/algs/pseudo_label.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from .consistency import ConsistencyRegularization 4 | from ..consistency.cross_entropy import CrossEntropy 5 | from .utils import make_pseudo_label, sharpening 6 | 7 | class PseudoLabel(ConsistencyRegularization): 8 | """ 9 | PseudoLabel 10 | 11 | Parameters 12 | -------- 13 | consistency: str 14 | consistency objective name 15 | threshold: float 16 | threshold to make mask 17 | sharpen: float 18 | sharpening temperature for target value 19 | temp_softmax: float 20 | temperature for temperature softmax 21 | """ 22 | def __init__( 23 | self, 24 | consistency, 25 | threshold = 0.95, 26 | sharpen: float = None, 27 | temp_softmax: float = None 28 | ): 29 | super().__init__( 30 | consistency, 31 | threshold, 32 | sharpen, 33 | temp_softmax 34 | ) 35 | 36 | def __call__(self, stu_preds, tea_logits, *args, **kwargs): 37 | hard_label, mask = make_pseudo_label(tea_logits, self.threshold) 38 | return stu_preds, hard_label, mask 39 | 40 | def __repr__(self): 41 | return f"PseudoLabel(threshold={self.threshold}, sharpen={self.sharpen}, tau={self.tau})" 42 | 43 | -------------------------------------------------------------------------------- /ssl_lib/algs/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def make_pseudo_label(logits, threshold): 6 | max_value, hard_label = logits.softmax(1).max(1) 7 | mask = (max_value >= threshold) 8 | return hard_label, mask 9 | 10 | 11 | def sharpening(soft_labels, temp): 12 | soft_labels = soft_labels.pow(temp) 13 | return soft_labels / soft_labels.abs().sum(1, keepdim=True) 14 | 15 | 16 | def tempereture_softmax(logits, tau): 17 | return (logits/tau).softmax(1) 18 | 19 | 20 | def mixup(x, y, alpha): 21 | device = x.device 22 | b = x.shape[0] 23 | permute = torch.randperm(b) 24 | perm_x = x[permute] 25 | perm_y = y[permute] 26 | factor = torch.distributions.beta.Beta(alpha, alpha).sample((b,1)).to(device) 27 | if x.ndim == 4: 28 | x_factor = factor[...,None,None] 29 | else: 30 | x_factor = factor 31 | mixed_x = x_factor * x + (1-x_factor) * perm_x 32 | mixed_y = factor * y + (1-factor) * perm_y 33 
| return mixed_x, mixed_y 34 | 35 | 36 | def anneal_loss(logits, labels, loss, global_step, max_iter, num_classes, schedule): 37 | tsa_start = 1 / num_classes 38 | threshold = get_tsa_threshold( 39 | schedule, global_step, max_iter, 40 | tsa_start, end=1 41 | ) 42 | with torch.no_grad(): 43 | probs = logits.softmax(1) 44 | correct_label_probs = probs.gather(1, labels[:,None]).squeeze() 45 | mask = correct_label_probs < threshold 46 | return (loss * mask).mean() 47 | 48 | 49 | def get_tsa_threshold(schedule, global_step, max_iter, start, end): 50 | step_ratio = global_step / max_iter 51 | if schedule == "linear": 52 | coef = step_ratio 53 | elif schedule == "exp": 54 | scale = 5 55 | coef = ((step_ratio - 1) * scale).exp() 56 | elif schedule == "log": 57 | scale = 5 58 | coef = 1 - (-step_ratio * scale).exp() 59 | else: 60 | raise NotImplementedError 61 | return coef * (end - start) + start -------------------------------------------------------------------------------- /ssl_lib/algs/vat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .consistency import ConsistencyRegularization 3 | 4 | class VAT(ConsistencyRegularization): 5 | """ 6 | Virtual Adversarial Training https://arxiv.org/abs/1704.03976 7 | 8 | Parameters 9 | -------- 10 | consistency: str 11 | consistency objective name 12 | threshold: float 13 | threshold to make mask 14 | sharpen: float 15 | sharpening temperature for target value 16 | temp_softmax: float 17 | temperature for temperature softmax 18 | objective: function 19 | objective function 20 | eps: float 21 | virtual adversarial noise norm 22 | xi: float 23 | perturbation for finite differential method 24 | n_iter: int 25 | number of iterations for power method 26 | """ 27 | def __init__( 28 | self, 29 | consistency, 30 | threshold: float = 1., 31 | sharpen: float = None, 32 | temp_softmax: float = None, 33 | objective = None, 34 | eps = 1.0, 35 | xi = 1e-6, 36 | n_iter = 1 37 | ): 38 | super().__init__( 39 | consistency, 40 | threshold, 41 | sharpen, 42 | temp_softmax 43 | ) 44 | self.eps = eps 45 | self.xi = xi 46 | self.n_iter = n_iter 47 | self.obj_func = objective 48 | 49 | def __call__( 50 | self, 51 | tea_logits, 52 | w_data, 53 | stu_forward, 54 | *args, 55 | **kwargs 56 | ): 57 | mask = self.gen_mask(tea_logits) 58 | targets = self.adjust_target(tea_logits) 59 | d = torch.randn_like(w_data) 60 | d = self.__normalize(d) 61 | for _ in range(self.n_iter): 62 | d.requires_grad = True 63 | x_hat = w_data + self.xi * d 64 | y = stu_forward(x_hat) 65 | loss = self.obj_func(y, targets) 66 | d = torch.autograd.grad(loss, d)[0] 67 | d = self.__normalize(d).detach() 68 | x_hat = w_data + self.eps * d 69 | y = stu_forward(x_hat) 70 | return y, targets, mask 71 | 72 | def __normalize(self, v): 73 | v = v / (1e-12 + self.__reduce_max(v.abs(), range(1, len(v.shape)))) # to avoid overflow by v.pow(2) 74 | v = v / (1e-6 + v.pow(2).sum(list(range(1, len(v.shape))), keepdim=True)).sqrt() 75 | return v 76 | 77 | def __reduce_max(self, v, idx_list): 78 | for i in idx_list: 79 | v = v.max(i, keepdim=True)[0] 80 | return v 81 | 82 | def __repr__(self): 83 | return f"VAT(threshold={self.threshold}, \ 84 | sharpen={self.sharpen}, \ 85 | tau={self.tau}, \ 86 | eps={self.eps}), \ 87 | xi={self.xi}" 88 | -------------------------------------------------------------------------------- /ssl_lib/augmentation/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import augmentation_pool -------------------------------------------------------------------------------- /ssl_lib/augmentation/augmentation_class.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as tt 3 | 4 | from . import augmentation_pool as aug_pool 5 | from .rand_augment import RandAugment 6 | 7 | 8 | class ReduceChannelwithNormalize: 9 | """ Reduce alpha channel of RGBA """ 10 | def __init__(self, mean, scale, zca): 11 | self.mean = mean 12 | self.scale = scale 13 | self.zca = zca 14 | 15 | def __call__(self, tch_img): 16 | rgb = tch_img[:3] 17 | i1, i2 = torch.where(tch_img[3] == 0) 18 | if self.zca: 19 | rgb = aug_pool.GCN()(tch_img) 20 | rgb = aug_pool.ZCA(self.mean, self.scale)(rgb) 21 | else: 22 | rgb = tt.functional.normalize(rgb, self.mean, self.scale, True) 23 | rgb[:, i1, i2] = 0 24 | return rgb 25 | 26 | def __repr__(self): 27 | return f"ReduceChannelwithNormalize(mean={self.mean}, scale={self.scale})" 28 | 29 | 30 | class RGB2RGBA: 31 | def __call__(self, x): 32 | return x.convert("RGBA") 33 | 34 | def __repr__(self): 35 | return "RGB2RGBA()" 36 | 37 | 38 | class StrongAugmentation: 39 | """ 40 | Strong augmentation class 41 | including RandAugment and Cutout 42 | """ 43 | def __init__( 44 | self, 45 | img_size: int, 46 | mean: list, 47 | scale: list, 48 | flip: bool, 49 | crop: bool, 50 | alg: str = "fixmatch", 51 | zca: bool = False, 52 | cutout: bool = True, 53 | ): 54 | augmentations = [tt.ToPILImage()] 55 | 56 | if flip: 57 | augmentations += [tt.RandomHorizontalFlip(p=0.5)] 58 | if crop: 59 | augmentations += [tt.RandomCrop(img_size, int(img_size*0.125), padding_mode="reflect")] 60 | 61 | augmentations += [ 62 | RGB2RGBA(), 63 | RandAugment(alg=alg), 64 | tt.ToTensor(), 65 | ReduceChannelwithNormalize(mean, scale, zca) 66 | ] 67 | if cutout: 68 | augmentations += [aug_pool.TorchCutout(16)] 69 | 70 | self.augmentations = tt.Compose(augmentations) 71 | 72 | def __call__(self, img): 73 | return self.augmentations(img) 74 | 75 | def __repr__(self): 76 | return repr(self.augmentations) 77 | 78 | 79 | class WeakAugmentation: 80 | """ 81 | Weak augmentation class 82 | including horizontal flip, random crop, and gaussian noise 83 | """ 84 | def __init__( 85 | self, 86 | img_size: int, 87 | mean: list, 88 | scale: list, 89 | flip=True, 90 | crop=True, 91 | noise=True, 92 | zca=False 93 | ): 94 | augmentations = [tt.ToPILImage()] 95 | if flip: 96 | augmentations.append(tt.RandomHorizontalFlip()) 97 | if crop: 98 | augmentations.append(tt.RandomCrop(img_size, int(img_size*0.125), padding_mode="reflect")) 99 | augmentations += [tt.ToTensor()] 100 | if zca: 101 | augmentations += [aug_pool.GCN(), aug_pool.ZCA(mean, scale)] 102 | else: 103 | augmentations += [tt.Normalize(mean, scale, True)] 104 | if noise: 105 | augmentations.append(aug_pool.GaussianNoise()) 106 | self.augmentations = tt.Compose(augmentations) 107 | 108 | def __call__(self, img): 109 | return self.augmentations(img) 110 | 111 | def __repr__(self): 112 | return repr(self.augmentations) 113 | -------------------------------------------------------------------------------- /ssl_lib/augmentation/augmentation_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from PIL import ImageOps, ImageEnhance, ImageFilter, Image 7 | 8 | 9 | """ 10 | For PIL.Image 11 | """ 12 | 13 | def autocontrast(x, 
*args, **kwargs): 14 | return ImageOps.autocontrast(x.convert("RGB")).convert("RGBA") 15 | 16 | 17 | def brightness(x, level, magnitude=10, max_level=1.8, *args, **kwargs): 18 | level = (level / magnitude) * max_level + 0.1 19 | return ImageEnhance.Brightness(x).enhance(level) 20 | 21 | 22 | def color(x, level, magnitude=10, max_level=1.8, *args, **kwargs): 23 | level = (level / magnitude) * max_level + 0.1 24 | return ImageEnhance.Color(x).enhance(level) 25 | 26 | 27 | def contrast(x, level, magnitude=10, max_level=1.8, *args, **kwargs): 28 | level = (level / magnitude) * max_level + 0.1 29 | return ImageEnhance.Contrast(x).enhance(level) 30 | 31 | 32 | def equalize(x, *args, **kwargs): 33 | return ImageOps.equalize(x.convert("RGB")).convert("RGBA") 34 | 35 | 36 | def identity(x, *args, **kwargs): 37 | return x 38 | 39 | 40 | def invert(x, *args, **kwargs): 41 | return ImageOps.invert(x.convert("RGB")).convert("RGBA") 42 | 43 | 44 | def posterize(x, level, magnitude=10, max_level=4, *args, **kwargs): 45 | level = int((level / magnitude) * max_level) 46 | return ImageOps.posterize(x.convert("RGB"), 4 - level).convert("RGBA") 47 | 48 | 49 | def rotate(x, level, magnitude=10, max_level=30, *args, **kwargs): 50 | degree = int((level / magnitude) * max_level) 51 | if random.random() > 0.5: 52 | degree = -degree 53 | return x.rotate(degree) 54 | 55 | 56 | def sharpness(x, level, magnitude=10, max_level=1.8, *args, **kwargs): 57 | level = (level / magnitude) * max_level + 0.1 58 | return ImageEnhance.Sharpness(x).enhance(level) 59 | 60 | 61 | def shear_x(x, level, magnitude=10, max_level=0.3, *args, **kwargs): 62 | level = (level / magnitude) * max_level 63 | if random.random() > 0.5: 64 | level = -level 65 | return x.transform(x.size, Image.AFFINE, (1, level, 0, 0, 1, 0)) 66 | 67 | 68 | def shear_y(x, level, magnitude=10, max_level=0.3, *args, **kwargs): 69 | level = (level / magnitude) * max_level 70 | if random.random() > 0.5: 71 | level = -level 72 | return x.transform(x.size, Image.AFFINE, (1, 0, 0, level, 1, 0)) 73 | 74 | 75 | def solarize(x, level, magnitude=10, max_level=256, *args, **kwargs): 76 | level = int((level / magnitude) * max_level) 77 | return ImageOps.solarize(x.convert("RGB"), 256 - level).convert("RGBA") 78 | 79 | 80 | def translate_x(x, level, magnitude=10, max_level=10, *args, **kwargs): 81 | level = int((level / magnitude) * max_level) 82 | if random.random() > 0.5: 83 | level = -level 84 | return x.transform(x.size, Image.AFFINE, (1, 0, level, 0, 1, 0)) 85 | 86 | 87 | def translate_y(x, level, magnitude=10, max_level=10, *args, **kwargs): 88 | level = int((level / magnitude) * max_level) 89 | if random.random() > 0.5: 90 | level = -level 91 | return x.transform(x.size, Image.AFFINE, (1, 0, 0, 0, 1, level)) 92 | 93 | 94 | def cutout(x, level, magnitude=10, max_level=20, *args, **kwargs): 95 | size = int((level / magnitude) * max_level) 96 | if size <= 0: 97 | return x 98 | w, h = x.size 99 | upper_coord, lower_coord = _gen_cutout_coord(h, w, size) 100 | 101 | pixels = x.load() 102 | for i in range(upper_coord[0], lower_coord[0]): 103 | for j in range(upper_coord[1], lower_coord[1]): 104 | pixels[i, j] = (127, 127, 127, 0) 105 | return x 106 | 107 | 108 | def _gen_cutout_coord(height, width, size): 109 | height_loc = random.randint(0, height - 1) 110 | width_loc = random.randint(0, width - 1) 111 | 112 | upper_coord = (max(0, height_loc - size // 2), 113 | max(0, width_loc - size // 2)) 114 | lower_coord = (min(height, height_loc + size // 2), 115 | min(width, width_loc + 
size // 2)) 116 | 117 | return upper_coord, lower_coord 118 | 119 | """ 120 | For torch.Tensor 121 | """ 122 | 123 | class TorchCutout: 124 | def __init__(self, size=16): 125 | self.size = size 126 | 127 | def __call__(self, img): 128 | h, w = img.shape[-2:] 129 | upper_coord, lower_coord = _gen_cutout_coord(h, w, self.size) 130 | 131 | mask_height = lower_coord[0] - upper_coord[0] 132 | mask_width = lower_coord[1] - upper_coord[1] 133 | assert mask_height > 0 134 | assert mask_width > 0 135 | 136 | mask = torch.ones_like(img) 137 | zeros = torch.zeros((img.shape[0], mask_height, mask_width)) 138 | mask[:, upper_coord[0]:lower_coord[0], upper_coord[1]:lower_coord[1]] = zeros 139 | return img * mask 140 | 141 | def __repr__(self): 142 | return f"TorchCutout(size={self.size})" 143 | 144 | 145 | class GaussianNoise: 146 | def __init__(self, std=0.15): 147 | self.std = std 148 | 149 | def __call__(self, x): 150 | with torch.no_grad(): 151 | return x + torch.randn_like(x) * self.std 152 | 153 | def __repr__(self): 154 | return f"GaussianNoise(std={self.std})" 155 | 156 | 157 | class BatchRandomFlip: 158 | def __init__(self, flip_prob=0.5): 159 | self.p = flip_prob 160 | 161 | def __call__(self, x): 162 | with torch.no_grad(): 163 | return torch.stack([ 164 | torch.flip(img, (-1,)) 165 | if random.random() > self.p 166 | else img 167 | for img in x 168 | ], 0) 169 | 170 | def __repr__(self): 171 | return f"BatchRandomFlip(flip_prob={self.p})" 172 | 173 | 174 | class RandomFlip: 175 | def __init__(self, flip_prob=0.5): 176 | self.p = flip_prob 177 | 178 | def __call__(self, x): 179 | if random.random() > self.p: 180 | return torch.flip(x, (-1,)) 181 | return x 182 | 183 | def __repr__(self): 184 | return f"RandomFlip(flip_prob={self.p})" 185 | 186 | 187 | class BatchRandomCrop: 188 | def __init__(self, padding=4): 189 | self.pad = padding 190 | 191 | def __call__(self, x): 192 | with torch.no_grad(): 193 | b, _, h, w = x.shape 194 | x = F.pad(x, [self.pad for _ in range(4)], mode="reflect") 195 | left, top = torch.randint(0, 1+self.pad*2, (b,)), torch.randint(0, 1+self.pad*2, (b,)) 196 | return torch.stack([ 197 | img[..., t:t+h, l:l+w] 198 | for img, t, l in zip(x, left, top) 199 | ], 0) 200 | 201 | def __repr__(self): 202 | return f"BatchRandomCrop(padding={self.pad})" 203 | 204 | 205 | class RandomCrop: 206 | def __init__(self, padding=4): 207 | self.pad = padding 208 | 209 | def __call__(self, x): 210 | with torch.no_grad(): 211 | _, h, w = x.shape 212 | x = F.pad(x[None], [self.pad for _ in range(4)], mode="reflect") 213 | left, top = random.randint(0, self.pad*2), random.randint(0, self.pad*2) 214 | return x[0, :, top:top+h, left:left+w] 215 | 216 | def __repr__(self): 217 | return f"RandomCrop(padding={self.pad})" 218 | 219 | 220 | class ZCA: 221 | def __init__(self, mean, scale): 222 | self.mean = torch.from_numpy(mean).float() 223 | self.scale = torch.from_numpy(scale).float() 224 | 225 | def __call__(self, x): 226 | c, h, w = x.shape 227 | x = x.reshape(-1) 228 | x = (x - self.mean) @ self.scale 229 | return x.reshape(c, h, w) 230 | 231 | def __repr__(self): 232 | return f"ZCA()" 233 | 234 | 235 | class GCN: 236 | """global contrast normalization""" 237 | def __init__(self, multiplier=55, eps=1e-10): 238 | self.multiplier = multiplier 239 | self.eps = eps 240 | 241 | def __call__(self, x): 242 | x -= x.mean() 243 | norm = x.norm(2) 244 | norm[norm < self.eps] = 1 245 | return self.multiplier * x / norm 246 | 247 | def __repr__(self): 248 | return f"GCN(multiplier={self.multiplier}, 
eps={self.eps})" 249 | 250 | 251 | """ 252 | For numpy.array 253 | """ 254 | def numpy_batch_gcn(images, multiplier=55, eps=1e-10): 255 | # global contrast normalization 256 | images = images.astype(np.float) 257 | images -= images.mean(axis=(1,2,3), keepdims=True) 258 | per_image_norm = np.sqrt(np.square(images).sum((1,2,3), keepdims=True)) 259 | per_image_norm[per_image_norm < eps] = 1 260 | return multiplier * images / per_image_norm 261 | -------------------------------------------------------------------------------- /ssl_lib/augmentation/builder.py: -------------------------------------------------------------------------------- 1 | from .augmentation_class import WeakAugmentation, StrongAugmentation 2 | 3 | 4 | def gen_strong_augmentation(img_size, mean, std, flip=True, crop=True, alg="fixmatch", zca=False): 5 | return StrongAugmentation(img_size, mean, std, flip, crop, alg, zca) 6 | 7 | 8 | def gen_weak_augmentation(img_size, mean, std, flip=True, crop=True, noise=True, zca=False): 9 | return WeakAugmentation(img_size, mean, std, flip, crop, noise, zca) 10 | -------------------------------------------------------------------------------- /ssl_lib/augmentation/rand_augment.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from . import augmentation_pool 3 | from . import utils 4 | 5 | 6 | class RandAugment: 7 | """ 8 | RandAugment class 9 | 10 | Parameters 11 | -------- 12 | nops: int 13 | number of operations per image 14 | magnitude: int 15 | maximmum magnitude 16 | alg: str 17 | algorithm name 18 | """ 19 | def __init__(self, nops=2, magnitude=10, prob=0.5, alg="fixmatch"): 20 | self.nops = nops 21 | self.magnitude = magnitude 22 | self.prob = prob 23 | if alg == "fixmatch": 24 | self.ops_list = utils.FIXMATCH_RANDAUGMENT_OPS_LIST 25 | elif alg == "uda": 26 | self.ops_list = utils.UDA_RANDAUGMENT_OPS_LIST 27 | else: 28 | raise NotImplementedError 29 | 30 | self.ops_max_level = utils.RANDAUGMENT_MAX_LEVELS 31 | 32 | def __call__(self, img): 33 | """ 34 | Apply augmentations to PIL image 35 | """ 36 | ops = np.random.choice(self.ops_list, self.nops) 37 | for name in ops: 38 | if np.random.rand() <= self.prob: 39 | level = np.random.randint(1, self.magnitude) 40 | max_level = self.ops_max_level[name] 41 | transform = getattr(augmentation_pool, name) 42 | img = transform(img, level, magnitude=self.magnitude, max_level=max_level) 43 | return img 44 | 45 | def __repr__(self): 46 | return f"RandAugment(nops={self.nops}, magnitude={self.magnitude})" 47 | -------------------------------------------------------------------------------- /ssl_lib/augmentation/utils.py: -------------------------------------------------------------------------------- 1 | FIXMATCH_RANDAUGMENT_OPS_LIST = [ 2 | 'identity', 3 | 'autocontrast', 4 | 'brightness', 5 | 'color', 6 | 'contrast', 7 | 'equalize', 8 | 'posterize', 9 | 'rotate', 10 | 'sharpness', 11 | 'shear_x', 12 | 'shear_y', 13 | 'solarize', 14 | 'translate_x', 15 | 'translate_y' 16 | ] 17 | 18 | 19 | UDA_RANDAUGMENT_OPS_LIST = [ 20 | 'invert', 21 | 'autocontrast', 22 | 'brightness', 23 | 'color', 24 | 'contrast', 25 | 'cutout', 26 | 'equalize', 27 | 'posterize', 28 | 'rotate', 29 | 'sharpness', 30 | 'shear_x', 31 | 'shear_y', 32 | 'solarize', 33 | 'translate_x', 34 | 'translate_y' 35 | ] 36 | 37 | 38 | RANDAUGMENT_MAX_LEVELS = { 39 | 'autocontrast': None, 40 | 'brightness': 1.8, 41 | 'color': 1.8, 42 | 'contrast': 1.8, 43 | 'cutout': 20, 44 | 'equalize': None, 45 | 'identity': None, 46 | 
'invert': None, 47 | 'posterize': 4, 48 | 'rotate': 30, 49 | 'sharpness': 1.8, 50 | 'shear_x': 0.3, 51 | 'shear_y':0.3, 52 | 'solarize': 256, 53 | 'translate_x': 10, 54 | 'translate_y': 10 55 | } 56 | -------------------------------------------------------------------------------- /ssl_lib/consistency/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perrying/pytorch-consistency-regularization/6624c4e0bb1813b5952445ce34f9d4e52484ce38/ssl_lib/consistency/__init__.py -------------------------------------------------------------------------------- /ssl_lib/consistency/builder.py: -------------------------------------------------------------------------------- 1 | from .cross_entropy import CrossEntropy 2 | from .mean_squared import MeanSquared 3 | 4 | 5 | def gen_consistency(type, cfg): 6 | if type == "ce": 7 | return CrossEntropy() 8 | elif type == "ms": 9 | return MeanSquared() 10 | else: 11 | return None -------------------------------------------------------------------------------- /ssl_lib/consistency/cross_entropy.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | def cross_entropy(y, target, mask=None): 5 | if target.ndim == 1: # for hard label 6 | loss = F.cross_entropy(y, target, reduction="none") 7 | else: 8 | loss = -(target * F.log_softmax(y, 1)).sum(1) 9 | if mask is not None: 10 | loss = mask * loss 11 | return loss.mean() 12 | 13 | class CrossEntropy(nn.Module): 14 | def forward(self, y, target, mask=None, *args, **kwargs): 15 | return cross_entropy(y, target.detach(), mask) 16 | -------------------------------------------------------------------------------- /ssl_lib/consistency/mean_squared.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | def mean_squared(y, target, mask=None): 5 | y = y.softmax(1) 6 | loss = F.mse_loss(y, target, reduction="none").mean(1) 7 | if mask is not None: 8 | loss = mask * loss 9 | return loss.mean() 10 | 11 | class MeanSquared(nn.Module): 12 | def forward(self, y, target, mask=None, *args, **kwargs): 13 | return mean_squared(y, target.detach(), mask) -------------------------------------------------------------------------------- /ssl_lib/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perrying/pytorch-consistency-regularization/6624c4e0bb1813b5952445ce34f9d4e52484ce38/ssl_lib/datasets/__init__.py -------------------------------------------------------------------------------- /ssl_lib/datasets/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils.data import DataLoader 4 | from torchvision import transforms 5 | 6 | from . import utils 7 | from . 
import dataset_class 8 | from ..augmentation.builder import gen_strong_augmentation, gen_weak_augmentation 9 | from ..augmentation.augmentation_pool import numpy_batch_gcn, ZCA, GCN 10 | 11 | 12 | def __val_labeled_unlabeled_split(cfg, train_data, test_data, num_classes, ul_data=None): 13 | num_validation = int(np.round(len(train_data["images"]) * cfg.val_ratio)) 14 | 15 | np.random.seed(cfg.seed) 16 | 17 | permutation = np.random.permutation(len(train_data["images"])) 18 | train_data["images"] = train_data["images"][permutation] 19 | train_data["labels"] = train_data["labels"][permutation] 20 | 21 | val_data, train_data = utils.dataset_split(train_data, num_validation, num_classes, cfg.random_split) 22 | l_train_data, ul_train_data = utils.dataset_split(train_data, cfg.num_labels, num_classes) 23 | 24 | if ul_data is not None: 25 | ul_train_data["images"] = np.concatenate([ul_train_data["images"], ul_data["images"]], 0) 26 | ul_train_data["labels"] = np.concatenate([ul_train_data["labels"], ul_data["labels"]], 0) 27 | 28 | return val_data, l_train_data, ul_train_data 29 | 30 | 31 | def __labeled_unlabeled_split(cfg, train_data, test_data, num_classes, ul_data=None): 32 | np.random.seed(cfg.seed) 33 | 34 | permutation = np.random.permutation(len(train_data["images"])) 35 | train_data["images"] = train_data["images"][permutation] 36 | train_data["labels"] = train_data["labels"][permutation] 37 | 38 | l_train_data, ul_train_data = utils.dataset_split(train_data, cfg.num_labels, num_classes) 39 | 40 | if ul_data is not None: 41 | ul_train_data["images"] = np.concatenate([ul_train_data["images"], ul_data["images"]], 0) 42 | ul_train_data["labels"] = np.concatenate([ul_train_data["labels"], ul_data["labels"]], 0) 43 | 44 | return l_train_data, ul_train_data 45 | 46 | 47 | def gen_dataloader(root, dataset, validation_split, cfg, logger=None): 48 | """ 49 | generate train, val, and test dataloaders 50 | 51 | Parameters 52 | -------- 53 | root: str 54 | root directory 55 | dataset: str 56 | dataset name, ['cifar10', 'cifar100', 'svhn', 'stl10'] 57 | validation_split: bool 58 | if True, return validation loader. 
59 | validation data is made from training data 60 | cfg: argparse.Namespace or something 61 | logger: logging.Logger 62 | """ 63 | ul_train_data = None 64 | if dataset == "svhn": 65 | train_data, test_data = utils.get_svhn(root) 66 | num_classes = 10 67 | img_size = 32 68 | elif dataset == "stl10": 69 | train_data, ul_train_data, test_data = utils.get_stl10(root) 70 | num_classes = 10 71 | img_size = 96 72 | elif dataset == "cifar10": 73 | train_data, test_data = utils.get_cifar10(root) 74 | num_classes = 10 75 | img_size = 32 76 | elif dataset == "cifar100": 77 | train_data, test_data = utils.get_cifar100(root) 78 | num_classes = 100 79 | img_size = 32 80 | else: 81 | raise NotImplementedError 82 | 83 | if validation_split: 84 | val_data, l_train_data, ul_train_data = __val_labeled_unlabeled_split( 85 | cfg, train_data, test_data, num_classes, ul_train_data) 86 | else: 87 | l_train_data, ul_train_data = __labeled_unlabeled_split( 88 | cfg, train_data, test_data, num_classes, ul_train_data) 89 | val_data = None 90 | 91 | ul_train_data["images"] = np.concatenate([ul_train_data["images"], l_train_data["images"]], 0) 92 | ul_train_data["labels"] = np.concatenate([ul_train_data["labels"], l_train_data["labels"]], 0) 93 | 94 | if logger is not None: 95 | logger.info("number of :\n \ 96 | training data: %d\n \ 97 | labeled data: %d\n \ 98 | unlabeled data: %d\n \ 99 | validation data: %d\n \ 100 | test data: %d", 101 | len(train_data["images"]), 102 | len(l_train_data["images"]), 103 | len(ul_train_data["images"]), 104 | 0 if val_data is None else len(val_data["images"]), 105 | len(test_data["images"])) 106 | 107 | labeled_train_data = dataset_class.LabeledDataset(l_train_data) 108 | unlabeled_train_data = dataset_class.UnlabeledDataset(ul_train_data) 109 | 110 | train_data = np.concatenate([ 111 | labeled_train_data.dataset["images"], 112 | unlabeled_train_data.dataset["images"] 113 | ], 0) 114 | 115 | if cfg.whiten: 116 | mean = train_data.mean((0, 1, 2)) / 255. 117 | scale = train_data.std((0, 1, 2)) / 255. 
118 | elif cfg.zca: 119 | mean, scale = utils.get_zca_normalization_param(numpy_batch_gcn(train_data)) 120 | else: 121 | # from [0, 1] to [-1, 1] 122 | mean = [0.5, 0.5, 0.5] 123 | scale = [0.5, 0.5, 0.5] 124 | 125 | # set augmentation 126 | # RA: RandAugment, WA: Weak Augmentation 127 | randauglist = "fixmatch" if cfg.alg == "pl" else "uda" 128 | 129 | flags = [True if b == "t" else False for b in cfg.wa.split(".")] 130 | 131 | if cfg.labeled_aug == "RA": 132 | labeled_augmentation = gen_strong_augmentation( 133 | img_size, mean, scale, flags[0], flags[1], randauglist, cfg.zca) 134 | elif cfg.labeled_aug == "WA": 135 | labeled_augmentation = gen_weak_augmentation(img_size, mean, scale, *flags, cfg.zca) 136 | else: 137 | raise NotImplementedError 138 | 139 | labeled_train_data.transform = labeled_augmentation 140 | 141 | if cfg.unlabeled_aug == "RA": 142 | unlabeled_augmentation = gen_strong_augmentation( 143 | img_size, mean, scale, flags[0], flags[1], randauglist, cfg.zca) 144 | elif cfg.unlabeled_aug == "WA": 145 | unlabeled_augmentation = gen_weak_augmentation(img_size, mean, scale, *flags, cfg.zca) 146 | else: 147 | raise NotImplementedError 148 | 149 | if logger is not None: 150 | logger.info("labeled augmentation") 151 | logger.info(labeled_augmentation) 152 | logger.info("unlabeled augmentation") 153 | logger.info(unlabeled_augmentation) 154 | 155 | unlabeled_train_data.weak_augmentation = unlabeled_augmentation 156 | 157 | if cfg.strong_aug: 158 | strong_augmentation = gen_strong_augmentation( 159 | img_size, mean, scale, flags[0], flags[1], randauglist, cfg.zca) 160 | unlabeled_train_data.strong_augmentation = strong_augmentation 161 | if logger is not None: 162 | logger.info(strong_augmentation) 163 | 164 | if cfg.zca: 165 | test_transform = transforms.Compose([GCN(), ZCA(mean, scale)]) 166 | else: 167 | test_transform = transforms.Compose([transforms.Normalize(mean, scale, True)]) 168 | 169 | test_data = dataset_class.LabeledDataset(test_data, test_transform) 170 | 171 | l_train_loader = DataLoader( 172 | labeled_train_data, 173 | cfg.l_batch_size, 174 | sampler=utils.InfiniteSampler(len(labeled_train_data), cfg.iteration * cfg.l_batch_size), 175 | num_workers=cfg.num_workers 176 | ) 177 | ul_train_loader = DataLoader( 178 | unlabeled_train_data, 179 | cfg.ul_batch_size, 180 | sampler=utils.InfiniteSampler(len(unlabeled_train_data), cfg.iteration * cfg.ul_batch_size), 181 | num_workers=cfg.num_workers 182 | ) 183 | test_loader = DataLoader( 184 | test_data, 185 | 1, 186 | shuffle=False, 187 | drop_last=False, 188 | num_workers=cfg.num_workers 189 | ) 190 | 191 | if validation_split: 192 | validation_data = dataset_class.LabeledDataset(val_data, test_transform) 193 | val_loader = DataLoader( 194 | validation_data, 195 | 1, 196 | shuffle=False, 197 | drop_last=False, 198 | num_workers=cfg.num_workers 199 | ) 200 | 201 | return ( 202 | l_train_loader, 203 | ul_train_loader, 204 | val_loader, 205 | test_loader, 206 | num_classes, 207 | img_size 208 | ) 209 | 210 | else: 211 | return ( 212 | l_train_loader, 213 | ul_train_loader, 214 | test_loader, 215 | num_classes, 216 | img_size 217 | ) 218 | -------------------------------------------------------------------------------- /ssl_lib/datasets/dataset_class.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class LabeledDataset: 5 | """ 6 | For labeled dataset 7 | """ 8 | def __init__(self, dataset, transform=None): 9 | self.dataset = dataset 10 | self.transform = 
transform 11 | 12 | def __getitem__(self, idx): 13 | image = torch.from_numpy(self.dataset["images"][idx]).float() 14 | image = image.permute(2, 0, 1).contiguous() / 255. 15 | label = int(self.dataset["labels"][idx]) 16 | if self.transform is not None: 17 | image = self.transform(image) 18 | return image, label 19 | 20 | def __len__(self): 21 | return len(self.dataset["images"]) 22 | 23 | 24 | class UnlabeledDataset: 25 | """ 26 | For unlabeled dataset 27 | """ 28 | def __init__(self, dataset, weak_augmentation=None, strong_augmentation=None): 29 | self.dataset = dataset 30 | self.weak_augmentation = weak_augmentation 31 | self.strong_augmentation = strong_augmentation 32 | 33 | def __getitem__(self, idx): 34 | image = torch.from_numpy(self.dataset["images"][idx]).float() 35 | image = image.permute(2, 0, 1).contiguous() / 255. 36 | label = int(self.dataset["labels"][idx]) 37 | w_aug_image = self.weak_augmentation(image) 38 | if self.strong_augmentation is not None: 39 | s_aug_image = self.strong_augmentation(image) 40 | else: 41 | s_aug_image = self.weak_augmentation(image) 42 | return w_aug_image, s_aug_image, label 43 | 44 | def __len__(self): 45 | return len(self.dataset["images"]) 46 | 47 | -------------------------------------------------------------------------------- /ssl_lib/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import Sampler 5 | from torchvision.datasets import SVHN, CIFAR10, CIFAR100, STL10 6 | 7 | 8 | class InfiniteSampler(Sampler): 9 | """ sampling without replacement """ 10 | def __init__(self, num_data, num_sample): 11 | epochs = num_sample // num_data + 1 12 | self.indices = torch.cat([torch.randperm(num_data) for _ in range(epochs)]).tolist()[:num_sample] 13 | 14 | def __iter__(self): 15 | return iter(self.indices) 16 | 17 | def __len__(self): 18 | return len(self.indices) 19 | 20 | 21 | def get_svhn(root): 22 | train_data = SVHN(root, "train", download=True) 23 | test_data = SVHN(root, "test", download=True) 24 | train_data = {"images": np.transpose(train_data.data.astype(np.float32), (0, 2, 3, 1)), 25 | "labels": train_data.labels.astype(np.int32)} 26 | test_data = {"images": np.transpose(test_data.data.astype(np.float32), (0, 2, 3, 1)), 27 | "labels": test_data.labels.astype(np.int32)} 28 | return train_data, test_data 29 | 30 | 31 | def get_cifar10(root): 32 | train_data = CIFAR10(root, download=True) 33 | test_data = CIFAR10(root, False) 34 | train_data = {"images": train_data.data.astype(np.float32), 35 | "labels": np.asarray(train_data.targets).astype(np.int32)} 36 | test_data = {"images": test_data.data.astype(np.float32), 37 | "labels": np.asarray(test_data.targets).astype(np.int32)} 38 | return train_data, test_data 39 | 40 | 41 | def get_cifar100(root): 42 | train_data = CIFAR100(root, download=True) 43 | test_data = CIFAR100(root, False) 44 | train_data = {"images": train_data.data.astype(np.float32), 45 | "labels": np.asarray(train_data.targets).astype(np.int32)} 46 | test_data = {"images": test_data.data.astype(np.float32), 47 | "labels": np.asarray(test_data.targets).astype(np.int32)} 48 | return train_data, test_data 49 | 50 | 51 | def get_stl10(root): 52 | train_data = STL10(root, split="train", download=True) 53 | ul_train_data = STL10(root, split="unlabeled") 54 | test_data = STL10(root, split="test") 55 | train_data = {"images": np.transpose(train_data.data.astype(np.float32), (0, 2, 3, 1)), 56 | "labels": 
train_data.labels} 57 | ul_train_data = {"images": np.transpose(ul_train_data.data.astype(np.float32), (0, 2, 3, 1)), 58 | "labels": ul_train_data.labels} 59 | test_data = {"images": np.transpose(test_data.data.astype(np.float32), (0, 2, 3, 1)), 60 | "labels": test_data.labels} 61 | return train_data, ul_train_data, test_data 62 | 63 | 64 | def dataset_split(data, num_data, num_classes, random=False): 65 | """split dataset into two datasets 66 | 67 | Parameters 68 | ----- 69 | data: dict with keys ["images", "labels"] 70 | each value is numpy.array 71 | num_data: int 72 | number of dataset1 73 | num_classes: int 74 | number of classes 75 | random: bool 76 | if True, dataset1 is randomly sampled from data. 77 | if False, dataset1 is uniformly sampled from data, 78 | which means that the dataset1 contains the same number of samples per class. 79 | 80 | Returns 81 | ----- 82 | dataset1, dataset2: the same dict as data. 83 | number of data in dataset1 is num_data. 84 | number of data in dataset1 is len(data) - num_data. 85 | """ 86 | dataset1 = {"images": [], "labels": []} 87 | dataset2 = {"images": [], "labels": []} 88 | images = data["images"] 89 | labels = data["labels"] 90 | 91 | # random sampling 92 | if random: 93 | dataset1["images"] = images[:num_data] 94 | dataset1["labels"] = labels[:num_data] 95 | dataset2["images"] = images[num_data:] 96 | dataset2["labels"] = labels[num_data:] 97 | 98 | else: 99 | data_per_class = num_data // num_classes 100 | for c in range(num_classes): 101 | c_idx = (labels == c) 102 | c_imgs = images[c_idx] 103 | c_lbls = labels[c_idx] 104 | dataset1["images"].append(c_imgs[:data_per_class]) 105 | dataset1["labels"].append(c_lbls[:data_per_class]) 106 | dataset2["images"].append(c_imgs[data_per_class:]) 107 | dataset2["labels"].append(c_lbls[data_per_class:]) 108 | for k in ("images", "labels"): 109 | dataset1[k] = np.concatenate(dataset1[k]) 110 | dataset2[k] = np.concatenate(dataset2[k]) 111 | 112 | return dataset1, dataset2 113 | 114 | 115 | def get_zca_normalization_param(images, scale=0.1, eps=1e-10): 116 | n_data, height, width, channels = images.shape 117 | images = images.transpose(0, 3, 1, 2) 118 | images = images.reshape(n_data, channels * height * width) 119 | image_cov = np.cov(images, rowvar=False) 120 | U, S, _ = np.linalg.svd(image_cov + scale * np.eye(image_cov.shape[0])) 121 | zca_decomp = np.dot(U, np.dot(np.diag(1/np.sqrt(S + eps)), U.T)) 122 | mean = images.mean(axis=0) 123 | return mean, zca_decomp 124 | -------------------------------------------------------------------------------- /ssl_lib/misc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perrying/pytorch-consistency-regularization/6624c4e0bb1813b5952445ce34f9d4e52484ce38/ssl_lib/misc/__init__.py -------------------------------------------------------------------------------- /ssl_lib/misc/meter.py: -------------------------------------------------------------------------------- 1 | class Meter: 2 | def __init__(self, ema_coef=0.9): 3 | self.ema_coef = ema_coef 4 | self.params = {} 5 | 6 | def add(self, params:dict, ignores:list = []): 7 | for k, v in params.items(): 8 | if k in ignores: 9 | continue 10 | if not k in self.params.keys(): 11 | self.params[k] = v 12 | else: 13 | self.params[k] -= (1 - self.ema_coef) * (self.params[k] - v) 14 | 15 | def state(self, header="", footer=""): 16 | state = header 17 | for k, v in self.params.items(): 18 | state += f" {k} {v:.6g} |" 19 | return state + " " + footer 
20 | 21 | def reset(self): 22 | self.params = {} 23 | -------------------------------------------------------------------------------- /ssl_lib/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perrying/pytorch-consistency-regularization/6624c4e0bb1813b5952445ce34f9d4e52484ce38/ssl_lib/models/__init__.py -------------------------------------------------------------------------------- /ssl_lib/models/builder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .resnet import ResNet 4 | from .shakenet import ShakeNet 5 | from .cnn13 import CNN13 6 | 7 | 8 | def gen_model(name, num_classes, img_size): 9 | scale = int(np.ceil(np.log2(img_size))) 10 | if name == "wrn": 11 | return ResNet(num_classes, 32, scale, 4) 12 | elif name == "shake": 13 | return ShakeNet(num_classes, 32, scale, 4) 14 | elif name == "cnn13": 15 | return CNN13(num_classes, 32) 16 | else: 17 | raise NotImplementedError -------------------------------------------------------------------------------- /ssl_lib/models/cnn13.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .utils import leaky_relu, conv3x3, BatchNorm2d, BaseModel 3 | 4 | 5 | class CNN13(BaseModel): 6 | """ 7 | 13-layer CNN 8 | 9 | Parameters 10 | -------- 11 | num_classes: int 12 | number of classes 13 | filters: int 14 | number of filters 15 | """ 16 | def __init__(self, num_classes, filters, *args, **kwargs): 17 | super().__init__() 18 | self.feature_extractor = nn.Sequential( 19 | conv3x3(3, filters, bias=True), 20 | leaky_relu(), 21 | BatchNorm2d(filters), 22 | conv3x3(filters, filters, bias=True), 23 | leaky_relu(), 24 | BatchNorm2d(filters), 25 | conv3x3(filters, filters, bias=True), 26 | leaky_relu(), 27 | BatchNorm2d(filters), 28 | nn.MaxPool2d(2, 2), 29 | conv3x3(filters, 2*filters, bias=True), 30 | leaky_relu(), 31 | BatchNorm2d(2*filters), 32 | conv3x3(2*filters, 2*filters, bias=True), 33 | leaky_relu(), 34 | BatchNorm2d(2*filters), 35 | conv3x3(2*filters, 2*filters, bias=True), 36 | leaky_relu(), 37 | BatchNorm2d(2*filters), 38 | nn.MaxPool2d(2, 2), 39 | nn.Conv2d(2*filters, 4*filters, 3), 40 | leaky_relu(), 41 | BatchNorm2d(4*filters), 42 | nn.Conv2d(4*filters, 2*filters, 1, bias=False), 43 | leaky_relu(), 44 | BatchNorm2d(2*filters), 45 | nn.Conv2d(2*filters, filters, 1, bias=False), 46 | leaky_relu(), 47 | BatchNorm2d(filters) 48 | ) 49 | 50 | self.classifier = nn.Linear(filters, num_classes) 51 | 52 | for m in self.modules(): 53 | if isinstance(m, (nn.Conv2d, nn.Linear)): 54 | nn.init.xavier_normal_(m.weight) 55 | if m.bias is not None: 56 | nn.init.constant_(m.bias, 0) 57 | -------------------------------------------------------------------------------- /ssl_lib/models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .utils import leaky_relu, conv3x3, BatchNorm2d, param_init, BaseModel 5 | 6 | 7 | class _Residual(nn.Module): 8 | def __init__(self, input_channels, output_channels, stride=1, activate_before_residual=False): 9 | super().__init__() 10 | layer = [] 11 | if activate_before_residual: 12 | self.pre_act = nn.Sequential( 13 | BatchNorm2d(input_channels), 14 | leaky_relu() 15 | ) 16 | else: 17 | self.pre_act = nn.Identity() 18 | layer.append(BatchNorm2d(input_channels)) 19 | 
layer.append(leaky_relu()) 20 | layer.append(conv3x3(input_channels, output_channels, stride)) 21 | layer.append(BatchNorm2d(output_channels)) 22 | layer.append(leaky_relu()) 23 | layer.append(conv3x3(output_channels, output_channels)) 24 | 25 | if stride >= 2 or input_channels != output_channels: 26 | self.identity = nn.Conv2d(input_channels, output_channels, 1, stride, bias=False) 27 | else: 28 | self.identity = nn.Identity() 29 | 30 | self.layer = nn.Sequential(*layer) 31 | 32 | def forward(self, x): 33 | x = self.pre_act(x) 34 | return self.identity(x) + self.layer(x) 35 | 36 | 37 | class ResNet(BaseModel): 38 | """ 39 | ResNet 40 | 41 | Parameters 42 | -------- 43 | num_classes: int 44 | number of classes 45 | filters: int 46 | number of filters 47 | scales: int 48 | number of scales 49 | repeat: int 50 | number of residual blocks per scale 51 | dropout: float 52 | dropout ratio (None indicates dropout is unused) 53 | """ 54 | def __init__(self, num_classes, filters, scales, repeat, dropout=None, *args, **kwargs): 55 | super().__init__() 56 | feature_extractor = [conv3x3(3, 16)] 57 | channels = 16 58 | for scale in range(scales): 59 | feature_extractor.append( 60 | _Residual(channels, filters<> 1, 1, bias=False)) 60 | branch2 = nn.Sequential(nn.ReLU(), nn.Conv2d(channels//2, filters >> 1, 1, bias=False)) 61 | bn = BatchNorm2d(filters) 62 | self.skip = _SkipBranch(branch1, branch2, bn) 63 | elif channels != filters: 64 | self.skip = nn.Sequential( 65 | nn.Conv2d(channels, filters, 1, bias=False), 66 | BatchNorm2d(filters) 67 | ) 68 | 69 | def forward(self, x): 70 | return self.branch(x) + self.skip(x) 71 | 72 | 73 | class ShakeNet(BaseModel): 74 | """ 75 | Shake-Shake model 76 | 77 | Parameters 78 | -------- 79 | num_classes: int 80 | number of classes 81 | filters: int 82 | number of filters 83 | scales: int 84 | number of scales 85 | repeat: int 86 | number of residual blocks per scale 87 | dropout: float 88 | dropout ratio (None indicates dropout is unused) 89 | """ 90 | def __init__(self, num_classes, filters, scales, repeat, dropout=None, *args, **kwargs): 91 | super().__init__() 92 | 93 | feature_extractor = [conv3x3(3, 16)] 94 | channels = 16 95 | 96 | for scale, i in itertools.product(range(scales), range(repeat)): 97 | if i == 0: 98 | feature_extractor.append(_Residual(channels, filters << scale, stride = 2 if scale else 1)) 99 | else: 100 | feature_extractor.append(_Residual(channels, filters << scale)) 101 | 102 | channels = filters << scale 103 | 104 | self.feature_extractor = nn.Sequential(*feature_extractor) 105 | 106 | classifier = [] 107 | if dropout is not None: 108 | classifier.append(nn.Dropout(dropout)) 109 | classifier.append(nn.Linear(channels, num_classes)) 110 | 111 | param_init(self.modules()) 112 | -------------------------------------------------------------------------------- /ssl_lib/models/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class BaseModel(nn.Module): 7 | def forward(self, x): 8 | f = self.feature_extractor(x) 9 | f = f.mean((2, 3)) 10 | return self.classifier(f) 11 | 12 | def logits_with_feature(self, x): 13 | f = self.feature_extractor(x) 14 | c = self.classifier(f.mean((2, 3))) 15 | return c, f 16 | 17 | def update_batch_stats(self, flag): 18 | for m in self.modules(): 19 | if isinstance(m, nn.BatchNorm2d): 20 | m.update_batch_stats = flag 21 | 22 | 23 | def conv3x3(i_c, o_c, stride=1, bias=False): 24 | return 
nn.Conv2d(i_c, o_c, 3, stride, 1, bias=bias) 25 | 26 | 27 | class BatchNorm2d(nn.BatchNorm2d): 28 | def __init__(self, channels, momentum=1e-3, eps=1e-3): 29 | super().__init__(channels) 30 | self.update_batch_stats = True 31 | 32 | def forward(self, x): 33 | if self.update_batch_stats or not self.training: 34 | return super().forward(x) 35 | else: 36 | return nn.functional.batch_norm( 37 | x, None, None, self.weight, self.bias, True, self.momentum, self.eps 38 | ) 39 | 40 | 41 | def leaky_relu(): 42 | return nn.LeakyReLU(0.1) 43 | 44 | 45 | """ 46 | For exponential moving average 47 | """ 48 | 49 | def apply_weight_decay(modules, decay_rate): 50 | """apply weight decay to weight parameters in nn.Conv2d and nn.Linear""" 51 | for m in modules: 52 | if isinstance(m, (nn.Conv2d, nn.Linear)): 53 | m.weight.data -= decay_rate * m.weight.data 54 | 55 | 56 | def param_init(modules): 57 | for m in modules: 58 | if isinstance(m, nn.Conv2d): 59 | f, _, k, _ = m.weight.shape 60 | nn.init.normal_(m.weight, 0, 1./math.sqrt(0.5 * k * k * f)) 61 | elif isinstance(m, nn.Linear): 62 | nn.init.xavier_normal_(m.weight) 63 | nn.init.constant_(m.bias, 0) 64 | 65 | 66 | def __ema(p1, p2, factor): 67 | return factor * p1 + (1 - factor) * p2 68 | 69 | 70 | def __param_update(ema_model, raw_model, factor): 71 | """ema for trainable parameters""" 72 | for ema_p, raw_p in zip(ema_model.parameters(), raw_model.parameters()): 73 | ema_p.data = __ema(ema_p.data, raw_p.data, factor) 74 | 75 | 76 | def __buffer_update(ema_model, raw_model, factor): 77 | """ema for buffer parameters (e.g., running_mean and running_var in nn.BatchNorm2d)""" 78 | for ema_p, raw_p in zip(ema_model.buffers(), raw_model.buffers()): 79 | ema_p.data = __ema(ema_p.data, raw_p.data, factor) 80 | # """copy buffer parameters (e.g., running_mean and running_var in nn.BatchNorm2d)""" 81 | # for ema_p, raw_p in zip(ema_model.buffers(), raw_model.buffers()): 82 | # ema_p.copy_(raw_p) 83 | 84 | 85 | def ema_update(ema_model, raw_model, ema_factor, weight_decay_factor=None, global_step=None): 86 | if global_step is not None: 87 | ema_factor = min(1 - 1 / (global_step+1), ema_factor) 88 | __param_update(ema_model, raw_model, ema_factor) 89 | __buffer_update(ema_model, raw_model, ema_factor) 90 | if weight_decay_factor is not None: 91 | apply_weight_decay(ema_model.modules(), weight_decay_factor) 92 | 93 | -------------------------------------------------------------------------------- /ssl_lib/param_scheduler/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perrying/pytorch-consistency-regularization/6624c4e0bb1813b5952445ce34f9d4e52484ce38/ssl_lib/param_scheduler/__init__.py -------------------------------------------------------------------------------- /ssl_lib/param_scheduler/scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import warnings 3 | import math 4 | import torch.optim as optim 5 | 6 | 7 | def exp_warmup(base_value, max_warmup_iter, cur_step): 8 | """exponential warmup proposed in mean teacher 9 | 10 | calcurate 11 | base_value * exp(-5(1 - t)^2), t = cur_step / max_warmup_iter 12 | 13 | Parameters 14 | ----- 15 | base_value: float 16 | maximum value 17 | max_warmup_iter: int 18 | maximum warmup iteration 19 | cur_step: int 20 | current iteration 21 | """ 22 | if max_warmup_iter <= cur_step: 23 | return base_value 24 | return base_value * math.exp(-5 * (1 - cur_step/max_warmup_iter)**2) 25 | 26 | 27 
| def linear_warmup(base_value, max_warmup_iter, cur_step): 28 | """linear warmup 29 | 30 | calcurate 31 | base_value * (cur_step / max_warmup_iter) 32 | 33 | Parameters 34 | ----- 35 | base_value: float 36 | maximum value 37 | max_warmup_iter: int 38 | maximum warmup iteration 39 | cur_step: int 40 | current iteration 41 | """ 42 | if max_warmup_iter <= cur_step: 43 | return base_value 44 | return base_value * cur_step / max_warmup_iter 45 | 46 | 47 | def cosine_decay(base_lr, max_iteration, cur_step): 48 | """cosine learning rate decay 49 | 50 | cosine learning rate decay with parameters proposed FixMatch 51 | base_lr * cos( (7\pi cur_step) / (16 max_warmup_iter) ) 52 | 53 | Parameters 54 | ----- 55 | base_lr: float 56 | maximum learning rate 57 | max_warmup_iter: int 58 | maximum warmup iteration 59 | cur_step: int 60 | current iteration 61 | """ 62 | return base_lr * (math.cos( (7*math.pi*cur_step) / (16*max_iteration) )) 63 | 64 | 65 | def CosineAnnealingLR(optimizer, max_iteration): 66 | """ 67 | generate cosine annealing learning rate scheduler as LambdaLR 68 | """ 69 | return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda = lambda cur_step : math.cos((7*math.pi*cur_step) / (16*max_iteration))) 70 | -------------------------------------------------------------------------------- /train_test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy, random, time, json 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | 7 | from ssl_lib.algs.builder import gen_ssl_alg 8 | from ssl_lib.algs import utils as alg_utils 9 | from ssl_lib.models import utils as model_utils 10 | from ssl_lib.consistency.builder import gen_consistency 11 | from ssl_lib.models.builder import gen_model 12 | from ssl_lib.datasets.builder import gen_dataloader 13 | from ssl_lib.param_scheduler import scheduler 14 | from ssl_lib.misc.meter import Meter 15 | 16 | 17 | def evaluation(raw_model, eval_model, loader, device): 18 | raw_model.eval() 19 | eval_model.eval() 20 | sum_raw_acc = sum_acc = sum_loss = 0 21 | with torch.no_grad(): 22 | for (data, labels) in loader: 23 | data, labels = data.to(device), labels.to(device) 24 | preds = eval_model(data) 25 | raw_preds = raw_model(data) 26 | loss = F.cross_entropy(preds, labels) 27 | sum_loss += loss.item() 28 | acc = (preds.max(1)[1] == labels).float().mean() 29 | raw_acc = (raw_preds.max(1)[1] == labels).float().mean() 30 | sum_acc += acc.item() 31 | sum_raw_acc += raw_acc.item() 32 | mean_raw_acc = sum_raw_acc / len(loader) 33 | mean_acc = sum_acc / len(loader) 34 | mean_loss = sum_loss / len(loader) 35 | raw_model.train() 36 | eval_model.train() 37 | return mean_raw_acc, mean_acc, mean_loss 38 | 39 | 40 | def param_update( 41 | cfg, 42 | cur_iteration, 43 | model, 44 | teacher_model, 45 | optimizer, 46 | ssl_alg, 47 | consistency, 48 | labeled_data, 49 | ul_weak_data, 50 | ul_strong_data, 51 | labels, 52 | average_model 53 | ): 54 | start_time = time.time() 55 | 56 | all_data = torch.cat([labeled_data, ul_weak_data, ul_strong_data], 0) 57 | forward_func = model.forward 58 | stu_logits = forward_func(all_data) 59 | labeled_preds = stu_logits[:labeled_data.shape[0]] 60 | 61 | stu_unlabeled_weak_logits, stu_unlabeled_strong_logits = torch.chunk(stu_logits[labels.shape[0]:], 2, dim=0) 62 | 63 | if cfg.tsa: 64 | none_reduced_loss = F.cross_entropy(labeled_preds, labels, reduction="none") 65 | L_supervised = alg_utils.anneal_loss( 66 | labeled_preds, labels, 
none_reduced_loss, cur_iteration+1, 67 | cfg.iteration, labeled_preds.shape[1], cfg.tsa_schedule) 68 | else: 69 | L_supervised = F.cross_entropy(labeled_preds, labels) 70 | 71 | if cfg.coef > 0: 72 | # get target values 73 | if teacher_model is not None: # get target values from teacher model 74 | t_forward_func = teacher_model.forward 75 | tea_logits = t_forward_func(all_data) 76 | tea_unlabeled_weak_logits, _ = torch.chunk(tea_logits[labels.shape[0]:], 2, dim=0) 77 | else: 78 | t_forward_func = forward_func 79 | tea_unlabeled_weak_logits = stu_unlabeled_weak_logits 80 | 81 | # calc consistency loss 82 | model.update_batch_stats(False) 83 | y, targets, mask = ssl_alg( 84 | stu_preds = stu_unlabeled_strong_logits, 85 | tea_logits = tea_unlabeled_weak_logits.detach(), 86 | data = ul_strong_data, 87 | stu_forward = forward_func, 88 | tea_forward = t_forward_func 89 | ) 90 | model.update_batch_stats(True) 91 | L_consistency = consistency(y, targets, mask, weak_prediction=tea_unlabeled_weak_logits.softmax(1)) 92 | 93 | else: 94 | L_consistency = torch.zeros_like(L_supervised) 95 | mask = None 96 | 97 | # calc total loss 98 | coef = scheduler.exp_warmup(cfg.coef, cfg.warmup_iter, cur_iteration+1) 99 | loss = L_supervised + coef * L_consistency 100 | if cfg.entropy_minimization > 0: 101 | loss -= cfg.entropy_minimization * \ 102 | (stu_unlabeled_weak_logits.softmax(1) * F.log_softmax(stu_unlabeled_weak_logits, 1)).sum(1).mean() 103 | 104 | # update parameters 105 | cur_lr = optimizer.param_groups[0]["lr"] 106 | optimizer.zero_grad() 107 | loss.backward() 108 | if cfg.weight_decay > 0: 109 | decay_coeff = cfg.weight_decay * cur_lr 110 | model_utils.apply_weight_decay(model.modules(), decay_coeff) 111 | optimizer.step() 112 | 113 | # update teacher parameters by exponential moving average 114 | if cfg.ema_teacher: 115 | model_utils.ema_update( 116 | teacher_model, model, cfg.ema_teacher_factor, 117 | cfg.weight_decay * cur_lr if cfg.ema_apply_wd else None, 118 | cur_iteration if cfg.ema_teacher_warmup else None) 119 | # update evaluation model's parameters by exponential moving average 120 | if cfg.weight_average: 121 | model_utils.ema_update( 122 | average_model, model, cfg.wa_ema_factor, 123 | cfg.weight_decay * cur_lr if cfg.wa_apply_wd else None) 124 | 125 | # calculate accuracy for labeled data 126 | acc = (labeled_preds.max(1)[1] == labels).float().mean() 127 | 128 | return { 129 | "acc": acc, 130 | "loss": loss.item(), 131 | "sup loss": L_supervised.item(), 132 | "ssl loss": L_consistency.item(), 133 | "mask": mask.float().mean().item() if mask is not None else 1, 134 | "coef": coef, 135 | "sec/iter": (time.time() - start_time) 136 | } 137 | 138 | 139 | def main(cfg, logger): 140 | # set seed 141 | torch.manual_seed(cfg.seed) 142 | numpy.random.seed(cfg.seed) 143 | random.seed(cfg.seed) 144 | # select device 145 | if torch.cuda.is_available(): 146 | device = "cuda" 147 | torch.backends.cudnn.benchmark = True 148 | else: 149 | logger.info("CUDA is NOT available") 150 | device = "cpu" 151 | # build data loader 152 | logger.info("load dataset") 153 | lt_loader, ult_loader, test_loader, num_classes, img_size = gen_dataloader(cfg.root, cfg.dataset, False, cfg, logger) 154 | 155 | # set consistency type 156 | consistency = gen_consistency(cfg.consistency, cfg) 157 | # set ssl algorithm 158 | ssl_alg = gen_ssl_alg(cfg.alg, cfg) 159 | # build student model 160 | model = gen_model(cfg.model, num_classes, img_size).to(device) 161 | # build teacher model 162 | if cfg.ema_teacher: 163 | teacher_model = 
gen_model(cfg.model, num_classes, img_size).to(device) 164 | teacher_model.load_state_dict(model.state_dict()) 165 | else: 166 | teacher_model = None 167 | # for evaluation 168 | if cfg.weight_average: 169 | average_model = gen_model(cfg.model, num_classes, img_size).to(device) 170 | average_model.load_state_dict(model.state_dict()) 171 | else: 172 | average_model = None 173 | 174 | model.train() 175 | 176 | logger.info(model) 177 | 178 | # build optimizer 179 | if cfg.optimizer == "sgd": 180 | optimizer = optim.SGD( 181 | model.parameters(), cfg.lr, cfg.momentum, weight_decay=0, nesterov=True 182 | ) 183 | elif cfg.optimizer == "adam": 184 | optimizer = optim.Adam( 185 | model.parameters(), cfg.lr, (cfg.momentum, 0.999), weight_decay=0 186 | ) 187 | else: 188 | raise NotImplementedError 189 | # set lr scheduler 190 | if cfg.lr_decay == "cos": 191 | lr_scheduler = scheduler.CosineAnnealingLR(optimizer, cfg.iteration) 192 | elif cfg.lr_decay == "step": 193 | # TODO: fixed milstones 194 | lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [400000, ], cfg.lr_decay_rate) 195 | else: 196 | raise NotImplementedError 197 | 198 | # init meter 199 | metric_meter = Meter() 200 | test_acc_list = [] 201 | raw_acc_list = [] 202 | 203 | logger.info("training") 204 | for i, (l_data, ul_data) in enumerate(zip(lt_loader, ult_loader)): 205 | l_aug, labels = l_data 206 | ul_w_aug, ul_s_aug, _ = ul_data 207 | 208 | params = param_update( 209 | cfg, i, model, teacher_model, optimizer, ssl_alg, 210 | consistency, l_aug.to(device), ul_w_aug.to(device), 211 | ul_s_aug.to(device), labels.to(device), 212 | average_model 213 | ) 214 | 215 | # moving average for reporting losses and accuracy 216 | metric_meter.add(params, ignores=["coef"]) 217 | 218 | # display losses every cfg.disp iterations 219 | if ((i+1) % cfg.disp) == 0: 220 | state = metric_meter.state( 221 | header = f'[{i+1}/{cfg.iteration}]', 222 | footer = f'ssl coef {params["coef"]:.4g} | lr {optimizer.param_groups[0]["lr"]:.4g}' 223 | ) 224 | logger.info(state) 225 | 226 | lr_scheduler.step() 227 | if ((i + 1) % cfg.checkpoint) == 0 or (i+1) == cfg.iteration: 228 | with torch.no_grad(): 229 | if cfg.weight_average: 230 | eval_model = average_model 231 | else: 232 | eval_model = model 233 | logger.info("test") 234 | mean_raw_acc, mean_test_acc, mean_test_loss = evaluation(model, eval_model, test_loader, device) 235 | logger.info("test loss %f | test acc. %f | raw acc. %f", mean_test_loss, mean_test_acc, mean_raw_acc) 236 | test_acc_list.append(mean_test_acc) 237 | raw_acc_list.append(mean_raw_acc) 238 | 239 | torch.save(model.state_dict(), os.path.join(cfg.out_dir, "model_checkpoint.pth")) 240 | torch.save(optimizer.state_dict(), os.path.join(cfg.out_dir, "optimizer_checkpoint.pth")) 241 | 242 | numpy.save(os.path.join(cfg.out_dir, "results"), test_acc_list) 243 | numpy.save(os.path.join(cfg.out_dir, "raw_results"), raw_acc_list) 244 | accuracies = {} 245 | for i in [1, 10, 20, 50]: 246 | logger.info("mean test acc. over last %d checkpoints: %f", i, numpy.median(test_acc_list[-i:])) 247 | logger.info("mean test acc. 
for raw model over last %d checkpoints: %f", i, numpy.median(raw_acc_list[-i:])) 248 | accuracies[f"last{i}"] = numpy.median(test_acc_list[-i:]) 249 | 250 | with open(os.path.join(cfg.out_dir, "results.json"), "w") as f: 251 | json.dump(accuracies, f, sort_keys=True) 252 | 253 | 254 | if __name__ == "__main__": 255 | import os, sys 256 | from parser import get_args 257 | args = get_args() 258 | os.makedirs(args.out_dir, exist_ok=True) 259 | 260 | # setup logger 261 | plain_formatter = logging.Formatter( 262 | "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S" 263 | ) 264 | logger = logging.getLogger(__name__) 265 | logger.setLevel(logging.DEBUG) 266 | s_handler = logging.StreamHandler(stream=sys.stdout) 267 | s_handler.setFormatter(plain_formatter) 268 | s_handler.setLevel(logging.DEBUG) 269 | logger.addHandler(s_handler) 270 | f_handler = logging.FileHandler(os.path.join(args.out_dir, "console.log")) 271 | f_handler.setFormatter(plain_formatter) 272 | f_handler.setLevel(logging.DEBUG) 273 | logger.addHandler(f_handler) 274 | logger.propagate = False 275 | 276 | logger.info(args) 277 | 278 | main(args, logger) -------------------------------------------------------------------------------- /train_val_test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy, random, time 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | 7 | from ssl_lib.algs.builder import gen_ssl_alg 8 | from ssl_lib.algs import utils as alg_utils 9 | from ssl_lib.models import utils as model_utils 10 | from ssl_lib.consistency.builder import gen_consistency 11 | from ssl_lib.models.builder import gen_model 12 | from ssl_lib.datasets.builder import gen_dataloader 13 | from ssl_lib.param_scheduler import scheduler 14 | from ssl_lib.misc.meter import Meter 15 | 16 | 17 | def evaluation(raw_model, eval_model, loader, device): 18 | raw_model.eval() 19 | eval_model.eval() 20 | sum_raw_acc = sum_acc = sum_loss = 0 21 | with torch.no_grad(): 22 | for (data, labels) in loader: 23 | data, labels = data.to(device), labels.to(device) 24 | preds = eval_model(data) 25 | raw_preds = raw_model(data) 26 | loss = F.cross_entropy(preds, labels) 27 | sum_loss += loss.item() 28 | acc = (preds.max(1)[1] == labels).float().mean() 29 | raw_acc = (raw_preds.max(1)[1] == labels).float().mean() 30 | sum_acc += acc.item() 31 | sum_raw_acc += raw_acc.item() 32 | mean_raw_acc = sum_raw_acc / len(loader) 33 | mean_acc = sum_acc / len(loader) 34 | mean_loss = sum_loss / len(loader) 35 | raw_model.train() 36 | eval_model.train() 37 | return mean_raw_acc, mean_acc, mean_loss 38 | 39 | 40 | def param_update( 41 | cfg, 42 | cur_iteration, 43 | model, 44 | teacher_model, 45 | optimizer, 46 | ssl_alg, 47 | consistency, 48 | labeled_data, 49 | ul_weak_data, 50 | ul_strong_data, 51 | labels, 52 | average_model 53 | ): 54 | start_time = time.time() 55 | 56 | all_data = torch.cat([labeled_data, ul_weak_data, ul_strong_data], 0) 57 | forward_func = model.forward 58 | stu_logits = forward_func(all_data) 59 | labeled_preds = stu_logits[:labeled_data.shape[0]] 60 | 61 | stu_unlabeled_weak_logits, stu_unlabeled_strong_logits = torch.chunk(stu_logits[labels.shape[0]:], 2, dim=0) 62 | 63 | if cfg.tsa: 64 | none_reduced_loss = F.cross_entropy(labeled_preds, labels, reduction="none") 65 | L_supervised = alg_utils.anneal_loss( 66 | labeled_preds, labels, none_reduced_loss, cur_iteration+1, 67 | cfg.iteration, labeled_preds.shape[1], 
cfg.tsa_schedule) 68 | else: 69 | L_supervised = F.cross_entropy(labeled_preds, labels) 70 | 71 | if cfg.coef > 0: 72 | # get target values 73 | if teacher_model is not None: # get target values from teacher model 74 | t_forward_func = teacher_model.forward 75 | tea_logits = t_forward_func(all_data) 76 | tea_unlabeled_weak_logits, _ = torch.chunk(tea_logits[labels.shape[0]:], 2, dim=0) 77 | else: 78 | t_forward_func = forward_func 79 | tea_unlabeled_weak_logits = stu_unlabeled_weak_logits 80 | 81 | # calc consistency loss 82 | model.update_batch_stats(False) 83 | y, targets, mask = ssl_alg( 84 | stu_preds = stu_unlabeled_strong_logits, 85 | tea_logits = tea_unlabeled_weak_logits.detach(), 86 | w_data = ul_weak_data, 87 | s_data = ul_strong_data, 88 | stu_forward = forward_func, 89 | tea_forward = t_forward_func 90 | ) 91 | model.update_batch_stats(True) 92 | L_consistency = consistency(y, targets, mask, weak_prediction=tea_unlabeled_weak_logits.softmax(1)) 93 | 94 | else: 95 | L_consistency = torch.zeros_like(L_supervised) 96 | mask = None 97 | 98 | # calc total loss 99 | coef = scheduler.exp_warmup(cfg.coef, cfg.warmup_iter, cur_iteration+1) 100 | loss = L_supervised + coef * L_consistency 101 | if cfg.entropy_minimization > 0: 102 | loss -= cfg.entropy_minimization * \ 103 | (stu_unlabeled_weak_logits.softmax(1) * F.log_softmax(stu_unlabeled_weak_logits, 1)).sum(1).mean() 104 | 105 | # update parameters 106 | cur_lr = optimizer.param_groups[0]["lr"] 107 | optimizer.zero_grad() 108 | loss.backward() 109 | if cfg.weight_decay > 0: 110 | decay_coeff = cfg.weight_decay * cur_lr 111 | model_utils.apply_weight_decay(model.modules(), decay_coeff) 112 | optimizer.step() 113 | 114 | # update teacher parameters by exponential moving average 115 | if cfg.ema_teacher: 116 | model_utils.ema_update( 117 | teacher_model, model, cfg.ema_teacher_factor, 118 | cfg.weight_decay * cur_lr if cfg.ema_apply_wd else None, 119 | cur_iteration if cfg.ema_teacher_warmup else None) 120 | # update evaluation model's parameters by exponential moving average 121 | if cfg.weight_average: 122 | model_utils.ema_update( 123 | average_model, model, cfg.wa_ema_factor, 124 | cfg.weight_decay * cur_lr if cfg.wa_apply_wd else None) 125 | 126 | # calculate accuracy for labeled data 127 | acc = (labeled_preds.max(1)[1] == labels).float().mean() 128 | 129 | return { 130 | "acc": acc, 131 | "loss": loss.item(), 132 | "sup loss": L_supervised.item(), 133 | "ssl loss": L_consistency.item(), 134 | "mask": mask.float().mean().item() if mask is not None else 1, 135 | "coef": coef, 136 | "sec/iter": (time.time() - start_time) 137 | } 138 | 139 | 140 | def main(cfg, logger): 141 | # set seed 142 | torch.manual_seed(cfg.seed) 143 | numpy.random.seed(cfg.seed) 144 | random.seed(cfg.seed) 145 | # select device 146 | if torch.cuda.is_available(): 147 | device = "cuda" 148 | torch.backends.cudnn.benchmark = True 149 | else: 150 | logger.info("CUDA is NOT available") 151 | device = "cpu" 152 | # build data loader 153 | logger.info("load dataset") 154 | lt_loader, ult_loader, val_loader, test_loader, num_classes, img_size = gen_dataloader(cfg.root, cfg.dataset, True, cfg, logger) 155 | 156 | # set consistency type 157 | consistency = gen_consistency(cfg.consistency, cfg) 158 | # set ssl algorithm 159 | ssl_alg = gen_ssl_alg(cfg.alg, cfg) 160 | # build student model 161 | model = gen_model(cfg.model, num_classes, img_size).to(device) 162 | # build teacher model 163 | if cfg.ema_teacher: 164 | teacher_model = gen_model(cfg.model, num_classes, 
img_size).to(device) 165 | teacher_model.load_state_dict(model.state_dict()) 166 | else: 167 | teacher_model = None 168 | # for evaluation 169 | if cfg.weight_average: 170 | average_model = gen_model(cfg.model, num_classes, img_size).to(device) 171 | average_model.load_state_dict(model.state_dict()) 172 | else: 173 | average_model = None 174 | 175 | model.train() 176 | 177 | logger.info(model) 178 | 179 | # build optimizer 180 | if cfg.optimizer == "sgd": 181 | optimizer = optim.SGD( 182 | model.parameters(), cfg.lr, cfg.momentum, weight_decay=0, nesterov=True 183 | ) 184 | elif cfg.optimizer == "adam": 185 | optimizer = optim.AdamW( 186 | model.parameters(), cfg.lr, (cfg.momentum, 0.999), weight_decay=0 187 | ) 188 | else: 189 | raise NotImplementedError 190 | # set lr scheduler 191 | if cfg.lr_decay == "cos": 192 | lr_scheduler = scheduler.CosineAnnealingLR(optimizer, cfg.iteration) 193 | elif cfg.lr_decay == "step": 194 | lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [400000, ], cfg.lr_decay_rate) 195 | else: 196 | raise NotImplementedError 197 | 198 | # init meter 199 | metric_meter = Meter() 200 | maximum_val_acc = 0 201 | 202 | logger.info("training") 203 | for i, (l_data, ul_data) in enumerate(zip(lt_loader, ult_loader)): 204 | l_aug, labels = l_data 205 | ul_w_aug, ul_s_aug, _ = ul_data 206 | 207 | params = param_update( 208 | cfg, i, model, teacher_model, optimizer, ssl_alg, 209 | consistency, l_aug.to(device), ul_w_aug.to(device), 210 | ul_s_aug.to(device), labels.to(device), 211 | average_model 212 | ) 213 | 214 | # moving average for reporting losses and accuracy 215 | metric_meter.add(params, ignores=["coef"]) 216 | 217 | # display losses every cfg.disp iterations 218 | if ((i+1) % cfg.disp) == 0: 219 | state = metric_meter.state( 220 | header = f'[{i+1}/{cfg.iteration}]', 221 | footer = f'ssl coef {params["coef"]:.4g} | lr {optimizer.param_groups[0]["lr"]:.4g}' 222 | ) 223 | logger.info(state) 224 | 225 | lr_scheduler.step() 226 | # validation 227 | if ((i + 1) % cfg.checkpoint) == 0 or (i+1) == cfg.iteration: 228 | with torch.no_grad(): 229 | if cfg.weight_average: 230 | eval_model = average_model 231 | else: 232 | eval_model = model 233 | logger.info("validation") 234 | mean_raw_acc, mean_val_acc, mean_val_loss = evaluation(model, eval_model, val_loader, device) 235 | logger.info("validation loss %f | validation acc. %f | raw acc. %f", mean_val_loss, mean_val_acc, mean_raw_acc) 236 | 237 | # test 238 | if not cfg.only_validation and mean_val_acc > maximum_val_acc: 239 | torch.save(eval_model.state_dict(), os.path.join(cfg.out_dir, "best_model.pth")) 240 | maximum_val_acc = mean_val_acc 241 | logger.info("test") 242 | mean_raw_acc, mean_test_acc, mean_test_loss = evaluation(model, eval_model, test_loader, device) 243 | logger.info("test loss %f | test acc. %f | raw acc. 
%f", mean_test_loss, mean_test_acc, mean_raw_acc) 244 | 245 | torch.save(model.state_dict(), os.path.join(cfg.out_dir, "model_checkpoint.pth")) 246 | torch.save(optimizer.state_dict(), os.path.join(cfg.out_dir, "optimizer_checkpoint.pth")) 247 | 248 | logger.info("test accuracy %f", mean_test_acc) 249 | 250 | 251 | if __name__ == "__main__": 252 | import os, sys 253 | from parser import get_args 254 | args = get_args() 255 | os.makedirs(args.out_dir, exist_ok=True) 256 | 257 | # setup logger 258 | plain_formatter = logging.Formatter( 259 | "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S" 260 | ) 261 | logger = logging.getLogger(__name__) 262 | logger.setLevel(logging.DEBUG) 263 | s_handler = logging.StreamHandler(stream=sys.stdout) 264 | s_handler.setFormatter(plain_formatter) 265 | s_handler.setLevel(logging.DEBUG) 266 | logger.addHandler(s_handler) 267 | f_handler = logging.FileHandler(os.path.join(args.out_dir, "console.log")) 268 | f_handler.setFormatter(plain_formatter) 269 | f_handler.setLevel(logging.DEBUG) 270 | logger.addHandler(f_handler) 271 | logger.propagate = False 272 | 273 | logger.info(args) 274 | 275 | main(args, logger) 276 | --------------------------------------------------------------------------------