├── LICENSE ├── README.md ├── build_dataset.py ├── config.py ├── lib ├── __init__.py ├── algs │ ├── __init__.py │ ├── ict.py │ ├── mean_teacher.py │ ├── mixmatch.py │ ├── pimodel.py │ ├── pseudo_label.py │ └── vat.py ├── datasets │ ├── __init__.py │ ├── cifar10.py │ └── svhn.py ├── transform.py └── wrn.py └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 teppei suzuki 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # realistic-ssl-evaluation-pytorch 2 | This repository is a reimplementation of [Realistic Evaluation of Deep Semi-Supervised Learning Algorithms](https://arxiv.org/abs/1804.09170), by Avital Oliver*, Augustus Odena*, Colin Raffel*, Ekin D. Cubuk, and Ian J. Goodfellow, arXiv preprint arXiv:1804.09170. 3 | The original repo is [here](https://github.com/brain-research/realistic-ssl-evaluation). 4 | 5 | NOTE: This repository has been deprecated. Please see [here](https://github.com/perrying/pytorch-consistency-regularization), which includes UDA and FixMatch. 6 | 7 | # Requirements 8 | - Python 3.6+ 9 | - PyTorch 1.1.0 10 | - torchvision 0.3.0 11 | - numpy 1.16.2 12 | 13 | # How to run 14 | Prepare the dataset 15 | 16 | ```python 17 | python build_dataset.py 18 | ``` 19 | 20 | The default setting is SVHN with 1000 labels. If you want to try other settings, please check the options first with ```python build_dataset.py -h```. 21 | 22 | Running experiments 23 | 24 | ```python 25 | python train.py 26 | ``` 27 | 28 | The default setting is VAT. 
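For example, to run the CIFAR-10 4000-label benchmark with Mean Teacher, or with VAT plus entropy minimization, invocations along these lines should work (a rough sketch based on the scripts' options; see the `-h` output for the exact flags):

```python
python build_dataset.py -d cifar10 -n 4000
python train.py -a MT -d cifar10
python train.py -a VAT --em 0.06 -d cifar10
```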
Please check the options with ```python train.py -h``` 29 | 30 | # Performance 31 | WIP 32 | 33 | |algorithm|paper||this repo| | 34 | |--|--|--|--|--| 35 | ||cifar10 4k labels|svhn 1k labels|cifar10 4k labels|svhn 1k labels| 36 | |Supervised|20.26 ±0.38|12.83 ±0.47|20.35±0.14|12.33±0.25 37 | |Pi-Model|16.37 ±0.63|7.19 ±0.27|16.24±0.38|7.81±0.39 38 | |Mean Teacher|15.87 ±0.28|5.65 ±0.47|15.77±0.22|6.48±0.44 39 | |VAT|13.86 ±0.27|5.63 ±0.20|13.83±0.49|5.84±0.20 40 | |VAT+EM|13.13 ±0.39|5.35 ±0.19|13.30±0.27|5.76±0.13 41 | |Pseudo-Label|17.78 ±0.57|7.62 ±0.29|N/A|N/A 42 | |[ICT](https://arxiv.org/abs/1903.03825)|( 7.66 ±0.17 )|( 3.53 ±0.07 )|N/A|N/A 43 | |[MixMatch](https://arxiv.org/abs/1905.02249)|( 6.50 )|( 3.27 ±0.31 )|N/A|N/A 44 | 45 | *NOTE: The experimental settings of the ICT and MixMatch papers are different from this benchmark.* 46 | 47 | # Reference 48 | - [Realistic Evaluation of Deep Semi-Supervised Learning Algorithms](https://arxiv.org/abs/1804.09170), by Avital Oliver*, Augustus Odena*, Colin Raffel*, Ekin D. Cubuk, and Ian J. Goodfellow, arXiv preprint arXiv:1804.09170. 49 | - [Interpolation Consistency Training for Semi-Supervised Learning](https://arxiv.org/abs/1903.03825), by Vikas Verma, Alex Lamb, Juho Kannala, Yoshua Bengio, David Lopez-Paz 50 | - [MixMatch: A Holistic Approach to Semi-Supervised Learning](https://arxiv.org/abs/1905.02249), by David Berthelot, Nicholas Carlini, Ian Goodfellow, Nicolas Papernot, Avital Oliver, Colin Raffel 51 | -------------------------------------------------------------------------------- /build_dataset.py: -------------------------------------------------------------------------------- 1 | from torchvision import datasets 2 | import argparse, os 3 | import numpy as np 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--seed", "-s", default=1, type=int, help="random seed") 7 | parser.add_argument("--dataset", "-d", default="svhn", type=str, help="dataset name : [svhn, cifar10]") 8 | parser.add_argument("--nlabels", "-n", default=1000, type=int, help="the number of labeled data") 9 | args = parser.parse_args() 10 | 11 | COUNTS = { 12 | "svhn": {"train": 73257, "test": 26032, "valid": 7326, "extra": 531131}, 13 | "cifar10": {"train": 50000, "test": 10000, "valid": 5000, "extra": 0}, 14 | "imagenet_32": { 15 | "train": 1281167, 16 | "test": 50000, 17 | "valid": 50050, 18 | "extra": 0, 19 | }, 20 | } 21 | 22 | _DATA_DIR = "./data" 23 | 24 | def split_l_u(train_set, n_labels): 25 | # NOTE: this function assumes that train_set is shuffled. 
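    # It keeps the first n_labels // num_classes samples of each class as the labeled
    # split and moves the rest to the unlabeled split, overwriting the unlabeled targets
    # with -1 so that they can be masked out (and ignored by the supervised loss) later.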
26 | images = train_set["images"] 27 | labels = train_set["labels"] 28 | classes = np.unique(labels) 29 | n_labels_per_cls = n_labels // len(classes) 30 | l_images = [] 31 | l_labels = [] 32 | u_images = [] 33 | u_labels = [] 34 | for c in classes: 35 | cls_mask = (labels == c) 36 | c_images = images[cls_mask] 37 | c_labels = labels[cls_mask] 38 | l_images += [c_images[:n_labels_per_cls]] 39 | l_labels += [c_labels[:n_labels_per_cls]] 40 | u_images += [c_images[n_labels_per_cls:]] 41 | u_labels += [np.zeros_like(c_labels[n_labels_per_cls:]) - 1] # dammy label 42 | l_train_set = {"images": np.concatenate(l_images, 0), "labels": np.concatenate(l_labels, 0)} 43 | u_train_set = {"images": np.concatenate(u_images, 0), "labels": np.concatenate(u_labels, 0)} 44 | return l_train_set, u_train_set 45 | 46 | def _load_svhn(): 47 | splits = {} 48 | for split in ["train", "test", "extra"]: 49 | tv_data = datasets.SVHN(_DATA_DIR, split, download=True) 50 | data = {} 51 | data["images"] = tv_data.data 52 | data["labels"] = tv_data.labels 53 | splits[split] = data 54 | return splits.values() 55 | 56 | def _load_cifar10(): 57 | splits = {} 58 | for train in [True, False]: 59 | tv_data = datasets.CIFAR10(_DATA_DIR, train, download=True) 60 | data = {} 61 | data["images"] = tv_data.data 62 | data["labels"] = np.array(tv_data.targets) 63 | splits["train" if train else "test"] = data 64 | return splits.values() 65 | 66 | def gcn(images, multiplier=55, eps=1e-10): 67 | # global contrast normalization 68 | images = images.astype(np.float) 69 | images -= images.mean(axis=(1,2,3), keepdims=True) 70 | per_image_norm = np.sqrt(np.square(images).sum((1,2,3), keepdims=True)) 71 | per_image_norm[per_image_norm < eps] = 1 72 | return multiplier * images / per_image_norm 73 | 74 | def get_zca_normalization_param(images, scale=0.1, eps=1e-10): 75 | n_data, height, width, channels = images.shape 76 | images = images.reshape(n_data, height*width*channels) 77 | image_cov = np.cov(images, rowvar=False) 78 | U, S, _ = np.linalg.svd(image_cov + scale * np.eye(image_cov.shape[0])) 79 | zca_decomp = np.dot(U, np.dot(np.diag(1/np.sqrt(S + eps)), U.T)) 80 | mean = images.mean(axis=0) 81 | return mean, zca_decomp 82 | 83 | def zca_normalization(images, mean, decomp): 84 | n_data, height, width, channels = images.shape 85 | images = images.reshape(n_data, -1) 86 | images = np.dot((images - mean), decomp) 87 | return images.reshape(n_data, height, width, channels) 88 | 89 | rng = np.random.RandomState(args.seed) 90 | 91 | validation_count = COUNTS[args.dataset]["valid"] 92 | 93 | extra_set = None # In general, there won't be extra data. 
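# Dataset-specific preprocessing: SVHN is kept as raw uint8 images (they are rescaled at
# load time in lib/datasets/svhn.py), while CIFAR-10 is preprocessed here with global
# contrast normalization followed by ZCA whitening whose statistics are fit on the training images.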
94 | if args.dataset == "svhn": 95 | train_set, test_set, extra_set = _load_svhn() 96 | elif args.dataset == "cifar10": 97 | train_set, test_set = _load_cifar10() 98 | train_set["images"] = gcn(train_set["images"]) 99 | test_set["images"] = gcn(test_set["images"]) 100 | mean, zca_decomp = get_zca_normalization_param(train_set["images"]) 101 | train_set["images"] = zca_normalization(train_set["images"], mean, zca_decomp) 102 | test_set["images"] = zca_normalization(test_set["images"], mean, zca_decomp) 103 | # N x H x W x C -> N x C x H x W 104 | train_set["images"] = np.transpose(train_set["images"], (0,3,1,2)) 105 | test_set["images"] = np.transpose(test_set["images"], (0,3,1,2)) 106 | 107 | # permute index of training set 108 | indices = rng.permutation(len(train_set["images"])) 109 | train_set["images"] = train_set["images"][indices] 110 | train_set["labels"] = train_set["labels"][indices] 111 | 112 | if extra_set is not None: 113 | extra_indices = rng.permutation(len(extra_set["images"])) 114 | extra_set["images"] = extra_set["images"][extra_indices] 115 | extra_set["labels"] = extra_set["labels"][extra_indices] 116 | 117 | # split training set into training and validation 118 | train_images = train_set["images"][validation_count:] 119 | train_labels = train_set["labels"][validation_count:] 120 | validation_images = train_set["images"][:validation_count] 121 | validation_labels = train_set["labels"][:validation_count] 122 | validation_set = {"images": validation_images, "labels": validation_labels} 123 | train_set = {"images": train_images, "labels": train_labels} 124 | 125 | # split training set into labeled data and unlabeled data 126 | l_train_set, u_train_set = split_l_u(train_set, args.nlabels) 127 | 128 | if not os.path.exists(os.path.join(_DATA_DIR, args.dataset)): 129 | os.mkdir(os.path.join(_DATA_DIR, args.dataset)) 130 | 131 | np.save(os.path.join(_DATA_DIR, args.dataset, "l_train"), l_train_set) 132 | np.save(os.path.join(_DATA_DIR, args.dataset, "u_train"), u_train_set) 133 | np.save(os.path.join(_DATA_DIR, args.dataset, "val"), validation_set) 134 | np.save(os.path.join(_DATA_DIR, args.dataset, "test"), test_set) 135 | if extra_set is not None: 136 | np.save(os.path.join(_DATA_DIR, args.dataset, "extra"), extra_set) 137 | 138 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from lib.datasets import svhn, cifar10 2 | import numpy as np 3 | 4 | shared_config = { 5 | "iteration" : 500000, 6 | "warmup" : 200000, 7 | "lr_decay_iter" : 400000, 8 | "lr_decay_factor" : 0.2, 9 | "batch_size" : 100, 10 | } 11 | ### dataset ### 12 | svhn_config = { 13 | "transform" : [False, True, False], # flip, rnd crop, gaussian noise 14 | "dataset" : svhn.SVHN, 15 | "num_classes" : 10, 16 | } 17 | cifar10_config = { 18 | "transform" : [True, True, True], 19 | "dataset" : cifar10.CIFAR10, 20 | "num_classes" : 10, 21 | } 22 | ### algorithm ### 23 | vat_config = { 24 | # virtual adversarial training 25 | "xi" : 1e-6, 26 | "eps" : {"cifar10":6, "svhn":1}, 27 | "consis_coef" : 0.3, 28 | "lr" : 3e-3 29 | } 30 | pl_config = { 31 | # pseudo label 32 | "threashold" : 0.95, 33 | "lr" : 3e-4, 34 | "consis_coef" : 1, 35 | } 36 | mt_config = { 37 | # mean teacher 38 | "ema_factor" : 0.95, 39 | "lr" : 4e-4, 40 | "consis_coef" : 8, 41 | } 42 | pi_config = { 43 | # Pi Model 44 | "lr" : 3e-4, 45 | "consis_coef" : 20.0, 46 | } 47 | ict_config = { 48 | # interpolation consistency 
training 49 | "ema_factor" : 0.999, 50 | "lr" : 4e-4, 51 | "consis_coef" : 100, 52 | "alpha" : 0.1, 53 | } 54 | mm_config = { 55 | # mixmatch 56 | "lr" : 3e-3, 57 | "consis_coef" : 100, 58 | "alpha" : 0.75, 59 | "T" : 0.5, 60 | "K" : 2, 61 | } 62 | supervised_config = { 63 | "lr" : 3e-3 64 | } 65 | ### master ### 66 | config = { 67 | "shared" : shared_config, 68 | "svhn" : svhn_config, 69 | "cifar10" : cifar10_config, 70 | "VAT" : vat_config, 71 | "PL" : pl_config, 72 | "MT" : mt_config, 73 | "PI" : pi_config, 74 | "ICT" : ict_config, 75 | "MM" : mm_config, 76 | "supervised" : supervised_config 77 | } 78 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perrying/realistic-ssl-evaluation-pytorch/446ada86776b0284d06fc42123a3522bb72c7ac4/lib/__init__.py -------------------------------------------------------------------------------- /lib/algs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perrying/realistic-ssl-evaluation-pytorch/446ada86776b0284d06fc42123a3522bb72c7ac4/lib/algs/__init__.py -------------------------------------------------------------------------------- /lib/algs/ict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | 7 | class ICT(nn.Module): 8 | def __init__(self, alpha, model, ema_factor): 9 | super().__init__() 10 | self.alpha = alpha 11 | self.mean_teacher = model 12 | self.mean_teacher.train() 13 | self.ema_factor = ema_factor 14 | self.global_step = 0 15 | 16 | def forward(self, x, y, model, mask): 17 | # NOTE: this implementaion uses mixup for only unlabeled data 18 | self.global_step += 1 # for moving average coef 19 | mask = mask.byte() 20 | model.update_batch_stats(False) 21 | mt_y = self.mean_teacher(x).detach() 22 | u_x, u_y = x[mask], mt_y[mask] 23 | l_x, l_y = x[mask==0], mt_y[mask==0] 24 | lam = np.random.beta(self.alpha, self.alpha) # sample mixup coef 25 | perm = torch.randperm(u_x.shape[0]) 26 | perm_u_x, perm_u_y = u_x[perm], u_y[perm] 27 | mixed_u_x = lam * u_x + (1 - lam) * perm_u_x 28 | mixed_u_y = (lam * u_y + (1 - lam) * perm_u_y).detach() 29 | y_hat = model(torch.cat([l_x, mixed_u_x], 0)) # "cat" indicates to compute batch stats from full batches 30 | loss = F.mse_loss(y_hat.softmax(1), torch.cat([l_y, mixed_u_y], 0).softmax(1), reduction="none").sum(1) 31 | # compute loss for only unlabeled data, but loss is normalized by full batchsize 32 | loss = loss[l_x.shape[0]:].sum() / x.shape[0] 33 | model.update_batch_stats(True) 34 | return loss 35 | 36 | def moving_average(self, parameters): 37 | ema_factor = min(1 - 1 / (self.global_step), self.ema_factor) 38 | for emp_p, p in zip(self.mean_teacher.parameters(), parameters): 39 | emp_p.data = ema_factor * emp_p.data + (1 - ema_factor) * p.data 40 | -------------------------------------------------------------------------------- /lib/algs/mean_teacher.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class MT(nn.Module): 6 | def __init__(self, model, ema_factor): 7 | super().__init__() 8 | self.model = model 9 | self.model.train() 10 | self.ema_factor = ema_factor 11 | self.global_step = 0 12 | 13 | def 
forward(self, x, y, model, mask): 14 | self.global_step += 1 15 | y_hat = self.model(x) 16 | model.update_batch_stats(False) 17 | y = model(x) # recompute y because the y passed to forward() is detached 18 | model.update_batch_stats(True) 19 | return (F.mse_loss(y.softmax(1), y_hat.softmax(1).detach(), reduction="none").mean(1) * mask).mean() 20 | 21 | def moving_average(self, parameters): 22 | ema_factor = min(1 - 1 / (self.global_step+1), self.ema_factor) 23 | for emp_p, p in zip(self.model.parameters(), parameters): 24 | emp_p.data = ema_factor * emp_p.data + (1 - ema_factor) * p.data 25 | -------------------------------------------------------------------------------- /lib/algs/mixmatch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class MixMatch(nn.Module): 6 | def __init__(self, temperature, n_augment, alpha): 7 | super().__init__() 8 | self.T = temperature 9 | self.K = n_augment 10 | self.beta_distrib = torch.distributions.beta.Beta(alpha, alpha) 11 | 12 | def sharpen(self, y): 13 | y = y.pow(1/self.T) 14 | return y / y.sum(1,keepdim=True) 15 | 16 | def forward(self, x, y, model, mask): 17 | # NOTE: this implementation applies mixup only to unlabeled data 18 | model.update_batch_stats(False) 19 | u_x = x[mask == 1] 20 | # make K augmented copies and guess their labels (the stochastic augmentation is applied inside the model's forward) 21 | u_x_hat = [u_x for _ in range(self.K)] 22 | y_hat = sum([model(u_x_hat[i]).softmax(1) for i in range(len(u_x_hat))]) / self.K 23 | y_hat = self.sharpen(y_hat) 24 | y_hat = y_hat.repeat(len(u_x_hat), 1) 25 | # mixup 26 | u_x_hat = torch.cat(u_x_hat, 0) 27 | index = torch.randperm(u_x_hat.shape[0]) 28 | shuffled_u_x_hat, shuffled_y_hat = u_x_hat[index], y_hat[index] 29 | lam = self.beta_distrib.sample().item() 30 | # lam = max(lam, 1-lam) 31 | mixed_x = lam * u_x_hat + (1-lam) * shuffled_u_x_hat 32 | mixed_y = lam * y_hat + (1-lam) * shuffled_y_hat # both terms are already probabilities 33 | # mean squared error between predicted probabilities and mixed targets 34 | loss = F.mse_loss(model(mixed_x).softmax(1), mixed_y) 35 | model.update_batch_stats(True) 36 | return loss 37 | -------------------------------------------------------------------------------- /lib/algs/pimodel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class PiModel(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | def forward(self, x, y, model, mask): 10 | # NOTE: 11 | # the stochastic transformation is embedded in the model's forward function 12 | # so the Pi-Model simply measures the consistency between two stochastic forward passes 13 | model.update_batch_stats(False) 14 | y_hat = model(x) 15 | model.update_batch_stats(True) 16 | return (F.mse_loss(y_hat.softmax(1), y.softmax(1).detach(), reduction="none").mean(1) * mask).mean() 17 | -------------------------------------------------------------------------------- /lib/algs/pseudo_label.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class PL(nn.Module): 6 | def __init__(self, threshold): 7 | super().__init__() 8 | self.th = threshold 9 | 10 | def forward(self, x, y, model, mask): 11 | y_probs = y.softmax(1) 12 | onehot_label = self.__make_one_hot(y_probs.max(1)[1]).float() 13 | gt_mask = (y_probs > self.th).float() 14 | gt_mask = gt_mask.max(1)[0] # reduce_any 15 | lt_mask = 1 - gt_mask # logical not 16 | p_target = gt_mask[:,None] * 10 * 
onehot_label + lt_mask[:,None] * y_probs 17 | model.update_batch_stats(False) 18 | output = model(x) 19 | loss = (-(p_target.detach() * F.log_softmax(output, 1)).sum(1)*mask).mean() 20 | model.update_batch_stats(True) 21 | return loss 22 | 23 | def __make_one_hot(self, y, n_classes=10): 24 | return torch.eye(n_classes)[y].to(y.device) 25 | -------------------------------------------------------------------------------- /lib/algs/vat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class VAT(nn.Module): 6 | def __init__(self, eps=1.0, xi=1e-6, n_iteration=1): 7 | super().__init__() 8 | self.eps = eps 9 | self.xi = xi 10 | self.n_iteration = n_iteration 11 | 12 | def kld(self, q_logit, p_logit): 13 | q = q_logit.softmax(1) 14 | qlogp = (q * self.__logsoftmax(p_logit)).sum(1) 15 | qlogq = (q * self.__logsoftmax(q_logit)).sum(1) 16 | return qlogq - qlogp 17 | 18 | def normalize(self, v): 19 | v = v / (1e-12 + self.__reduce_max(v.abs(), range(1, len(v.shape)))) 20 | v = v / (1e-6 + v.pow(2).sum((1,2,3),keepdim=True)).sqrt() 21 | return v 22 | 23 | def forward(self, x, y, model, mask): 24 | model.update_batch_stats(False) 25 | d = torch.randn_like(x) 26 | d = self.normalize(d) 27 | for _ in range(self.n_iteration): 28 | d.requires_grad = True 29 | x_hat = x + self.xi * d 30 | y_hat = model(x_hat) 31 | kld = self.kld(y.detach(), y_hat).mean() 32 | d = torch.autograd.grad(kld, d)[0] 33 | d = self.normalize(d).detach() 34 | x_hat = x + self.eps * d 35 | y_hat = model(x_hat) 36 | # NOTE: 37 | # The original implementation of VAT defines KL(P(y|x)||P(y|x+r_adv)) as the loss function 38 | # However, Avital Oliver's implementation uses KL(P(y|x+r_adv)||P(y|x)) as the VAT loss 39 | # see issue https://github.com/brain-research/realistic-ssl-evaluation/issues/27 40 | loss = (self.kld(y_hat, y.detach()) * mask).mean() 41 | model.update_batch_stats(True) 42 | return loss 43 | 44 | def __reduce_max(self, v, idx_list): 45 | for i in idx_list: 46 | v = v.max(i, keepdim=True)[0] 47 | return v 48 | 49 | def __logsoftmax(self,x): 50 | xdev = x - x.max(1, keepdim=True)[0] 51 | lsm = xdev - xdev.exp().sum(1, keepdim=True).log() 52 | return lsm -------------------------------------------------------------------------------- /lib/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perrying/realistic-ssl-evaluation-pytorch/446ada86776b0284d06fc42123a3522bb72c7ac4/lib/datasets/__init__.py -------------------------------------------------------------------------------- /lib/datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | class CIFAR10: 5 | def __init__(self, root, split="l_train"): 6 | self.dataset = np.load(os.path.join(root, "cifar10", split+".npy"), allow_pickle=True).item() 7 | 8 | def __getitem__(self, idx): 9 | image = self.dataset["images"][idx] 10 | label = self.dataset["labels"][idx] 11 | return image, label 12 | 13 | def __len__(self): 14 | return len(self.dataset["images"]) -------------------------------------------------------------------------------- /lib/datasets/svhn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | class SVHN: 5 | def __init__(self, root, split="l_train"): 6 | self.dataset = np.load(os.path.join(root, 
"svhn", split+".npy"), allow_pickle=True).item() 7 | 8 | def __getitem__(self, idx): 9 | image = self.dataset["images"][idx] 10 | label = self.dataset["labels"][idx] 11 | image = (image/255. - 0.5)/0.5 12 | return image, label 13 | 14 | def __len__(self): 15 | return len(self.dataset["images"]) -------------------------------------------------------------------------------- /lib/transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import random 4 | 5 | class transform: 6 | def __init__(self, flip=True, r_crop=True, g_noise=True): 7 | self.flip = flip 8 | self.r_crop = r_crop 9 | self.g_noise = g_noise 10 | print("holizontal flip : {}, random crop : {}, gaussian noise : {}".format( 11 | self.flip, self.r_crop, self.g_noise 12 | )) 13 | 14 | def __call__(self, x): 15 | if self.flip and random.random() > 0.5: 16 | x = x.flip(-1) 17 | if self.r_crop: 18 | h, w = x.shape[-2:] 19 | x = F.pad(x, [2,2,2,2], mode="reflect") 20 | l, t = random.randint(0, 4), random.randint(0,4) 21 | x = x[:,:,t:t+h,l:l+w] 22 | if self.g_noise: 23 | n = torch.randn_like(x) * 0.15 24 | x = n + x 25 | return x 26 | -------------------------------------------------------------------------------- /lib/wrn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def conv3x3(i_c, o_c, stride=1): 5 | return nn.Conv2d(i_c, o_c, 3, stride, 1, bias=False) 6 | 7 | class BatchNorm2d(nn.BatchNorm2d): 8 | def __init__(self, channels, momentum=1e-3, eps=1e-3): 9 | super().__init__(channels) 10 | self.update_batch_stats = True 11 | 12 | def forward(self, x): 13 | if self.update_batch_stats: 14 | return super().forward(x) 15 | else: 16 | return nn.functional.batch_norm( 17 | x, None, None, self.weight, self.bias, True, self.momentum, self.eps 18 | ) 19 | 20 | def relu(): 21 | return nn.LeakyReLU(0.1) 22 | 23 | class residual(nn.Module): 24 | def __init__(self, input_channels, output_channels, stride=1, activate_before_residual=False): 25 | super().__init__() 26 | layer = [] 27 | if activate_before_residual: 28 | self.pre_act = nn.Sequential( 29 | BatchNorm2d(input_channels), 30 | relu() 31 | ) 32 | else: 33 | self.pre_act = nn.Identity() 34 | layer.append(BatchNorm2d(input_channels)) 35 | layer.append(relu()) 36 | layer.append(conv3x3(input_channels, output_channels, stride)) 37 | layer.append(BatchNorm2d(output_channels)) 38 | layer.append(relu()) 39 | layer.append(conv3x3(output_channels, output_channels)) 40 | 41 | if stride >= 2 or input_channels != output_channels: 42 | self.identity = nn.Conv2d(input_channels, output_channels, 1, stride, bias=False) 43 | else: 44 | self.identity = nn.Identity() 45 | 46 | self.layer = nn.Sequential(*layer) 47 | 48 | def forward(self, x): 49 | x = self.pre_act(x) 50 | return self.identity(x) + self.layer(x) 51 | 52 | class WRN(nn.Module): 53 | """ WRN28-width with leaky relu (negative slope is 0.1)""" 54 | def __init__(self, width, num_classes, transform_fn=None): 55 | super().__init__() 56 | 57 | self.init_conv = conv3x3(3, 16) 58 | 59 | filters = [16, 16*width, 32*width, 64*width] 60 | 61 | unit1 = [residual(filters[0], filters[1], activate_before_residual=True)] + \ 62 | [residual(filters[1], filters[1]) for _ in range(1, 4)] 63 | self.unit1 = nn.Sequential(*unit1) 64 | 65 | unit2 = [residual(filters[1], filters[2], 2)] + \ 66 | [residual(filters[2], filters[2]) for _ in range(1, 4)] 67 | self.unit2 = 
nn.Sequential(*unit2) 68 | 69 | unit3 = [residual(filters[2], filters[3], 2)] + \ 70 | [residual(filters[3], filters[3]) for _ in range(1, 4)] 71 | self.unit3 = nn.Sequential(*unit3) 72 | 73 | self.unit4 = nn.Sequential(*[BatchNorm2d(filters[3]), relu(), nn.AdaptiveAvgPool2d(1)]) 74 | 75 | self.output = nn.Linear(filters[3], num_classes) 76 | 77 | for m in self.modules(): 78 | if isinstance(m, nn.Conv2d): 79 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 80 | elif isinstance(m, nn.BatchNorm2d): 81 | nn.init.constant_(m.weight, 1) 82 | nn.init.constant_(m.bias, 0) 83 | elif isinstance(m, nn.Linear): 84 | nn.init.xavier_normal_(m.weight) 85 | nn.init.constant_(m.bias, 0) 86 | 87 | self.transform_fn = transform_fn 88 | 89 | def forward(self, x, return_feature=False): 90 | if self.training and self.transform_fn is not None: 91 | x = self.transform_fn(x) 92 | x = self.init_conv(x) 93 | x = self.unit1(x) 94 | x = self.unit2(x) 95 | x = self.unit3(x) 96 | f = self.unit4(x) 97 | c = self.output(f.squeeze()) 98 | if return_feature: 99 | return [c, f] 100 | else: 101 | return c 102 | 103 | def update_batch_stats(self, flag): 104 | for m in self.modules(): 105 | if isinstance(m, nn.BatchNorm2d): 106 | m.update_batch_stats = flag 107 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import torch.nn.functional as F 7 | from torch.utils.data import DataLoader 8 | 9 | import argparse, math, time, json, os 10 | 11 | from lib import wrn, transform 12 | from config import config 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--alg", "-a", default="VAT", type=str, help="ssl algorithm : [supervised, PI, MT, VAT, PL, ICT, MM]") 16 | parser.add_argument("--em", default=0, type=float, help="coefficient of entropy minimization. 
If you try VAT + EM, set 0.06") 17 | parser.add_argument("--validation", default=25000, type=int, help="validate at this interval (default 25000)") 18 | parser.add_argument("--dataset", "-d", default="svhn", type=str, help="dataset name : [svhn, cifar10]") 19 | parser.add_argument("--root", "-r", default="data", type=str, help="dataset dir") 20 | parser.add_argument("--output", "-o", default="./exp_res", type=str, help="output dir") 21 | args = parser.parse_args() 22 | 23 | if torch.cuda.is_available(): 24 | device = "cuda" 25 | torch.backends.cudnn.benchmark = True 26 | else: 27 | device = "cpu" 28 | 29 | condition = {} 30 | exp_name = "" 31 | 32 | print("dataset : {}".format(args.dataset)) 33 | condition["dataset"] = args.dataset 34 | exp_name += str(args.dataset) + "_" 35 | 36 | dataset_cfg = config[args.dataset] 37 | transform_fn = transform.transform(*dataset_cfg["transform"]) # transform function (flip, crop, noise) 38 | 39 | l_train_dataset = dataset_cfg["dataset"](args.root, "l_train") 40 | u_train_dataset = dataset_cfg["dataset"](args.root, "u_train") 41 | val_dataset = dataset_cfg["dataset"](args.root, "val") 42 | test_dataset = dataset_cfg["dataset"](args.root, "test") 43 | 44 | print("labeled data : {}, unlabeled data : {}, training data : {}".format( 45 | len(l_train_dataset), len(u_train_dataset), len(l_train_dataset)+len(u_train_dataset))) 46 | print("validation data : {}, test data : {}".format(len(val_dataset), len(test_dataset))) 47 | condition["number_of_data"] = { 48 | "labeled":len(l_train_dataset), "unlabeled":len(u_train_dataset), 49 | "validation":len(val_dataset), "test":len(test_dataset) 50 | } 51 | 52 | class RandomSampler(torch.utils.data.Sampler): 53 | """ sampling without replacement """ 54 | def __init__(self, num_data, num_sample): 55 | iterations = num_sample // num_data + 1 56 | self.indices = torch.cat([torch.randperm(num_data) for _ in range(iterations)]).tolist()[:num_sample] 57 | 58 | def __iter__(self): 59 | return iter(self.indices) 60 | 61 | def __len__(self): 62 | return len(self.indices) 63 | 64 | shared_cfg = config["shared"] 65 | if args.alg != "supervised": 66 | # batch size = 0.5 x batch size 67 | l_loader = DataLoader( 68 | l_train_dataset, shared_cfg["batch_size"]//2, drop_last=True, 69 | sampler=RandomSampler(len(l_train_dataset), shared_cfg["iteration"] * shared_cfg["batch_size"]//2) 70 | ) 71 | else: 72 | l_loader = DataLoader( 73 | l_train_dataset, shared_cfg["batch_size"], drop_last=True, 74 | sampler=RandomSampler(len(l_train_dataset), shared_cfg["iteration"] * shared_cfg["batch_size"]) 75 | ) 76 | print("algorithm : {}".format(args.alg)) 77 | condition["algorithm"] = args.alg 78 | exp_name += str(args.alg) + "_" 79 | 80 | u_loader = DataLoader( 81 | u_train_dataset, shared_cfg["batch_size"]//2, drop_last=True, 82 | sampler=RandomSampler(len(u_train_dataset), shared_cfg["iteration"] * shared_cfg["batch_size"]//2) 83 | ) 84 | 85 | val_loader = DataLoader(val_dataset, 128, shuffle=False, drop_last=False) 86 | test_loader = DataLoader(test_dataset, 128, shuffle=False, drop_last=False) 87 | 88 | print("maximum iteration : {}".format(min(len(l_loader), len(u_loader)))) 89 | 90 | alg_cfg = config[args.alg] 91 | print("parameters : ", alg_cfg) 92 | condition["h_parameters"] = alg_cfg 93 | 94 | if args.em > 0: 95 | print("entropy minimization : {}".format(args.em)) 96 | exp_name += "em_" 97 | condition["entropy_maximization"] = args.em 98 | 99 | model = wrn.WRN(2, dataset_cfg["num_classes"], transform_fn).to(device) 100 | optimizer = 
optim.Adam(model.parameters(), lr=alg_cfg["lr"]) 101 | 102 | trainable_paramters = sum([p.data.nelement() for p in model.parameters()]) 103 | print("trainable parameters : {}".format(trainable_paramters)) 104 | 105 | if args.alg == "VAT": # virtual adversarial training 106 | from lib.algs.vat import VAT 107 | ssl_obj = VAT(alg_cfg["eps"][args.dataset], alg_cfg["xi"], 1) 108 | elif args.alg == "PL": # pseudo label 109 | from lib.algs.pseudo_label import PL 110 | ssl_obj = PL(alg_cfg["threashold"]) 111 | elif args.alg == "MT": # mean teacher 112 | from lib.algs.mean_teacher import MT 113 | t_model = wrn.WRN(2, dataset_cfg["num_classes"], transform_fn).to(device) 114 | t_model.load_state_dict(model.state_dict()) 115 | ssl_obj = MT(t_model, alg_cfg["ema_factor"]) 116 | elif args.alg == "PI": # PI Model 117 | from lib.algs.pimodel import PiModel 118 | ssl_obj = PiModel() 119 | elif args.alg == "ICT": # interpolation consistency training 120 | from lib.algs.ict import ICT 121 | t_model = wrn.WRN(2, dataset_cfg["num_classes"], transform_fn).to(device) 122 | t_model.load_state_dict(model.state_dict()) 123 | ssl_obj = ICT(alg_cfg["alpha"], t_model, alg_cfg["ema_factor"]) 124 | elif args.alg == "MM": # MixMatch 125 | from lib.algs.mixmatch import MixMatch 126 | ssl_obj = MixMatch(alg_cfg["T"], alg_cfg["K"], alg_cfg["alpha"]) 127 | elif args.alg == "supervised": 128 | pass 129 | else: 130 | raise ValueError("{} is unknown algorithm".format(args.alg)) 131 | 132 | print() 133 | iteration = 0 134 | maximum_val_acc = 0 135 | s = time.time() 136 | for l_data, u_data in zip(l_loader, u_loader): 137 | iteration += 1 138 | l_input, target = l_data 139 | l_input, target = l_input.to(device).float(), target.to(device).long() 140 | 141 | if args.alg != "supervised": # for ssl algorithm 142 | u_input, dummy_target = u_data 143 | u_input, dummy_target = u_input.to(device).float(), dummy_target.to(device).long() 144 | 145 | target = torch.cat([target, dummy_target], 0) 146 | unlabeled_mask = (target == -1).float() 147 | 148 | inputs = torch.cat([l_input, u_input], 0) 149 | outputs = model(inputs) 150 | 151 | # ramp up exp(-5(1 - t)^2) 152 | coef = alg_cfg["consis_coef"] * math.exp(-5 * (1 - min(iteration/shared_cfg["warmup"], 1))**2) 153 | ssl_loss = ssl_obj(inputs, outputs.detach(), model, unlabeled_mask) * coef 154 | 155 | else: 156 | outputs = model(l_input) 157 | coef = 0 158 | ssl_loss = torch.zeros(1).to(device) 159 | 160 | # supervised loss 161 | cls_loss = F.cross_entropy(outputs, target, reduction="none", ignore_index=-1).mean() 162 | 163 | loss = cls_loss + ssl_loss 164 | 165 | if args.em > 0: 166 | loss -= args.em * ((outputs.softmax(1) * F.log_softmax(outputs, 1)).sum(1) * unlabeled_mask).mean() 167 | 168 | optimizer.zero_grad() 169 | loss.backward() 170 | optimizer.step() 171 | 172 | if args.alg == "MT" or args.alg == "ICT": 173 | # parameter update with exponential moving average 174 | ssl_obj.moving_average(model.parameters()) 175 | # display 176 | if iteration == 1 or (iteration % 100) == 0: 177 | wasted_time = time.time() - s 178 | rest = (shared_cfg["iteration"] - iteration)/100 * wasted_time / 60 179 | print("iteration [{}/{}] cls loss : {:.6e}, SSL loss : {:.6e}, coef : {:.5e}, time : {:.3f} iter/sec, rest : {:.3f} min, lr : {}".format( 180 | iteration, shared_cfg["iteration"], cls_loss.item(), ssl_loss.item(), coef, 100 / wasted_time, rest, optimizer.param_groups[0]["lr"]), 181 | "\r", end="") 182 | s = time.time() 183 | 184 | # validation 185 | if (iteration % args.validation) == 0 or 
iteration == shared_cfg["iteration"]: 186 | with torch.no_grad(): 187 | model.eval() 188 | print() 189 | print("### validation ###") 190 | sum_acc = 0. 191 | s = time.time() 192 | for j, data in enumerate(val_loader): 193 | input, target = data 194 | input, target = input.to(device).float(), target.to(device).long() 195 | 196 | output = model(input) 197 | 198 | pred_label = output.max(1)[1] 199 | sum_acc += (pred_label == target).float().sum() 200 | if ((j+1) % 10) == 0: 201 | d_p_s = 10/(time.time()-s) 202 | print("[{}/{}] time : {:.1f} data/sec, rest : {:.2f} sec".format( 203 | j+1, len(val_loader), d_p_s, (len(val_loader) - j-1)/d_p_s 204 | ), "\r", end="") 205 | s = time.time() 206 | acc = sum_acc/float(len(val_dataset)) 207 | print() 208 | print("validation accuracy : {}".format(acc)) 209 | # test 210 | if maximum_val_acc < acc: 211 | print("### test ###") 212 | maximum_val_acc = acc 213 | sum_acc = 0. 214 | s = time.time() 215 | for j, data in enumerate(test_loader): 216 | input, target = data 217 | input, target = input.to(device).float(), target.to(device).long() 218 | output = model(input) 219 | pred_label = output.max(1)[1] 220 | sum_acc += (pred_label == target).float().sum() 221 | if ((j+1) % 10) == 0: 222 | d_p_s = 10/(time.time()-s) 223 | print("[{}/{}] time : {:.1f} data/sec, rest : {:.2f} sec".format( 224 | j+1, len(test_loader), d_p_s, (len(test_loader) - j-1)/d_p_s 225 | ), "\r", end="") 226 | s = time.time() 227 | print() 228 | test_acc = sum_acc / float(len(test_dataset)) 229 | print("test accuracy : {}".format(test_acc)) 230 | # torch.save(model.state_dict(), os.path.join(args.output, "best_model.pth")) 231 | model.train() 232 | s = time.time() 233 | # lr decay 234 | if iteration == shared_cfg["lr_decay_iter"]: 235 | optimizer.param_groups[0]["lr"] *= shared_cfg["lr_decay_factor"] 236 | 237 | print("test acc : {}".format(test_acc)) 238 | condition["test_acc"] = test_acc.item() 239 | 240 | exp_name += str(int(time.time())) # unique ID 241 | if not os.path.exists(args.output): 242 | os.mkdir(args.output) 243 | with open(os.path.join(args.output, exp_name + ".json"), "w") as f: 244 | json.dump(condition, f) 245 | --------------------------------------------------------------------------------