├── CIFAR-N ├── CIFAR-100_human.pt └── CIFAR-10_human.pt ├── README.md ├── dataloader.py ├── fashionmnist ├── test_images.npy ├── test_labels.npy ├── train_images.npy └── train_labels.npy ├── main.py ├── models ├── PreResNet.py ├── encoders.py └── vae.py ├── requirements.txt └── tools.py /CIFAR-N/CIFAR-100_human.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tmllab/2023_NeurIPS_CS-isolate/a5400d028019648ef1e01cb2bbbdf358af4a7dc8/CIFAR-N/CIFAR-100_human.pt -------------------------------------------------------------------------------- /CIFAR-N/CIFAR-10_human.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tmllab/2023_NeurIPS_CS-isolate/a5400d028019648ef1e01cb2bbbdf358af4a7dc8/CIFAR-N/CIFAR-10_human.pt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CS-Isolate: Extracting Hard Confident Examples by Content and Style Isolation 2 | 3 | Official implementation of CS-Isolate: Extracting Hard Confident Examples by Content and Style Isolation (NeurIPS 2023). 4 | 5 | ## Abstract 6 | 7 | Label noise widely exists in large-scale image datasets. To mitigate the side effects of label noise, state-of-the-art methods focus on selecting confident examples by leveraging semi-supervised learning. Existing research shows that the ability to extract hard confident examples, which are close to the decision boundary, significantly influences the generalization ability of the learned classifier. In this paper, we find that a key reason some hard examples lie close to the decision boundary is the entanglement of style factors with content factors. The hard examples become more discriminative when we focus solely on content factors, such as semantic information, while ignoring style factors. Nonetheless, given only noisy data, content factors are not directly observed and have to be inferred. To tackle the problem of inferring content factors for classification when learning with noisy labels, our objective is to ensure that the content factors of all examples in the same underlying clean class remain unchanged as their style information changes. To achieve this, we utilize different data augmentation techniques to alter the styles while regularizing content factors based on some confident examples. By training existing methods with our inferred content factors, we demonstrate the effectiveness of CS-Isolate in learning hard examples on benchmark datasets. 8 | 9 | ## Experiments 10 | 11 | To install the necessary Python packages: 12 | 13 | ``` 14 | pip install -r requirements.txt 15 | ``` 16 | 17 | Download the CIFAR-10 and CIFAR-100 datasets: 18 | 19 | ``` 20 | wget -c https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz 21 | tar -xzvf cifar-10-python.tar.gz 22 | mv cifar-10-batches-py cifar-10 23 | wget -c https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz 24 | tar -xzvf cifar-100-python.tar.gz 25 | mv cifar-100-python cifar-100 26 | ``` 27 | 28 | For the FashionMNIST dataset, we use the version provided by [PTD](https://github.com/xiaoboxia/Part-dependent-label-noise), in which the images and labels have already been processed into .npy format. You can download the processed [fashionmnist](https://drive.google.com/open?id=1Tz3W3JVYv2nu-mdM6x33KSnRIY1B7ygQ) files here.
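After downloading, you can sanity-check the FashionMNIST files from Python (a minimal sketch; the exact shapes and dtypes depend on the PTD preprocessing, but `dataloader.py` assumes 28x28 grayscale images with one integer label per image):

```
import numpy as np

for name in ['train_images', 'train_labels', 'test_images', 'test_labels']:
    arr = np.load('./fashionmnist/%s.npy' % name)
    print(name, arr.shape, arr.dtype)
# Expect 60000 training and 10000 test examples following the standard
# FashionMNIST split; images are 28x28 (possibly stored flattened).
```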
29 | 30 | To train the model on FashionMNIST: 31 | 32 | ``` 33 | python main.py --noise_mode instance --dataset fashionmnist --data_path ./fashionmnist --num_class 10 --r 0.4 34 | ``` 35 | 36 | To train the model on CIFAR-10: 37 | 38 | ``` 39 | python main.py --noise_mode instance --dataset cifar10 --data_path ./cifar-10 --num_class 10 --r 0.4 40 | ``` 41 | 42 | To train the model on CIFAR-100: 43 | 44 | ``` 45 | python main.py --noise_mode instance --dataset cifar100 --data_path ./cifar-100 --num_class 100 --r 0.4 --lambda_u 100 46 | ``` 47 | 48 | To train the model on CIFAR-10N: 49 | 50 | ``` 51 | python main.py --noise_mode worse_label --dataset cifar10 --data_path ./cifar-10 --num_class 10 52 | ``` 53 | 54 | ## Citation 55 | 56 | If you find our work insightful, please consider citing our paper: 57 | 58 | ``` 59 | @inproceedings{lin2023cs, 60 | title={CS-Isolate: Extracting Hard Confident Examples by Content and Style Isolation}, 61 | author={Lin, Yexiong and Yao, Yu and Shi, Xiaolong and Gong, Mingming and Shen, Xu and Xu, Dong and Liu, Tongliang}, 62 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems}, 63 | year={2023} 64 | } 65 | ``` 66 | 67 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | import torchvision 3 | import torchvision.transforms as transforms 4 | import random 5 | import numpy as np 6 | import pandas as pd 7 | from PIL import Image 8 | import os 9 | import pickle 10 | import torch 11 | import tools 12 | from torchnet.meter import AUCMeter 13 | 14 | import albumentations as A 15 | from albumentations.pytorch import ToTensorV2 16 | import warnings 17 | warnings.filterwarnings("ignore") 18 | 19 | 20 | def unpickle(file): 21 | import pickle as cPickle 22 | with open(file, 'rb') as fo: 23 | dict = cPickle.load(fo, encoding='latin1') 24 | return dict 25 | 26 | class cifar_dataset(Dataset): 27 | def __init__(self, dataset, r, noise_mode, root_dir, transform_weak, transform_strong,mode, noise_file='', pred=[], probability=[], log=''): 28 | 29 | self.r = r 30 | self.transform = transform_weak 31 | self.transform_strong = transform_strong 32 | self.mode = mode 33 | self.transition = {0:0,2:0,4:7,7:7,1:1,9:1,3:5,5:3,6:6,8:8} # class transition for asymmetric noise 34 | self.replay_list = [] 35 | self.replay_num = 1000 36 | self.replay_file = '%s/%.2f_%s.npy'%(root_dir,r,noise_mode) 37 | self.id_file = '%s/%.2f_%s_id.npy'%(root_dir,r,noise_mode) 38 | if os.path.exists(self.replay_file): 39 | print('loading replay') 40 | self.replay_list = np.load(self.replay_file, allow_pickle=True) 41 | print('loading id file') 42 | self.u_c_list = np.load(self.id_file, allow_pickle=True) 43 | 44 | 45 | if self.mode=='test': 46 | if dataset=='cifar10': 47 | test_dic = unpickle('%s/test_batch'%root_dir) 48 | self.test_data = test_dic['data'] 49 | self.test_data = self.test_data.reshape((10000, 3, 32, 32)) 50 | self.test_data = self.test_data.transpose((0, 2, 3, 1)) 51 | self.test_label = test_dic['labels'] 52 | num_classes_ = 10 53 | elif dataset=='cifar100': 54 | test_dic = unpickle('%s/test'%root_dir) 55 | self.test_data = test_dic['data'] 56 | self.test_data = self.test_data.reshape((10000, 3, 32, 32)) 57 | self.test_data = self.test_data.transpose((0, 2, 3, 1)) 58 | self.test_label = test_dic['fine_labels'] 59 | num_classes_ = 100 60 | elif dataset=='fashionmnist': 61 | 
self.test_data = np.load('%s/test_images.npy'%root_dir) 62 | self.test_label = np.load('%s/test_labels.npy'%root_dir) 63 | num_classes_ = 10 64 | else: 65 | train_data=[] 66 | train_label=[] 67 | if dataset=='cifar10': 68 | for n in range(1,6): 69 | dpath = '%s/data_batch_%d'%(root_dir,n) 70 | data_dic = unpickle(dpath) 71 | train_data.append(data_dic['data']) 72 | train_label = train_label+data_dic['labels'] 73 | train_data = np.concatenate(train_data) 74 | num_classes_ = 10 75 | train_data = train_data.reshape((50000, 3, 32, 32)) 76 | train_data = train_data.transpose((0, 2, 3, 1)) 77 | feature_size = 32*32*3 78 | elif dataset=='cifar100': 79 | train_dic = unpickle('%s/train'%root_dir) 80 | train_data = train_dic['data'] 81 | train_label = train_dic['fine_labels'] 82 | num_classes_ = 100 83 | train_data = train_data.reshape((50000, 3, 32, 32)) 84 | train_data = train_data.transpose((0, 2, 3, 1)) 85 | feature_size = 32*32*3 86 | elif dataset=='fashionmnist': 87 | train_data = np.load('%s/train_images.npy'%root_dir) 88 | train_label = np.load('%s/train_labels.npy'%root_dir) 89 | num_classes_ = 10 90 | feature_size = 28*28 91 | 92 | if noise_mode in ['worse_label', 'aggre_label', 'random_label1', 'random_label2', 'random_label3']: 93 | noise_label = torch.load('./CIFAR-N/CIFAR-10_human.pt') 94 | worst_label = noise_label['worse_label'] 95 | aggre_label = noise_label['aggre_label'] 96 | random_label1 = noise_label['random_label1'] 97 | random_label2 = noise_label['random_label2'] 98 | random_label3 = noise_label['random_label3'] 99 | print('loading %s'%(noise_mode)) 100 | noise_label = noise_label[noise_mode] 101 | elif noise_mode == 'noisy_label': 102 | noise_label = torch.load('./CIFAR-N/CIFAR-100_human.pt') 103 | print('loading %s'%(noise_mode)) 104 | noise_label = noise_label[noise_mode] 105 | else: 106 | if os.path.exists(noise_file): 107 | print('loading %s'%noise_file) 108 | noise_label = torch.load(noise_file) 109 | else: 110 | data_ = torch.from_numpy(train_data).float().cuda() 111 | targets_ = torch.IntTensor(train_label).cuda() 112 | dataset = zip(data_, targets_) 113 | if noise_mode == 'instance': 114 | train_label = torch.FloatTensor(train_label).cuda() 115 | noise_label = tools.get_instance_noisy_label(self.r, dataset, train_label, num_classes = num_classes_, feature_size = feature_size, norm_std=0.1, seed=123) 116 | elif noise_mode == 'sym': 117 | noise_label = [] 118 | idx = list(range(train_data.shape[0])) 119 | random.shuffle(idx) 120 | num_noise = int(self.r*train_data.shape[0]) 121 | noise_idx = idx[:num_noise] 122 | for i in range(train_data.shape[0]): 123 | if i in noise_idx: 124 | noiselabel = random.randint(0,num_classes_-1) 125 | noise_label.append(noiselabel) 126 | else: 127 | noise_label.append(train_label[i]) 128 | noise_label = np.array(noise_label) 129 | elif noise_mode == 'pair': 130 | train_label = np.array(train_label) 131 | train_label = train_label.reshape((-1,1)) 132 | noise_label = tools.noisify_pairflip(train_label, self.r, 123, num_classes_) 133 | noise_label = noise_label[:, 0] 134 | print("save noisy labels to %s ..."%noise_file) 135 | torch.save(noise_label, noise_file) 136 | 137 | 138 | if self.mode == 'warmup': 139 | self.train_data = train_data 140 | self.noise_label = noise_label 141 | elif self.mode == 'all': 142 | self.train_data = train_data 143 | self.noise_label = noise_label 144 | self.pred_idx = pred.nonzero()[0] 145 | # updating id 146 | y_cluster_id_map={} 147 | self.cluster_id_map={} 148 | for i in range(train_data.shape[0]): 149 | if i 
in self.pred_idx: 150 | if noise_label[i] not in y_cluster_id_map.keys(): 151 | y_cluster_id_map[noise_label[i]]=self.u_c_list[i] 152 | self.cluster_id_map[self.u_c_list[i]]=noise_label[i] 153 | else: 154 | if self.mode == "labeled": 155 | pred_idx = pred.nonzero()[0] 156 | self.probability = [probability[i] for i in pred_idx] 157 | # update the mapping between cluster ids and noisy class labels using the confident examples 158 | y_cluster_id_map={} 159 | self.cluster_id_map={} 160 | 161 | for i in range(train_data.shape[0]): 162 | if i in pred_idx: 163 | if noise_label[i] not in y_cluster_id_map.keys(): 164 | y_cluster_id_map[noise_label[i]]=self.u_c_list[i] 165 | self.cluster_id_map[self.u_c_list[i]]=noise_label[i] 166 | 167 | clean = (np.array(noise_label)==np.array(train_label)) 168 | auc_meter = AUCMeter() 169 | auc_meter.reset() 170 | auc_meter.add(probability,clean) 171 | auc,_,_ = auc_meter.value() 172 | print('Number of labeled samples:%d AUC:%.3f'%(pred.sum(),auc)) 173 | 174 | log.write('Number of labeled samples:%d AUC:%.3f\n'%(pred.sum(),auc)) 175 | log.flush() 176 | 177 | elif self.mode == "unlabeled": 178 | pred_idx = (1-pred).nonzero()[0] 179 | self.probability = [probability[i] for i in pred_idx] 180 | 181 | self.train_data = train_data[pred_idx] 182 | self.u_c_list = self.u_c_list[pred_idx] 183 | self.noise_label = [noise_label[i] for i in pred_idx] 184 | print("%s data has a size of %d"%(self.mode,len(self.noise_label))) 185 | 186 | def __getitem__(self, index): 187 | if self.mode=='labeled': 188 | img, target, prob = self.train_data[index], self.noise_label[index], self.probability[index] 189 | u_s = random.randint(0, self.replay_num-1) 190 | img1 = self.transform(image=img) 191 | img2 = self.transform_strong(image=img) 192 | img1=img1['image'] 193 | img2=img2['image'] 194 | u_c = self.u_c_list[index] 195 | if u_c in self.cluster_id_map.keys(): 196 | u_c=self.cluster_id_map[u_c] 197 | return img1, img2, target, u_c, u_s, prob 198 | elif self.mode=='unlabeled': 199 | img, target, prob = self.train_data[index], self.noise_label[index], self.probability[index] 200 | u_s = random.randint(0, self.replay_num-1) 201 | img1 = self.transform(image=img) 202 | img2 = self.transform_strong(image=img) 203 | img1=img1['image'] 204 | img2=img2['image'] 205 | u_c = self.u_c_list[index] 206 | return img1, img2, target, u_c, u_s, prob 207 | elif self.mode=='warmup': 208 | img, target = self.train_data[index], self.noise_label[index] 209 | img = self.transform(image=img) 210 | img=img['image'] 211 | return img, target, index 212 | elif self.mode=='all': 213 | img, target = self.train_data[index], self.noise_label[index] 214 | if len(self.replay_list) < self.replay_num: 215 | img = self.transform_strong(image=img) 216 | self.replay_list.append(img['replay']) 217 | if len(self.replay_list) == self.replay_num: 218 | print('saving replay') 219 | np.save(self.replay_file, np.array(self.replay_list)) 220 | u_s=len(self.replay_list)-1 221 | else: 222 | u_s = random.randint(0, self.replay_num-1) 223 | img = A.ReplayCompose.replay(self.replay_list[u_s], image=img) 224 | u_c = self.u_c_list[index] 225 | if u_c in self.cluster_id_map.keys(): 226 | u_c=self.cluster_id_map[u_c] 227 | img=img['image'] 228 | 229 | if index in self.pred_idx: 230 | pred_clean = 1 231 | else: 232 | pred_clean=0 233 | return img, target, u_c, u_s, pred_clean, index 234 | elif self.mode=='test': 235 | img, target = self.test_data[index], self.test_label[index] 236 | img = self.transform(image=img) 237 | img=img['image'] 238 | return img, target 239 | 240 | def __len__(self): 241 | if
self.mode!='test': 242 | return len(self.train_data) 243 | else: 244 | return len(self.test_data) 245 | 246 | 247 | class cifar_dataloader(): 248 | def __init__(self, dataset, r, noise_mode, batch_size, num_workers, root_dir, log, noise_file=''): 249 | self.dataset = dataset 250 | self.r = r 251 | self.noise_mode = noise_mode 252 | self.batch_size = batch_size 253 | self.num_workers = num_workers 254 | self.root_dir = root_dir 255 | self.log = log 256 | self.noise_file = noise_file 257 | if self.dataset in ['cifar10', 'cifar100']: 258 | self.transform_train = A.ReplayCompose( 259 | [ 260 | A.ShiftScaleRotate(p=0.5), 261 | A.CropAndPad(px=4, keep_size=False, always_apply=True), 262 | A.RandomCrop(height=32, width=32, always_apply=True), 263 | A.HorizontalFlip(), 264 | A.RandomBrightnessContrast(p=0.5), 265 | A.ColorJitter(0.8, 0.8, 0.8, 0.2,p=0.8), 266 | A.ToGray(p=0.2), 267 | A.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)), 268 | ToTensorV2(), 269 | ] 270 | ) 271 | self.transform_train_norm = A.Compose([ 272 | A.CropAndPad(px=4, keep_size=False, always_apply=True), 273 | A.RandomCrop(height=32, width=32, always_apply=True), 274 | A.HorizontalFlip(), 275 | A.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)), 276 | ToTensorV2(), 277 | ]) 278 | self.transform_test = A.Compose( 279 | [ 280 | A.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)), 281 | ToTensorV2(), 282 | ] 283 | ) 284 | elif self.dataset=='fashionmnist': 285 | self.transform_train = A.ReplayCompose( 286 | [ 287 | A.ShiftScaleRotate(p=0.5), 288 | A.CropAndPad(px=2, keep_size=False, always_apply=True), 289 | A.RandomCrop(height=28, width=28, always_apply=True), 290 | A.HorizontalFlip(), 291 | A.Normalize(mean=(0.1307,), std=(0.3081)), 292 | ToTensorV2(), 293 | ] 294 | ) 295 | self.transform_train_norm = A.Compose([ 296 | A.CropAndPad(px=2, keep_size=False, always_apply=True), 297 | A.RandomCrop(height=28, width=28, always_apply=True), 298 | A.HorizontalFlip(), 299 | A.Normalize(mean=(0.1307,), std=(0.3081)), 300 | ToTensorV2(), 301 | ]) 302 | self.transform_test = A.Compose( 303 | [ 304 | A.Normalize(mean=(0.1307,), std=(0.3081)), 305 | ToTensorV2(), 306 | ] 307 | ) 308 | def run(self,mode,pred=[],prob=[]): 309 | if mode=='warmup': 310 | all_dataset = cifar_dataset(dataset=self.dataset, noise_mode=self.noise_mode, r=self.r, root_dir=self.root_dir, transform_weak=self.transform_train_norm, transform_strong=self.transform_train, mode="warmup",noise_file=self.noise_file) 311 | trainloader = DataLoader( 312 | dataset=all_dataset, 313 | batch_size=self.batch_size, 314 | shuffle=True, 315 | num_workers=self.num_workers) 316 | return trainloader 317 | 318 | elif mode=='all': 319 | all_dataset = cifar_dataset(dataset=self.dataset, noise_mode=self.noise_mode, r=self.r, root_dir=self.root_dir, transform_weak=self.transform_train_norm, transform_strong=self.transform_train, mode="all",noise_file=self.noise_file, pred=pred, probability=prob) 320 | trainloader = DataLoader( 321 | dataset=all_dataset, 322 | batch_size=self.batch_size, 323 | shuffle=True, 324 | num_workers=self.num_workers) 325 | return trainloader 326 | 327 | elif mode=='train': 328 | labeled_dataset = cifar_dataset(dataset=self.dataset, noise_mode=self.noise_mode, r=self.r, root_dir=self.root_dir, transform_weak=self.transform_train_norm, transform_strong=self.transform_train, mode="labeled", noise_file=self.noise_file, pred=pred, probability=prob,log=self.log) 329 | labeled_trainloader = DataLoader( 330 | 
dataset=labeled_dataset, 331 | batch_size=self.batch_size, 332 | shuffle=True, 333 | num_workers=self.num_workers) 334 | 335 | unlabeled_dataset = cifar_dataset(dataset=self.dataset, noise_mode=self.noise_mode, r=self.r, root_dir=self.root_dir, transform_weak=self.transform_train_norm, transform_strong=self.transform_train, mode="unlabeled", noise_file=self.noise_file, pred=pred, probability=prob) 336 | unlabeled_trainloader = DataLoader( 337 | dataset=unlabeled_dataset, 338 | batch_size=self.batch_size, 339 | shuffle=True, 340 | num_workers=self.num_workers) 341 | return labeled_trainloader, unlabeled_trainloader 342 | 343 | elif mode=='test': 344 | test_dataset = cifar_dataset(dataset=self.dataset, noise_mode=self.noise_mode, r=self.r, root_dir=self.root_dir, transform_weak=self.transform_test, transform_strong=self.transform_test, mode='test') 345 | test_loader = DataLoader( 346 | dataset=test_dataset, 347 | batch_size=self.batch_size, 348 | shuffle=False, 349 | num_workers=self.num_workers) 350 | return test_loader 351 | 352 | elif mode=='eval_train': 353 | eval_dataset = cifar_dataset(dataset=self.dataset, noise_mode=self.noise_mode, r=self.r, root_dir=self.root_dir, transform_weak=self.transform_test, transform_strong=self.transform_test, mode='warmup', noise_file=self.noise_file) 354 | eval_loader = DataLoader( 355 | dataset=eval_dataset, 356 | batch_size=self.batch_size, 357 | shuffle=False, 358 | num_workers=self.num_workers) 359 | return eval_loader 360 | -------------------------------------------------------------------------------- /fashionmnist/test_images.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tmllab/2023_NeurIPS_CS-isolate/a5400d028019648ef1e01cb2bbbdf358af4a7dc8/fashionmnist/test_images.npy -------------------------------------------------------------------------------- /fashionmnist/test_labels.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tmllab/2023_NeurIPS_CS-isolate/a5400d028019648ef1e01cb2bbbdf358af4a7dc8/fashionmnist/test_labels.npy -------------------------------------------------------------------------------- /fashionmnist/train_images.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tmllab/2023_NeurIPS_CS-isolate/a5400d028019648ef1e01cb2bbbdf358af4a7dc8/fashionmnist/train_images.npy -------------------------------------------------------------------------------- /fashionmnist/train_labels.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tmllab/2023_NeurIPS_CS-isolate/a5400d028019648ef1e01cb2bbbdf358af4a7dc8/fashionmnist/train_labels.npy -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | ## Reference: 2 | ## 1. DivideMix: https://github.com/LiJunnan1992/DivideMix 3 | ## 2. CausalNL: https://github.com/a5507203/IDLN 4 | ## Our code is heavily based on the above-mentioned repositories. 
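## Note (editorial sketch): eval_train, used in the training loop below, appears
## to follow the DivideMix recipe: fit a two-component Gaussian mixture to the
## per-sample training losses and read the posterior of the low-mean component
## as the probability that a label is clean; samples with probability above
## args.p_threshold then form the labeled set during co-divide. A minimal,
## hypothetical sketch of that step, not the authors' exact implementation
## (`gmm_clean_probability` and `per_sample_losses` are illustrative names):
#
# import numpy as np
# from sklearn.mixture import GaussianMixture
#
# def gmm_clean_probability(per_sample_losses):
#     # Min-max normalize the losses, fit a 2-component GMM, and return the
#     # posterior of the low-mean (likely-clean) component for every sample.
#     losses = np.asarray(per_sample_losses, dtype=np.float64).reshape(-1, 1)
#     losses = (losses - losses.min()) / (losses.max() - losses.min() + 1e-8)
#     gmm = GaussianMixture(n_components=2, max_iter=10, tol=1e-2, reg_covar=5e-4)
#     gmm.fit(losses)
#     return gmm.predict_proba(losses)[:, gmm.means_.argmin()]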
5 | 6 | # Loading libraries 7 | import sys 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | import torch.nn.functional as F 12 | import torch.backends.cudnn as cudnn 13 | import time 14 | import random 15 | import argparse 16 | import numpy as np 17 | from models.PreResNet import * 18 | from models.vae import * 19 | from sklearn.mixture import GaussianMixture 20 | import dataloader 21 | import argparse 22 | import os 23 | import numpy as np 24 | from tqdm import tqdm 25 | import albumentations as A 26 | from albumentations.pytorch import ToTensorV2 27 | 28 | # Default values 29 | parser = argparse.ArgumentParser(description='PyTorch CIFAR Training') 30 | parser.add_argument('--batch_size', default=64, type=int, help='train batchsize') 31 | parser.add_argument('--lr', '--learning_rate', default=0.02, type=float, help='initial learning rate') 32 | parser.add_argument('--vae_lr', '--vae_learning_rate', default=0.001, type=float, help='initial vae learning rate') 33 | parser.add_argument('--noise_mode', default='instance') 34 | parser.add_argument('--alpha', default=4, type=float, help='parameter for Beta') 35 | parser.add_argument('--lambda_u', default=25, type=float, help='weight for unsupervised loss') 36 | parser.add_argument('--p_threshold', default=0.5, type=float, help='clean probability threshold') 37 | parser.add_argument('--T', default=0.5, type=float, help='sharpening temperature') 38 | parser.add_argument('--num_epochs', default=300, type=int) 39 | parser.add_argument('--r', default=0.5, type=float, help='noise ratio') 40 | parser.add_argument('--lambda_elbo', default=0.001, type=float, help='weight for elbo') 41 | parser.add_argument('--lambda_ref', default=0.001, type=float, help='weight for ref') 42 | parser.add_argument('--id', default='') 43 | parser.add_argument('--seed', default=123) 44 | parser.add_argument('--gpuid', default=0, type=int) 45 | parser.add_argument('--num_class', default=10, type=int) 46 | parser.add_argument('--data_path', default='./cifar-10', type=str, help='path to dataset') 47 | parser.add_argument('--dataset', default='cifar10', type=str) 48 | parser.add_argument('--z_dim', default=32, type=int) 49 | args,_ = parser.parse_known_args() 50 | print(args) 51 | 52 | torch.cuda.set_device(args.gpuid) 53 | random.seed(args.seed) 54 | torch.manual_seed(args.seed) 55 | torch.cuda.manual_seed_all(args.seed) 56 | 57 | class AverageMeter(object): 58 | """Computes and stores the average and current value""" 59 | def __init__(self, name, fmt=':f'): 60 | self.name = name 61 | self.fmt = fmt 62 | self.reset() 63 | 64 | def reset(self): 65 | self.val = 0 66 | self.avg = 0 67 | self.sum = 0 68 | self.count = 0 69 | 70 | def update(self, val, n=1): 71 | self.val = val 72 | self.sum += val * n 73 | self.count += n 74 | self.avg = self.sum / self.count 75 | 76 | def __str__(self): 77 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 78 | return fmtstr.format(**self.__dict__) 79 | 80 | def unpickle(file): 81 | import pickle as cPickle 82 | with open(file, 'rb') as fo: 83 | dict = cPickle.load(fo, encoding='latin1') 84 | return dict 85 | 86 | def preprocess(dataset, root_dir, r, noise_mode, replay_num=1000): 87 | replay_file = '%s/%.2f_%s.npy'%(root_dir,r,noise_mode) 88 | if os.path.exists(replay_file): 89 | return 90 | train_data=[] 91 | train_label=[] 92 | if dataset=='cifar10': 93 | for n in range(1,6): 94 | dpath = '%s/data_batch_%d'%(root_dir,n) 95 | data_dic = unpickle(dpath) 96 | train_data.append(data_dic['data']) 97 | 
train_label = train_label+data_dic['labels'] 98 | train_data = np.concatenate(train_data) 99 | num_classes_ = 10 100 | train_data = train_data.reshape((50000, 3, 32, 32)) 101 | train_data = train_data.transpose((0, 2, 3, 1)) 102 | elif dataset=='cifar100': 103 | train_dic = unpickle('%s/train'%root_dir) 104 | train_data = train_dic['data'] 105 | train_label = train_dic['fine_labels'] 106 | num_classes_ = 100 107 | train_data = train_data.reshape((50000, 3, 32, 32)) 108 | train_data = train_data.transpose((0, 2, 3, 1)) 109 | elif dataset=='fashionmnist': 110 | train_data = np.load('%s/train_images.npy'%root_dir) 111 | train_label = np.load('%s/train_labels.npy'%root_dir) 112 | num_classes_=10 113 | 114 | ind = np.random.randint(0, train_data.shape[0]) 115 | data = train_data[ind] 116 | if dataset in ['cifar10', 'cifar100']: 117 | transform_train = A.ReplayCompose( 118 | [ 119 | A.ShiftScaleRotate(p=0.5), 120 | A.CropAndPad(px=4, keep_size=False, always_apply=True), 121 | A.RandomCrop(height=32, width=32, always_apply=True), 122 | A.HorizontalFlip(), 123 | A.RandomBrightnessContrast(p=0.5), 124 | A.ColorJitter(0.8, 0.8, 0.8, 0.2,p=0.8), 125 | A.ToGray(p=0.2), 126 | A.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)), 127 | ToTensorV2(), 128 | ] 129 | ) 130 | elif dataset=='fashionmnist': 131 | transform_train = A.ReplayCompose( 132 | [ 133 | A.ShiftScaleRotate(p=0.5), 134 | A.CropAndPad(px=2, keep_size=False, always_apply=True), 135 | A.RandomCrop(height=28, width=28, always_apply=True), 136 | A.HorizontalFlip(), 137 | A.Normalize(mean=(0.1307,), std=(0.3081)), 138 | ToTensorV2(), 139 | ] 140 | ) 141 | replay_list = [] 142 | while True: 143 | if len(replay_list) < replay_num: 144 | img = transform_train(image=data) 145 | replay_list.append(img['replay']) 146 | if len(replay_list) == replay_num: 147 | print('saving replay') 148 | np.save(replay_file, np.array(replay_list)) 149 | break 150 | id_file = '%s/%.2f_%s_id.npy'%(root_dir,r,noise_mode) 151 | img_id = np.arange(train_data.shape[0]) 152 | np.random.shuffle(img_id) 153 | img_id+=num_classes_ 154 | print('saving id file') 155 | np.save(id_file, img_id) 156 | 157 | def factor_func(step, end_step): 158 | if step= lr_decrease: 462 | lr /= 10 463 | for param_group in optimizer1.param_groups: 464 | param_group['lr'] = lr 465 | for param_group in optimizer2.param_groups: 466 | param_group['lr'] = lr 467 | lr=args.vae_lr 468 | if epoch >= lr_decrease: 469 | lr /= 10 470 | for param_group in optimizer_vae1.param_groups: 471 | param_group['lr'] = lr 472 | for param_group in optimizer_vae2.param_groups: 473 | param_group['lr'] = lr 474 | 475 | if epoch < warm_up: 476 | print('Warmup Net1') 477 | warmup(epoch,net1,optimizer1,warmup_trainloader) 478 | print('\nWarmup Net2') 479 | warmup(epoch,net2,optimizer2,warmup_trainloader) 480 | else: 481 | if epoch==warm_up: 482 | torch.save({ 483 | 'epoch': epoch, 484 | 'net1_state_dict': net1.state_dict(), 485 | 'net2_state_dict': net2.state_dict(), 486 | 'optimizer1_state_dict': optimizer1.state_dict(), 487 | 'optimizer2_state_dict': optimizer2.state_dict(), 488 | 'vae_model1_state_dict': vae_model1.state_dict(), 489 | 'vae_model2_state_dict': vae_model2.state_dict(), 490 | 'optimizer_vae1_state_dict': optimizer_vae1.state_dict(), 491 | 'optimizer_vae2_state_dict': optimizer_vae2.state_dict() 492 | }, './saved/%s/warmup_checkpoint_%s_%.2f'%(args.dataset, args.noise_mode, args.r)+'.tar') 493 | 494 | prob1,all_loss[0],eval_correct1=eval_train(net1,all_loss[0]) 495 | 
prob2,all_loss[1],eval_correct2=eval_train(net2,all_loss[1]) 496 | pred1 = (prob1 > args.p_threshold) 497 | pred2 = (prob2 > args.p_threshold) 498 | 499 | print('Train Net1') 500 | print('updating loader') 501 | labeled_trainloader, unlabeled_trainloader = loader.run('train',pred2,prob2) # co-divide 502 | train_all_loader = loader.run("all",pred2,prob2) 503 | loss_1 = train(epoch,net1,net2,optimizer1,vae_model1,vae_model2,optimizer_vae1,labeled_trainloader, unlabeled_trainloader, train_all_loader, net_1=True) # train net1 504 | 505 | print('\nTrain Net2') 506 | print('updating loader') 507 | labeled_trainloader, unlabeled_trainloader = loader.run('train',pred1,prob1) # co-divide 508 | train_all_loader = loader.run("all",pred1,prob1) 509 | loss_2 = train(epoch,net2,net1,optimizer2,vae_model2,vae_model1, optimizer_vae2,labeled_trainloader, unlabeled_trainloader, train_all_loader, net_1=False) # train net2 510 | test(epoch,net1,net2) 511 | pbar.update(epoch) 512 | epoch += 1 513 | torch.save({ 514 | 'epoch': epoch, 515 | 'net1_state_dict': net1.state_dict(), 516 | 'net2_state_dict': net2.state_dict(), 517 | 'optimizer1_state_dict': optimizer1.state_dict(), 518 | 'optimizer2_state_dict': optimizer2.state_dict(), 519 | 'vae_model1_state_dict': vae_model1.state_dict(), 520 | 'vae_model2_state_dict': vae_model2.state_dict(), 521 | 'optimizer_vae1_state_dict': optimizer_vae1.state_dict(), 522 | 'optimizer_vae2_state_dict': optimizer_vae2.state_dict(), 523 | }, './saved/%s/checkpoint_%s_%.2f'%(args.dataset, args.noise_mode, args.r)+'.tar') 524 | pbar.close() 525 | end = time.time() 526 | print(end - start) 527 | 528 | -------------------------------------------------------------------------------- /models/PreResNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def conv3x3(in_planes, out_planes, stride=1): 7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 8 | 9 | class BasicBlock(nn.Module): 10 | expansion = 1 11 | 12 | def __init__(self, in_planes, planes, stride=1): 13 | super(BasicBlock, self).__init__() 14 | self.conv1 = conv3x3(in_planes, planes, stride) 15 | self.bn1 = nn.BatchNorm2d(planes) 16 | self.conv2 = conv3x3(planes, planes) 17 | self.bn2 = nn.BatchNorm2d(planes) 18 | 19 | self.shortcut = nn.Sequential() 20 | if stride != 1 or in_planes != self.expansion*planes: 21 | self.shortcut = nn.Sequential( 22 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 23 | nn.BatchNorm2d(self.expansion*planes) 24 | ) 25 | 26 | def forward(self, x): 27 | out = F.relu(self.bn1(self.conv1(x))) 28 | out = self.bn2(self.conv2(out)) 29 | out += self.shortcut(x) 30 | out = F.relu(out) 31 | return out 32 | 33 | 34 | class PreActBlock(nn.Module): 35 | '''Pre-activation version of the BasicBlock.''' 36 | expansion = 1 37 | 38 | def __init__(self, in_planes, planes, stride=1): 39 | super(PreActBlock, self).__init__() 40 | self.bn1 = nn.BatchNorm2d(in_planes) 41 | self.conv1 = conv3x3(in_planes, planes, stride) 42 | self.bn2 = nn.BatchNorm2d(planes) 43 | self.conv2 = conv3x3(planes, planes) 44 | 45 | self.shortcut = nn.Sequential() 46 | if stride != 1 or in_planes != self.expansion*planes: 47 | self.shortcut = nn.Sequential( 48 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 49 | ) 50 | 51 | def forward(self, x): 52 | out = F.relu(self.bn1(x)) 53 | shortcut = 
self.shortcut(out) 54 | out = self.conv1(out) 55 | out = self.conv2(F.relu(self.bn2(out))) 56 | out += shortcut 57 | return out 58 | 59 | 60 | class Bottleneck(nn.Module): 61 | expansion = 4 62 | 63 | def __init__(self, in_planes, planes, stride=1): 64 | super(Bottleneck, self).__init__() 65 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 66 | self.bn1 = nn.BatchNorm2d(planes) 67 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 68 | self.bn2 = nn.BatchNorm2d(planes) 69 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 70 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 71 | 72 | self.shortcut = nn.Sequential() 73 | if stride != 1 or in_planes != self.expansion*planes: 74 | self.shortcut = nn.Sequential( 75 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 76 | nn.BatchNorm2d(self.expansion*planes) 77 | ) 78 | 79 | def forward(self, x): 80 | out = F.relu(self.bn1(self.conv1(x))) 81 | out = F.relu(self.bn2(self.conv2(out))) 82 | out = self.bn3(self.conv3(out)) 83 | out += self.shortcut(x) 84 | out = F.relu(out) 85 | return out 86 | 87 | 88 | class PreActBottleneck(nn.Module): 89 | '''Pre-activation version of the original Bottleneck module.''' 90 | expansion = 4 91 | 92 | def __init__(self, in_planes, planes, stride=1): 93 | super(PreActBottleneck, self).__init__() 94 | self.bn1 = nn.BatchNorm2d(in_planes) 95 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 96 | self.bn2 = nn.BatchNorm2d(planes) 97 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 98 | self.bn3 = nn.BatchNorm2d(planes) 99 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 100 | self.drop_layer = nn.Dropout(p=0.2) 101 | 102 | self.shortcut = nn.Sequential() 103 | if stride != 1 or in_planes != self.expansion*planes: 104 | self.shortcut = nn.Sequential( 105 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 106 | ) 107 | 108 | def forward(self, x): 109 | out = F.relu(self.bn1(x)) 110 | shortcut = self.shortcut(out) 111 | out = self.drop_layer(self.conv1(out)) 112 | out = self.drop_layer(self.conv2(F.relu(self.bn2(out)))) 113 | out = self.drop_layer(self.conv3(F.relu(self.bn3(out)))) 114 | out += shortcut 115 | return out 116 | 117 | 118 | class ResNet(nn.Module): 119 | def __init__(self, block, num_blocks, num_classes=10, in_c=3, z_dim=32): 120 | super(ResNet, self).__init__() 121 | self.in_planes = 64 122 | 123 | self.conv1 = conv3x3(in_c,64) 124 | self.bn1 = nn.BatchNorm2d(64) 125 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 126 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 127 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 128 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 129 | self.mean = nn.Sequential(nn.Linear(512*block.expansion, z_dim)) 130 | self.logvar = nn.Sequential(nn.Linear(512*block.expansion, z_dim)) 131 | self.fc = nn.Sequential(nn.Linear(z_dim, z_dim), 132 | nn.ReLU(), 133 | nn.Linear(z_dim, num_classes)) 134 | def _make_layer(self, block, planes, num_blocks, stride): 135 | strides = [stride] + [1]*(num_blocks-1) 136 | layers = [] 137 | for stride in strides: 138 | layers.append(block(self.in_planes, planes, stride)) 139 | self.in_planes = planes * block.expansion 140 | return nn.Sequential(*layers) 141 | 142 | def forward(self, x, lin=0, 
lout=5): 143 | out = x 144 | if lin < 1 and lout > -1: 145 | out = self.conv1(out) 146 | out = self.bn1(out) 147 | out = F.relu(out) 148 | if lin < 2 and lout > 0: 149 | out = self.layer1(out) 150 | if lin < 3 and lout > 1: 151 | out = self.layer2(out) 152 | if lin < 4 and lout > 2: 153 | out = self.layer3(out) 154 | if lin < 5 and lout > 3: 155 | out = self.layer4(out) 156 | if lout > 4: 157 | out = F.avg_pool2d(out, 4) 158 | out = out.view(out.size(0), -1) 159 | 160 | zc_mean = self.mean(out) 161 | zc_logvar = self.logvar(out) 162 | zc = self.reparameterize(zc_mean, zc_logvar) 163 | y = self.fc(zc_mean) 164 | 165 | return zc_mean, zc_logvar, zc, y 166 | 167 | def reparameterize(self, mean, logvar): 168 | std = torch.exp(logvar / 2) # in log-space, squareroot is divide by two 169 | epsilon = torch.randn_like(std) 170 | return epsilon * std + mean 171 | 172 | 173 | def ResNet18(num_classes=10, in_c=3, z_dim=32): 174 | return ResNet(PreActBlock, [2,2,2,2], num_classes=num_classes, in_c=in_c, z_dim=z_dim) 175 | 176 | def ResNet34(num_classes=10, z_dim=32): 177 | return ResNet(BasicBlock, [3,4,6,3], num_classes=num_classes, z_dim=z_dim) 178 | 179 | def ResNet50(num_classes=10, z_dim=32): 180 | return ResNet(Bottleneck, [3,4,6,3], num_classes=num_classes, z_dim=z_dim) 181 | 182 | def ResNet101(num_classes=10, z_dim=32): 183 | return ResNet(Bottleneck, [3,4,23,3], num_classes=num_classes, z_dim=z_dim) 184 | 185 | def ResNet152(num_classes=10, z_dim=32): 186 | return ResNet(Bottleneck, [3,8,36,3], num_classes=num_classes, z_dim=z_dim) 187 | 188 | -------------------------------------------------------------------------------- /models/encoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | from torch import nn 4 | from torch.nn import functional as F 5 | __all__ = ["CONV_Decoder_FMNIST", "CONV_Encoder_FMNIST", "Z_Encoder", "X_Decoder","CONV_Encoder_CIFAR","CONV_Decoder_CIFAR"] 6 | 7 | 8 | def make_hidden_layers(num_hidden_layers=1, hidden_size=5, prefix="y"): 9 | block = nn.Sequential() 10 | for i in range(num_hidden_layers): 11 | block.add_module(prefix+"_"+str(i), nn.Sequential(nn.Linear(hidden_size,hidden_size),nn.BatchNorm1d(hidden_size),nn.LeakyReLU())) 12 | return block 13 | 14 | 15 | class CONV_Encoder_FMNIST(nn.Module): 16 | def __init__(self, in_channels =1, feature_dim = 28, num_classes = 2, hidden_dims = [32, 64, 128, 256], z_dim = 2): 17 | super().__init__() 18 | self.z_dim = z_dim 19 | self.feature_dim = feature_dim 20 | modules = [] 21 | 22 | for h_dim in hidden_dims: 23 | modules.append( 24 | nn.Sequential( 25 | nn.Conv2d(in_channels, out_channels=h_dim, 26 | kernel_size= 3, stride= 2, padding = 1), 27 | nn.BatchNorm2d(h_dim), 28 | nn.LeakyReLU()) 29 | ) 30 | in_channels = h_dim 31 | 32 | self.encoder = nn.Sequential(*modules) 33 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, z_dim) 34 | self.fc_logvar = nn.Linear(hidden_dims[-1]*4, z_dim) 35 | 36 | def forward(self, x): 37 | x = self.encoder(x) 38 | x = torch.flatten(x, start_dim=1) 39 | mu = self.fc_mu(x) 40 | log_var = self.fc_logvar(x) 41 | return mu, log_var 42 | 43 | 44 | class CONV_Decoder_FMNIST(nn.Module): 45 | 46 | def __init__(self, num_classes = 2, hidden_dims = [256, 128, 64, 32], z_dim = 1): 47 | super().__init__() 48 | self.decoder_input = nn.Linear(z_dim, hidden_dims[0] * 4) 49 | modules = [] 50 | for i in range(len(hidden_dims) - 1): 51 | modules.append( 52 | nn.Sequential( 53 | nn.ConvTranspose2d(hidden_dims[i], 54 | hidden_dims[i 
+ 1], 55 | kernel_size=3, 56 | stride = 2, 57 | padding=1, 58 | output_padding=1), 59 | nn.BatchNorm2d(hidden_dims[i + 1]), 60 | nn.LeakyReLU()) 61 | ) 62 | self.decoder = nn.Sequential(*modules) 63 | 64 | self.final_layer = nn.Sequential( 65 | nn.ConvTranspose2d(hidden_dims[-1], 66 | hidden_dims[-1], 67 | kernel_size=3, 68 | stride=2, 69 | padding=1, 70 | ), 71 | nn.BatchNorm2d(hidden_dims[-1]), 72 | nn.LeakyReLU(), 73 | nn.Conv2d(hidden_dims[-1], out_channels= 1, 74 | kernel_size= 4)) 75 | 76 | 77 | def forward(self, z): 78 | out = self.decoder_input(z) 79 | out = out.view(-1, 256, 2, 2) 80 | out = self.decoder(out) 81 | out = self.final_layer(out) 82 | return out 83 | 84 | 85 | 86 | class CONV_Encoder_CIFAR(nn.Module): 87 | def __init__(self, in_channels =3, feature_dim = 32, num_classes = 2, embed_size = 10, hidden_dims = [32, 64, 128, 256], z_dim = 2): 88 | super().__init__() 89 | self.z_dim = z_dim 90 | self.feature_dim = feature_dim 91 | self.in_channels = in_channels 92 | modules = [] 93 | 94 | for h_dim in hidden_dims: 95 | modules.append( 96 | nn.Sequential( 97 | nn.Conv2d(in_channels, out_channels=h_dim, 98 | kernel_size= 3, stride= 2, padding = 1), 99 | nn.BatchNorm2d(h_dim), 100 | nn.LeakyReLU()) 101 | ) 102 | in_channels = h_dim 103 | 104 | self.encoder = nn.Sequential(*modules) 105 | self.fc_mu = nn.Sequential(nn.Linear(hidden_dims[-1]*4,hidden_dims[-1]*4), 106 | nn.LeakyReLU(), 107 | nn.Linear(hidden_dims[-1]*4, z_dim)) 108 | self.fc_logvar = nn.Sequential(nn.Linear(hidden_dims[-1]*4,hidden_dims[-1]*4), 109 | nn.LeakyReLU(), 110 | nn.Linear(hidden_dims[-1]*4, z_dim)) 111 | 112 | def forward(self, x): 113 | x = self.encoder(x) 114 | x = torch.flatten(x, start_dim=1) 115 | mu = self.fc_mu(x) 116 | log_var = self.fc_logvar(x) 117 | return mu, log_var 118 | 119 | 120 | 121 | class CONV_Decoder_CIFAR(nn.Module): 122 | 123 | def __init__(self, num_classes = 2, embed_size=10, hidden_dims = [256, 128, 64,32], z_dim = 1): 124 | super().__init__() 125 | self.decoder_input = nn.Linear(z_dim, hidden_dims[0] * 4) 126 | modules = [] 127 | for i in range(len(hidden_dims) - 1): 128 | modules.append( 129 | nn.Sequential( 130 | nn.ConvTranspose2d(hidden_dims[i], 131 | hidden_dims[i + 1], 132 | kernel_size=3, 133 | stride = 2, 134 | padding=1, 135 | output_padding=1), 136 | nn.BatchNorm2d(hidden_dims[i + 1]), 137 | nn.LeakyReLU()) 138 | ) 139 | self.decoder = nn.Sequential(*modules) 140 | 141 | 142 | self.final_layer = nn.Sequential( 143 | nn.ConvTranspose2d(hidden_dims[-1], 144 | hidden_dims[-1], 145 | kernel_size=3, 146 | stride=2, 147 | padding=1, 148 | output_padding=1), 149 | nn.BatchNorm2d(hidden_dims[-1]), 150 | nn.LeakyReLU(), 151 | nn.Conv2d(hidden_dims[-1], out_channels= 3, 152 | kernel_size= 3, stride=1, padding= 1)) 153 | 154 | 155 | def forward(self, z): 156 | out = self.decoder_input(z) 157 | out = out.view(-1, 256, 2, 2) 158 | out = self.decoder(out) 159 | out = self.final_layer(out) 160 | return out 161 | 162 | 163 | 164 | class Z_Encoder(nn.Module): 165 | def __init__(self, feature_dim = 2, num_classes = 2, embed_size=10, num_hidden_layers=1, hidden_size = 5, z_dim = 2): 166 | super().__init__() 167 | self.z_fc1 = nn.Linear(feature_dim, hidden_size) 168 | self.z_h_layers = make_hidden_layers(num_hidden_layers, hidden_size=hidden_size, prefix="z") 169 | self.z_fc_mu = nn.Linear(hidden_size, z_dim) # fc21 for mean of Z 170 | self.z_fc_logvar = nn.Linear(hidden_size, z_dim) # fc22 for log variance of Z 171 | 172 | def forward(self, x): 173 | out = F.leaky_relu(self.z_fc1(x)) 174 
| out = self.z_h_layers(out) 175 | mu = F.elu(self.z_fc_mu(out)) 176 | logvar = F.elu(self.z_fc_logvar(out)) 177 | return mu, logvar 178 | 179 | 180 | class X_Decoder(nn.Module): 181 | def __init__(self, feature_dim = 2, num_classes = 2, embed_size=10, num_hidden_layers=1, hidden_size = 5, z_dim = 1): 182 | super().__init__() 183 | self.recon_fc1 = nn.Linear(z_dim, hidden_size) 184 | self.recon_h_layers = make_hidden_layers(num_hidden_layers, hidden_size=hidden_size, prefix="recon") 185 | self.recon_fc2 = nn.Linear(hidden_size, feature_dim) 186 | 187 | def forward(self, z): 188 | out = F.leaky_relu(self.recon_fc1(z)) 189 | out = self.recon_h_layers(out) 190 | x = self.recon_fc2(out) 191 | return x -------------------------------------------------------------------------------- /models/vae.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.utils.data 4 | from torch import nn, optim 5 | from torch.nn import functional as F 6 | from .encoders import * 7 | from .PreResNet import * 8 | 9 | 10 | __all__ = ["VAE_FASHIONMNIST","VAE_CIFAR10","VAE_CIFAR100"] 11 | 12 | 13 | class PositionalEncoding(nn.Module): 14 | 15 | def __init__(self, d_model: int, max_len: int = 5000): 16 | super().__init__() 17 | position = torch.arange(max_len).unsqueeze(1) 18 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 19 | pe = torch.zeros(max_len, d_model) 20 | pe[:, 0::2] = torch.sin(position * div_term) 21 | pe[:, 1::2] = torch.cos(position * div_term) 22 | self.register_buffer('pe', pe) 23 | 24 | def forward(self, index: torch.Tensor) -> torch.Tensor: 25 | """ 26 | Arguments: 27 | index: Tensor, shape ``[batch_size]`` 28 | """ 29 | return self.pe[index] 30 | 31 | 32 | class BaseVAE(nn.Module): 33 | def __init__(self, feature_dim=28, num_hidden_layers=1, hidden_size=32, z_dim =10, num_classes=100, embed_size = 10): 34 | super().__init__() 35 | self.zs_encoder = Z_Encoder(feature_dim=feature_dim, num_classes=num_classes, embed_size=embed_size, num_hidden_layers=num_hidden_layers, hidden_size = hidden_size, z_dim=z_dim) 36 | self.x_decoder = X_Decoder(feature_dim=feature_dim, num_hidden_layers=num_hidden_layers, num_classes=num_classes, embed_size=embed_size, hidden_size = hidden_size, z_dim=2*z_dim) 37 | self.kl_divergence = None 38 | self.flow = None 39 | self.z_dim = z_dim 40 | self.num_classes = num_classes 41 | self.embed_size = embed_size 42 | self.replay_num = 1000+10 43 | self.zc_prior = nn.Sequential(nn.Linear(32, 128), 44 | nn.LeakyReLU(), 45 | nn.Linear(128,128), 46 | nn.LeakyReLU(), 47 | nn.Linear(128, z_dim)) 48 | self.zs_prior = nn.Sequential(nn.Linear(self.embed_size,64), 49 | nn.LeakyReLU(), 50 | nn.Linear(64,64), 51 | nn.LeakyReLU(), 52 | nn.Linear(64, z_dim)) 53 | self.pe_uc = PositionalEncoding(32, 80000) 54 | self.pe_us = PositionalEncoding(self.embed_size, self.replay_num) 55 | 56 | def _y_hat_reparameterize(self, c_logits): 57 | return F.gumbel_softmax(c_logits, dim=1) 58 | 59 | def _z_reparameterize(self, mu, logvar): 60 | std = torch.exp(0.5*logvar) 61 | eps = torch.randn_like(std) # eps ~ N(0, 1): Gaussian noise for the reparameterization trick 62 | return mu + eps*std 63 | 64 | def forward(self, x: torch.Tensor, uc: torch.Tensor, us: torch.Tensor, net): 65 | zc_mean, zc_logvar, zc, c_logits = net.forward(x) 66 | embed_uc = self.pe_uc(uc) 67 | embed_us = self.pe_us(us) 68 | p_zc_m = self.zc_prior(embed_uc) 69 | p_zs_m = self.zs_prior(embed_us) 70 | zs_mean, zs_logvar = self.zs_encoder(x) 71 | 72 | zs = self._z_reparameterize(zs_mean,
zs_logvar) 73 | z = torch.cat((zc, zs),dim=1) 74 | x_hat = self.x_decoder(z) 75 | 76 | return x_hat, c_logits, zc_mean, zc_logvar, zs_mean, zs_logvar, p_zc_m, p_zs_m 77 | 78 | 79 | 80 | class VAE_FASHIONMNIST(BaseVAE): 81 | def __init__(self, feature_dim=28, input_channel=1, z_dim =10, num_classes=10, embed_size = 10): 82 | super().__init__() 83 | 84 | self.zs_encoder = CONV_Encoder_FMNIST(feature_dim=feature_dim, num_classes=num_classes, z_dim=z_dim) 85 | self.x_decoder = CONV_Decoder_FMNIST(num_classes=num_classes, z_dim=2*z_dim) 86 | self.z_dim = z_dim 87 | self.num_classes = num_classes 88 | self.embed_size = embed_size 89 | self.replay_num = 1000+10 90 | self.zc_prior = nn.Sequential(nn.Linear(32, 128), 91 | nn.LeakyReLU(), 92 | nn.Linear(128,128), 93 | nn.LeakyReLU(), 94 | nn.Linear(128, z_dim)) 95 | self.zs_prior = nn.Sequential(nn.Linear(self.embed_size,64), 96 | nn.LeakyReLU(), 97 | nn.Linear(64,64), 98 | nn.LeakyReLU(), 99 | nn.Linear(64, z_dim)) 100 | self.pe_uc = PositionalEncoding(32, 80000) 101 | self.pe_us = PositionalEncoding(self.embed_size, self.replay_num) 102 | 103 | 104 | class VAE_CIFAR100(BaseVAE): 105 | def __init__(self, feature_dim=32, input_channel=3, z_dim =32, num_classes=100, embed_size = 10): 106 | super().__init__() 107 | self.zs_encoder = CONV_Encoder_CIFAR(feature_dim=feature_dim, num_classes=num_classes, embed_size=embed_size, z_dim=z_dim) 108 | self.x_decoder = CONV_Decoder_CIFAR(num_classes=num_classes, z_dim=2*z_dim) 109 | self.z_dim = z_dim 110 | self.num_classes = num_classes 111 | self.embed_size = embed_size 112 | self.replay_num = 1000+10 113 | self.zc_prior = nn.Sequential(nn.Linear(32, 128), 114 | nn.LeakyReLU(), 115 | nn.Linear(128,128), 116 | nn.LeakyReLU(), 117 | nn.Linear(128, z_dim)) 118 | self.zs_prior = nn.Sequential(nn.Linear(self.embed_size,64), 119 | nn.LeakyReLU(), 120 | nn.Linear(64,64), 121 | nn.LeakyReLU(), 122 | nn.Linear(64, z_dim)) 123 | self.pe_uc = PositionalEncoding(32, 80000) 124 | self.pe_us = PositionalEncoding(self.embed_size, self.replay_num) 125 | 126 | 127 | class VAE_CIFAR10(BaseVAE): 128 | def __init__(self, feature_dim=32, input_channel=3, z_dim =32, num_classes=10, embed_size = 10): 129 | super().__init__() 130 | self.zs_encoder = CONV_Encoder_CIFAR(feature_dim=feature_dim, num_classes=num_classes, embed_size=embed_size, z_dim=z_dim) 131 | self.x_decoder = CONV_Decoder_CIFAR(num_classes=num_classes, z_dim=2*z_dim) 132 | self.z_dim = z_dim 133 | self.num_classes = num_classes 134 | self.embed_size = embed_size 135 | self.replay_num = 1000+10 136 | self.zc_prior = nn.Sequential(nn.Linear(32, 128), 137 | nn.LeakyReLU(), 138 | nn.Linear(128,128), 139 | nn.LeakyReLU(), 140 | nn.Linear(128, z_dim)) 141 | self.zs_prior = nn.Sequential(nn.Linear(self.embed_size,64), 142 | nn.LeakyReLU(), 143 | nn.Linear(64,64), 144 | nn.LeakyReLU(), 145 | nn.Linear(64, z_dim)) 146 | self.pe_uc = PositionalEncoding(32, 80000) 147 | self.pe_us = PositionalEncoding(self.embed_size, self.replay_num) 148 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations==1.3.1 2 | numpy==1.24.3 3 | pandas==2.0.3 4 | Pillow==9.4.0 5 | 6 | scikit_learn==1.3.0 7 | scipy==1.11.4 8 | torch==1.12.1 9 | torchnet==0.0.4 10 | torchvision==0.13.1 11 | tqdm==4.65.0 12 | -------------------------------------------------------------------------------- /tools.py:
-------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from math import inf 5 | from scipy import stats 6 | import torch.nn.functional as F 7 | import torch.nn as nn 8 | from numpy.testing import assert_array_almost_equal 9 | 10 | def get_instance_noisy_label(n, dataset, labels, num_classes, feature_size, norm_std, seed): 11 | # n -> noise_rate 12 | # dataset -> an iterable of (x, y) pairs (e.g. zipped tensors), not a train_loader 13 | # labels -> labels (targets) 14 | # num_classes -> number of classes 15 | # feature_size -> the size of input images (e.g. 28*28) 16 | # norm_std -> default 0.1 17 | # seed -> random_seed 18 | print("building dataset...") 19 | label_num = num_classes 20 | np.random.seed(int(seed)) 21 | torch.manual_seed(int(seed)) 22 | torch.cuda.manual_seed(int(seed)) 23 | 24 | P = [] 25 | flip_distribution = stats.truncnorm((0 - n) / norm_std, (1 - n) / norm_std, loc=n, scale=norm_std) 26 | flip_rate = flip_distribution.rvs(labels.shape[0]) 27 | 28 | if isinstance(labels, list): 29 | labels = torch.FloatTensor(labels) 30 | labels = labels.cuda() 31 | 32 | W = np.random.randn(label_num, feature_size, label_num) 33 | 34 | W = torch.FloatTensor(W).cuda() 35 | for i, (x, y) in enumerate(dataset): 36 | # 1*m * m*10 = 1*10 37 | x = x.cuda() 38 | A = x.contiguous().view(1, -1).mm(W[y]).squeeze(0) 39 | A[y] = -inf 40 | A = flip_rate[i] * F.softmax(A, dim=0) 41 | A[y] += 1 - flip_rate[i] 42 | P.append(A) 43 | P = torch.stack(P, 0).cpu().numpy() 44 | l = [i for i in range(label_num)] 45 | new_label = [np.random.choice(l, p=P[i]) for i in range(labels.shape[0])] 46 | record = [[0 for _ in range(label_num)] for i in range(label_num)] # debug bookkeeping: realized transition counts; not used in the output 47 | 48 | for a, b in zip(labels, new_label): 49 | a, b = int(a), int(b) 50 | record[a][b] += 1 51 | 52 | pidx = np.random.choice(range(P.shape[0]), 1000) # debug: inspect a few class-0 flip distributions; has no effect on the result 53 | cnt = 0 54 | for i in range(1000): 55 | if labels[pidx[i]] == 0: 56 | a = P[pidx[i], :] 57 | cnt += 1 58 | if cnt >= 10: 59 | break 60 | return np.array(new_label) 61 | 62 | # basic function 63 | def multiclass_noisify(y, P, random_state=1): 64 | """ Flip classes according to the transition probability matrix P. 65 | Labels are expected to be integers between 0 and the number of classes - 1. 66 | """ 67 | # print (np.max(y), P.shape[0]) 68 | assert P.shape[0] == P.shape[1] 69 | assert np.max(y) < P.shape[0] 70 | 71 | # row stochastic matrix 72 | assert_array_almost_equal(P.sum(axis=1), np.ones(P.shape[1])) 73 | assert (P >= 0.0).all() 74 | 75 | m = y.shape[0] 76 | new_y = y.copy() 77 | flipper = np.random.RandomState(random_state) 78 | 79 | for idx in np.arange(m): 80 | i = y[idx] 81 | # draw a one-hot vector indicating the flipped class 82 | flipped = flipper.multinomial(1, P[i, :][0], 1)[0] 83 | new_y[idx] = np.where(flipped == 1)[0] 84 | 85 | return new_y 86 | 87 | # noisify_pairflip calls the function "multiclass_noisify" 88 | def noisify_pairflip(y_train, noise, random_state=1, nb_classes=10): 89 | """Pair flipping: with probability `noise`, class i is 90 | mistaken for class i+1 (and the last class for class 0). 91 | """ 92 | P = np.eye(nb_classes) 93 | n = noise 94 | 95 | if n > 0.0: 96 | # 0 -> 1 97 | P[0, 0], P[0, 1] = 1. - n, n 98 | for i in range(1, nb_classes-1): 99 | P[i, i], P[i, i + 1] = 1. - n, n 100 | P[nb_classes-1, nb_classes-1], P[nb_classes-1, 0] = 1.
- n, n 101 | 102 | y_train_noisy = multiclass_noisify(y_train, P=P, 103 | random_state=random_state) 104 | actual_noise = (y_train_noisy != y_train).mean() 105 | assert actual_noise > 0.0 106 | print('Actual noise %.2f' % actual_noise) 107 | y_train = y_train_noisy 108 | # print (P) 109 | 110 | return y_train 111 | 112 | --------------------------------------------------------------------------------
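For reference, a minimal usage sketch of the pair-flip helper above, mirroring how dataloader.py calls it (labels must be shaped as a column vector; the label values here are synthetic, purely for illustration):

```
import numpy as np
import tools

y = np.random.randint(0, 10, size=(50000, 1))  # synthetic clean labels in column-vector layout
y_noisy = tools.noisify_pairflip(y, 0.4, random_state=123, nb_classes=10)  # prints the actual noise rate (~0.40)
y_noisy = y_noisy[:, 0]  # flatten back to one label per example, as dataloader.py does
```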