├── LICENSE ├── README.md ├── check_result.py ├── command.sh ├── loader_cifar.py ├── loader_cifar_zca.py ├── loader_svhn.py ├── methods.py ├── preresnet_sd_cifar.py ├── train.py └── wideresnet.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 siit-vtt 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ssl (semi-supervised learning) 2 | This repository contains code to reproduce "[Realistic Evaluation of Deep Semi-Supervised Learning Algorithms](https://arxiv.org/abs/1804.09170)" in PyTorch. Currently, only the supervised baseline, the Pi model [2], and Mean Teacher [3] are implemented. We attempted to follow the description in the paper, but several differences were made intentionally, and there may be other, unintentional differences from the experiments in the paper. 3 | 4 | * The training code is under modification. 5 | 6 | # Prerequisites 7 | Tested on 8 | * python 2.7 9 | * pytorch 0.4.0 10 | 11 | Download the ZCA-preprocessed CIFAR-10 dataset 12 | * As described in the paper, global contrast normalization (GCN) and ZCA whitening are important preprocessing steps for performance. We preprocess the CIFAR-10 dataset using the code implemented in the [Mean-Teacher repository](https://github.com/CuriousAI/mean-teacher); the preprocessing code is in its tensorflow/dataset folder. 13 | Place the preprocessed file (e.g. cifar10_gcn_zca_v2.npz) into a subfolder (e.g. cifar10_zca). 14 | 15 | # Experiment detail 16 | 17 | 18 | # To Run 19 | For the baseline 20 | 21 | python train.py -a=wideresnet -m=baseline -o=adam -b=225 --dataset=cifar10_zca --gpu=0,1 --lr=0.003 --boundary=0 22 | 23 | For the Pi model 24 | 25 | python train.py -a=wideresnet -m=pi -o=adam -b=225 --dataset=cifar10_zca --gpu=0,1 --lr=0.0003 --boundary=0 26 | For Mean Teacher 27 | 28 | python train.py -a=wideresnet -m=mt -o=adam -b=225 --dataset=cifar10_zca --gpu=0,1 --lr=0.0004 --boundary=0 29 | 30 | * The `boundary` option selects one of 10 different label/unlabeled splits of the training set (0-9). 31 | 32 | You can check the average error rates over `n` runs using `check_result.py`. 
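Under the hood, `check_result.py` loads each checkpoint, takes the best test error (100 minus the best test precision@1, preferring the teacher's precision when a Mean Teacher checkpoint is loaded), and reports the mean and standard deviation over all checkpoints. A minimal standalone sketch of that aggregation, assuming checkpoints named `<fname><i>_latest.pth.tar` under `--fdir` (the naming `check_result.py` expects) and a hypothetical run directory:

    import os
    import numpy as np
    import torch

    fdir = 'ckpt_cifar10_zca_wideresnet_baseline_adam_e1200/'  # hypothetical run directory
    fname, nckpt = 'wideresnet', 10

    errors = []
    for i in range(nckpt):
        ckpt = torch.load(os.path.join(fdir, fname + str(i) + '_latest.pth.tar'))
        # prefer the teacher's test precision when present (Mean Teacher runs)
        key = 'best_test_prec1_t' if 'best_test_prec1_t' in ckpt else 'best_test_prec1'
        errors.append(100.0 - ckpt[key])

    print('Best error rate: %.2f(%.2f)' % (np.mean(errors), np.std(errors)))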
For example, if you trained the baseline model on 10 different boundaries: 33 | 34 | python check_result.py --fdir ckpt_cifar10_zca_wideresnet_baseline_adam_e1200/ --fname wideresnet --nckpt 10 35 | 36 | # Result (CIFAR-10): test error rate in %, std in parentheses 37 | |Method |WideResnet28x2 [1] |WideResnet28x3 w/ dropout (ours) | 38 | |-------------|----------------------|-----------------------------------| 39 | |Supervised |20.26 (0.38) | | 40 | |PI Model |16.37 (0.63) | | 41 | |Mean Teacher |15.87 (0.28) | | 42 | |VAT |13.86 (0.27) |- | 43 | |VAT + EM |13.13 (0.39) |- | 44 | 45 | 46 | # References 47 | [1] Oliver, Avital, et al. "Realistic Evaluation of Deep Semi-Supervised Learning Algorithms." arXiv preprint arXiv:1804.09170 (2018). 48 | 49 | [2] Laine, Samuli, and Timo Aila. "Temporal ensembling for semi-supervised learning." arXiv preprint arXiv:1610.02242 (2016). 50 | 51 | [3] Tarvainen, Antti, and Harri Valpola. "Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results." Advances in neural information processing systems. 2017. 52 | 53 | [4] https://github.com/CuriousAI/mean-teacher 54 | 55 | [5] https://github.com/facebookresearch/odin 56 | -------------------------------------------------------------------------------- /check_result.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | import argparse 5 | parser = argparse.ArgumentParser(description='Summarize semi-supervised training results') 6 | parser.add_argument('--fdir', default='ckpt', type=str, metavar='PATH', 7 | help='path to load checkpoints (default: ckpt)') 8 | parser.add_argument('--fname', default='wideresnet', type=str, metavar='PATH', 9 | help='checkpoint filename (default: wideresnet)') 10 | parser.add_argument('--nckpt', default=1, type=int, help='number of checkpoints') 11 | parser.add_argument('--plot', default=False, action='store_true', help='plot accuracy/loss/lr curves') 12 | args = parser.parse_args() 13 | 14 | fdir = args.fdir 15 | fname = args.fname 16 | nckpt = args.nckpt 17 | 18 | best_prec1s = [] 19 | for i in range(nckpt): 20 | path = os.path.join(fdir,fname+str(i)+'_latest.pth.tar') 21 | checkpoint = torch.load(path) 22 | print(path) 23 | if 'best_test_prec1_t' in checkpoint: 24 | print("Teacher precision") 25 | best_prec1 = 100.0 - checkpoint['best_test_prec1_t'] 26 | else: 27 | best_prec1 = 100.0 - checkpoint['best_test_prec1'] 28 | best_prec1_val = 100.0 - checkpoint['best_prec1'] 29 | print('Test Error: ',best_prec1) 30 | print('Val. 
Error: ',best_prec1_val) 31 | best_prec1s.append(best_prec1) 32 | 33 | fname_acc = os.path.join(fdir,'accuracy%d.png'%i) 34 | fname_lr = os.path.join(fdir,'lr%d.png'%i) 35 | fname_loss = os.path.join(fdir,'losses%d.png'%i) 36 | acc1_tr = checkpoint['acc1_tr'] 37 | acc1_val = checkpoint['acc1_val'] 38 | acc1_te = checkpoint['acc1_test'] 39 | losses_tr = checkpoint['losses_tr'] 40 | losses_val = checkpoint['losses_val'] 41 | losses_te = checkpoint['losses_test'] 42 | weights_cl = checkpoint['weights_cl'] 43 | learning_rate = checkpoint['learning_rate'] 44 | losses_cl_tr = [] 45 | if 'losses_cl_tr' in checkpoint: 46 | losses_cl_tr = checkpoint['losses_cl_tr'] 47 | 48 | if(args.plot): 49 | import matplotlib.pyplot as plt 50 | fig = plt.figure() 51 | ax = plt.subplot(1,1,1) 52 | ax.plot(acc1_tr, label='train_acc1') 53 | ax.plot(acc1_val, label='val_acc1') 54 | ax.plot(acc1_te, label='test_acc1') 55 | ax.legend() 56 | ax.grid(linestyle='--') 57 | plt.savefig(fname_acc) 58 | #plt.show() 59 | plt.clf() 60 | 61 | fig = plt.figure() 62 | ax = plt.subplot(2,1,1) 63 | ax.plot(learning_rate, label='lr') 64 | ax.legend() 65 | ax.grid(linestyle='--') 66 | ax = plt.subplot(2,1,2) 67 | ax.plot(weights_cl, label='w_cl') 68 | ax.legend() 69 | ax.grid(linestyle='--') 70 | plt.savefig(fname_lr) 71 | #plt.show() 72 | plt.clf() 73 | 74 | 75 | fig = plt.figure() 76 | ax = plt.subplot(2,1,1) 77 | ax.plot(losses_tr, label='train_loss') 78 | ax.plot(losses_val, label='val_loss') 79 | ax.plot(losses_te, label='test_loss') 80 | ax.legend() 81 | ax.grid(linestyle='--') 82 | ax = plt.subplot(2,1,2) 83 | ax.plot(losses_cl_tr, label='train_loss_cl') 84 | ax.legend() 85 | ax.grid(linestyle='--') 86 | plt.savefig(fname_loss) 87 | #plt.show() 88 | plt.clf() 89 | 90 | 91 | #plt.show() 92 | best_prec1s = np.array(best_prec1s) 93 | bmean = np.around(np.mean(best_prec1s), decimals=2) 94 | bstd = np.around(np.std(best_prec1s), decimals=2) 95 | print('Best error rate: %.2f(%.2f)'%(bmean,bstd)) 96 | #print('Best precision: ',bmean,'(',bstd,')') 97 | 98 | #for key, val in checkpoint.iteritems(): 99 | # print(key) 100 | 101 | 102 | -------------------------------------------------------------------------------- /command.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | # 5 | #python train.py -a=wideresnet -m=baseline -o=adam -b=225 --dataset=cifar10 --ckpt=ckpt --gpu=0,1 --lr=0.003 --boundary=0 --epochs=1200 && 6 | #python train.py -a=wideresnet -m=pi -o=adam -b=225 --dataset=cifar10 --ckpt=ckpt --gpu=0,1 --lr=0.0003 --boundary=0 --epochs=1200 && 7 | #python train.py -a=wideresnet -m=mt -o=adam -b=225 --dataset=cifar10 --ckpt=ckpt --gpu=0,1 --lr=0.0004 --boundary=0 --epochs=1200 && 8 | #python train.py -a=wideresnet -m=baseline -o=adam -b=225 --dataset=cifar10_zca --ckpt=ckpt --gpu=0,1 --lr=0.003 --boundary=0 --epochs=1200 && 9 | #python train.py -a=wideresnet -m=pi -o=adam -b=225 --dataset=cifar10_zca --ckpt=ckpt --gpu=0,1 --lr=0.0003 --boundary=0 --epochs=1200 && 10 | #python train.py -a=wideresnet -m=mt -o=adam -b=225 --dataset=cifar10_zca --ckpt=ckpt --gpu=0,1 --lr=0.0004 --boundary=0 --epochs=1200 && 11 | #python train.py -a=wideresnet -m=baseline -o=adam -b=198 --dataset=svhn --ckpt=ckpt --gpu=0,1 --lr=0.003 --boundary=0 --epochs=1200 && 12 | #python train.py -a=wideresnet -m=pi -o=adam -b=198 --dataset=svhn --ckpt=ckpt --gpu=0,1 --lr=0.0003 --boundary=0 --epochs=1200 && 13 | #python train.py -a=wideresnet -m=mt -o=adam -b=198 --dataset=svhn --ckpt=ckpt --gpu=0,1 
--lr=0.0004 --boundary=0 --epochs=1200 && 14 | 15 | ls 16 | -------------------------------------------------------------------------------- /loader_cifar.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from PIL import Image 3 | import os 4 | import os.path 5 | import errno 6 | import numpy as np 7 | import sys 8 | if sys.version_info[0] == 2: 9 | import cPickle as pickle 10 | else: 11 | import pickle 12 | 13 | import torch.utils.data as data 14 | from torchvision.datasets.utils import download_url, check_integrity 15 | 16 | 17 | class CIFAR10(data.Dataset): 18 | """`CIFAR10 `_ Dataset. 19 | Args: 20 | root (string): Root directory of dataset where directory 21 | ``cifar-10-batches-py`` exists. 22 | train (bool, optional): If True, creates dataset from training set, otherwise 23 | creates from test set. 24 | transform (callable, optional): A function/transform that takes in an PIL image 25 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 26 | target_transform (callable, optional): A function/transform that takes in the 27 | target and transforms it. 28 | download (bool, optional): If true, downloads the dataset from the internet and 29 | puts it in root directory. If dataset is already downloaded, it is not 30 | downloaded again. 31 | """ 32 | base_folder = 'cifar-10-batches-py' 33 | url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 34 | filename = "cifar-10-python.tar.gz" 35 | tgz_md5 = 'c58f30108f718f92721af3b95e74349a' 36 | train_list = [ 37 | ['data_batch_1', 'c99cafc152244af753f735de768cd75f'], 38 | ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'], 39 | ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'], 40 | ['data_batch_4', '634d18415352ddfa80567beed471001a'], 41 | ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'], 42 | ] 43 | 44 | test_list = [ 45 | ['test_batch', '40351d587109b95175f43aff81a1287e'], 46 | ] 47 | nclass = 10 48 | split_list = ['label', 'unlabel', 'valid', 'test'] 49 | 50 | def __init__(self, root, split='train', 51 | transform=None, target_transform=None, 52 | download=False, boundary=0): 53 | self.root = os.path.expanduser(root) 54 | self.transform = transform 55 | self.target_transform = target_transform 56 | self.split = split 57 | assert(boundary<10) 58 | print("Boundary: ", boundary) 59 | if self.split not in self.split_list: 60 | raise ValueError('Wrong split entered! Please use split="train" ' 61 | 'or split="extra" or split="test"') 62 | 63 | if download: 64 | self.download() 65 | 66 | if not self._check_integrity(): 67 | raise RuntimeError('Dataset not found or corrupted.' 
+ 68 | ' You can use download=True to download it') 69 | 70 | # now load the picked numpy arrays 71 | if self.split is 'label' or self.split is 'unlabel' or self.split is 'valid': 72 | self.train_data = [] 73 | self.train_labels = [] 74 | for fentry in self.train_list: 75 | f = fentry[0] 76 | file = os.path.join(self.root, self.base_folder, f) 77 | fo = open(file, 'rb') 78 | if sys.version_info[0] == 2: 79 | entry = pickle.load(fo) 80 | else: 81 | entry = pickle.load(fo, encoding='latin1') 82 | self.train_data.append(entry['data']) 83 | if 'labels' in entry: 84 | self.train_labels += entry['labels'] 85 | else: 86 | self.train_labels += entry['fine_labels'] 87 | fo.close() 88 | 89 | self.train_data = np.concatenate(self.train_data) 90 | if boundary is not 0: 91 | bidx = 5000 * boundary 92 | self.train_data = [self.train_data[bidx:],self.train_data[:bidx]] 93 | self.train_data = np.concatenate(self.train_data) 94 | self.train_labels = [self.train_labels[bidx:],self.train_labels[:bidx]] 95 | self.train_labels = np.concatenate(self.train_labels) 96 | 97 | train_datau = [] 98 | train_labelsu = [] 99 | train_data1 = [] 100 | train_labels1 = [] 101 | valid_data1 = [] 102 | valid_labels1 = [] 103 | num_labels_valid = [0 for _ in range(self.nclass)] 104 | num_labels_train = [0 for _ in range(self.nclass)] 105 | for i in range(self.train_data.shape[0]): 106 | tmp_label = self.train_labels[i] 107 | if num_labels_valid[tmp_label] < 500: 108 | valid_data1.append(self.train_data[i]) 109 | valid_labels1.append(self.train_labels[i]) 110 | num_labels_valid[tmp_label] += 1 111 | elif num_labels_train[tmp_label] < 400: 112 | train_data1.append(self.train_data[i]) 113 | train_labels1.append(self.train_labels[i]) 114 | num_labels_train[tmp_label] += 1 115 | 116 | #train_datau.append(self.train_data[i]) 117 | #train_labelsu.append(self.train_labels[i]) 118 | else: 119 | train_datau.append(self.train_data[i]) 120 | train_labelsu.append(self.train_labels[i]) 121 | 122 | if self.split is 'label': 123 | self.train_data = train_data1 124 | self.train_labels = train_labels1 125 | 126 | self.train_data = np.concatenate(self.train_data) 127 | self.train_data = self.train_data.reshape((len(train_data1), 3, 32, 32)) 128 | self.train_data = self.train_data.transpose((0, 2, 3, 1)) # convert to HWC 129 | 130 | num_tr = self.train_data.shape[0] 131 | #print(self.train_data1[:1,:1,:5,:5]) 132 | #print(self.train_labels1[:10]) 133 | #print(self.train_data[:1,:1,:5,:5]) 134 | #print(self.train_labels[:10]) 135 | print('Label: ',num_tr) #label 136 | 137 | #self.midx=0 138 | #self.idx_offset = num_tr_ul - (num_tr_ul//num_tr) * num_tr 139 | #print('Offset: :',self.idx_offset) 140 | 141 | elif self.split is 'unlabel': 142 | self.train_data_ul = train_datau 143 | self.train_labels_ul = train_labelsu 144 | 145 | self.train_data_ul = np.concatenate(self.train_data_ul) 146 | self.train_data_ul = self.train_data_ul.reshape((len(train_datau), 3, 32, 32)) 147 | self.train_data_ul = self.train_data_ul.transpose((0, 2, 3, 1)) # convert to HWC 148 | 149 | num_tr_ul = self.train_data_ul.shape[0] 150 | print('Unlabel: ',num_tr_ul) #unlabel 151 | 152 | elif self.split is 'valid': 153 | self.valid_data = valid_data1 154 | self.valid_labels = valid_labels1 155 | 156 | self.valid_data = np.concatenate(self.valid_data) 157 | self.valid_data = self.valid_data.reshape((len(valid_data1), 3, 32, 32)) 158 | self.valid_data = self.valid_data.transpose((0, 2, 3, 1)) # convert to HWC 159 | 160 | num_val = self.valid_data.shape[0] 161 | print('Valid: 
',num_val) #valid 162 | #print(self.valid_data[:1,:1,:5,:5]) 163 | #print(self.valid_labels[:10]) 164 | 165 | elif self.split is 'test': 166 | f = self.test_list[0][0] 167 | file = os.path.join(self.root, self.base_folder, f) 168 | fo = open(file, 'rb') 169 | if sys.version_info[0] == 2: 170 | entry = pickle.load(fo) 171 | else: 172 | entry = pickle.load(fo, encoding='latin1') 173 | self.test_data = entry['data'] 174 | if 'labels' in entry: 175 | self.test_labels = entry['labels'] 176 | else: 177 | self.test_labels = entry['fine_labels'] 178 | fo.close() 179 | self.test_data = self.test_data.reshape((10000, 3, 32, 32)) 180 | self.test_data = self.test_data.transpose((0, 2, 3, 1)) # convert to HWC 181 | 182 | def __getitem__(self, index): 183 | """ 184 | Args: 185 | index (int): Index 186 | Returns: 187 | tuple: (image, target) where target is index of the target class. 188 | """ 189 | if self.split is 'label': 190 | img, target = self.train_data[index], self.train_labels[index] 191 | elif self.split is 'unlabel': 192 | img, target = self.train_data_ul[index], self.train_labels_ul[index] 193 | elif self.split is 'valid': 194 | img, target = self.valid_data[index], self.valid_labels[index] 195 | elif self.split is 'test': 196 | img, target = self.test_data[index], self.test_labels[index] 197 | 198 | # doing this so that it is consistent with all other datasets 199 | # to return a PIL Image 200 | img1 = np.copy(img) 201 | img = Image.fromarray(img) 202 | img1 = Image.fromarray(img1) 203 | 204 | if self.transform is not None: 205 | img = self.transform(img) 206 | img1 = self.transform(img1) 207 | 208 | if self.target_transform is not None: 209 | target = self.target_transform(target) 210 | 211 | return img, target, img1 212 | 213 | def __len__(self): 214 | if self.split is 'label': 215 | return len(self.train_data) 216 | elif self.split is 'unlabel': 217 | return len(self.train_data_ul) 218 | elif self.split is 'valid': 219 | return len(self.valid_data) 220 | elif self.split is 'test': 221 | return len(self.test_data) 222 | 223 | def _check_integrity(self): 224 | root = self.root 225 | for fentry in (self.train_list + self.test_list): 226 | filename, md5 = fentry[0], fentry[1] 227 | fpath = os.path.join(root, self.base_folder, filename) 228 | if not check_integrity(fpath, md5): 229 | return False 230 | return True 231 | 232 | def download(self): 233 | import tarfile 234 | 235 | if self._check_integrity(): 236 | print('Files already downloaded and verified') 237 | return 238 | 239 | root = self.root 240 | download_url(self.url, root, self.filename, self.tgz_md5) 241 | 242 | # extract file 243 | cwd = os.getcwd() 244 | tar = tarfile.open(os.path.join(root, self.filename), "r:gz") 245 | os.chdir(root) 246 | tar.extractall() 247 | tar.close() 248 | os.chdir(cwd) 249 | 250 | def __repr__(self): 251 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 252 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 253 | tmp = 'train' if self.train is True else 'test' 254 | fmt_str += ' Split: {}\n'.format(tmp) 255 | fmt_str += ' Root Location: {}\n'.format(self.root) 256 | tmp = ' Transforms (if any): ' 257 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 258 | tmp = ' Target Transforms (if any): ' 259 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 260 | return fmt_str 261 | 262 | 263 | class CIFAR100(CIFAR10): 264 | """`CIFAR100 `_ Dataset. 
265 | This is a subclass of the `CIFAR10` Dataset. 266 | """ 267 | base_folder = 'cifar-100-python' 268 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 269 | filename = "cifar-100-python.tar.gz" 270 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' 271 | train_list = [ 272 | ['train', '16019d7e3df5f24257cddd939b257f8d'], 273 | ] 274 | 275 | test_list = [ 276 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], 277 | ] 278 | nclass = 100 279 | 280 | 281 | if __name__ == '__main__': 282 | 283 | ''' 284 | for i in range(10): 285 | print("Boundary %d///////////////////////////////////////"%i) 286 | data_train = CIFAR10('/tmp', split='label', download=True, transform=None, boundary=i) 287 | data_train_ul = CIFAR10('/tmp', split='unlabel', download=True, transform=None, boundary=i) 288 | data_valid = CIFAR10('/tmp', split='valid', download=True, transform=None, boundary=i) 289 | data_test = CIFAR10('/tmp', split='test', download=True, transform=None, boundary=i) 290 | 291 | print("Number of data") 292 | print(len(data_train)) 293 | print(len(data_train_ul)) 294 | print(len(data_valid)) 295 | print(len(data_test)) 296 | ''' 297 | 298 | import torch.utils.data as data 299 | from math import ceil 300 | 301 | batch_size = 230 302 | 303 | labelset = CIFAR10('/tmp', split='label', download=True, transform=None, boundary=0) 304 | unlabelset = CIFAR10('/tmp', split='unlabel', download=True, transform=None, boundary=0) 305 | 306 | for i in range(100,256): 307 | batch_size = i 308 | label_size = len(labelset) 309 | unlabel_size = len(unlabelset) 310 | iter_per_epoch = int(ceil(float(label_size + unlabel_size)/batch_size)) 311 | batch_size_label = int(ceil(float(label_size) / iter_per_epoch)) 312 | batch_size_unlabel = int(ceil(float(unlabel_size) / iter_per_epoch)) 313 | iter_label = int(ceil(float(label_size)/batch_size_label)) 314 | iter_unlabel = int(ceil(float(unlabel_size)/batch_size_unlabel)) 315 | if iter_label == iter_unlabel: 316 | print('Batch size: ', batch_size) 317 | print('Iter/epoch: ', iter_per_epoch) 318 | print('Batch size (label): ', batch_size_label) 319 | print('Batch size (unlabel): ', batch_size_unlabel) 320 | print('Iter/epoch (label): ', iter_label) 321 | print('Iter/epoch (unlabel): ', iter_unlabel) 322 | 323 | 324 | label_loader = data.DataLoader(labelset, batch_size=batch_size_label, shuffle=True) 325 | label_iter = iter(label_loader) 326 | 327 | unlabel_loader = data.DataLoader(unlabelset, batch_size=batch_size_unlabel, shuffle=True) 328 | unlabel_iter = iter(unlabel_loader) 329 | 330 | print(len(label_iter)) 331 | print(len(unlabel_iter)) 332 | 333 | 334 | 335 | 336 | -------------------------------------------------------------------------------- /loader_cifar_zca.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from PIL import Image 3 | import os 4 | import os.path 5 | import errno 6 | import numpy as np 7 | import sys 8 | if sys.version_info[0] == 2: 9 | import cPickle as pickle 10 | else: 11 | import pickle 12 | 13 | import torch.utils.data as data 14 | from torchvision.datasets.utils import download_url, check_integrity 15 | import torch 16 | import random 17 | 18 | class CIFAR10(data.Dataset): 19 | """`CIFAR10 `_ Dataset. 20 | Args: 21 | root (string): Root directory of dataset where directory 22 | ``cifar-10-batches-py`` exists. 23 | train (bool, optional): If True, creates dataset from training set, otherwise 24 | creates from test set. 
25 | transform (callable, optional): A function/transform that takes in an PIL image 26 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 27 | target_transform (callable, optional): A function/transform that takes in the 28 | target and transforms it. 29 | download (bool, optional): If true, downloads the dataset from the internet and 30 | puts it in root directory. If dataset is already downloaded, it is not 31 | downloaded again. 32 | """ 33 | 34 | #data_file = 'cifar10_zca/cifar10_gcn_zca_v2.npz' 35 | nclass = 10 36 | split_list = ['label', 'unlabel', 'valid', 'test'] 37 | 38 | def __init__(self, root, split='train', 39 | transform=None, target_transform=None, 40 | download=False, boundary=0): 41 | self.root = os.path.expanduser(root) 42 | self.transform = transform 43 | self.target_transform = target_transform 44 | self.split = split 45 | assert(boundary<10) 46 | print("Boundary: ", boundary) 47 | if self.split not in self.split_list: 48 | raise ValueError('Wrong split entered! Please use split="train" ' 49 | 'or split="extra" or split="test"') 50 | 51 | # load data 52 | self.data = np.load(root) 53 | #self.data = np.load(self.data_file) 54 | #self.train_data_zca = self.data['train_x'].transpose(0,3,1,2) 55 | #self.train_labels_zca = self.data['train_y'] 56 | #self.test_data_zca = self.data['test_x'].transpose(0,3,1,2) 57 | #self.test_labels_zca = self.data['test_y'] 58 | 59 | # now load the picked numpy arrays 60 | if self.split is 'label' or self.split is 'unlabel' or self.split is 'valid': 61 | 62 | self.train_data = self.data['train_x'].astype(np.float32).transpose(0,3,1,2) 63 | #self.train_data = np.concatenate(self.train_data) 64 | self.train_labels = self.data['train_y'].astype(int) 65 | print(self.train_data.shape) 66 | print(self.train_labels.shape) 67 | if boundary is not 0: 68 | bidx = 5000 * boundary 69 | self.train_data = [self.train_data[bidx:],self.train_data[:bidx]] 70 | self.train_data = np.concatenate(self.train_data) 71 | self.train_labels = [self.train_labels[bidx:],self.train_labels[:bidx]] 72 | self.train_labels = np.concatenate(self.train_labels) 73 | 74 | train_datau = [] 75 | train_labelsu = [] 76 | train_data1 = [] 77 | train_labels1 = [] 78 | valid_data1 = [] 79 | valid_labels1 = [] 80 | num_labels_valid = [0 for _ in range(self.nclass)] 81 | num_labels_train = [0 for _ in range(self.nclass)] 82 | for i in range(self.train_data.shape[0]): 83 | tmp_label = self.train_labels[i] 84 | if num_labels_valid[tmp_label] < 500: 85 | valid_data1.append(self.train_data[i]) 86 | valid_labels1.append(self.train_labels[i]) 87 | num_labels_valid[tmp_label] += 1 88 | elif num_labels_train[tmp_label] < 400: 89 | train_data1.append(self.train_data[i]) 90 | train_labels1.append(self.train_labels[i]) 91 | num_labels_train[tmp_label] += 1 92 | 93 | #train_datau.append(self.train_data[i]) 94 | #train_labelsu.append(self.train_labels[i]) 95 | else: 96 | train_datau.append(self.train_data[i]) 97 | train_labelsu.append(self.train_labels[i]) 98 | 99 | if self.split is 'label': 100 | self.train_data = train_data1 101 | self.train_labels = train_labels1 102 | 103 | self.train_data = np.concatenate(self.train_data) 104 | self.train_data = self.train_data.reshape((len(train_data1), 3, 32, 32)) 105 | self.train_data = self.train_data.transpose((0, 2, 3, 1)) # convert to HWC 106 | 107 | num_tr = self.train_data.shape[0] 108 | #print(self.train_data1[:1,:1,:5,:5]) 109 | #print(self.train_labels1[:10]) 110 | #print(self.train_data[:1,:1,:5,:5]) 111 | 
#print(self.train_labels[:10]) 112 | print('Label: ',num_tr) #label 113 | 114 | #self.midx=0 115 | #self.idx_offset = num_tr_ul - (num_tr_ul//num_tr) * num_tr 116 | #print('Offset: :',self.idx_offset) 117 | 118 | elif self.split is 'unlabel': 119 | self.train_data_ul = train_datau 120 | self.train_labels_ul = train_labelsu 121 | 122 | self.train_data_ul = np.concatenate(self.train_data_ul) 123 | self.train_data_ul = self.train_data_ul.reshape((len(train_datau), 3, 32, 32)) 124 | self.train_data_ul = self.train_data_ul.transpose((0, 2, 3, 1)) # convert to HWC 125 | 126 | num_tr_ul = self.train_data_ul.shape[0] 127 | print('Unlabel: ',num_tr_ul) #unlabel 128 | 129 | elif self.split is 'valid': 130 | self.valid_data = valid_data1 131 | self.valid_labels = valid_labels1 132 | 133 | self.valid_data = np.concatenate(self.valid_data) 134 | self.valid_data = self.valid_data.reshape((len(valid_data1), 3, 32, 32)) 135 | self.valid_data = self.valid_data.transpose((0, 2, 3, 1)) # convert to HWC 136 | 137 | num_val = self.valid_data.shape[0] 138 | print('Valid: ',num_val) #valid 139 | #print(self.valid_data[:1,:1,:5,:5]) 140 | #print(self.valid_labels[:10]) 141 | 142 | elif self.split is 'test': 143 | #self.test_data = self.test_data.reshape((10000, 3, 32, 32)) 144 | #self.test_data = self.test_data.transpose((0, 2, 3, 1)) # convert to HWC 145 | self.test_data = self.data['test_x'].astype(np.float32) 146 | self.test_labels = self.data['test_y'].astype(int) 147 | 148 | def __getitem__(self, index): 149 | """ 150 | Args: 151 | index (int): Index 152 | Returns: 153 | tuple: (image, target) where target is index of the target class. 154 | """ 155 | if self.split is 'label': 156 | img, target = self.train_data[index], self.train_labels[index] 157 | elif self.split is 'unlabel': 158 | img, target = self.train_data_ul[index], self.train_labels_ul[index] 159 | elif self.split is 'valid': 160 | img, target = self.valid_data[index], self.valid_labels[index] 161 | elif self.split is 'test': 162 | img, target = self.test_data[index], self.test_labels[index] 163 | 164 | # doing this so that it is consistent with all other datasets 165 | # to return a PIL Image 166 | #img = Image.fromarray(img) 167 | img1 = np.copy(img) 168 | #img1 = Image.fromarray(img1) 169 | if self.split is 'label' or self.split is 'unlabel': 170 | img = random_crop(img, 32, padding=2) 171 | img = horizontal_flip(img, 0.5) 172 | img = img.copy() 173 | img = torch.from_numpy(img) 174 | img = img + torch.randn_like(img) * 0.15 175 | img = img.permute(2,0,1) 176 | #img = self.transform(img) 177 | 178 | img1 = random_crop(img1, 32, padding=2) 179 | img1 = horizontal_flip(img1, 0.5) 180 | img1 = img1.copy() 181 | img1 = torch.from_numpy(img1) 182 | img1 = img1 + torch.randn_like(img1) * 0.15 183 | img1 = img1.permute(2,0,1) 184 | #img1 = self.transform(img1) 185 | else: 186 | img = torch.from_numpy(img) 187 | img = img.permute(2,0,1) 188 | 189 | img1 = torch.from_numpy(img1) 190 | img1 = img1.permute(2,0,1) 191 | 192 | if self.target_transform is not None: 193 | target = self.target_transform(target) 194 | 195 | return img, target, img1 196 | 197 | def __len__(self): 198 | if self.split is 'label': 199 | return len(self.train_data) 200 | elif self.split is 'unlabel': 201 | return len(self.train_data_ul) 202 | elif self.split is 'valid': 203 | return len(self.valid_data) 204 | elif self.split is 'test': 205 | return len(self.test_data) 206 | 207 | def _check_integrity(self): 208 | root = self.root 209 | for fentry in (self.train_list + 
self.test_list): 210 | filename, md5 = fentry[0], fentry[1] 211 | fpath = os.path.join(root, self.base_folder, filename) 212 | if not check_integrity(fpath, md5): 213 | return False 214 | return True 215 | 216 | def download(self): 217 | import tarfile 218 | 219 | if self._check_integrity(): 220 | print('Files already downloaded and verified') 221 | return 222 | 223 | root = self.root 224 | download_url(self.url, root, self.filename, self.tgz_md5) 225 | 226 | # extract file 227 | cwd = os.getcwd() 228 | tar = tarfile.open(os.path.join(root, self.filename), "r:gz") 229 | os.chdir(root) 230 | tar.extractall() 231 | tar.close() 232 | os.chdir(cwd) 233 | 234 | def __repr__(self): 235 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 236 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 237 | tmp = 'train' if self.train is True else 'test' 238 | fmt_str += ' Split: {}\n'.format(tmp) 239 | fmt_str += ' Root Location: {}\n'.format(self.root) 240 | tmp = ' Transforms (if any): ' 241 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 242 | tmp = ' Target Transforms (if any): ' 243 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 244 | return fmt_str 245 | 246 | 247 | def horizontal_flip(image, rate=0.5): 248 | if random.random() < rate: 249 | #image = np.flip(image,1).copy() 250 | image = image[:, ::-1, :] 251 | return image 252 | 253 | def random_crop(image, crop_size, padding=4): 254 | crop_size = check_size(crop_size) 255 | image = np.pad(image,((padding,padding),(padding,padding),(0,0)),'constant',constant_values=0) 256 | h, w, _ = image.shape 257 | top = random.randrange(0, h - crop_size[0]) 258 | left = random.randrange(0, w - crop_size[1]) 259 | bottom = top + crop_size[0] 260 | right = left + crop_size[1] 261 | image = image[top:bottom, left:right, :] 262 | return image 263 | 264 | def check_size(size): 265 | if type(size) == int: 266 | size = (size, size) 267 | if type(size) != tuple: 268 | raise TypeError('size is int or tuple') 269 | return size 270 | 271 | if __name__ == '__main__': 272 | 273 | ''' 274 | for i in range(10): 275 | print("Boundary %d///////////////////////////////////////"%i) 276 | data_train = CIFAR10('/tmp', split='label', download=True, transform=None, boundary=i) 277 | data_train_ul = CIFAR10('/tmp', split='unlabel', download=True, transform=None, boundary=i) 278 | data_valid = CIFAR10('/tmp', split='valid', download=True, transform=None, boundary=i) 279 | data_test = CIFAR10('/tmp', split='test', download=True, transform=None, boundary=i) 280 | 281 | print("Number of data") 282 | print(len(data_train)) 283 | print(len(data_train_ul)) 284 | print(len(data_valid)) 285 | print(len(data_test)) 286 | ''' 287 | 288 | import torch.utils.data as data 289 | from math import ceil 290 | 291 | batch_size = 230 292 | 293 | labelset = CIFAR10('/tmp', split='label', download=True, transform=None, boundary=0) 294 | unlabelset = CIFAR10('/tmp', split='unlabel', download=True, transform=None, boundary=0) 295 | 296 | for i in range(90,256): 297 | batch_size = i 298 | label_size = len(labelset) 299 | unlabel_size = len(unlabelset) 300 | iter_per_epoch = int(ceil(float(label_size + unlabel_size)/batch_size)) 301 | batch_size_label = int(ceil(float(label_size) / iter_per_epoch)) 302 | batch_size_unlabel = int(ceil(float(unlabel_size) / iter_per_epoch)) 303 | iter_label = int(ceil(float(label_size)/batch_size_label)) 304 | iter_unlabel = 
int(ceil(float(unlabel_size)/batch_size_unlabel)) 305 | if iter_label == iter_unlabel: 306 | print('Batch size: ', batch_size) 307 | print('Iter/epoch: ', iter_per_epoch) 308 | print('Batch size (label): ', batch_size_label) 309 | print('Batch size (unlabel): ', batch_size_unlabel) 310 | print('Iter/epoch (label): ', iter_label) 311 | print('Iter/epoch (unlabel): ', iter_unlabel) 312 | 313 | 314 | label_loader = data.DataLoader(labelset, batch_size=batch_size_label, shuffle=True) 315 | label_iter = iter(label_loader) 316 | 317 | unlabel_loader = data.DataLoader(unlabelset, batch_size=batch_size_unlabel, shuffle=True) 318 | unlabel_iter = iter(unlabel_loader) 319 | 320 | print(len(label_iter)) 321 | print(len(unlabel_iter)) 322 | 323 | 324 | 325 | 326 | -------------------------------------------------------------------------------- /loader_svhn.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from PIL import Image 3 | import os 4 | import os.path 5 | import errno 6 | import numpy as np 7 | import sys 8 | if sys.version_info[0] == 2: 9 | import cPickle as pickle 10 | else: 11 | import pickle 12 | 13 | import torch.utils.data as data 14 | from torchvision.datasets.utils import download_url, check_integrity 15 | 16 | 17 | class SVHN(data.Dataset): 18 | """`SVHN `_ Dataset. 19 | Note: The SVHN dataset assigns the label `10` to the digit `0`. However, in this Dataset, 20 | we assign the label `0` to the digit `0` to be compatible with PyTorch loss functions which 21 | expect the class labels to be in the range `[0, C-1]` 22 | 23 | Args: 24 | root (string): Root directory of dataset where directory 25 | ``SVHN`` exists. 26 | split (string): One of {'train', 'test', 'extra'}. 27 | Accordingly dataset is selected. 'extra' is Extra training set. 28 | transform (callable, optional): A function/transform that takes in an PIL image 29 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 30 | target_transform (callable, optional): A function/transform that takes in the 31 | target and transforms it. 32 | download (bool, optional): If true, downloads the dataset from the internet and 33 | puts it in root directory. If dataset is already downloaded, it is not 34 | downloaded again. 35 | 36 | """ 37 | url = "" 38 | filename = "" 39 | file_md5 = "" 40 | 41 | split_list = { 42 | 'label': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat", 43 | "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"], 44 | 'unlabel': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat", 45 | "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"], 46 | 'valid': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat", 47 | "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"], 48 | 'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat", 49 | "test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"], 50 | 'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat", 51 | "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]} 52 | 53 | def __init__(self, root, split='label', 54 | transform=None, target_transform=None, download=False, boundary=0): 55 | self.root = os.path.expanduser(root) 56 | self.transform = transform 57 | self.target_transform = target_transform 58 | self.split = split # training set or test set or extra set 59 | 60 | if self.split not in self.split_list: 61 | raise ValueError('Wrong split entered! 
Please use split="train" ' 62 | 'or split="extra" or split="test"') 63 | 64 | self.url = self.split_list[split][0] 65 | self.filename = self.split_list[split][1] 66 | self.file_md5 = self.split_list[split][2] 67 | assert(boundary<10) 68 | print('Boundary: ', boundary) 69 | 70 | if download: 71 | self.download() 72 | 73 | if not self._check_integrity(): 74 | raise RuntimeError('Dataset not found or corrupted.' + 75 | ' You can use download=True to download it') 76 | 77 | # import here rather than at top of file because this is 78 | # an optional dependency for torchvision 79 | import scipy.io as sio 80 | 81 | # reading(loading) mat file as array 82 | loaded_mat = sio.loadmat(os.path.join(self.root, self.filename)) 83 | 84 | self.train_data = loaded_mat['X'] 85 | # loading from the .mat file gives an np array of type np.uint8 86 | # converting to np.int64, so that we have a LongTensor after 87 | # the conversion from the numpy array 88 | # the squeeze is needed to obtain a 1D tensor 89 | self.train_labels = loaded_mat['y'].astype(np.int64).squeeze() 90 | 91 | # the svhn dataset assigns the class label "10" to the digit 0 92 | # this makes it inconsistent with several loss functions 93 | # which expect the class labels to be in the range [0, C-1] 94 | np.place(self.train_labels, self.train_labels == 10, 0) 95 | self.train_data = np.transpose(self.train_data, (3, 2, 0, 1)) 96 | 97 | if self.split is 'label' or self.split is 'unlabel' or self.split is 'valid': 98 | if boundary is not 0: 99 | bidx = 7000 * boundary 100 | self.train_data = [self.train_data[bidx:], self.train_data[:bidx]] 101 | self.train_data = np.concatenate(self.train_data) 102 | self.train_labels = [self.train_labels[bidx:], self.train_labels[:bidx]] 103 | self.train_labels = np.concatenate(self.train_labels) 104 | 105 | print(self.split) 106 | train_datau = [] 107 | train_labelsu = [] 108 | train_data1 = [] 109 | train_labels1 = [] 110 | valid_data1 = [] 111 | valid_labels1 = [] 112 | num_labels_train = [0 for _ in range(10)] 113 | num_labels_valid = [0 for _ in range(10)] 114 | 115 | for i in range(self.train_data.shape[0]): 116 | tmp_label = self.train_labels[i] 117 | if num_labels_valid[tmp_label] < 732: 118 | valid_data1.append(self.train_data[i]) 119 | valid_labels1.append(self.train_labels[i]) 120 | num_labels_valid[tmp_label] += 1 121 | elif num_labels_train[tmp_label] < 100: 122 | train_data1.append(self.train_data[i]) 123 | train_labels1.append(self.train_labels[i]) 124 | num_labels_train[tmp_label] += 1 125 | 126 | #train_datau.append(self.train_data[i]) 127 | #train_labelsu.append(self.train_labels[i]) 128 | else: 129 | train_datau.append(self.train_data[i]) 130 | train_labelsu.append(self.train_labels[i]) 131 | 132 | if self.split is 'label': 133 | self.train_data = train_data1 134 | self.train_labels = train_labels1 135 | 136 | self.train_data = np.concatenate(self.train_data) 137 | self.train_data = self.train_data.reshape((len(train_data1), 3, 32, 32)) 138 | #self.train_data = self.train_data.transpose((0, 2, 3, 1)) # convert to HWC 139 | 140 | num_tr = self.train_data.shape[0] 141 | print('Label: ',num_tr) #label 142 | 143 | elif self.split is 'unlabel': 144 | self.train_data_ul = train_datau 145 | self.train_labels_ul = train_labelsu 146 | 147 | self.train_data_ul = np.concatenate(self.train_data_ul) 148 | self.train_data_ul = self.train_data_ul.reshape((len(train_datau), 3, 32, 32)) 149 | #self.train_data_ul = self.train_data_ul.transpose((0, 2, 3, 1)) # convert to HWC 150 | 151 | num_tr_ul = 
self.train_data_ul.shape[0] 152 | print('Unlabel: ',num_tr_ul) #unlabel 153 | 154 | elif self.split is 'valid': 155 | self.valid_data = valid_data1 156 | self.valid_labels = valid_labels1 157 | 158 | self.valid_data = np.concatenate(self.valid_data) 159 | self.valid_data = self.valid_data.reshape((len(valid_data1), 3, 32, 32)) 160 | #self.valid_data = self.valid_data.transpose((0, 2, 3, 1)) # convert to HWC 161 | 162 | num_val = self.valid_data.shape[0] 163 | print('Valid: ',num_val) #valid 164 | #print(self.valid_data[:1,:1,:5,:5]) 165 | #print(self.valid_labels[:10]) 166 | 167 | else: 168 | print(self.split) 169 | self.test_data = self.train_data.reshape((len(self.train_data), 3, 32, 32)) 170 | #self.train_data = self.train_data.transpose((0, 2, 3, 1)) # convert to HWC 171 | self.test_labels = self.train_labels 172 | 173 | 174 | 175 | def __getitem__(self, index): 176 | """ 177 | Args: 178 | index (int): Index 179 | 180 | Returns: 181 | tuple: (image, target) where target is index of the target class. 182 | """ 183 | if self.split is 'label': 184 | img, target = self.train_data[index], int(self.train_labels[index]) 185 | elif self.split is 'unlabel': 186 | img, target = self.train_data_ul[index], int(self.train_labels_ul[index]) 187 | elif self.split is 'valid': 188 | img, target = self.valid_data[index], int(self.valid_labels[index]) 189 | elif self.split is 'test': 190 | img, target = self.test_data[index], int(self.test_labels[index]) 191 | 192 | 193 | # doing this so that it is consistent with all other datasets 194 | # to return a PIL Image 195 | img1 = np.copy(img) 196 | img = Image.fromarray(np.transpose(img, (1, 2, 0))) 197 | img1 = Image.fromarray(np.transpose(img1, (1, 2, 0))) 198 | 199 | if self.transform is not None: 200 | img = self.transform(img) 201 | img1 = self.transform(img1) 202 | 203 | if self.target_transform is not None: 204 | target = self.target_transform(target) 205 | 206 | return img, target, img1 207 | 208 | def __len__(self): 209 | if self.split is 'label': 210 | return len(self.train_data) 211 | elif self.split is 'unlabel': 212 | return len(self.train_data_ul) 213 | elif self.split is 'valid': 214 | return len(self.valid_data) 215 | elif self.split is 'test': 216 | return len(self.test_data) 217 | else: 218 | assert(False) 219 | 220 | def _check_integrity(self): 221 | root = self.root 222 | md5 = self.split_list[self.split][2] 223 | fpath = os.path.join(root, self.filename) 224 | return check_integrity(fpath, md5) 225 | 226 | def download(self): 227 | md5 = self.split_list[self.split][2] 228 | download_url(self.url, self.root, self.filename, md5) 229 | 230 | def __repr__(self): 231 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 232 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 233 | fmt_str += ' Split: {}\n'.format(self.split) 234 | fmt_str += ' Root Location: {}\n'.format(self.root) 235 | tmp = ' Transforms (if any): ' 236 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 237 | tmp = ' Target Transforms (if any): ' 238 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 239 | return fmt_str 240 | 241 | 242 | if __name__ == '__main__': 243 | 244 | ''' 245 | for i in range(10): 246 | print("Boundary %d///////////////////////////////////////"%i) 247 | data_train = SVHN('/tmp', split='label', download=True, transform=None, boundary=i) 248 | data_train_ul = SVHN('/tmp', split='unlabel', download=True, transform=None, boundary=i) 249 
| data_valid = SVHN('/tmp', split='valid', download=True, transform=None, boundary=i) 250 | data_test = SVHN('/tmp', split='test', download=True, transform=None, boundary=i) 251 | 252 | print("Number of data") 253 | print(len(data_train)) 254 | print(len(data_train_ul)) 255 | print(len(data_valid)) 256 | print(len(data_test)) 257 | 258 | ''' 259 | import torch.utils.data as data 260 | from math import ceil 261 | 262 | batch_size = 230 263 | 264 | labelset = SVHN('/tmp', split='label', download=True, transform=None, boundary=0) 265 | unlabelset = SVHN('/tmp', split='unlabel', download=True, transform=None, boundary=0) 266 | 267 | for i in range(100,256): 268 | batch_size = i 269 | label_size = len(labelset) 270 | unlabel_size = len(unlabelset) 271 | iter_per_epoch = int(ceil(float(label_size + unlabel_size)/batch_size)) 272 | batch_size_label = int(ceil(float(label_size) / iter_per_epoch)) 273 | batch_size_unlabel = int(ceil(float(unlabel_size) / iter_per_epoch)) 274 | iter_label = int(ceil(float(label_size)/batch_size_label)) 275 | iter_unlabel = int(ceil(float(unlabel_size)/batch_size_unlabel)) 276 | if iter_label == iter_unlabel: 277 | print('Batch size: ', batch_size) 278 | print('Iter/epoch: ', iter_per_epoch) 279 | print('Batch size (label): ', batch_size_label) 280 | print('Batch size (unlabel): ', batch_size_unlabel) 281 | print('Iter/epoch (label): ', iter_label) 282 | print('Iter/epoch (unlabel): ', iter_unlabel) 283 | 284 | 285 | label_loader = data.DataLoader(labelset, batch_size=batch_size_label, shuffle=True) 286 | label_iter = iter(label_loader) 287 | 288 | unlabel_loader = data.DataLoader(unlabelset, batch_size=batch_size_unlabel, shuffle=True) 289 | unlabel_iter = iter(unlabel_loader) 290 | 291 | print(len(label_iter)) 292 | print(len(unlabel_iter)) 293 | 294 | 295 | -------------------------------------------------------------------------------- /methods.py: -------------------------------------------------------------------------------- 1 | import time 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | def train_sup(label_loader, model, criterions, optimizer, epoch, args): 7 | batch_time = AverageMeter() 8 | data_time = AverageMeter() 9 | losses = AverageMeter() 10 | top1 = AverageMeter() 11 | top5 = AverageMeter() 12 | 13 | # switch to train mode 14 | model.train() 15 | 16 | criterion, _, _, criterion_l1 = criterions 17 | 18 | end = time.time() 19 | 20 | label_iter = iter(label_loader) 21 | for i in range(len(label_iter)): 22 | input, target, _ = next(label_iter) 23 | # measure data loading time 24 | data_time.update(time.time() - end) 25 | sl = input.shape 26 | batch_size = sl[0] 27 | target = target.cuda(async=True) 28 | input_var = torch.autograd.Variable(input) 29 | target_var = torch.autograd.Variable(target) 30 | # compute output 31 | output = model(input_var) 32 | 33 | loss_ce = criterion(output, target_var) / float(batch_size) 34 | 35 | reg_l1 = cal_reg_l1(model, criterion_l1) 36 | 37 | loss = loss_ce + args.weight_l1 * reg_l1 38 | 39 | # measure accuracy and record loss 40 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 41 | losses.update(loss_ce.item(), input.size(0)) 42 | top1.update(prec1.item(), input.size(0)) 43 | top5.update(prec5.item(), input.size(0)) 44 | 45 | # compute gradient and do SGD step 46 | optimizer.zero_grad() 47 | loss.backward() 48 | optimizer.step() 49 | 50 | # measure elapsed time 51 | batch_time.update(time.time() - end) 52 | end = time.time() 53 | 54 | if i % args.print_freq == 0: 55 | print('Epoch: 
[{0}][{1}/{2}]\t' 56 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 57 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 58 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 59 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 60 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 61 | epoch, i, len(label_iter), batch_time=batch_time, 62 | data_time=data_time, loss=losses, top1=top1, top5=top5)) 63 | 64 | return top1.avg , losses.avg 65 | 66 | def train_pi(label_loader, unlabel_loader, model, criterions, optimizer, epoch, args, weight_pi=20.0): 67 | batch_time = AverageMeter() 68 | data_time = AverageMeter() 69 | losses = AverageMeter() 70 | losses_pi = AverageMeter() 71 | top1 = AverageMeter() 72 | top5 = AverageMeter() 73 | weights_cl = AverageMeter() 74 | 75 | # switch to train mode 76 | model.train() 77 | 78 | criterion, criterion_mse, _, criterion_l1 = criterions 79 | 80 | end = time.time() 81 | 82 | label_iter = iter(label_loader) 83 | unlabel_iter = iter(unlabel_loader) 84 | len_iter = len(unlabel_iter) 85 | for i in range(len_iter): 86 | # set weights for the consistency loss 87 | weight_cl = cal_consistency_weight(epoch*len_iter+i, end_ep=(args.epochs//2)*len_iter, end_w=1.0) 88 | 89 | try: 90 | input, target, input1 = next(label_iter) 91 | except StopIteration: 92 | label_iter = iter(label_loader) 93 | input, target, input1 = next(label_iter) 94 | input_ul, _, input1_ul = next(unlabel_iter) 95 | sl = input.shape 96 | su = input_ul.shape 97 | batch_size = sl[0] + su[0] 98 | # measure data loading time 99 | data_time.update(time.time() - end) 100 | target = target.cuda(async=True) 101 | input_var = torch.autograd.Variable(input) 102 | input1_var = torch.autograd.Variable(input1) 103 | input_ul_var = torch.autograd.Variable(input_ul) 104 | input1_ul_var = torch.autograd.Variable(input1_ul) 105 | input_concat_var = torch.cat([input_var, input_ul_var]) 106 | input1_concat_var = torch.cat([input1_var, input1_ul_var]) 107 | 108 | target_var = torch.autograd.Variable(target) 109 | 110 | # compute output 111 | output = model(input_concat_var) 112 | with torch.no_grad(): 113 | output1 = model(input1_concat_var) 114 | 115 | output_label = output[:sl[0]] 116 | #pred = F.softmax(output, 1) # consistency loss on logit is better 117 | #pred1 = F.softmax(output1, 1) 118 | loss_ce = criterion(output_label, target_var) / float(sl[0]) 119 | loss_pi = criterion_mse(output, output1) / float(args.num_classes * batch_size) 120 | 121 | reg_l1 = cal_reg_l1(model, criterion_l1) 122 | 123 | loss = loss_ce + args.weight_l1 * reg_l1 + weight_cl * weight_pi * loss_pi 124 | 125 | # measure accuracy and record loss 126 | prec1, prec5 = accuracy(output_label.data, target, topk=(1, 5)) 127 | losses.update(loss_ce.item(), input.size(0)) 128 | losses_pi.update(loss_pi.item(), input.size(0)) 129 | top1.update(prec1.item(), input.size(0)) 130 | top5.update(prec5.item(), input.size(0)) 131 | weights_cl.update(weight_cl, input.size(0)) 132 | 133 | # compute gradient and do SGD step 134 | optimizer.zero_grad() 135 | loss.backward() 136 | optimizer.step() 137 | 138 | # measure elapsed time 139 | batch_time.update(time.time() - end) 140 | end = time.time() 141 | 142 | if i % args.print_freq == 0: 143 | print('Epoch: [{0}][{1}/{2}]\t' 144 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 145 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 146 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 147 | 'LossPi {loss_pi.val:.4f} ({loss_pi.avg:.4f})\t' 148 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 149 | 'Prec@5 {top5.val:.3f} 
({top5.avg:.3f})'.format( 150 | epoch, i, len_iter, batch_time=batch_time, 151 | data_time=data_time, loss=losses, loss_pi=losses_pi, 152 | top1=top1, top5=top5)) 153 | 154 | return top1.avg , losses.avg, losses_pi.avg, weights_cl.avg 155 | 156 | def train_mt(label_loader, unlabel_loader, model, model_teacher, criterions, optimizer, epoch, args, ema_const=0.95, weight_mt=8.0): 157 | batch_time = AverageMeter() 158 | data_time = AverageMeter() 159 | losses = AverageMeter() 160 | losses_cl = AverageMeter() 161 | top1 = AverageMeter() 162 | top5 = AverageMeter() 163 | top1_t = AverageMeter() 164 | top5_t = AverageMeter() 165 | weights_cl = AverageMeter() 166 | 167 | # switch to train mode 168 | model.train() 169 | model_teacher.train() 170 | 171 | criterion, criterion_mse, _, criterion_l1 = criterions 172 | 173 | end = time.time() 174 | 175 | label_iter = iter(label_loader) 176 | unlabel_iter = iter(unlabel_loader) 177 | len_iter = len(unlabel_iter) 178 | for i in range(len_iter): 179 | # set weights for the consistency loss 180 | global_step = epoch * len_iter + i 181 | weight_cl = cal_consistency_weight(global_step, end_ep=(args.epochs//2)*len_iter, end_w=1.0) 182 | 183 | try: 184 | input, target, input1 = next(label_iter) 185 | except StopIteration: 186 | label_iter = iter(label_loader) 187 | input, target, input1 = next(label_iter) 188 | input_ul, _, input1_ul = next(unlabel_iter) 189 | sl = input.shape 190 | su = input_ul.shape 191 | batch_size = sl[0] + su[0] 192 | # measure data loading time 193 | data_time.update(time.time() - end) 194 | target = target.cuda(async=True) 195 | input_var = torch.autograd.Variable(input) 196 | input1_var = torch.autograd.Variable(input1) 197 | input_ul_var = torch.autograd.Variable(input_ul) 198 | input1_ul_var = torch.autograd.Variable(input1_ul) 199 | input_concat_var = torch.cat([input_var, input_ul_var]) 200 | input1_concat_var = torch.cat([input1_var, input1_ul_var]) 201 | 202 | target_var = torch.autograd.Variable(target) 203 | 204 | # compute output 205 | output = model(input_concat_var) 206 | with torch.no_grad(): 207 | output1 = model_teacher(input1_concat_var) 208 | 209 | output_label = output[:sl[0]] 210 | output1_label = output1[:sl[0]] 211 | #pred = F.softmax(output, 1) 212 | #pred1 = F.softmax(output1, 1) 213 | loss_ce = criterion(output_label, target_var) /float(sl[0]) 214 | loss_cl = criterion_mse(output, output1) /float(args.num_classes * batch_size) 215 | 216 | reg_l1 = cal_reg_l1(model, criterion_l1) 217 | 218 | loss = loss_ce + args.weight_l1 * reg_l1 + weight_cl * weight_mt * loss_cl 219 | 220 | # measure accuracy and record loss 221 | prec1, prec5 = accuracy(output_label.data, target, topk=(1, 5)) 222 | prec1_t, prec5_t = accuracy(output1_label.data, target, topk=(1, 5)) 223 | losses.update(loss_ce.item(), input.size(0)) 224 | losses_cl.update(loss_cl.item(), input.size(0)) 225 | top1.update(prec1.item(), input.size(0)) 226 | top5.update(prec5.item(), input.size(0)) 227 | top1_t.update(prec1_t.item(), input.size(0)) 228 | top5_t.update(prec5_t.item(), input.size(0)) 229 | weights_cl.update(weight_cl, input.size(0)) 230 | 231 | # compute gradient and do SGD step 232 | optimizer.zero_grad() 233 | loss.backward() 234 | optimizer.step() 235 | update_ema_variables(model, model_teacher, ema_const, global_step) 236 | 237 | # measure elapsed time 238 | batch_time.update(time.time() - end) 239 | end = time.time() 240 | 241 | if i % args.print_freq == 0: 242 | print('Epoch: [{0}][{1}/{2}]\t' 243 | 'Time {batch_time.val:.3f} 
({batch_time.avg:.3f})\t' 244 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 245 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 246 | 'LossCL {loss_cl.val:.4f} ({loss_cl.avg:.4f})\t' 247 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 248 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t' 249 | 'PrecT@1 {top1_t.val:.3f} ({top1_t.avg:.3f})\t' 250 | 'PrecT@5 {top5_t.val:.3f} ({top5_t.avg:.3f})'.format( 251 | epoch, i, len_iter, batch_time=batch_time, 252 | data_time=data_time, loss=losses, loss_cl=losses_cl, 253 | top1=top1, top5=top5, top1_t=top1_t, top5_t=top5_t)) 254 | 255 | return top1.avg , losses.avg, losses_cl.avg, top1_t.avg, weights_cl.avg 256 | 257 | 258 | def validate(val_loader, model, criterions, args, mode = 'valid'): 259 | batch_time = AverageMeter() 260 | losses = AverageMeter() 261 | top1 = AverageMeter() 262 | top5 = AverageMeter() 263 | 264 | # switch to evaluate mode 265 | model.eval() 266 | 267 | criterion, criterion_mse, _, _ = criterions 268 | 269 | end = time.time() 270 | with torch.no_grad(): 271 | for i, (input, target, _) in enumerate(val_loader): 272 | sl = input.shape 273 | batch_size = sl[0] 274 | target = target.cuda(async=True) 275 | input_var = torch.autograd.Variable(input) 276 | target_var = torch.autograd.Variable(target) 277 | 278 | # compute output 279 | output = model(input_var) 280 | softmax = torch.nn.LogSoftmax(dim=1)(output) 281 | loss = criterion(output, target_var) / float(batch_size) 282 | 283 | # measure accuracy and record loss 284 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 285 | losses.update(loss.item(), input.size(0)) 286 | top1.update(prec1.item(), input.size(0)) 287 | top5.update(prec5.item(), input.size(0)) 288 | 289 | # measure elapsed time 290 | batch_time.update(time.time() - end) 291 | end = time.time() 292 | 293 | if i % args.print_freq == 0: 294 | if mode == 'test': 295 | print('Test: [{0}/{1}]\t' 296 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 297 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 298 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 299 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 300 | i, len(val_loader), batch_time=batch_time, loss=losses, 301 | top1=top1, top5=top5)) 302 | else: 303 | print('Valid: [{0}/{1}]\t' 304 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 305 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 306 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 307 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 308 | i, len(val_loader), batch_time=batch_time, loss=losses, 309 | top1=top1, top5=top5)) 310 | 311 | print(' ****** Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.3f} ' 312 | .format(top1=top1, top5=top5, loss=losses)) 313 | 314 | return top1.avg, losses.avg 315 | 316 | 317 | class AverageMeter(object): 318 | """Computes and stores the average and current value""" 319 | def __init__(self): 320 | self.reset() 321 | 322 | def reset(self): 323 | self.val = 0 324 | self.avg = 0 325 | self.sum = 0 326 | self.count = 0 327 | 328 | def update(self, val, n=1): 329 | self.val = val 330 | self.sum += val * n 331 | self.count += n 332 | self.avg = self.sum / self.count 333 | 334 | 335 | def accuracy(output, target, topk=(1,)): 336 | """Computes the precision@k for the specified values of k""" 337 | maxk = max(topk) 338 | batch_size = target.size(0) 339 | 340 | _, pred = output.topk(maxk, 1, True, True) 341 | pred = pred.t() 342 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 343 | 344 | res = [] 345 | for k in topk: 346 | correct_k = correct[:k].view(-1).float().sum(0) 347 
| res.append(correct_k.mul_(100.0 / batch_size)) 348 | return res 349 | 350 | def cal_consistency_weight(epoch, init_ep=0, end_ep=150, init_w=0.0, end_w=20.0): 351 | """Sets the weights for the consistency loss""" 352 | if epoch > end_ep: 353 | weight_cl = end_w 354 | elif epoch < init_ep: 355 | weight_cl = init_w 356 | else: 357 | T = float(epoch - init_ep)/float(end_ep - init_ep) 358 | #weight_mse = T * (end_w - init_w) + init_w #linear 359 | weight_cl = (math.exp(-5.0 * (1.0 - T) * (1.0 - T))) * (end_w - init_w) + init_w #exp 360 | #print('Consistency weight: %f'%weight_cl) 361 | return weight_cl 362 | 363 | def cal_reg_l1(model, criterion_l1): 364 | reg_loss = 0 365 | np = 0 366 | for param in model.parameters(): 367 | reg_loss += criterion_l1(param, torch.zeros_like(param)) 368 | np += param.nelement() 369 | reg_loss = reg_loss / np 370 | return reg_loss 371 | 372 | def update_ema_variables(model, model_teacher, alpha, global_step): 373 | # Use the true average until the exponential average is more correct 374 | alpha = min(1.0 - 1.0 / float(global_step + 1), alpha) 375 | for param_t, param in zip(model_teacher.parameters(), model.parameters()): 376 | param_t.data.mul_(alpha).add_(1 - alpha, param.data) 377 | 378 | 379 | -------------------------------------------------------------------------------- /preresnet_sd_cifar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | 7 | __all__ = ['resnet'] 8 | 9 | 10 | def conv3x3(in_planes, out_planes, stride=1): 11 | "3x3 convolution with padding" 12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | expansion = 1 18 | 19 | def __init__(self, inplanes, planes, stride=1, downsample=None, death_rate=0.): 20 | super(BasicBlock, self).__init__() 21 | self.bn1 = nn.BatchNorm2d(inplanes) 22 | self.relu = nn.ReLU(inplace=True) 23 | self.conv1 = conv3x3(inplanes, planes, stride) 24 | self.bn2 = nn.BatchNorm2d(planes) 25 | self.conv2 = conv3x3(planes, planes) 26 | self.downsample = downsample 27 | self.stride = stride 28 | self.death_rate =death_rate 29 | 30 | def forward(self, x): 31 | residual = x 32 | if self.downsample is not None: 33 | residual = self.downsample(x) 34 | 35 | if not self.training or torch.rand(1)[0] >= self.death_rate: 36 | out = self.bn1(x) 37 | out = self.relu(out) 38 | out = self.conv1(out) 39 | 40 | out = self.bn2(out) 41 | out = self.relu(out) 42 | out = self.conv2(out) 43 | 44 | if self.training: 45 | out /= (1. 
- self.death_rate) 46 | 47 | out += residual 48 | else: 49 | out = residual 50 | 51 | return out 52 | 53 | 54 | class Bottleneck(nn.Module): 55 | expansion = 4 56 | 57 | def __init__(self, inplanes, planes, stride=1, downsample=None, death_rate=0.): 58 | super(Bottleneck, self).__init__() 59 | self.bn1 = nn.BatchNorm2d(inplanes) 60 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 61 | self.bn2 = nn.BatchNorm2d(planes) 62 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 63 | padding=1, bias=False) 64 | self.bn3 = nn.BatchNorm2d(planes) 65 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 66 | self.relu = nn.ReLU(inplace=True) 67 | self.downsample = downsample 68 | self.stride = stride 69 | self.death_rate =death_rate 70 | 71 | def forward(self, x): 72 | residual = x 73 | if self.downsample is not None: 74 | residual = self.downsample(x) 75 | 76 | if not self.training or torch.rand(1)[0] >= self.death_rate: 77 | out = self.bn1(x) 78 | out = self.relu(out) 79 | out = self.conv1(out) 80 | 81 | out = self.bn2(out) 82 | out = self.relu(out) 83 | out = self.conv2(out) 84 | 85 | out = self.bn3(out) 86 | out = self.relu(out) 87 | out = self.conv3(out) 88 | 89 | if self.training: 90 | out /= (1. - self.death_rate) 91 | 92 | out += residual 93 | else: 94 | out = residual 95 | 96 | return out 97 | 98 | class ResNet(nn.Module): 99 | 100 | def __init__(self, depth, num_classes=1000, death_mode='linear', death_rate=0.5): 101 | super(ResNet, self).__init__() 102 | # Model type specifies number of layers for CIFAR-10 model 103 | assert (depth - 2) % 6 == 0, 'depth should be 6n+2' 104 | n = (depth - 2) // 6 105 | 106 | block = Bottleneck if depth >=44 else BasicBlock 107 | 108 | nblocks = (depth - 2) // 2 109 | if death_mode == 'uniform': 110 | death_rates = [death_rate] * nblocks 111 | print("Stochastic Depth: uniform mode") 112 | elif death_mode == 'linear': 113 | death_rates = [float(i + 1) * death_rate / float(nblocks) 114 | for i in range(nblocks)] 115 | print("Stochastic Depth: linear mode") 116 | else: 117 | death_rates = [0.] * (3 * n) 118 | print("Stochastic Depth: none mode") 119 | 120 | self.inplanes = 16 121 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, 122 | bias=False) 123 | self.layer1 = self._make_layer(block, 16, n, death_rates[:n]) 124 | self.layer2 = self._make_layer(block, 32, n, death_rates[n:2*n], stride=2) 125 | self.layer3 = self._make_layer(block, 64, n, death_rates[2*n:], stride=2) 126 | self.bn1 = nn.BatchNorm2d(64 * block.expansion) 127 | self.relu = nn.ReLU(inplace=True) 128 | self.avgpool = nn.AvgPool2d(8) 129 | self.fc1 = nn.Linear(64 * block.expansion, 64 * block.expansion) 130 | self.fc = nn.Linear(64 * block.expansion, num_classes) 131 | 132 | for m in self.modules(): 133 | if isinstance(m, nn.Conv2d): 134 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 135 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 136 | elif isinstance(m, nn.BatchNorm2d): 137 | m.weight.data.fill_(1) 138 | m.bias.data.zero_() 139 | 140 | def _make_layer(self, block, planes, blocks, death_rates, stride=1): 141 | downsample = None 142 | if stride != 1 or self.inplanes != planes * block.expansion: 143 | downsample = nn.Sequential( 144 | nn.Conv2d(self.inplanes, planes * block.expansion, 145 | kernel_size=1, stride=stride, bias=False), 146 | #nn.BatchNorm2d(planes * block.expansion), 147 | #nn.AvgPool2d((2,2), stride = (2, 2)) 148 | ) 149 | 150 | layers = [] 151 | layers.append(block(self.inplanes, planes, stride, downsample, death_rate=death_rates[0])) 152 | self.inplanes = planes * block.expansion 153 | for i in range(1, blocks): 154 | layers.append(block(self.inplanes, planes, death_rate=death_rates[i])) 155 | 156 | return nn.Sequential(*layers) 157 | 158 | def split2(self, x): 159 | x1, x2 = torch.split(x,x.shape[1]/2,1) 160 | return x1, x2 161 | 162 | def reparameterize(self, mu, logvar): 163 | if self.training: 164 | std = logvar.mul(0.5).exp_() 165 | eps = torch.autograd.Variable(std.data.new(std.size()).normal_()) 166 | return eps.mul(std).add_(mu) 167 | else: 168 | return mu 169 | 170 | def forward(self, x): 171 | x = self.conv1(x) 172 | 173 | x = self.layer1(x) # 32x32 174 | x = self.layer2(x) # 16x16 175 | x = self.layer3(x) # 8x8 176 | 177 | x = self.bn1(x) 178 | x = self.relu(x) 179 | 180 | x = self.avgpool(x) 181 | x = x.view(x.size(0), -1) 182 | 183 | x = self.fc1(x) 184 | 185 | output = self.fc(x) 186 | 187 | 188 | return output 189 | 190 | 191 | def resnet(**kwargs): 192 | """ 193 | Constructs a ResNet model. 194 | """ 195 | return ResNet(**kwargs) 196 | 197 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # this code is modified from the pytorch code: https://github.com/CSAILVision/places365 2 | # JH Kim 3 | # 4 | 5 | import argparse 6 | import os 7 | import shutil 8 | import time 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.parallel 13 | import torch.backends.cudnn as cudnn 14 | import torch.optim 15 | import torch.utils.data as data 16 | import torchvision.transforms as transforms 17 | import torchvision.datasets as datasets 18 | 19 | import preresnet_sd_cifar as preresnet_cifar 20 | import wideresnet 21 | import pdb 22 | import bisect 23 | 24 | import loader_cifar as cifar 25 | import loader_cifar_zca as cifar_zca 26 | import loader_svhn as svhn 27 | import math 28 | from math import ceil 29 | import torch.nn.functional as F 30 | from methods import train_sup, train_pi, train_mt, validate 31 | 32 | 33 | parser = argparse.ArgumentParser(description='PyTorch Semi-supervised learning Training') 34 | parser.add_argument('--arch', '-a', metavar='ARCH', default='wideresnet', 35 | help='model architecture: '+ ' (default: wideresnet)') 36 | parser.add_argument('--model', '-m', metavar='MODEL', default='baseline', 37 | help='model: '+' (default: baseline)', choices=['baseline', 'pi', 'mt']) 38 | parser.add_argument('--optim', '-o', metavar='OPTIM', default='adam', 39 | help='optimizer: '+' (default: adam)', choices=['adam', 'sgd']) 40 | parser.add_argument('--dataset', '-d', metavar='DATASET', default='cifar10_zca', 41 | help='dataset: '+' (default: cifar10)', choices=['cifar10', 'cifar10_zca', 'svhn']) 42 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 43 | help='number of data loading workers (default: 4)') 44 | 
parser.add_argument('--epochs', default=1200, type=int, metavar='N', 45 | help='number of total epochs to run') 46 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 47 | help='manual epoch number (useful on restarts)') 48 | parser.add_argument('-b', '--batch-size', default=256, type=int, 49 | metavar='N', help='mini-batch size (default: 225)') 50 | parser.add_argument('--lr', '--learning-rate', default=0.003, type=float, 51 | metavar='LR', help='initial learning rate') 52 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 53 | help='momentum') 54 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 55 | metavar='W', help='weight decay (default: 1e-4)') 56 | parser.add_argument('--weight_l1', '--l1', default=1e-3, type=float, 57 | metavar='W1', help='l1 regularization (default: 1e-3)') 58 | parser.add_argument('--print-freq', '-p', default=100, type=int, 59 | metavar='N', help='print frequency (default: 10)') 60 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 61 | help='path to latest checkpoint (default: none)') 62 | parser.add_argument('--num_classes',default=10, type=int, help='number of classes in the model') 63 | parser.add_argument('--ckpt', default='ckpt', type=str, metavar='PATH', 64 | help='path to save checkpoint (default: ckpt)') 65 | parser.add_argument('--boundary',default=0, type=int, help='different label/unlabel division [0,9]') 66 | parser.add_argument('--gpu',default=0, type=str, help='cuda_visible_devices') 67 | args = parser.parse_args() 68 | 69 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 70 | os.environ["CUDA_VISIBLE_DEVICES"]=str(args.gpu) 71 | 72 | best_prec1 = 0 73 | best_test_prec1 = 0 74 | acc1_tr, losses_tr = [], [] 75 | losses_cl_tr = [] 76 | acc1_val, losses_val, losses_et_val = [], [], [] 77 | acc1_test, losses_test, losses_et_test = [], [], [] 78 | acc1_t_tr, acc1_t_val, acc1_t_test = [], [], [] 79 | learning_rate, weights_cl = [], [] 80 | 81 | def main(): 82 | global args, best_prec1, best_test_prec1 83 | global acc1_tr, losses_tr 84 | global losses_cl_tr 85 | global acc1_val, losses_val, losses_et_val 86 | global acc1_test, losses_test, losses_et_test 87 | global weights_cl 88 | args = parser.parse_args() 89 | print args 90 | if args.dataset == 'svhn': 91 | drop_rate=0.3 92 | widen_factor=3 93 | else: 94 | drop_rate=0.3 95 | widen_factor=3 96 | 97 | # create model 98 | if args.arch == 'preresnet': 99 | print("Model: %s"%args.arch) 100 | model = preresnet_cifar.resnet(depth=32, num_classes=args.num_classes) 101 | elif args.arch == 'wideresnet': 102 | print("Model: %s"%args.arch) 103 | model = wideresnet.WideResNet(28, args.num_classes, widen_factor=widen_factor, dropRate=drop_rate, leakyRate=0.1) 104 | else: 105 | assert(False) 106 | 107 | if args.model == 'mt': 108 | import copy 109 | model_teacher = copy.deepcopy(model) 110 | model_teacher = torch.nn.DataParallel(model_teacher).cuda() 111 | 112 | model = torch.nn.DataParallel(model).cuda() 113 | print model 114 | 115 | # optionally resume from a checkpoint 116 | if args.resume: 117 | if os.path.isfile(args.resume): 118 | print("=> loading checkpoint '{}'".format(args.resume)) 119 | checkpoint = torch.load(args.resume) 120 | args.start_epoch = checkpoint['epoch'] 121 | best_prec1 = checkpoint['best_prec1'] 122 | 123 | model.load_state_dict(checkpoint['state_dict']) 124 | if args.model=='mt': model_teacher.load_state_dict(checkpoint['state_dict']) 125 | print("=> loaded checkpoint '{}' (epoch {})" 126 | 
.format(args.resume, checkpoint['epoch'])) 127 | else: 128 | print("=> no checkpoint found at '{}'".format(args.resume)) 129 | 130 | if args.optim == 'sgd' or args.optim == 'adam': 131 | pass 132 | else: 133 | print('Not Implemented Optimizer') 134 | assert(False) 135 | 136 | 137 | ckpt_dir = args.ckpt+'_'+args.dataset+'_'+args.arch+'_'+args.model+'_'+args.optim 138 | ckpt_dir = ckpt_dir + '_e%d'%(args.epochs) 139 | if not os.path.exists(ckpt_dir): 140 | os.makedirs(ckpt_dir) 141 | print(ckpt_dir) 142 | cudnn.benchmark = True 143 | 144 | # Data loading code 145 | if args.dataset == 'cifar10': 146 | dataloader = cifar.CIFAR10 147 | num_classes = 10 148 | data_dir = '/tmp/' 149 | 150 | normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], 151 | std=[0.2023, 0.1994, 0.2010]) 152 | transform_train = transforms.Compose([ 153 | transforms.RandomCrop(32, padding=2), 154 | transforms.RandomHorizontalFlip(), 155 | transforms.ToTensor(), 156 | normalize, 157 | ]) 158 | 159 | transform_test = transforms.Compose([ 160 | transforms.ToTensor(), 161 | normalize, 162 | ]) 163 | 164 | elif args.dataset == 'cifar10_zca': 165 | dataloader = cifar_zca.CIFAR10 166 | num_classes = 10 167 | data_dir = 'cifar10_zca/cifar10_gcn_zca_v2.npz' 168 | 169 | # transform is implemented inside zca dataloader 170 | transform_train = transforms.Compose([ 171 | transforms.ToTensor(), 172 | ]) 173 | 174 | transform_test = transforms.Compose([ 175 | transforms.ToTensor(), 176 | ]) 177 | 178 | 179 | elif args.dataset == 'svhn': 180 | dataloader = svhn.SVHN 181 | num_classes = 10 182 | data_dir = '/tmp/' 183 | 184 | normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], 185 | std=[0.5, 0.5, 0.5]) 186 | transform_train = transforms.Compose([ 187 | transforms.RandomCrop(32, padding=2), 188 | transforms.ToTensor(), 189 | normalize, 190 | ]) 191 | 192 | transform_test = transforms.Compose([ 193 | transforms.ToTensor(), 194 | normalize, 195 | ]) 196 | 197 | 198 | labelset = dataloader(root=data_dir, split='label', download=True, transform=transform_train, boundary=args.boundary) 199 | unlabelset = dataloader(root=data_dir, split='unlabel', download=True, transform=transform_train, boundary=args.boundary) 200 | batch_size_label = args.batch_size//2 201 | batch_size_unlabel = args.batch_size//2 202 | if args.model == 'baseline': batch_size_label=args.batch_size 203 | 204 | label_loader = data.DataLoader(labelset, 205 | batch_size=batch_size_label, 206 | shuffle=True, 207 | num_workers=args.workers, 208 | pin_memory=True) 209 | label_iter = iter(label_loader) 210 | 211 | unlabel_loader = data.DataLoader(unlabelset, 212 | batch_size=batch_size_unlabel, 213 | shuffle=True, 214 | num_workers=args.workers, 215 | pin_memory=True) 216 | unlabel_iter = iter(unlabel_loader) 217 | 218 | print("Batch size (label): ", batch_size_label) 219 | print("Batch size (unlabel): ", batch_size_unlabel) 220 | 221 | 222 | validset = dataloader(root=data_dir, split='valid', download=True, transform=transform_test, boundary=args.boundary) 223 | val_loader = data.DataLoader(validset, 224 | batch_size=args.batch_size, 225 | shuffle=False, 226 | num_workers=args.workers, 227 | pin_memory=True) 228 | 229 | testset = dataloader(root=data_dir, split='test', download=True, transform=transform_test) 230 | test_loader = data.DataLoader(testset, 231 | batch_size=args.batch_size, 232 | shuffle=False, 233 | num_workers=args.workers, 234 | pin_memory=True) 235 | 236 | # deifine loss function (criterion) and optimizer 237 | criterion = 
nn.CrossEntropyLoss(size_average=False).cuda() 238 | criterion_mse = nn.MSELoss(size_average=False).cuda() 239 | criterion_kl = nn.KLDivLoss(size_average=False).cuda() 240 | criterion_l1 = nn.L1Loss(size_average=False).cuda() 241 | 242 | criterions = (criterion, criterion_mse, criterion_kl, criterion_l1) 243 | 244 | if args.optim == 'adam': 245 | print('Using Adam optimizer') 246 | optimizer = torch.optim.Adam(model.parameters(), args.lr, 247 | betas=(0.9,0.999), 248 | weight_decay=args.weight_decay) 249 | elif args.optim == 'sgd': 250 | print('Using SGD optimizer') 251 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 252 | momentum=args.momentum, 253 | weight_decay=args.weight_decay) 254 | 255 | for epoch in range(args.start_epoch, args.epochs): 256 | if args.optim == 'adam': 257 | print('Learning rate schedule for Adam') 258 | lr = adjust_learning_rate_adam(optimizer, epoch) 259 | elif args.optim == 'sgd': 260 | print('Learning rate schedule for SGD') 261 | lr = adjust_learning_rate(optimizer, epoch) 262 | 263 | # train for one epoch 264 | if args.model == 'baseline': 265 | print('Supervised Training') 266 | for i in range(10): #baseline repeat 10 times since small number of training set 267 | prec1_tr, loss_tr = train_sup(label_loader, model, criterions, optimizer, epoch, args) 268 | weight_cl = 0.0 269 | elif args.model == 'pi': 270 | print('Pi model') 271 | prec1_tr, loss_tr, loss_cl_tr, weight_cl = train_pi(label_loader, unlabel_loader, model, criterions, optimizer, epoch, args) 272 | elif args.model == 'mt': 273 | print('Mean Teacher model') 274 | prec1_tr, loss_tr, loss_cl_tr, prec1_t_tr, weight_cl = train_mt(label_loader, unlabel_loader, model, model_teacher, criterions, optimizer, epoch, args) 275 | else: 276 | print("Not Implemented ", args.model) 277 | assert(False) 278 | 279 | # evaluate on validation set 280 | prec1_val, loss_val = validate(val_loader, model, criterions, args, 'valid') 281 | prec1_test, loss_test = validate(test_loader, model, criterions, args, 'test') 282 | if args.model=='mt': 283 | prec1_t_val, loss_t_val = validate(val_loader, model_teacher, criterions, args, 'valid') 284 | prec1_t_test, loss_t_test = validate(test_loader, model_teacher, criterions, args, 'test') 285 | 286 | # append values 287 | acc1_tr.append(prec1_tr) 288 | losses_tr.append(loss_tr) 289 | acc1_val.append(prec1_val) 290 | losses_val.append(loss_val) 291 | acc1_test.append(prec1_test) 292 | losses_test.append(loss_test) 293 | if args.model != 'baseline': 294 | losses_cl_tr.append(loss_cl_tr) 295 | if args.model=='mt': 296 | acc1_t_tr.append(prec1_t_tr) 297 | acc1_t_val.append(prec1_t_val) 298 | acc1_t_test.append(prec1_t_test) 299 | weights_cl.append(weight_cl) 300 | learning_rate.append(lr) 301 | 302 | # remember best prec@1 and save checkpoint 303 | if args.model == 'mt': 304 | is_best = prec1_t_val > best_prec1 305 | if is_best: 306 | best_test_prec1_t = prec1_t_test 307 | best_test_prec1 = prec1_test 308 | print("Best test precision: %.3f"%best_test_prec1_t) 309 | best_prec1 = max(prec1_t_val, best_prec1) 310 | dict_checkpoint = { 311 | 'epoch': epoch + 1, 312 | 'state_dict': model.state_dict(), 313 | 'best_prec1': best_prec1, 314 | 'best_test_prec1' : best_test_prec1, 315 | 'acc1_tr': acc1_tr, 316 | 'losses_tr': losses_tr, 317 | 'losses_cl_tr': losses_cl_tr, 318 | 'acc1_val': acc1_val, 319 | 'losses_val': losses_val, 320 | 'acc1_test' : acc1_test, 321 | 'losses_test' : losses_test, 322 | 'acc1_t_tr': acc1_t_tr, 323 | 'acc1_t_val': acc1_t_val, 324 | 'acc1_t_test': 
acc1_t_test, 325 | 'state_dict_teacher': model_teacher.state_dict(), 326 | 'best_test_prec1_t' : best_test_prec1_t, 327 | 'weights_cl' : weights_cl, 328 | 'learning_rate' : learning_rate, 329 | } 330 | 331 | else: 332 | is_best = prec1_val > best_prec1 333 | if is_best: 334 | best_test_prec1 = prec1_test 335 | print("Best test precision: %.3f"%best_test_prec1) 336 | best_prec1 = max(prec1_val, best_prec1) 337 | dict_checkpoint = { 338 | 'epoch': epoch + 1, 339 | 'state_dict': model.state_dict(), 340 | 'best_prec1': best_prec1, 341 | 'best_test_prec1' : best_test_prec1, 342 | 'acc1_tr': acc1_tr, 343 | 'losses_tr': losses_tr, 344 | 'losses_cl_tr': losses_cl_tr, 345 | 'acc1_val': acc1_val, 346 | 'losses_val': losses_val, 347 | 'acc1_test' : acc1_test, 348 | 'losses_test' : losses_test, 349 | 'weights_cl' : weights_cl, 350 | 'learning_rate' : learning_rate, 351 | } 352 | 353 | save_checkpoint(dict_checkpoint, is_best, args.arch.lower()+str(args.boundary), dirname=ckpt_dir) 354 | 355 | 356 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', dirname='.'): 357 | fpath = os.path.join(dirname, filename + '_latest.pth.tar') 358 | torch.save(state, fpath) 359 | if is_best: 360 | bpath = os.path.join(dirname, filename + '_best.pth.tar') 361 | shutil.copyfile(fpath, bpath) 362 | 363 | def adjust_learning_rate(optimizer, epoch): 364 | """Sets the learning rate to the initial LR decayed by 10 at [150, 225, 300] epochs""" 365 | 366 | boundary = [args.epochs//2,args.epochs//4*3,args.epochs] 367 | lr = args.lr * 0.1 ** int(bisect.bisect_left(boundary, epoch)) 368 | print('Learning rate: %f'%lr) 369 | #print(epoch, lr, bisect.bisect_left(boundary, epoch)) 370 | # lr = args.lr * (0.1 ** (epoch // 30)) 371 | for param_group in optimizer.param_groups: 372 | param_group['lr'] = lr 373 | 374 | return lr 375 | 376 | def adjust_learning_rate_adam(optimizer, epoch): 377 | """Sets the learning rate to the initial LR decayed by 5 at [240] epochs""" 378 | 379 | boundary = [args.epochs//5*4] 380 | lr = args.lr * 0.2 ** int(bisect.bisect_left(boundary, epoch)) 381 | print('Learning rate: %f'%lr) 382 | #print(epoch, lr) 383 | for param_group in optimizer.param_groups: 384 | param_group['lr'] = lr 385 | 386 | return lr 387 | 388 | 389 | 390 | if __name__ == '__main__': 391 | main() 392 | -------------------------------------------------------------------------------- /wideresnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | # Modified 9 | # ReLU --> LeakyReLU 10 | 11 | import math 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | 17 | class BasicBlock(nn.Module): 18 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0, leakyRate=0.01, actBeforeRes=True): 19 | super(BasicBlock, self).__init__() 20 | self.bn1 = nn.BatchNorm2d(in_planes) 21 | self.relu1 = nn.LeakyReLU(leakyRate, inplace=True) 22 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 23 | padding=1, bias=False) 24 | self.bn2 = nn.BatchNorm2d(out_planes) 25 | self.relu2 = nn.LeakyReLU(leakyRate, inplace=True) 26 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 27 | padding=1, bias=False) 28 | self.droprate = dropRate 29 | self.equalInOut = (in_planes == out_planes) 30 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 31 | padding=0, bias=False) or None 32 | self.activateBeforeResidual= actBeforeRes 33 | def forward(self, x): 34 | if not self.equalInOut and self.activateBeforeResidual: 35 | x = self.relu1(self.bn1(x)) 36 | out = self.conv1(x) 37 | else: 38 | out = self.conv1(self.relu1(self.bn1(x))) 39 | #out = self.conv1(out if self.equalInOut else x) 40 | #out = self.conv1(self.equalInOut and out or x) 41 | if self.droprate > 0: 42 | out = F.dropout(out, p=self.droprate, training=self.training) 43 | out = self.conv2(self.relu2(self.bn2(out))) 44 | res = self.convShortcut(x) if not self.equalInOut else x 45 | return torch.add(res, out) 46 | 47 | class NetworkBlock(nn.Module): 48 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0, leakyRate=0.01, actBeforeRes=True): 49 | super(NetworkBlock, self).__init__() 50 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate, leakyRate, actBeforeRes) 51 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate, leakyRate, actBeforeRes): 52 | layers = [] 53 | for i in range(nb_layers): 54 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate, leakyRate, actBeforeRes)) 55 | return nn.Sequential(*layers) 56 | def forward(self, x): 57 | return self.layer(x) 58 | 59 | class WideResNet(nn.Module): 60 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0, leakyRate=0.01): 61 | super(WideResNet, self).__init__() 62 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] 63 | assert((depth - 4) % 6 == 0) 64 | n = (depth - 4) / 6 65 | block = BasicBlock 66 | # 1st conv before any network block 67 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 68 | padding=1, bias=False) 69 | # 1st block 70 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate, leakyRate, False) 71 | # 2nd block 72 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate, leakyRate, True) 73 | # 3rd block 74 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate, leakyRate, True) 75 | # global average pooling and classifier 76 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 77 | self.relu = nn.LeakyReLU(leakyRate, inplace=True) 78 | self.fc = nn.Linear(nChannels[3], num_classes) 79 | self.nChannels = nChannels[3] 80 | 81 | for m in self.modules(): 82 | if isinstance(m, nn.Conv2d): 83 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 84 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 85 | elif isinstance(m, nn.BatchNorm2d): 86 | m.weight.data.fill_(1) 87 | m.bias.data.zero_() 88 | elif isinstance(m, nn.Linear): 89 | m.bias.data.zero_() 90 | def forward(self, x): 91 | out = self.conv1(x) 92 | out = self.block1(out) 93 | out = self.block2(out) 94 | out = self.block3(out) 95 | out = self.relu(self.bn1(out)) 96 | out = F.avg_pool2d(out, 8) 97 | out = out.view(-1, self.nChannels) 98 | return self.fc(out) 99 | --------------------------------------------------------------------------------
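The consistency-weight ramp-up and the EMA teacher update defined in methods.py above can be exercised in isolation. A minimal sketch, assuming the repository root is on the Python path and the stated python 2.7 / pytorch 0.4.0 environment; the two Linear modules and alpha=0.99 are illustrative stand-ins, not values taken from the training code:

    import torch
    import torch.nn as nn
    from methods import cal_consistency_weight, update_ema_variables

    student = nn.Linear(4, 2)
    teacher = nn.Linear(4, 2)
    teacher.load_state_dict(student.state_dict())  # start teacher and student from identical weights

    global_step = 0
    for epoch in range(200):
        # exponential ramp-up: ~0 at epoch 0, reaching end_w=20.0 at end_ep=150
        weight_cl = cal_consistency_weight(epoch, end_ep=150, end_w=20.0)
        # ... student forward/backward and optimizer.step() would go here ...
        global_step += 1
        # the teacher tracks the student by exponential moving average; alpha is
        # clipped inside update_ema_variables so early steps use a plain average
        update_ema_variables(student, teacher, 0.99, global_step)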
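preresnet_sd_cifar.py is the stochastic-depth alternative selected with -a=preresnet: in training mode each residual block is skipped with probability death_rate and, when it does run, its output is rescaled by 1/(1 - death_rate) before the shortcut is added; in eval mode every block runs without rescaling. A short sketch using the same constructor call as train.py (the batch size is arbitrary):

    import torch
    import preresnet_sd_cifar as preresnet_cifar

    # depth=32 selects BasicBlock; death_mode='linear', death_rate=0.5 are the defaults
    model = preresnet_cifar.resnet(depth=32, num_classes=10)
    x = torch.randn(8, 3, 32, 32)            # CIFAR-sized input

    model.train()                            # blocks dropped stochastically, survivors rescaled
    logits_train = model(x)

    model.eval()                             # every block active, no rescaling
    with torch.no_grad():
        logits_eval = model(x)
    print(logits_eval.shape)                 # torch.Size([8, 10])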
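For the Adam runs, adjust_learning_rate_adam in train.py decays the learning rate by a factor of 5 (multiplier 0.2) exactly once, after epochs//5*4 epochs; the "[240]" in its docstring only matches a 300-epoch run, while with the argparse defaults (--epochs=1200, --lr=0.003) the drop takes effect from epoch 961. The formula can be checked standalone:

    import bisect

    epochs, base_lr = 1200, 0.003            # argparse defaults in train.py
    boundary = [epochs // 5 * 4]             # [960]

    for epoch in (0, 959, 960, 961, 1199):
        lr = base_lr * 0.2 ** int(bisect.bisect_left(boundary, epoch))
        print(epoch, lr)                     # 0.003 through epoch 960, 0.0006 afterwards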
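wideresnet.py is the default architecture; main() in train.py instantiates it as a WideResNet-28 with widen_factor=3, dropRate=0.3 and leakyRate=0.1 for both the CIFAR and SVHN branches. A minimal sketch of that call, again assuming the stated python 2.7 / pytorch 0.4.0 setup (n = (depth - 4) / 6 relies on Python 2 integer division):

    import torch
    import wideresnet

    # same constructor call as main() in train.py
    model = wideresnet.WideResNet(28, 10, widen_factor=3, dropRate=0.3, leakyRate=0.1)

    x = torch.randn(4, 3, 32, 32)            # batch size chosen for illustration
    logits = model(x)
    print(logits.shape)                      # torch.Size([4, 10])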