├── README.md
├── checkpoint
│   └── erm_r100_c10_trial1.t7
├── config.py
├── data_loader.py
├── etc
│   ├── celebA_test_orig.txt
│   ├── celebA_train_orig.txt
│   ├── celebA_val_orig.txt
│   └── celeb_loader.py
├── models
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   └── resnet32.cpython-37.pyc
│   └── resnet32.py
├── run.sh
├── train.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
 1 | # M2m: Imbalanced Classification via Major-to-minor Translation
 2 | 
 3 | This repository contains code for the paper
 4 | **"M2m: Imbalanced Classification via Major-to-minor Translation"**
 5 | by [Jaehyung Kim](https://sites.google.com/view/jaehyungkim)\*, [Jongheon Jeong](https://sites.google.com/view/jongheonj)\* and [Jinwoo Shin](http://alinlab.kaist.ac.kr/shin.html).
 6 | 
 7 | ## Dependencies
 8 | 
 9 | * `python3`
10 | * `pytorch >= 1.1.0`
11 | * `torchvision`
12 | * `tqdm`
13 | 
14 | ## Scripts
15 | Please check out `run.sh` for the scripts used to reproduce the reported CIFAR-10-LT results.
16 | 
17 | ### Training procedure of M2m
18 | 1. Train a baseline network g, which is later used to generate minority samples:
19 | ```
20 | python train.py --no_over --ratio 100 --decay 2e-4 --model resnet32 --dataset cifar10 \
21 | --lr 0.1 --batch-size 128 --name 'ERM' --warm 200 --epoch 200
22 | ```
23 | 2. Train another network f using M2m with the pre-trained g:
24 | ```
25 | python train.py -gen -r --ratio 100 --decay 2e-4 --model resnet32 --dataset cifar10 \
26 | --lr 0.1 --batch-size 128 --name 'M2m' --beta 0.999 --lam 0.5 --gamma 0.9 \
27 | --step_size 0.1 --attack_iter 10 --warm 160 --epoch 200 --net_g ./checkpoint/pre_trained_g.t7
28 | ```
29 | We also provide a pre-trained ResNet-32 model of g at `checkpoint/erm_r100_c10_trial1.t7`,
30 | so one can apply M2m directly without the pre-training step, as follows:
31 | ```
32 | python train.py -gen -r --ratio 100 --decay 2e-4 --model resnet32 --dataset cifar10 \
33 | --lr 0.1 --batch-size 128 --name 'M2m' --beta 0.999 --lam 0.5 --gamma 0.9 \
34 | --step_size 0.1 --attack_iter 10 --warm 160 --epoch 200 --net_g ./checkpoint/erm_r100_c10_trial1.t7
35 | ```
36 | 
--------------------------------------------------------------------------------
/checkpoint/erm_r100_c10_trial1.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alinlab/M2m/42d08a5399c1b62925044287e7ee8d134260a08a/checkpoint/erm_r100_c10_trial1.t7
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
 1 | import argparse
 2 | 
 3 | import numpy as np
 4 | import torch
 5 | import torch.nn as nn
 6 | import torch.backends.cudnn as cudnn
 7 | import torchvision.transforms as transforms
 8 | 
 9 | from data_loader import make_longtailed_imb, get_imbalanced, get_oversampled, get_smote
10 | from utils import InputNormalize, sum_t
11 | import models
12 | 
13 | 
14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15 | cudnn.benchmark = True
16 | if torch.cuda.is_available():
17 |     N_GPUS = torch.cuda.device_count()
18 | else:
19 |     N_GPUS = 0
20 | 
21 | 
22 | def parse_args():
23 |     parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
24 |     parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
25 |     parser.add_argument('--model', default='resnet32', type=str,
26 |                         help='model type (default: resnet32)')
27 |     parser.add_argument('--batch-size', default=128, type=int, 
help='batch size') 28 | parser.add_argument('--epoch', default=200, type=int, 29 | help='total epochs to run') 30 | parser.add_argument('--seed', default=None, type=int, help='random seed') 31 | parser.add_argument('--dataset', required=True, 32 | choices=['cifar10', 'cifar100'], help='Dataset') 33 | parser.add_argument('--decay', default=2e-4, type=float, help='weight decay') 34 | parser.add_argument('--no-augment', dest='augment', action='store_false', 35 | help='use standard augmentation (default: True)') 36 | 37 | parser.add_argument('--name', default='0', type=str, help='name of run') 38 | parser.add_argument('--resume', '-r', action='store_true', 39 | help='resume from checkpoint') 40 | parser.add_argument('--net_g', default=None, type=str, 41 | help='checkpoint path of network for generation') 42 | parser.add_argument('--net_g2', default=None, type=str, 43 | help='checkpoint path of network for generation') 44 | parser.add_argument('--net_t', default=None, type=str, 45 | help='checkpoint path of network for train') 46 | parser.add_argument('--net_both', default=None, type=str, 47 | help='checkpoint path of both networks') 48 | 49 | parser.add_argument('--beta', default=0.999, type=float, help='Hyper-parameter for rejection/sampling') 50 | parser.add_argument('--lam', default=0.5, type=float, help='Hyper-parameter for regularization of translation') 51 | parser.add_argument('--warm', default=160, type=int, help='Deferred strategy for re-balancing') 52 | parser.add_argument('--gamma', default=0.99, type=float, help='Threshold of the generation') 53 | 54 | parser.add_argument('--eff_beta', default=1.0, type=float, help='Hyper-parameter for effective number') 55 | parser.add_argument('--focal_gamma', default=1.0, type=float, help='Hyper-parameter for Focal Loss') 56 | 57 | parser.add_argument('--gen', '-gen', action='store_true', help='') 58 | parser.add_argument('--step_size', default=0.1, type=float, help='') 59 | parser.add_argument('--attack_iter', default=10, type=int, help='') 60 | 61 | parser.add_argument('--imb_type', default='longtail', type=str, 62 | choices=['none', 'longtail', 'step'], 63 | help='Type of artificial imbalance') 64 | parser.add_argument('--loss_type', default='CE', type=str, 65 | choices=['CE', 'Focal', 'LDAM'], 66 | help='Type of loss for imbalance') 67 | parser.add_argument('--ratio', default=100, type=int, help='max/min') 68 | parser.add_argument('--imb_start', default=5, type=int, help='start idx of step imbalance') 69 | 70 | parser.add_argument('--smote', '-s', action='store_true', help='oversampling') 71 | parser.add_argument('--cost', '-c', action='store_true', help='oversampling') 72 | parser.add_argument('--effect_over', action='store_true', help='Use effective number in oversampling') 73 | parser.add_argument('--no_over', dest='over', action='store_false', help='Do not use over-sampling') 74 | 75 | return parser.parse_args() 76 | 77 | 78 | ARGS = parse_args() 79 | if ARGS.seed is not None: 80 | SEED = ARGS.seed 81 | else: 82 | SEED = np.random.randint(10000) 83 | np.random.seed(SEED) 84 | torch.manual_seed(SEED) 85 | torch.cuda.manual_seed_all(SEED) 86 | 87 | DATASET = ARGS.dataset 88 | BATCH_SIZE = ARGS.batch_size 89 | MODEL = ARGS.model 90 | 91 | LR = ARGS.lr 92 | EPOCH = ARGS.epoch 93 | START_EPOCH = 0 94 | 95 | LOGFILE_BASE = f"S{SEED}_{ARGS.name}_" \ 96 | f"L{ARGS.lam}_W{ARGS.warm}_" \ 97 | f"E{ARGS.step_size}_I{ARGS.attack_iter}_" \ 98 | f"{DATASET}_R{ARGS.ratio}_{MODEL}_G{ARGS.gamma}_B{ARGS.beta}" 99 | 100 | # Data 101 | print('==> Preparing 
data: %s' % DATASET) 102 | 103 | if DATASET == 'cifar100': 104 | N_CLASSES = 100 105 | N_SAMPLES = 500 106 | mean = torch.tensor([0.5071, 0.4867, 0.4408]) 107 | std = torch.tensor([0.2675, 0.2565, 0.2761]) 108 | elif DATASET == 'cifar10': 109 | N_CLASSES = 10 110 | N_SAMPLES = 5000 111 | mean = torch.tensor([0.4914, 0.4822, 0.4465]) 112 | std = torch.tensor([0.2023, 0.1994, 0.2010]) 113 | else: 114 | raise NotImplementedError() 115 | 116 | normalizer = InputNormalize(mean, std).to(device) 117 | 118 | if 'cifar' in DATASET: 119 | if ARGS.augment: 120 | transform_train = transforms.Compose([ 121 | transforms.RandomCrop(32, padding=4), 122 | transforms.RandomHorizontalFlip(), 123 | transforms.ToTensor(), 124 | ]) 125 | else: 126 | transform_train = transforms.Compose([ 127 | transforms.ToTensor(), 128 | ]) 129 | 130 | transform_test = transforms.Compose([ 131 | transforms.ToTensor(), 132 | ]) 133 | else: 134 | raise NotImplementedError() 135 | 136 | ## Data Loader ## 137 | 138 | N_SAMPLES_PER_CLASS_BASE = [int(N_SAMPLES)] * N_CLASSES 139 | if ARGS.imb_type == 'longtail': 140 | N_SAMPLES_PER_CLASS_BASE = make_longtailed_imb(N_SAMPLES, N_CLASSES, ARGS.ratio) 141 | elif ARGS.imb_type == 'step': 142 | for i in range(ARGS.imb_start, N_CLASSES): 143 | N_SAMPLES_PER_CLASS_BASE[i] = int(N_SAMPLES * (1 / ARGS.ratio)) 144 | 145 | N_SAMPLES_PER_CLASS_BASE = tuple(N_SAMPLES_PER_CLASS_BASE) 146 | print(N_SAMPLES_PER_CLASS_BASE) 147 | 148 | train_loader, val_loader, test_loader = get_imbalanced(DATASET, N_SAMPLES_PER_CLASS_BASE, BATCH_SIZE, 149 | transform_train, transform_test) 150 | 151 | ## To apply effective number for over-sampling or cost-sensitive ## 152 | 153 | if ARGS.over and ARGS.effect_over: 154 | _beta = ARGS.eff_beta 155 | effective_num = 1.0 - np.power(_beta, N_SAMPLES_PER_CLASS_BASE) 156 | N_SAMPLES_PER_CLASS = tuple(np.array(effective_num) / (1 - _beta)) 157 | print(N_SAMPLES_PER_CLASS) 158 | else: 159 | N_SAMPLES_PER_CLASS = N_SAMPLES_PER_CLASS_BASE 160 | N_SAMPLES_PER_CLASS_T = torch.Tensor(N_SAMPLES_PER_CLASS).to(device) 161 | 162 | 163 | def adjust_learning_rate(optimizer, lr_init, epoch): 164 | """decrease the learning rate at 160 and 180 epoch ( from LDAM-DRW, NeurIPS19 )""" 165 | lr = lr_init 166 | 167 | if epoch < 5: 168 | lr = (epoch + 1) * lr_init / 5 169 | else: 170 | if epoch >= 160: 171 | lr /= 100 172 | if epoch >= 180: 173 | lr /= 100 174 | 175 | for param_group in optimizer.param_groups: 176 | param_group['lr'] = lr 177 | 178 | 179 | def evaluate(net, dataloader, logger=None): 180 | is_training = net.training 181 | net.eval() 182 | criterion = nn.CrossEntropyLoss() 183 | 184 | total_loss = 0.0 185 | correct, total = 0.0, 0.0 186 | major_correct, neutral_correct, minor_correct = 0.0, 0.0, 0.0 187 | major_total, neutral_total, minor_total = 0.0, 0.0, 0.0 188 | 189 | class_correct = torch.zeros(N_CLASSES) 190 | class_total = torch.zeros(N_CLASSES) 191 | 192 | for inputs, targets in dataloader: 193 | batch_size = inputs.size(0) 194 | inputs, targets = inputs.to(device), targets.to(device) 195 | 196 | outputs, _ = net(normalizer(inputs)) 197 | loss = criterion(outputs, targets) 198 | 199 | total_loss += loss.item() * batch_size 200 | predicted = outputs[:, :N_CLASSES].max(1)[1] 201 | total += batch_size 202 | correct_mask = (predicted == targets) 203 | correct += sum_t(correct_mask) 204 | 205 | # For accuracy of minority / majority classes. 
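# Classes are indexed in order of decreasing training-set size under the long-tailed
# split (see N_SAMPLES_PER_CLASS_BASE above), so the first N_CLASSES // 3 classes are
# reported as 'major', the last N_CLASSES // 3 as 'minor', and the remaining middle
# classes as 'neutral'.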
206 | major_mask = targets < (N_CLASSES // 3) 207 | major_total += sum_t(major_mask) 208 | major_correct += sum_t(correct_mask * major_mask) 209 | 210 | minor_mask = targets >= (N_CLASSES - (N_CLASSES // 3)) 211 | minor_total += sum_t(minor_mask) 212 | minor_correct += sum_t(correct_mask * minor_mask) 213 | 214 | neutral_mask = ~(major_mask + minor_mask) 215 | neutral_total += sum_t(neutral_mask) 216 | neutral_correct += sum_t(correct_mask * neutral_mask) 217 | 218 | for i in range(N_CLASSES): 219 | class_mask = (targets == i) 220 | class_total[i] += sum_t(class_mask) 221 | class_correct[i] += sum_t(correct_mask * class_mask) 222 | 223 | results = { 224 | 'loss': total_loss / total, 225 | 'acc': 100. * correct / total, 226 | 'major_acc': 100. * major_correct / major_total, 227 | 'neutral_acc': 100. * neutral_correct / neutral_total, 228 | 'minor_acc': 100. * minor_correct / minor_total, 229 | 'class_acc': 100. * class_correct / class_total, 230 | } 231 | 232 | msg = 'Loss: %.3f | Acc: %.3f%% (%d/%d) | Major_ACC: %.3f%% | Neutral_ACC: %.3f%% | Minor ACC: %.3f%% ' % \ 233 | ( 234 | results['loss'], results['acc'], correct, total, 235 | results['major_acc'], results['neutral_acc'], results['minor_acc'] 236 | ) 237 | if logger: 238 | logger.log(msg) 239 | else: 240 | print(msg) 241 | 242 | net.train(is_training) 243 | return results 244 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy.random as nr 4 | import numpy as np 5 | import bisect 6 | from PIL import Image 7 | 8 | from torchvision import datasets, transforms 9 | from torch.utils.data import DataLoader 10 | from torch.utils.data.sampler import SubsetRandomSampler, WeightedRandomSampler 11 | from torch.utils.data import Dataset 12 | from torch.utils.data import TensorDataset 13 | from scipy import io 14 | 15 | num_test_samples_cifar10 = [1000] * 10 16 | num_test_samples_cifar100 = [100] * 100 17 | 18 | DATA_ROOT = os.path.expanduser('~/data') 19 | 20 | 21 | def make_longtailed_imb(max_num, class_num, gamma): 22 | mu = np.power(1/gamma, 1/(class_num - 1)) 23 | print(mu) 24 | class_num_list = [] 25 | for i in range(class_num): 26 | class_num_list.append(int(max_num * np.power(mu, i))) 27 | 28 | return list(class_num_list) 29 | 30 | 31 | def get_val_test_data(dataset, num_sample_per_class, shuffle=False, random_seed=0): 32 | """ 33 | Return a list of indices for validation and test from a dataset. 34 | Input: A test dataset (e.g., CIFAR-10) 35 | Output: validation_list and test_list 36 | """ 37 | length = dataset.__len__() 38 | num_sample_per_class = list(num_sample_per_class) 39 | num_samples = num_sample_per_class[0] # Suppose that all classes have the same number of test samples 40 | 41 | val_list = [] 42 | test_list = [] 43 | indices = list(range(0, length)) 44 | if shuffle: 45 | nr.shuffle(indices) 46 | for i in range(0, length): 47 | index = indices[i] 48 | _, label = dataset.__getitem__(index) 49 | if num_sample_per_class[label] > (9 * num_samples / 10): 50 | val_list.append(index) 51 | num_sample_per_class[label] -= 1 52 | else: 53 | test_list.append(index) 54 | num_sample_per_class[label] -= 1 55 | 56 | return val_list, test_list 57 | 58 | 59 | def get_oversampled_data(dataset, num_sample_per_class, random_seed=0): 60 | """ 61 | Return a list of imbalanced indices from a dataset. 
62 | Input: A dataset (e.g., CIFAR-10), num_sample_per_class: list of integers 63 | Output: oversampled_list ( weights are increased ) 64 | """ 65 | length = dataset.__len__() 66 | num_sample_per_class = list(num_sample_per_class) 67 | num_samples = list(num_sample_per_class) 68 | 69 | selected_list = [] 70 | indices = list(range(0,length)) 71 | for i in range(0, length): 72 | index = indices[i] 73 | _, label = dataset.__getitem__(index) 74 | if num_sample_per_class[label] > 0: 75 | selected_list.append(1 / num_samples[label]) 76 | num_sample_per_class[label] -= 1 77 | 78 | return selected_list 79 | 80 | 81 | def get_imbalanced_data(dataset, num_sample_per_class, shuffle=False, random_seed=0): 82 | """ 83 | Return a list of imbalanced indices from a dataset. 84 | Input: A dataset (e.g., CIFAR-10), num_sample_per_class: list of integers 85 | Output: imbalanced_list 86 | """ 87 | length = dataset.__len__() 88 | num_sample_per_class = list(num_sample_per_class) 89 | selected_list = [] 90 | indices = list(range(0,length)) 91 | 92 | for i in range(0, length): 93 | index = indices[i] 94 | _, label = dataset.__getitem__(index) 95 | if num_sample_per_class[label] > 0: 96 | selected_list.append(index) 97 | num_sample_per_class[label] -= 1 98 | 99 | return selected_list 100 | 101 | 102 | def get_oversampled(dataset, num_sample_per_class, batch_size, TF_train, TF_test): 103 | print("Building {} CV data loader with {} workers".format(dataset, 8)) 104 | ds = [] 105 | 106 | if dataset == 'cifar10': 107 | dataset_ = datasets.CIFAR10 108 | num_test_samples = num_test_samples_cifar10 109 | elif dataset == 'cifar100': 110 | dataset_ = datasets.CIFAR100 111 | num_test_samples = num_test_samples_cifar100 112 | else: 113 | raise NotImplementedError() 114 | 115 | train_cifar = dataset_(root=DATA_ROOT, train=True, download=False, transform=TF_train) 116 | 117 | targets = np.array(train_cifar.targets) 118 | classes, class_counts = np.unique(targets, return_counts=True) 119 | nb_classes = len(classes) 120 | 121 | imbal_class_counts = [int(i) for i in num_sample_per_class] 122 | class_indices = [np.where(targets == i)[0] for i in range(nb_classes)] 123 | 124 | imbal_class_indices = [class_idx[:class_count] for class_idx, class_count in zip(class_indices, imbal_class_counts)] 125 | imbal_class_indices = np.hstack(imbal_class_indices) 126 | 127 | train_cifar.targets = targets[imbal_class_indices] 128 | train_cifar.data = train_cifar.data[imbal_class_indices] 129 | 130 | assert len(train_cifar.targets) == len(train_cifar.data) 131 | 132 | train_in_idx = get_oversampled_data(train_cifar, num_sample_per_class) 133 | train_in_loader = DataLoader(train_cifar, batch_size=batch_size, 134 | sampler=WeightedRandomSampler(train_in_idx, len(train_in_idx)), num_workers=8) 135 | ds.append(train_in_loader) 136 | 137 | test_cifar = dataset_(root=DATA_ROOT, train=False, download=False, transform=TF_test) 138 | val_idx, test_idx = get_val_test_data(test_cifar, num_test_samples) 139 | val_loader = DataLoader(test_cifar, batch_size=100, 140 | sampler=SubsetRandomSampler(val_idx), num_workers=8) 141 | test_loader = DataLoader(test_cifar, batch_size=100, 142 | sampler=SubsetRandomSampler(test_idx), num_workers=8) 143 | ds.append(val_loader) 144 | ds.append(test_loader) 145 | ds = ds[0] if len(ds) == 1 else ds 146 | 147 | return ds 148 | 149 | 150 | def get_imbalanced(dataset, num_sample_per_class, batch_size, TF_train, TF_test): 151 | print("Building CV {} data loader with {} workers".format(dataset, 8)) 152 | ds = [] 153 | 154 | if 
dataset == 'cifar10': 155 | dataset_ = datasets.CIFAR10 156 | num_test_samples = num_test_samples_cifar10 157 | elif dataset == 'cifar100': 158 | dataset_ = datasets.CIFAR100 159 | num_test_samples = num_test_samples_cifar100 160 | else: 161 | raise NotImplementedError() 162 | 163 | train_cifar = dataset_(root=DATA_ROOT, train=True, download=False, transform=TF_train) 164 | train_in_idx = get_imbalanced_data(train_cifar, num_sample_per_class) 165 | train_in_loader = torch.utils.data.DataLoader(train_cifar, batch_size=batch_size, 166 | sampler=SubsetRandomSampler(train_in_idx), num_workers=8) 167 | ds.append(train_in_loader) 168 | 169 | test_cifar = dataset_(root=DATA_ROOT, train=False, download=False, transform=TF_test) 170 | val_idx, test_idx= get_val_test_data(test_cifar, num_test_samples) 171 | val_loader = torch.utils.data.DataLoader(test_cifar, batch_size=100, 172 | sampler=SubsetRandomSampler(val_idx), num_workers=8) 173 | test_loader = torch.utils.data.DataLoader(test_cifar, batch_size=100, 174 | sampler=SubsetRandomSampler(test_idx), num_workers=8) 175 | ds.append(val_loader) 176 | ds.append(test_loader) 177 | ds = ds[0] if len(ds) == 1 else ds 178 | 179 | return ds 180 | 181 | 182 | def smote(data, targets, n_class, n_max): 183 | aug_data = [] 184 | aug_label = [] 185 | 186 | for k in range(1, n_class): 187 | indices = np.where(targets == k)[0] 188 | class_data = data[indices] 189 | class_len = len(indices) 190 | class_dist = np.zeros((class_len, class_len)) 191 | 192 | # Augmentation with SMOTE ( k-nearest ) 193 | if smote: 194 | for i in range(class_len): 195 | for j in range(class_len): 196 | class_dist[i, j] = np.linalg.norm(class_data[i] - class_data[j]) 197 | sorted_idx = np.argsort(class_dist) 198 | 199 | for i in range(n_max - class_len): 200 | lam = nr.uniform(0, 1) 201 | row_idx = i % class_len 202 | col_idx = int((i - row_idx) / class_len) % (class_len - 1) 203 | new_data = np.round( 204 | lam * class_data[row_idx] + (1 - lam) * class_data[sorted_idx[row_idx, 1 + col_idx]]) 205 | 206 | aug_data.append(new_data.astype('uint8')) 207 | aug_label.append(k) 208 | 209 | return np.array(aug_data), np.array(aug_label) 210 | 211 | 212 | def get_smote(dataset, num_sample_per_class, batch_size, TF_train, TF_test): 213 | print("Building CV {} data loader with {} workers".format(dataset, 8)) 214 | ds = [] 215 | 216 | if dataset == 'cifar10': 217 | dataset_ = datasets.CIFAR10 218 | num_test_samples = num_test_samples_cifar10 219 | elif dataset == 'cifar100': 220 | dataset_ = datasets.CIFAR100 221 | num_test_samples = num_test_samples_cifar100 222 | else: 223 | raise NotImplementedError() 224 | 225 | train_cifar = dataset_(root=DATA_ROOT, train=True, download=False, transform=TF_train) 226 | 227 | targets = np.array(train_cifar.targets) 228 | classes, class_counts = np.unique(targets, return_counts=True) 229 | nb_classes = len(classes) 230 | 231 | imbal_class_counts = [int(i) for i in num_sample_per_class] 232 | class_indices = [np.where(targets == i)[0] for i in range(nb_classes)] 233 | 234 | imbal_class_indices = [class_idx[:class_count] for class_idx, class_count in zip(class_indices, imbal_class_counts)] 235 | imbal_class_indices = np.hstack(imbal_class_indices) 236 | 237 | train_cifar.targets = targets[imbal_class_indices] 238 | train_cifar.data = train_cifar.data[imbal_class_indices] 239 | 240 | assert len(train_cifar.targets) == len(train_cifar.data) 241 | 242 | class_max = max(num_sample_per_class) 243 | aug_data, aug_label = smote(train_cifar.data, train_cifar.targets, 
nb_classes, class_max) 244 | 245 | train_cifar.targets = np.concatenate((train_cifar.targets, aug_label), axis=0) 246 | train_cifar.data = np.concatenate((train_cifar.data, aug_data), axis=0) 247 | 248 | print("Augmented data num = {}".format(len(aug_label))) 249 | print(train_cifar.data.shape) 250 | 251 | train_in_loader = torch.utils.data.DataLoader(train_cifar, batch_size=batch_size, shuffle=True, num_workers=8) 252 | ds.append(train_in_loader) 253 | 254 | test_cifar = dataset_(root=DATA_ROOT, train=False, download=False, transform=TF_test) 255 | val_idx, test_idx = get_val_test_data(test_cifar, num_test_samples) 256 | val_loader = torch.utils.data.DataLoader(test_cifar, batch_size=100, 257 | sampler=SubsetRandomSampler(val_idx), num_workers=8) 258 | test_loader = torch.utils.data.DataLoader(test_cifar, batch_size=100, 259 | sampler=SubsetRandomSampler(test_idx), num_workers=8) 260 | ds.append(val_loader) 261 | ds.append(test_loader) 262 | ds = ds[0] if len(ds) == 1 else ds 263 | 264 | return ds -------------------------------------------------------------------------------- /etc/celebA_test_orig.txt: -------------------------------------------------------------------------------- 1 | /home/jaehyung/data/CelebA/image64/001999.jpg 0 2 | /home/jaehyung/data/CelebA/image64/002004.jpg 0 3 | /home/jaehyung/data/CelebA/image64/002030.jpg 0 4 | /home/jaehyung/data/CelebA/image64/002081.jpg 0 5 | /home/jaehyung/data/CelebA/image64/002124.jpg 0 6 | /home/jaehyung/data/CelebA/image64/002132.jpg 0 7 | /home/jaehyung/data/CelebA/image64/002188.jpg 0 8 | /home/jaehyung/data/CelebA/image64/002191.jpg 0 9 | /home/jaehyung/data/CelebA/image64/002276.jpg 0 10 | /home/jaehyung/data/CelebA/image64/002305.jpg 0 11 | /home/jaehyung/data/CelebA/image64/002327.jpg 0 12 | /home/jaehyung/data/CelebA/image64/002341.jpg 0 13 | /home/jaehyung/data/CelebA/image64/002359.jpg 0 14 | /home/jaehyung/data/CelebA/image64/002363.jpg 0 15 | /home/jaehyung/data/CelebA/image64/002505.jpg 0 16 | /home/jaehyung/data/CelebA/image64/002516.jpg 0 17 | /home/jaehyung/data/CelebA/image64/002528.jpg 0 18 | /home/jaehyung/data/CelebA/image64/002673.jpg 0 19 | /home/jaehyung/data/CelebA/image64/002723.jpg 0 20 | /home/jaehyung/data/CelebA/image64/002744.jpg 0 21 | /home/jaehyung/data/CelebA/image64/002772.jpg 0 22 | /home/jaehyung/data/CelebA/image64/002828.jpg 0 23 | /home/jaehyung/data/CelebA/image64/002874.jpg 0 24 | /home/jaehyung/data/CelebA/image64/002935.jpg 0 25 | /home/jaehyung/data/CelebA/image64/002960.jpg 0 26 | /home/jaehyung/data/CelebA/image64/002961.jpg 0 27 | /home/jaehyung/data/CelebA/image64/002980.jpg 0 28 | /home/jaehyung/data/CelebA/image64/003027.jpg 0 29 | /home/jaehyung/data/CelebA/image64/003110.jpg 0 30 | /home/jaehyung/data/CelebA/image64/003317.jpg 0 31 | /home/jaehyung/data/CelebA/image64/003384.jpg 0 32 | /home/jaehyung/data/CelebA/image64/003431.jpg 0 33 | /home/jaehyung/data/CelebA/image64/003454.jpg 0 34 | /home/jaehyung/data/CelebA/image64/003511.jpg 0 35 | /home/jaehyung/data/CelebA/image64/003527.jpg 0 36 | /home/jaehyung/data/CelebA/image64/003576.jpg 0 37 | /home/jaehyung/data/CelebA/image64/003586.jpg 0 38 | /home/jaehyung/data/CelebA/image64/003644.jpg 0 39 | /home/jaehyung/data/CelebA/image64/003678.jpg 0 40 | /home/jaehyung/data/CelebA/image64/003686.jpg 0 41 | /home/jaehyung/data/CelebA/image64/003849.jpg 0 42 | /home/jaehyung/data/CelebA/image64/003867.jpg 0 43 | /home/jaehyung/data/CelebA/image64/003981.jpg 0 44 | /home/jaehyung/data/CelebA/image64/004023.jpg 0 45 | 
/home/jaehyung/data/CelebA/image64/004058.jpg 0 46 | /home/jaehyung/data/CelebA/image64/004109.jpg 0 47 | /home/jaehyung/data/CelebA/image64/004110.jpg 0 48 | /home/jaehyung/data/CelebA/image64/004146.jpg 0 49 | /home/jaehyung/data/CelebA/image64/004191.jpg 0 50 | /home/jaehyung/data/CelebA/image64/004230.jpg 0 51 | /home/jaehyung/data/CelebA/image64/004249.jpg 0 52 | /home/jaehyung/data/CelebA/image64/004250.jpg 0 53 | /home/jaehyung/data/CelebA/image64/004317.jpg 0 54 | /home/jaehyung/data/CelebA/image64/004330.jpg 0 55 | /home/jaehyung/data/CelebA/image64/004367.jpg 0 56 | /home/jaehyung/data/CelebA/image64/004382.jpg 0 57 | /home/jaehyung/data/CelebA/image64/004399.jpg 0 58 | /home/jaehyung/data/CelebA/image64/004449.jpg 0 59 | /home/jaehyung/data/CelebA/image64/004455.jpg 0 60 | /home/jaehyung/data/CelebA/image64/004492.jpg 0 61 | /home/jaehyung/data/CelebA/image64/004611.jpg 0 62 | /home/jaehyung/data/CelebA/image64/004800.jpg 0 63 | /home/jaehyung/data/CelebA/image64/004812.jpg 0 64 | /home/jaehyung/data/CelebA/image64/004814.jpg 0 65 | /home/jaehyung/data/CelebA/image64/004852.jpg 0 66 | /home/jaehyung/data/CelebA/image64/004868.jpg 0 67 | /home/jaehyung/data/CelebA/image64/004976.jpg 0 68 | /home/jaehyung/data/CelebA/image64/005002.jpg 0 69 | /home/jaehyung/data/CelebA/image64/005012.jpg 0 70 | /home/jaehyung/data/CelebA/image64/005020.jpg 0 71 | /home/jaehyung/data/CelebA/image64/005026.jpg 0 72 | /home/jaehyung/data/CelebA/image64/005042.jpg 0 73 | /home/jaehyung/data/CelebA/image64/005077.jpg 0 74 | /home/jaehyung/data/CelebA/image64/005117.jpg 0 75 | /home/jaehyung/data/CelebA/image64/005128.jpg 0 76 | /home/jaehyung/data/CelebA/image64/005140.jpg 0 77 | /home/jaehyung/data/CelebA/image64/005146.jpg 0 78 | /home/jaehyung/data/CelebA/image64/005158.jpg 0 79 | /home/jaehyung/data/CelebA/image64/005163.jpg 0 80 | /home/jaehyung/data/CelebA/image64/005165.jpg 0 81 | /home/jaehyung/data/CelebA/image64/005200.jpg 0 82 | /home/jaehyung/data/CelebA/image64/005211.jpg 0 83 | /home/jaehyung/data/CelebA/image64/005295.jpg 0 84 | /home/jaehyung/data/CelebA/image64/005318.jpg 0 85 | /home/jaehyung/data/CelebA/image64/005353.jpg 0 86 | /home/jaehyung/data/CelebA/image64/005361.jpg 0 87 | /home/jaehyung/data/CelebA/image64/005384.jpg 0 88 | /home/jaehyung/data/CelebA/image64/005497.jpg 0 89 | /home/jaehyung/data/CelebA/image64/005509.jpg 0 90 | /home/jaehyung/data/CelebA/image64/005591.jpg 0 91 | /home/jaehyung/data/CelebA/image64/005669.jpg 0 92 | /home/jaehyung/data/CelebA/image64/005702.jpg 0 93 | /home/jaehyung/data/CelebA/image64/005721.jpg 0 94 | /home/jaehyung/data/CelebA/image64/005726.jpg 0 95 | /home/jaehyung/data/CelebA/image64/005731.jpg 0 96 | /home/jaehyung/data/CelebA/image64/005789.jpg 0 97 | /home/jaehyung/data/CelebA/image64/005833.jpg 0 98 | /home/jaehyung/data/CelebA/image64/005839.jpg 0 99 | /home/jaehyung/data/CelebA/image64/005920.jpg 0 100 | /home/jaehyung/data/CelebA/image64/005959.jpg 0 101 | /home/jaehyung/data/CelebA/image64/000164.jpg 1 102 | /home/jaehyung/data/CelebA/image64/000168.jpg 1 103 | /home/jaehyung/data/CelebA/image64/000173.jpg 1 104 | /home/jaehyung/data/CelebA/image64/000177.jpg 1 105 | /home/jaehyung/data/CelebA/image64/000179.jpg 1 106 | /home/jaehyung/data/CelebA/image64/000181.jpg 1 107 | /home/jaehyung/data/CelebA/image64/000187.jpg 1 108 | /home/jaehyung/data/CelebA/image64/000193.jpg 1 109 | /home/jaehyung/data/CelebA/image64/000204.jpg 1 110 | /home/jaehyung/data/CelebA/image64/000206.jpg 1 111 | 
/home/jaehyung/data/CelebA/image64/000212.jpg 1 112 | /home/jaehyung/data/CelebA/image64/000214.jpg 1 113 | /home/jaehyung/data/CelebA/image64/000216.jpg 1 114 | /home/jaehyung/data/CelebA/image64/000230.jpg 1 115 | /home/jaehyung/data/CelebA/image64/000233.jpg 1 116 | /home/jaehyung/data/CelebA/image64/000242.jpg 1 117 | /home/jaehyung/data/CelebA/image64/000245.jpg 1 118 | /home/jaehyung/data/CelebA/image64/000246.jpg 1 119 | /home/jaehyung/data/CelebA/image64/000250.jpg 1 120 | /home/jaehyung/data/CelebA/image64/000256.jpg 1 121 | /home/jaehyung/data/CelebA/image64/000260.jpg 1 122 | /home/jaehyung/data/CelebA/image64/000262.jpg 1 123 | /home/jaehyung/data/CelebA/image64/000269.jpg 1 124 | /home/jaehyung/data/CelebA/image64/000272.jpg 1 125 | /home/jaehyung/data/CelebA/image64/000277.jpg 1 126 | /home/jaehyung/data/CelebA/image64/000284.jpg 1 127 | /home/jaehyung/data/CelebA/image64/000285.jpg 1 128 | /home/jaehyung/data/CelebA/image64/000286.jpg 1 129 | /home/jaehyung/data/CelebA/image64/000289.jpg 1 130 | /home/jaehyung/data/CelebA/image64/000292.jpg 1 131 | /home/jaehyung/data/CelebA/image64/000294.jpg 1 132 | /home/jaehyung/data/CelebA/image64/000297.jpg 1 133 | /home/jaehyung/data/CelebA/image64/000298.jpg 1 134 | /home/jaehyung/data/CelebA/image64/000300.jpg 1 135 | /home/jaehyung/data/CelebA/image64/000303.jpg 1 136 | /home/jaehyung/data/CelebA/image64/000315.jpg 1 137 | /home/jaehyung/data/CelebA/image64/000320.jpg 1 138 | /home/jaehyung/data/CelebA/image64/000321.jpg 1 139 | /home/jaehyung/data/CelebA/image64/000324.jpg 1 140 | /home/jaehyung/data/CelebA/image64/000327.jpg 1 141 | /home/jaehyung/data/CelebA/image64/000332.jpg 1 142 | /home/jaehyung/data/CelebA/image64/000336.jpg 1 143 | /home/jaehyung/data/CelebA/image64/000339.jpg 1 144 | /home/jaehyung/data/CelebA/image64/000340.jpg 1 145 | /home/jaehyung/data/CelebA/image64/000342.jpg 1 146 | /home/jaehyung/data/CelebA/image64/000346.jpg 1 147 | /home/jaehyung/data/CelebA/image64/000349.jpg 1 148 | /home/jaehyung/data/CelebA/image64/000350.jpg 1 149 | /home/jaehyung/data/CelebA/image64/000354.jpg 1 150 | /home/jaehyung/data/CelebA/image64/000355.jpg 1 151 | /home/jaehyung/data/CelebA/image64/000356.jpg 1 152 | /home/jaehyung/data/CelebA/image64/000362.jpg 1 153 | /home/jaehyung/data/CelebA/image64/000363.jpg 1 154 | /home/jaehyung/data/CelebA/image64/000367.jpg 1 155 | /home/jaehyung/data/CelebA/image64/000369.jpg 1 156 | /home/jaehyung/data/CelebA/image64/000370.jpg 1 157 | /home/jaehyung/data/CelebA/image64/000374.jpg 1 158 | /home/jaehyung/data/CelebA/image64/000392.jpg 1 159 | /home/jaehyung/data/CelebA/image64/000394.jpg 1 160 | /home/jaehyung/data/CelebA/image64/000401.jpg 1 161 | /home/jaehyung/data/CelebA/image64/000406.jpg 1 162 | /home/jaehyung/data/CelebA/image64/000409.jpg 1 163 | /home/jaehyung/data/CelebA/image64/000418.jpg 1 164 | /home/jaehyung/data/CelebA/image64/000422.jpg 1 165 | /home/jaehyung/data/CelebA/image64/000427.jpg 1 166 | /home/jaehyung/data/CelebA/image64/000430.jpg 1 167 | /home/jaehyung/data/CelebA/image64/000433.jpg 1 168 | /home/jaehyung/data/CelebA/image64/000435.jpg 1 169 | /home/jaehyung/data/CelebA/image64/000439.jpg 1 170 | /home/jaehyung/data/CelebA/image64/000452.jpg 1 171 | /home/jaehyung/data/CelebA/image64/000455.jpg 1 172 | /home/jaehyung/data/CelebA/image64/000456.jpg 1 173 | /home/jaehyung/data/CelebA/image64/000461.jpg 1 174 | /home/jaehyung/data/CelebA/image64/000477.jpg 1 175 | /home/jaehyung/data/CelebA/image64/000484.jpg 1 176 | 
/home/jaehyung/data/CelebA/image64/000486.jpg 1 177 | /home/jaehyung/data/CelebA/image64/000489.jpg 1 178 | /home/jaehyung/data/CelebA/image64/000490.jpg 1 179 | /home/jaehyung/data/CelebA/image64/000502.jpg 1 180 | /home/jaehyung/data/CelebA/image64/000505.jpg 1 181 | /home/jaehyung/data/CelebA/image64/000507.jpg 1 182 | /home/jaehyung/data/CelebA/image64/000508.jpg 1 183 | /home/jaehyung/data/CelebA/image64/000512.jpg 1 184 | /home/jaehyung/data/CelebA/image64/000515.jpg 1 185 | /home/jaehyung/data/CelebA/image64/000522.jpg 1 186 | /home/jaehyung/data/CelebA/image64/000524.jpg 1 187 | /home/jaehyung/data/CelebA/image64/000525.jpg 1 188 | /home/jaehyung/data/CelebA/image64/000526.jpg 1 189 | /home/jaehyung/data/CelebA/image64/000529.jpg 1 190 | /home/jaehyung/data/CelebA/image64/000531.jpg 1 191 | /home/jaehyung/data/CelebA/image64/000535.jpg 1 192 | /home/jaehyung/data/CelebA/image64/000536.jpg 1 193 | /home/jaehyung/data/CelebA/image64/000537.jpg 1 194 | /home/jaehyung/data/CelebA/image64/000539.jpg 1 195 | /home/jaehyung/data/CelebA/image64/000540.jpg 1 196 | /home/jaehyung/data/CelebA/image64/000546.jpg 1 197 | /home/jaehyung/data/CelebA/image64/000548.jpg 1 198 | /home/jaehyung/data/CelebA/image64/000554.jpg 1 199 | /home/jaehyung/data/CelebA/image64/000559.jpg 1 200 | /home/jaehyung/data/CelebA/image64/000561.jpg 1 201 | /home/jaehyung/data/CelebA/image64/000365.jpg 2 202 | /home/jaehyung/data/CelebA/image64/000371.jpg 2 203 | /home/jaehyung/data/CelebA/image64/000380.jpg 2 204 | /home/jaehyung/data/CelebA/image64/000382.jpg 2 205 | /home/jaehyung/data/CelebA/image64/000399.jpg 2 206 | /home/jaehyung/data/CelebA/image64/000420.jpg 2 207 | /home/jaehyung/data/CelebA/image64/000421.jpg 2 208 | /home/jaehyung/data/CelebA/image64/000431.jpg 2 209 | /home/jaehyung/data/CelebA/image64/000432.jpg 2 210 | /home/jaehyung/data/CelebA/image64/000442.jpg 2 211 | /home/jaehyung/data/CelebA/image64/000454.jpg 2 212 | /home/jaehyung/data/CelebA/image64/000468.jpg 2 213 | /home/jaehyung/data/CelebA/image64/000471.jpg 2 214 | /home/jaehyung/data/CelebA/image64/000475.jpg 2 215 | /home/jaehyung/data/CelebA/image64/000476.jpg 2 216 | /home/jaehyung/data/CelebA/image64/000478.jpg 2 217 | /home/jaehyung/data/CelebA/image64/000483.jpg 2 218 | /home/jaehyung/data/CelebA/image64/000485.jpg 2 219 | /home/jaehyung/data/CelebA/image64/000492.jpg 2 220 | /home/jaehyung/data/CelebA/image64/000497.jpg 2 221 | /home/jaehyung/data/CelebA/image64/000499.jpg 2 222 | /home/jaehyung/data/CelebA/image64/000503.jpg 2 223 | /home/jaehyung/data/CelebA/image64/000510.jpg 2 224 | /home/jaehyung/data/CelebA/image64/000513.jpg 2 225 | /home/jaehyung/data/CelebA/image64/000516.jpg 2 226 | /home/jaehyung/data/CelebA/image64/000523.jpg 2 227 | /home/jaehyung/data/CelebA/image64/000528.jpg 2 228 | /home/jaehyung/data/CelebA/image64/000562.jpg 2 229 | /home/jaehyung/data/CelebA/image64/000563.jpg 2 230 | /home/jaehyung/data/CelebA/image64/000567.jpg 2 231 | /home/jaehyung/data/CelebA/image64/000579.jpg 2 232 | /home/jaehyung/data/CelebA/image64/000582.jpg 2 233 | /home/jaehyung/data/CelebA/image64/000588.jpg 2 234 | /home/jaehyung/data/CelebA/image64/000596.jpg 2 235 | /home/jaehyung/data/CelebA/image64/000599.jpg 2 236 | /home/jaehyung/data/CelebA/image64/000602.jpg 2 237 | /home/jaehyung/data/CelebA/image64/000605.jpg 2 238 | /home/jaehyung/data/CelebA/image64/000615.jpg 2 239 | /home/jaehyung/data/CelebA/image64/000634.jpg 2 240 | /home/jaehyung/data/CelebA/image64/000642.jpg 2 241 | 
/home/jaehyung/data/CelebA/image64/000644.jpg 2 242 | /home/jaehyung/data/CelebA/image64/000656.jpg 2 243 | /home/jaehyung/data/CelebA/image64/000658.jpg 2 244 | /home/jaehyung/data/CelebA/image64/000666.jpg 2 245 | /home/jaehyung/data/CelebA/image64/000667.jpg 2 246 | /home/jaehyung/data/CelebA/image64/000680.jpg 2 247 | /home/jaehyung/data/CelebA/image64/000681.jpg 2 248 | /home/jaehyung/data/CelebA/image64/000691.jpg 2 249 | /home/jaehyung/data/CelebA/image64/000692.jpg 2 250 | /home/jaehyung/data/CelebA/image64/000693.jpg 2 251 | /home/jaehyung/data/CelebA/image64/000699.jpg 2 252 | /home/jaehyung/data/CelebA/image64/000702.jpg 2 253 | /home/jaehyung/data/CelebA/image64/000709.jpg 2 254 | /home/jaehyung/data/CelebA/image64/000715.jpg 2 255 | /home/jaehyung/data/CelebA/image64/000719.jpg 2 256 | /home/jaehyung/data/CelebA/image64/000721.jpg 2 257 | /home/jaehyung/data/CelebA/image64/000751.jpg 2 258 | /home/jaehyung/data/CelebA/image64/000759.jpg 2 259 | /home/jaehyung/data/CelebA/image64/000770.jpg 2 260 | /home/jaehyung/data/CelebA/image64/000787.jpg 2 261 | /home/jaehyung/data/CelebA/image64/000791.jpg 2 262 | /home/jaehyung/data/CelebA/image64/000800.jpg 2 263 | /home/jaehyung/data/CelebA/image64/000804.jpg 2 264 | /home/jaehyung/data/CelebA/image64/000809.jpg 2 265 | /home/jaehyung/data/CelebA/image64/000812.jpg 2 266 | /home/jaehyung/data/CelebA/image64/000819.jpg 2 267 | /home/jaehyung/data/CelebA/image64/000822.jpg 2 268 | /home/jaehyung/data/CelebA/image64/000824.jpg 2 269 | /home/jaehyung/data/CelebA/image64/000835.jpg 2 270 | /home/jaehyung/data/CelebA/image64/000837.jpg 2 271 | /home/jaehyung/data/CelebA/image64/000856.jpg 2 272 | /home/jaehyung/data/CelebA/image64/000857.jpg 2 273 | /home/jaehyung/data/CelebA/image64/000863.jpg 2 274 | /home/jaehyung/data/CelebA/image64/000871.jpg 2 275 | /home/jaehyung/data/CelebA/image64/000872.jpg 2 276 | /home/jaehyung/data/CelebA/image64/000894.jpg 2 277 | /home/jaehyung/data/CelebA/image64/000895.jpg 2 278 | /home/jaehyung/data/CelebA/image64/000899.jpg 2 279 | /home/jaehyung/data/CelebA/image64/000917.jpg 2 280 | /home/jaehyung/data/CelebA/image64/000918.jpg 2 281 | /home/jaehyung/data/CelebA/image64/000939.jpg 2 282 | /home/jaehyung/data/CelebA/image64/000941.jpg 2 283 | /home/jaehyung/data/CelebA/image64/000943.jpg 2 284 | /home/jaehyung/data/CelebA/image64/000945.jpg 2 285 | /home/jaehyung/data/CelebA/image64/000947.jpg 2 286 | /home/jaehyung/data/CelebA/image64/000948.jpg 2 287 | /home/jaehyung/data/CelebA/image64/000950.jpg 2 288 | /home/jaehyung/data/CelebA/image64/000955.jpg 2 289 | /home/jaehyung/data/CelebA/image64/000970.jpg 2 290 | /home/jaehyung/data/CelebA/image64/000974.jpg 2 291 | /home/jaehyung/data/CelebA/image64/000976.jpg 2 292 | /home/jaehyung/data/CelebA/image64/000980.jpg 2 293 | /home/jaehyung/data/CelebA/image64/000982.jpg 2 294 | /home/jaehyung/data/CelebA/image64/000985.jpg 2 295 | /home/jaehyung/data/CelebA/image64/000986.jpg 2 296 | /home/jaehyung/data/CelebA/image64/000996.jpg 2 297 | /home/jaehyung/data/CelebA/image64/000997.jpg 2 298 | /home/jaehyung/data/CelebA/image64/000998.jpg 2 299 | /home/jaehyung/data/CelebA/image64/001014.jpg 2 300 | /home/jaehyung/data/CelebA/image64/001015.jpg 2 301 | /home/jaehyung/data/CelebA/image64/000240.jpg 3 302 | /home/jaehyung/data/CelebA/image64/000241.jpg 3 303 | /home/jaehyung/data/CelebA/image64/000243.jpg 3 304 | /home/jaehyung/data/CelebA/image64/000259.jpg 3 305 | /home/jaehyung/data/CelebA/image64/000261.jpg 3 306 | 
/home/jaehyung/data/CelebA/image64/000263.jpg 3 307 | /home/jaehyung/data/CelebA/image64/000267.jpg 3 308 | /home/jaehyung/data/CelebA/image64/000268.jpg 3 309 | /home/jaehyung/data/CelebA/image64/000270.jpg 3 310 | /home/jaehyung/data/CelebA/image64/000277.jpg 3 311 | /home/jaehyung/data/CelebA/image64/000279.jpg 3 312 | /home/jaehyung/data/CelebA/image64/000281.jpg 3 313 | /home/jaehyung/data/CelebA/image64/000296.jpg 3 314 | /home/jaehyung/data/CelebA/image64/000301.jpg 3 315 | /home/jaehyung/data/CelebA/image64/000305.jpg 3 316 | /home/jaehyung/data/CelebA/image64/000307.jpg 3 317 | /home/jaehyung/data/CelebA/image64/000309.jpg 3 318 | /home/jaehyung/data/CelebA/image64/000312.jpg 3 319 | /home/jaehyung/data/CelebA/image64/000328.jpg 3 320 | /home/jaehyung/data/CelebA/image64/000329.jpg 3 321 | /home/jaehyung/data/CelebA/image64/000337.jpg 3 322 | /home/jaehyung/data/CelebA/image64/000352.jpg 3 323 | /home/jaehyung/data/CelebA/image64/000353.jpg 3 324 | /home/jaehyung/data/CelebA/image64/000359.jpg 3 325 | /home/jaehyung/data/CelebA/image64/000373.jpg 3 326 | /home/jaehyung/data/CelebA/image64/000380.jpg 3 327 | /home/jaehyung/data/CelebA/image64/000384.jpg 3 328 | /home/jaehyung/data/CelebA/image64/000387.jpg 3 329 | /home/jaehyung/data/CelebA/image64/000388.jpg 3 330 | /home/jaehyung/data/CelebA/image64/000389.jpg 3 331 | /home/jaehyung/data/CelebA/image64/000391.jpg 3 332 | /home/jaehyung/data/CelebA/image64/000396.jpg 3 333 | /home/jaehyung/data/CelebA/image64/000398.jpg 3 334 | /home/jaehyung/data/CelebA/image64/000401.jpg 3 335 | /home/jaehyung/data/CelebA/image64/000404.jpg 3 336 | /home/jaehyung/data/CelebA/image64/000407.jpg 3 337 | /home/jaehyung/data/CelebA/image64/000412.jpg 3 338 | /home/jaehyung/data/CelebA/image64/000415.jpg 3 339 | /home/jaehyung/data/CelebA/image64/000428.jpg 3 340 | /home/jaehyung/data/CelebA/image64/000429.jpg 3 341 | /home/jaehyung/data/CelebA/image64/000436.jpg 3 342 | /home/jaehyung/data/CelebA/image64/000443.jpg 3 343 | /home/jaehyung/data/CelebA/image64/000445.jpg 3 344 | /home/jaehyung/data/CelebA/image64/000451.jpg 3 345 | /home/jaehyung/data/CelebA/image64/000453.jpg 3 346 | /home/jaehyung/data/CelebA/image64/000459.jpg 3 347 | /home/jaehyung/data/CelebA/image64/000466.jpg 3 348 | /home/jaehyung/data/CelebA/image64/000469.jpg 3 349 | /home/jaehyung/data/CelebA/image64/000470.jpg 3 350 | /home/jaehyung/data/CelebA/image64/000472.jpg 3 351 | /home/jaehyung/data/CelebA/image64/000473.jpg 3 352 | /home/jaehyung/data/CelebA/image64/000474.jpg 3 353 | /home/jaehyung/data/CelebA/image64/000480.jpg 3 354 | /home/jaehyung/data/CelebA/image64/000481.jpg 3 355 | /home/jaehyung/data/CelebA/image64/000487.jpg 3 356 | /home/jaehyung/data/CelebA/image64/000495.jpg 3 357 | /home/jaehyung/data/CelebA/image64/000504.jpg 3 358 | /home/jaehyung/data/CelebA/image64/000506.jpg 3 359 | /home/jaehyung/data/CelebA/image64/000514.jpg 3 360 | /home/jaehyung/data/CelebA/image64/000519.jpg 3 361 | /home/jaehyung/data/CelebA/image64/000527.jpg 3 362 | /home/jaehyung/data/CelebA/image64/000532.jpg 3 363 | /home/jaehyung/data/CelebA/image64/000541.jpg 3 364 | /home/jaehyung/data/CelebA/image64/000542.jpg 3 365 | /home/jaehyung/data/CelebA/image64/000544.jpg 3 366 | /home/jaehyung/data/CelebA/image64/000549.jpg 3 367 | /home/jaehyung/data/CelebA/image64/000552.jpg 3 368 | /home/jaehyung/data/CelebA/image64/000553.jpg 3 369 | /home/jaehyung/data/CelebA/image64/000557.jpg 3 370 | /home/jaehyung/data/CelebA/image64/000558.jpg 3 371 | 
/home/jaehyung/data/CelebA/image64/000561.jpg 3 372 | /home/jaehyung/data/CelebA/image64/000565.jpg 3 373 | /home/jaehyung/data/CelebA/image64/000575.jpg 3 374 | /home/jaehyung/data/CelebA/image64/000590.jpg 3 375 | /home/jaehyung/data/CelebA/image64/000592.jpg 3 376 | /home/jaehyung/data/CelebA/image64/000594.jpg 3 377 | /home/jaehyung/data/CelebA/image64/000597.jpg 3 378 | /home/jaehyung/data/CelebA/image64/000598.jpg 3 379 | /home/jaehyung/data/CelebA/image64/000607.jpg 3 380 | /home/jaehyung/data/CelebA/image64/000609.jpg 3 381 | /home/jaehyung/data/CelebA/image64/000611.jpg 3 382 | /home/jaehyung/data/CelebA/image64/000613.jpg 3 383 | /home/jaehyung/data/CelebA/image64/000622.jpg 3 384 | /home/jaehyung/data/CelebA/image64/000643.jpg 3 385 | /home/jaehyung/data/CelebA/image64/000646.jpg 3 386 | /home/jaehyung/data/CelebA/image64/000650.jpg 3 387 | /home/jaehyung/data/CelebA/image64/000657.jpg 3 388 | /home/jaehyung/data/CelebA/image64/000662.jpg 3 389 | /home/jaehyung/data/CelebA/image64/000664.jpg 3 390 | /home/jaehyung/data/CelebA/image64/000670.jpg 3 391 | /home/jaehyung/data/CelebA/image64/000684.jpg 3 392 | /home/jaehyung/data/CelebA/image64/000685.jpg 3 393 | /home/jaehyung/data/CelebA/image64/000688.jpg 3 394 | /home/jaehyung/data/CelebA/image64/000689.jpg 3 395 | /home/jaehyung/data/CelebA/image64/000691.jpg 3 396 | /home/jaehyung/data/CelebA/image64/000705.jpg 3 397 | /home/jaehyung/data/CelebA/image64/000706.jpg 3 398 | /home/jaehyung/data/CelebA/image64/000708.jpg 3 399 | /home/jaehyung/data/CelebA/image64/000714.jpg 3 400 | /home/jaehyung/data/CelebA/image64/000723.jpg 3 401 | /home/jaehyung/data/CelebA/image64/000958.jpg 4 402 | /home/jaehyung/data/CelebA/image64/000969.jpg 4 403 | /home/jaehyung/data/CelebA/image64/000974.jpg 4 404 | /home/jaehyung/data/CelebA/image64/000993.jpg 4 405 | /home/jaehyung/data/CelebA/image64/001006.jpg 4 406 | /home/jaehyung/data/CelebA/image64/001028.jpg 4 407 | /home/jaehyung/data/CelebA/image64/001033.jpg 4 408 | /home/jaehyung/data/CelebA/image64/001083.jpg 4 409 | /home/jaehyung/data/CelebA/image64/001101.jpg 4 410 | /home/jaehyung/data/CelebA/image64/001109.jpg 4 411 | /home/jaehyung/data/CelebA/image64/001127.jpg 4 412 | /home/jaehyung/data/CelebA/image64/001189.jpg 4 413 | /home/jaehyung/data/CelebA/image64/001215.jpg 4 414 | /home/jaehyung/data/CelebA/image64/001237.jpg 4 415 | /home/jaehyung/data/CelebA/image64/001241.jpg 4 416 | /home/jaehyung/data/CelebA/image64/001243.jpg 4 417 | /home/jaehyung/data/CelebA/image64/001258.jpg 4 418 | /home/jaehyung/data/CelebA/image64/001259.jpg 4 419 | /home/jaehyung/data/CelebA/image64/001270.jpg 4 420 | /home/jaehyung/data/CelebA/image64/001310.jpg 4 421 | /home/jaehyung/data/CelebA/image64/001338.jpg 4 422 | /home/jaehyung/data/CelebA/image64/001356.jpg 4 423 | /home/jaehyung/data/CelebA/image64/001372.jpg 4 424 | /home/jaehyung/data/CelebA/image64/001393.jpg 4 425 | /home/jaehyung/data/CelebA/image64/001420.jpg 4 426 | /home/jaehyung/data/CelebA/image64/001432.jpg 4 427 | /home/jaehyung/data/CelebA/image64/001483.jpg 4 428 | /home/jaehyung/data/CelebA/image64/001531.jpg 4 429 | /home/jaehyung/data/CelebA/image64/001533.jpg 4 430 | /home/jaehyung/data/CelebA/image64/001543.jpg 4 431 | /home/jaehyung/data/CelebA/image64/001577.jpg 4 432 | /home/jaehyung/data/CelebA/image64/001604.jpg 4 433 | /home/jaehyung/data/CelebA/image64/001625.jpg 4 434 | /home/jaehyung/data/CelebA/image64/001634.jpg 4 435 | /home/jaehyung/data/CelebA/image64/001655.jpg 4 436 | 
/home/jaehyung/data/CelebA/image64/001657.jpg 4 437 | /home/jaehyung/data/CelebA/image64/001679.jpg 4 438 | /home/jaehyung/data/CelebA/image64/001703.jpg 4 439 | /home/jaehyung/data/CelebA/image64/001711.jpg 4 440 | /home/jaehyung/data/CelebA/image64/001717.jpg 4 441 | /home/jaehyung/data/CelebA/image64/001724.jpg 4 442 | /home/jaehyung/data/CelebA/image64/001742.jpg 4 443 | /home/jaehyung/data/CelebA/image64/001754.jpg 4 444 | /home/jaehyung/data/CelebA/image64/001757.jpg 4 445 | /home/jaehyung/data/CelebA/image64/001759.jpg 4 446 | /home/jaehyung/data/CelebA/image64/001780.jpg 4 447 | /home/jaehyung/data/CelebA/image64/001781.jpg 4 448 | /home/jaehyung/data/CelebA/image64/001833.jpg 4 449 | /home/jaehyung/data/CelebA/image64/001846.jpg 4 450 | /home/jaehyung/data/CelebA/image64/001930.jpg 4 451 | /home/jaehyung/data/CelebA/image64/001971.jpg 4 452 | /home/jaehyung/data/CelebA/image64/001978.jpg 4 453 | /home/jaehyung/data/CelebA/image64/001980.jpg 4 454 | /home/jaehyung/data/CelebA/image64/001982.jpg 4 455 | /home/jaehyung/data/CelebA/image64/001986.jpg 4 456 | /home/jaehyung/data/CelebA/image64/002078.jpg 4 457 | /home/jaehyung/data/CelebA/image64/002106.jpg 4 458 | /home/jaehyung/data/CelebA/image64/002124.jpg 4 459 | /home/jaehyung/data/CelebA/image64/002131.jpg 4 460 | /home/jaehyung/data/CelebA/image64/002132.jpg 4 461 | /home/jaehyung/data/CelebA/image64/002141.jpg 4 462 | /home/jaehyung/data/CelebA/image64/002145.jpg 4 463 | /home/jaehyung/data/CelebA/image64/002166.jpg 4 464 | /home/jaehyung/data/CelebA/image64/002187.jpg 4 465 | /home/jaehyung/data/CelebA/image64/002188.jpg 4 466 | /home/jaehyung/data/CelebA/image64/002191.jpg 4 467 | /home/jaehyung/data/CelebA/image64/002202.jpg 4 468 | /home/jaehyung/data/CelebA/image64/002205.jpg 4 469 | /home/jaehyung/data/CelebA/image64/002232.jpg 4 470 | /home/jaehyung/data/CelebA/image64/002255.jpg 4 471 | /home/jaehyung/data/CelebA/image64/002270.jpg 4 472 | /home/jaehyung/data/CelebA/image64/002327.jpg 4 473 | /home/jaehyung/data/CelebA/image64/002341.jpg 4 474 | /home/jaehyung/data/CelebA/image64/002366.jpg 4 475 | /home/jaehyung/data/CelebA/image64/002400.jpg 4 476 | /home/jaehyung/data/CelebA/image64/002459.jpg 4 477 | /home/jaehyung/data/CelebA/image64/002504.jpg 4 478 | /home/jaehyung/data/CelebA/image64/002505.jpg 4 479 | /home/jaehyung/data/CelebA/image64/002515.jpg 4 480 | /home/jaehyung/data/CelebA/image64/002523.jpg 4 481 | /home/jaehyung/data/CelebA/image64/002529.jpg 4 482 | /home/jaehyung/data/CelebA/image64/002614.jpg 4 483 | /home/jaehyung/data/CelebA/image64/002631.jpg 4 484 | /home/jaehyung/data/CelebA/image64/002632.jpg 4 485 | /home/jaehyung/data/CelebA/image64/002679.jpg 4 486 | /home/jaehyung/data/CelebA/image64/002682.jpg 4 487 | /home/jaehyung/data/CelebA/image64/002715.jpg 4 488 | /home/jaehyung/data/CelebA/image64/002743.jpg 4 489 | /home/jaehyung/data/CelebA/image64/002754.jpg 4 490 | /home/jaehyung/data/CelebA/image64/002764.jpg 4 491 | /home/jaehyung/data/CelebA/image64/002765.jpg 4 492 | /home/jaehyung/data/CelebA/image64/002782.jpg 4 493 | /home/jaehyung/data/CelebA/image64/002787.jpg 4 494 | /home/jaehyung/data/CelebA/image64/002804.jpg 4 495 | /home/jaehyung/data/CelebA/image64/002807.jpg 4 496 | /home/jaehyung/data/CelebA/image64/002835.jpg 4 497 | /home/jaehyung/data/CelebA/image64/002838.jpg 4 498 | /home/jaehyung/data/CelebA/image64/002852.jpg 4 499 | /home/jaehyung/data/CelebA/image64/002854.jpg 4 500 | /home/jaehyung/data/CelebA/image64/002872.jpg 4 501 | 
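Note: the CelebA split listings in `etc/` (the test split above, the validation split below) use one `<absolute image path> <class label>` entry per line, which is the format parsed by `LT_Dataset` in `etc/celeb_loader.py` further down. As a minimal sketch (not part of the repository), one could tally the per-class counts of such a listing to check its imbalance; the helper name and the path in the usage comment are only illustrative:
```
from collections import Counter

def count_labels(split_file):
    """Return {class label: number of images} for a '<path> <label>' listing file."""
    counts = Counter()
    with open(split_file) as f:
        for line in f:
            if line.strip():                       # skip blank lines
                counts[int(line.split()[1])] += 1  # second field is the integer class label
    return dict(sorted(counts.items()))

# Example usage (assuming a local copy of the listing):
# print(count_labels('./etc/celebA_test_orig.txt'))
```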
-------------------------------------------------------------------------------- /etc/celebA_val_orig.txt: -------------------------------------------------------------------------------- 1 | /home/jaehyung/data/CelebA/image64/000051.jpg 0 2 | /home/jaehyung/data/CelebA/image64/000079.jpg 0 3 | /home/jaehyung/data/CelebA/image64/000115.jpg 0 4 | /home/jaehyung/data/CelebA/image64/000125.jpg 0 5 | /home/jaehyung/data/CelebA/image64/000134.jpg 0 6 | /home/jaehyung/data/CelebA/image64/000182.jpg 0 7 | /home/jaehyung/data/CelebA/image64/000209.jpg 0 8 | /home/jaehyung/data/CelebA/image64/000226.jpg 0 9 | /home/jaehyung/data/CelebA/image64/000299.jpg 0 10 | /home/jaehyung/data/CelebA/image64/000306.jpg 0 11 | /home/jaehyung/data/CelebA/image64/000386.jpg 0 12 | /home/jaehyung/data/CelebA/image64/000402.jpg 0 13 | /home/jaehyung/data/CelebA/image64/000425.jpg 0 14 | /home/jaehyung/data/CelebA/image64/000623.jpg 0 15 | /home/jaehyung/data/CelebA/image64/000729.jpg 0 16 | /home/jaehyung/data/CelebA/image64/000846.jpg 0 17 | /home/jaehyung/data/CelebA/image64/000902.jpg 0 18 | /home/jaehyung/data/CelebA/image64/000905.jpg 0 19 | /home/jaehyung/data/CelebA/image64/000907.jpg 0 20 | /home/jaehyung/data/CelebA/image64/000926.jpg 0 21 | /home/jaehyung/data/CelebA/image64/000938.jpg 0 22 | /home/jaehyung/data/CelebA/image64/000969.jpg 0 23 | /home/jaehyung/data/CelebA/image64/001072.jpg 0 24 | /home/jaehyung/data/CelebA/image64/001100.jpg 0 25 | /home/jaehyung/data/CelebA/image64/001105.jpg 0 26 | /home/jaehyung/data/CelebA/image64/001117.jpg 0 27 | /home/jaehyung/data/CelebA/image64/001149.jpg 0 28 | /home/jaehyung/data/CelebA/image64/001192.jpg 0 29 | /home/jaehyung/data/CelebA/image64/001207.jpg 0 30 | /home/jaehyung/data/CelebA/image64/001208.jpg 0 31 | /home/jaehyung/data/CelebA/image64/001270.jpg 0 32 | /home/jaehyung/data/CelebA/image64/001322.jpg 0 33 | /home/jaehyung/data/CelebA/image64/001368.jpg 0 34 | /home/jaehyung/data/CelebA/image64/001383.jpg 0 35 | /home/jaehyung/data/CelebA/image64/001398.jpg 0 36 | /home/jaehyung/data/CelebA/image64/001400.jpg 0 37 | /home/jaehyung/data/CelebA/image64/001471.jpg 0 38 | /home/jaehyung/data/CelebA/image64/001490.jpg 0 39 | /home/jaehyung/data/CelebA/image64/001519.jpg 0 40 | /home/jaehyung/data/CelebA/image64/001655.jpg 0 41 | /home/jaehyung/data/CelebA/image64/001656.jpg 0 42 | /home/jaehyung/data/CelebA/image64/001669.jpg 0 43 | /home/jaehyung/data/CelebA/image64/001684.jpg 0 44 | /home/jaehyung/data/CelebA/image64/001706.jpg 0 45 | /home/jaehyung/data/CelebA/image64/001742.jpg 0 46 | /home/jaehyung/data/CelebA/image64/001759.jpg 0 47 | /home/jaehyung/data/CelebA/image64/001766.jpg 0 48 | /home/jaehyung/data/CelebA/image64/001917.jpg 0 49 | /home/jaehyung/data/CelebA/image64/001924.jpg 0 50 | /home/jaehyung/data/CelebA/image64/001978.jpg 0 51 | /home/jaehyung/data/CelebA/image64/000007.jpg 1 52 | /home/jaehyung/data/CelebA/image64/000008.jpg 1 53 | /home/jaehyung/data/CelebA/image64/000011.jpg 1 54 | /home/jaehyung/data/CelebA/image64/000012.jpg 1 55 | /home/jaehyung/data/CelebA/image64/000014.jpg 1 56 | /home/jaehyung/data/CelebA/image64/000017.jpg 1 57 | /home/jaehyung/data/CelebA/image64/000020.jpg 1 58 | /home/jaehyung/data/CelebA/image64/000027.jpg 1 59 | /home/jaehyung/data/CelebA/image64/000033.jpg 1 60 | /home/jaehyung/data/CelebA/image64/000035.jpg 1 61 | /home/jaehyung/data/CelebA/image64/000036.jpg 1 62 | /home/jaehyung/data/CelebA/image64/000040.jpg 1 63 | /home/jaehyung/data/CelebA/image64/000041.jpg 1 64 | 
/home/jaehyung/data/CelebA/image64/000044.jpg 1 65 | /home/jaehyung/data/CelebA/image64/000046.jpg 1 66 | /home/jaehyung/data/CelebA/image64/000047.jpg 1 67 | /home/jaehyung/data/CelebA/image64/000049.jpg 1 68 | /home/jaehyung/data/CelebA/image64/000050.jpg 1 69 | /home/jaehyung/data/CelebA/image64/000055.jpg 1 70 | /home/jaehyung/data/CelebA/image64/000056.jpg 1 71 | /home/jaehyung/data/CelebA/image64/000057.jpg 1 72 | /home/jaehyung/data/CelebA/image64/000059.jpg 1 73 | /home/jaehyung/data/CelebA/image64/000065.jpg 1 74 | /home/jaehyung/data/CelebA/image64/000070.jpg 1 75 | /home/jaehyung/data/CelebA/image64/000072.jpg 1 76 | /home/jaehyung/data/CelebA/image64/000073.jpg 1 77 | /home/jaehyung/data/CelebA/image64/000076.jpg 1 78 | /home/jaehyung/data/CelebA/image64/000078.jpg 1 79 | /home/jaehyung/data/CelebA/image64/000080.jpg 1 80 | /home/jaehyung/data/CelebA/image64/000081.jpg 1 81 | /home/jaehyung/data/CelebA/image64/000091.jpg 1 82 | /home/jaehyung/data/CelebA/image64/000098.jpg 1 83 | /home/jaehyung/data/CelebA/image64/000101.jpg 1 84 | /home/jaehyung/data/CelebA/image64/000102.jpg 1 85 | /home/jaehyung/data/CelebA/image64/000114.jpg 1 86 | /home/jaehyung/data/CelebA/image64/000117.jpg 1 87 | /home/jaehyung/data/CelebA/image64/000118.jpg 1 88 | /home/jaehyung/data/CelebA/image64/000123.jpg 1 89 | /home/jaehyung/data/CelebA/image64/000124.jpg 1 90 | /home/jaehyung/data/CelebA/image64/000129.jpg 1 91 | /home/jaehyung/data/CelebA/image64/000131.jpg 1 92 | /home/jaehyung/data/CelebA/image64/000132.jpg 1 93 | /home/jaehyung/data/CelebA/image64/000135.jpg 1 94 | /home/jaehyung/data/CelebA/image64/000144.jpg 1 95 | /home/jaehyung/data/CelebA/image64/000145.jpg 1 96 | /home/jaehyung/data/CelebA/image64/000148.jpg 1 97 | /home/jaehyung/data/CelebA/image64/000158.jpg 1 98 | /home/jaehyung/data/CelebA/image64/000160.jpg 1 99 | /home/jaehyung/data/CelebA/image64/000161.jpg 1 100 | /home/jaehyung/data/CelebA/image64/000162.jpg 1 101 | /home/jaehyung/data/CelebA/image64/000013.jpg 2 102 | /home/jaehyung/data/CelebA/image64/000018.jpg 2 103 | /home/jaehyung/data/CelebA/image64/000019.jpg 2 104 | /home/jaehyung/data/CelebA/image64/000022.jpg 2 105 | /home/jaehyung/data/CelebA/image64/000024.jpg 2 106 | /home/jaehyung/data/CelebA/image64/000029.jpg 2 107 | /home/jaehyung/data/CelebA/image64/000054.jpg 2 108 | /home/jaehyung/data/CelebA/image64/000071.jpg 2 109 | /home/jaehyung/data/CelebA/image64/000092.jpg 2 110 | /home/jaehyung/data/CelebA/image64/000094.jpg 2 111 | /home/jaehyung/data/CelebA/image64/000100.jpg 2 112 | /home/jaehyung/data/CelebA/image64/000108.jpg 2 113 | /home/jaehyung/data/CelebA/image64/000111.jpg 2 114 | /home/jaehyung/data/CelebA/image64/000112.jpg 2 115 | /home/jaehyung/data/CelebA/image64/000122.jpg 2 116 | /home/jaehyung/data/CelebA/image64/000126.jpg 2 117 | /home/jaehyung/data/CelebA/image64/000133.jpg 2 118 | /home/jaehyung/data/CelebA/image64/000140.jpg 2 119 | /home/jaehyung/data/CelebA/image64/000141.jpg 2 120 | /home/jaehyung/data/CelebA/image64/000147.jpg 2 121 | /home/jaehyung/data/CelebA/image64/000152.jpg 2 122 | /home/jaehyung/data/CelebA/image64/000156.jpg 2 123 | /home/jaehyung/data/CelebA/image64/000157.jpg 2 124 | /home/jaehyung/data/CelebA/image64/000167.jpg 2 125 | /home/jaehyung/data/CelebA/image64/000178.jpg 2 126 | /home/jaehyung/data/CelebA/image64/000191.jpg 2 127 | /home/jaehyung/data/CelebA/image64/000201.jpg 2 128 | /home/jaehyung/data/CelebA/image64/000217.jpg 2 129 | /home/jaehyung/data/CelebA/image64/000218.jpg 2 130 | 
/home/jaehyung/data/CelebA/image64/000223.jpg 2 131 | /home/jaehyung/data/CelebA/image64/000228.jpg 2 132 | /home/jaehyung/data/CelebA/image64/000239.jpg 2 133 | /home/jaehyung/data/CelebA/image64/000247.jpg 2 134 | /home/jaehyung/data/CelebA/image64/000271.jpg 2 135 | /home/jaehyung/data/CelebA/image64/000278.jpg 2 136 | /home/jaehyung/data/CelebA/image64/000287.jpg 2 137 | /home/jaehyung/data/CelebA/image64/000291.jpg 2 138 | /home/jaehyung/data/CelebA/image64/000302.jpg 2 139 | /home/jaehyung/data/CelebA/image64/000317.jpg 2 140 | /home/jaehyung/data/CelebA/image64/000322.jpg 2 141 | /home/jaehyung/data/CelebA/image64/000325.jpg 2 142 | /home/jaehyung/data/CelebA/image64/000326.jpg 2 143 | /home/jaehyung/data/CelebA/image64/000334.jpg 2 144 | /home/jaehyung/data/CelebA/image64/000338.jpg 2 145 | /home/jaehyung/data/CelebA/image64/000341.jpg 2 146 | /home/jaehyung/data/CelebA/image64/000343.jpg 2 147 | /home/jaehyung/data/CelebA/image64/000347.jpg 2 148 | /home/jaehyung/data/CelebA/image64/000348.jpg 2 149 | /home/jaehyung/data/CelebA/image64/000351.jpg 2 150 | /home/jaehyung/data/CelebA/image64/000357.jpg 2 151 | /home/jaehyung/data/CelebA/image64/000001.jpg 3 152 | /home/jaehyung/data/CelebA/image64/000002.jpg 3 153 | /home/jaehyung/data/CelebA/image64/000006.jpg 3 154 | /home/jaehyung/data/CelebA/image64/000023.jpg 3 155 | /home/jaehyung/data/CelebA/image64/000028.jpg 3 156 | /home/jaehyung/data/CelebA/image64/000032.jpg 3 157 | /home/jaehyung/data/CelebA/image64/000033.jpg 3 158 | /home/jaehyung/data/CelebA/image64/000034.jpg 3 159 | /home/jaehyung/data/CelebA/image64/000039.jpg 3 160 | /home/jaehyung/data/CelebA/image64/000042.jpg 3 161 | /home/jaehyung/data/CelebA/image64/000043.jpg 3 162 | /home/jaehyung/data/CelebA/image64/000045.jpg 3 163 | /home/jaehyung/data/CelebA/image64/000052.jpg 3 164 | /home/jaehyung/data/CelebA/image64/000064.jpg 3 165 | /home/jaehyung/data/CelebA/image64/000067.jpg 3 166 | /home/jaehyung/data/CelebA/image64/000073.jpg 3 167 | /home/jaehyung/data/CelebA/image64/000075.jpg 3 168 | /home/jaehyung/data/CelebA/image64/000083.jpg 3 169 | /home/jaehyung/data/CelebA/image64/000085.jpg 3 170 | /home/jaehyung/data/CelebA/image64/000088.jpg 3 171 | /home/jaehyung/data/CelebA/image64/000090.jpg 3 172 | /home/jaehyung/data/CelebA/image64/000097.jpg 3 173 | /home/jaehyung/data/CelebA/image64/000099.jpg 3 174 | /home/jaehyung/data/CelebA/image64/000103.jpg 3 175 | /home/jaehyung/data/CelebA/image64/000107.jpg 3 176 | /home/jaehyung/data/CelebA/image64/000110.jpg 3 177 | /home/jaehyung/data/CelebA/image64/000116.jpg 3 178 | /home/jaehyung/data/CelebA/image64/000136.jpg 3 179 | /home/jaehyung/data/CelebA/image64/000143.jpg 3 180 | /home/jaehyung/data/CelebA/image64/000151.jpg 3 181 | /home/jaehyung/data/CelebA/image64/000163.jpg 3 182 | /home/jaehyung/data/CelebA/image64/000165.jpg 3 183 | /home/jaehyung/data/CelebA/image64/000170.jpg 3 184 | /home/jaehyung/data/CelebA/image64/000171.jpg 3 185 | /home/jaehyung/data/CelebA/image64/000172.jpg 3 186 | /home/jaehyung/data/CelebA/image64/000175.jpg 3 187 | /home/jaehyung/data/CelebA/image64/000178.jpg 3 188 | /home/jaehyung/data/CelebA/image64/000180.jpg 3 189 | /home/jaehyung/data/CelebA/image64/000188.jpg 3 190 | /home/jaehyung/data/CelebA/image64/000189.jpg 3 191 | /home/jaehyung/data/CelebA/image64/000198.jpg 3 192 | /home/jaehyung/data/CelebA/image64/000200.jpg 3 193 | /home/jaehyung/data/CelebA/image64/000205.jpg 3 194 | /home/jaehyung/data/CelebA/image64/000208.jpg 3 195 | 
/home/jaehyung/data/CelebA/image64/000215.jpg 3 196 | /home/jaehyung/data/CelebA/image64/000219.jpg 3 197 | /home/jaehyung/data/CelebA/image64/000222.jpg 3 198 | /home/jaehyung/data/CelebA/image64/000225.jpg 3 199 | /home/jaehyung/data/CelebA/image64/000235.jpg 3 200 | /home/jaehyung/data/CelebA/image64/000238.jpg 3 201 | /home/jaehyung/data/CelebA/image64/000021.jpg 4 202 | /home/jaehyung/data/CelebA/image64/000051.jpg 4 203 | /home/jaehyung/data/CelebA/image64/000094.jpg 4 204 | /home/jaehyung/data/CelebA/image64/000125.jpg 4 205 | /home/jaehyung/data/CelebA/image64/000127.jpg 4 206 | /home/jaehyung/data/CelebA/image64/000202.jpg 4 207 | /home/jaehyung/data/CelebA/image64/000209.jpg 4 208 | /home/jaehyung/data/CelebA/image64/000213.jpg 4 209 | /home/jaehyung/data/CelebA/image64/000221.jpg 4 210 | /home/jaehyung/data/CelebA/image64/000224.jpg 4 211 | /home/jaehyung/data/CelebA/image64/000234.jpg 4 212 | /home/jaehyung/data/CelebA/image64/000252.jpg 4 213 | /home/jaehyung/data/CelebA/image64/000264.jpg 4 214 | /home/jaehyung/data/CelebA/image64/000295.jpg 4 215 | /home/jaehyung/data/CelebA/image64/000306.jpg 4 216 | /home/jaehyung/data/CelebA/image64/000310.jpg 4 217 | /home/jaehyung/data/CelebA/image64/000313.jpg 4 218 | /home/jaehyung/data/CelebA/image64/000316.jpg 4 219 | /home/jaehyung/data/CelebA/image64/000330.jpg 4 220 | /home/jaehyung/data/CelebA/image64/000383.jpg 4 221 | /home/jaehyung/data/CelebA/image64/000440.jpg 4 222 | /home/jaehyung/data/CelebA/image64/000444.jpg 4 223 | /home/jaehyung/data/CelebA/image64/000449.jpg 4 224 | /home/jaehyung/data/CelebA/image64/000498.jpg 4 225 | /home/jaehyung/data/CelebA/image64/000500.jpg 4 226 | /home/jaehyung/data/CelebA/image64/000518.jpg 4 227 | /home/jaehyung/data/CelebA/image64/000520.jpg 4 228 | /home/jaehyung/data/CelebA/image64/000562.jpg 4 229 | /home/jaehyung/data/CelebA/image64/000575.jpg 4 230 | /home/jaehyung/data/CelebA/image64/000578.jpg 4 231 | /home/jaehyung/data/CelebA/image64/000596.jpg 4 232 | /home/jaehyung/data/CelebA/image64/000608.jpg 4 233 | /home/jaehyung/data/CelebA/image64/000639.jpg 4 234 | /home/jaehyung/data/CelebA/image64/000677.jpg 4 235 | /home/jaehyung/data/CelebA/image64/000687.jpg 4 236 | /home/jaehyung/data/CelebA/image64/000719.jpg 4 237 | /home/jaehyung/data/CelebA/image64/000720.jpg 4 238 | /home/jaehyung/data/CelebA/image64/000779.jpg 4 239 | /home/jaehyung/data/CelebA/image64/000798.jpg 4 240 | /home/jaehyung/data/CelebA/image64/000817.jpg 4 241 | /home/jaehyung/data/CelebA/image64/000818.jpg 4 242 | /home/jaehyung/data/CelebA/image64/000846.jpg 4 243 | /home/jaehyung/data/CelebA/image64/000851.jpg 4 244 | /home/jaehyung/data/CelebA/image64/000853.jpg 4 245 | /home/jaehyung/data/CelebA/image64/000867.jpg 4 246 | /home/jaehyung/data/CelebA/image64/000891.jpg 4 247 | /home/jaehyung/data/CelebA/image64/000899.jpg 4 248 | /home/jaehyung/data/CelebA/image64/000903.jpg 4 249 | /home/jaehyung/data/CelebA/image64/000911.jpg 4 250 | /home/jaehyung/data/CelebA/image64/000915.jpg 4 251 | -------------------------------------------------------------------------------- /etc/celeb_loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import os 4 | import json 5 | from torchvision import transforms, utils, datasets 6 | import random 7 | import numpy as np 8 | 9 | from torch import cuda 10 | from torch.utils.data import Dataset, DataLoader, ConcatDataset 11 | from torch.utils.data.sampler import 
WeightedRandomSampler 12 | 13 | image_size = 32 14 | data_transforms = { 15 | 'train': transforms.Compose([ 16 | transforms.Resize(image_size), 17 | transforms.RandomCrop(image_size, padding=4), 18 | transforms.RandomHorizontalFlip(), 19 | transforms.ToTensor(), 20 | ]), 21 | 'val': transforms.Compose([ 22 | transforms.Resize(image_size), 23 | transforms.ToTensor(), 24 | ]), 25 | 'test': transforms.Compose([ 26 | transforms.Resize(image_size), 27 | transforms.ToTensor(), 28 | ]) 29 | } 30 | 31 | class LT_Dataset(Dataset): 32 | def __init__(self, root, txt, transform=None): 33 | self.img_path = [] 34 | self.labels = [] 35 | self.transform = transform 36 | with open(txt) as f: 37 | for line in f: 38 | self.img_path.append(os.path.join(root, line.split()[0])) 39 | self.labels.append(int(line.split()[1])) 40 | 41 | def __len__(self): 42 | return len(self.labels) 43 | 44 | def __getitem__(self, index): 45 | path = self.img_path[index] 46 | label = self.labels[index] 47 | 48 | with open(path, 'rb') as f: 49 | sample = Image.open(f).convert('RGB') 50 | 51 | if self.transform is not None: 52 | sample = self.transform(sample) 53 | 54 | return sample, label 55 | 56 | 57 | def default_loader(path): 58 | return Image.open(path).convert('RGB') 59 | 60 | 61 | def get_celeb_loader(batch_size, mode=False, smote=False, num_workers=16): 62 | txt_train = './CelebA/celebA_train_orig.txt' 63 | txt_val = './CelebA/celebA_val_orig.txt' 64 | txt_test = './CelebA/celebA_test_orig.txt' 65 | 66 | data_root = '/home/temp/data/CelebA/' 67 | 68 | set_train = LT_Dataset(data_root, txt_train, data_transforms['train']) 69 | set_val = LT_Dataset(data_root, txt_val, data_transforms['val']) 70 | set_test = LT_Dataset(data_root, txt_test, data_transforms['test']) 71 | 72 | train_loader = DataLoader(set_train, batch_size, shuffle=True, num_workers=num_workers,pin_memory=cuda.is_available()) 73 | val_loader = DataLoader(set_val, batch_size, shuffle=False, num_workers=num_workers, pin_memory=cuda.is_available()) 74 | test_loader = DataLoader(set_test, batch_size, shuffle=False, num_workers=num_workers, pin_memory=cuda.is_available()) 75 | 76 | return train_loader, val_loader, test_loader 77 | 78 | 79 | if __name__ == '__main__': 80 | 81 | for mode in ["train", "val", "test"]: 82 | loader = get_celeb_loader(128, mode, 4) 83 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet32 import * 2 | -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alinlab/M2m/42d08a5399c1b62925044287e7ee8d134260a08a/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet32.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alinlab/M2m/42d08a5399c1b62925044287e7ee8d134260a08a/models/__pycache__/resnet32.cpython-37.pyc -------------------------------------------------------------------------------- /models/resnet32.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Parameter 5 | from torch.nn import init 6 | import 
math 7 | 8 | 9 | class DownsampleA(nn.Module): 10 | def __init__(self, nIn, nOut, stride): 11 | super(DownsampleA, self).__init__() 12 | assert stride == 2 13 | self.avg = nn.AvgPool2d(kernel_size=1, stride=stride) 14 | 15 | def forward(self, x): 16 | x = self.avg(x) 17 | return torch.cat((x, x.mul(0)), 1) 18 | 19 | 20 | class DownsampleC(nn.Module): 21 | def __init__(self, nIn, nOut, stride): 22 | super(DownsampleC, self).__init__() 23 | assert stride != 1 or nIn != nOut 24 | self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False) 25 | 26 | def forward(self, x): 27 | x = self.conv(x) 28 | return x 29 | 30 | 31 | class DownsampleD(nn.Module): 32 | def __init__(self, nIn, nOut, stride): 33 | super(DownsampleD, self).__init__() 34 | assert stride == 2 35 | self.conv = nn.Conv2d(nIn, nOut, kernel_size=2, stride=stride, padding=0, bias=False) 36 | self.bn = nn.BatchNorm2d(nOut) 37 | 38 | def forward(self, x): 39 | x = self.conv(x) 40 | x = self.bn(x) 41 | return x 42 | 43 | 44 | class NormedLinear(nn.Module): 45 | def __init__(self, in_features, out_features): 46 | super(NormedLinear, self).__init__() 47 | self.weight = Parameter(torch.Tensor(in_features, out_features)) 48 | self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5) 49 | 50 | def forward(self, x): 51 | out = F.normalize(x, dim=1).mm(F.normalize(self.weight, dim=0)) 52 | return out 53 | 54 | 55 | class ResNetBasicblock(nn.Module): 56 | expansion = 1 57 | """ 58 | RexNet basicblock (https://github.com/facebook/fb.resnet.torch/blob/master/models/resnet.lua) 59 | """ 60 | def __init__(self, inplanes, planes, stride=1, downsample=None): 61 | super(ResNetBasicblock, self).__init__() 62 | 63 | self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 64 | self.bn_a = nn.BatchNorm2d(planes) 65 | 66 | self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 67 | self.bn_b = nn.BatchNorm2d(planes) 68 | 69 | self.downsample = downsample 70 | 71 | def forward(self, x): 72 | residual = x 73 | 74 | basicblock = self.conv_a(x) 75 | basicblock = self.bn_a(basicblock) 76 | basicblock = F.relu(basicblock, inplace=True) 77 | 78 | basicblock = self.conv_b(basicblock) 79 | basicblock = self.bn_b(basicblock) 80 | 81 | if self.downsample is not None: 82 | residual = self.downsample(x) 83 | 84 | return F.relu(residual + basicblock, inplace=True) 85 | 86 | 87 | class CifarResNet(nn.Module): 88 | """ 89 | ResNet optimized for the Cifar dataset, as specified in 90 | https://arxiv.org/abs/1512.03385.pdf 91 | """ 92 | def __init__(self, block, depth, num_classes, normalized=False, gray=False): 93 | """ Constructor 94 | Args: 95 | depth: number of layers. 
96 | num_classes: number of classes 97 | base_width: base width 98 | """ 99 | super(CifarResNet, self).__init__() 100 | 101 | #Model type specifies number of layers for CIFAR-10 and CIFAR-100 model 102 | assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110' 103 | layer_blocks = (depth - 2) // 6 104 | print ('CifarResNet : Depth : {} , Layers for each block : {}'.format(depth, layer_blocks)) 105 | 106 | self.num_classes = num_classes 107 | self.normalized = normalized 108 | self.gray = gray 109 | 110 | if self.gray: 111 | self.conv_1_3x3 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1, bias=False) 112 | else: 113 | self.conv_1_3x3 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 114 | self.bn_1 = nn.BatchNorm2d(16) 115 | 116 | self.inplanes = 16 117 | self.stage_1 = self._make_layer(block, 16, layer_blocks, 1) 118 | self.stage_2 = self._make_layer(block, 32, layer_blocks, 2) 119 | self.stage_3 = self._make_layer(block, 64, layer_blocks, 2) 120 | self.avgpool = nn.AdaptiveAvgPool2d((1,1)) 121 | if self.normalized: 122 | self.linear = NormedLinear(64, num_classes) 123 | else: 124 | self.linear = nn.Linear(64*block.expansion, num_classes) 125 | 126 | for m in self.modules(): 127 | if isinstance(m, nn.Conv2d): 128 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 129 | m.weight.data.normal_(0, math.sqrt(2. / n)) 130 | #m.bias.data.zero_() 131 | elif isinstance(m, nn.BatchNorm2d): 132 | m.weight.data.fill_(1) 133 | m.bias.data.zero_() 134 | elif isinstance(m, nn.Linear): 135 | init.kaiming_normal_(m.weight) 136 | m.bias.data.zero_() 137 | 138 | def _make_layer(self, block, planes, blocks, stride=1): 139 | downsample = None 140 | if stride != 1 or self.inplanes != planes * block.expansion: 141 | downsample = DownsampleA(self.inplanes, planes * block.expansion, stride) 142 | 143 | layers = [] 144 | layers.append(block(self.inplanes, planes, stride, downsample)) 145 | self.inplanes = planes * block.expansion 146 | for i in range(1, blocks): 147 | layers.append(block(self.inplanes, planes)) 148 | 149 | return nn.Sequential(*layers) 150 | 151 | def forward(self, x): 152 | x = self.conv_1_3x3(x) 153 | x1 = F.relu(self.bn_1(x), inplace=True) 154 | x2 = self.stage_1(x1) 155 | x3 = self.stage_2(x2) 156 | x = self.stage_3(x3) 157 | x = self.avgpool(x) 158 | x = x.view(x.size(0), -1) 159 | return self.linear(x), [x1, x2, x3, x] 160 | 161 | 162 | def resnet20(num_classes=10): 163 | """Constructs a ResNet-20 model for CIFAR-10 (by default) 164 | Args: 165 | num_classes (uint): number of classes 166 | """ 167 | model = CifarResNet(ResNetBasicblock, 20, num_classes) 168 | return model 169 | 170 | def resnet32(num_classes=10): 171 | """Constructs a ResNet-32 model for CIFAR-10 (by default) 172 | Args: 173 | num_classes (uint): number of classes 174 | """ 175 | model = CifarResNet(ResNetBasicblock, 32, num_classes) 176 | return model 177 | 178 | def resnet32_norm(num_classes=10): 179 | """Constructs a ResNet-32 model for CIFAR-10 (by default) 180 | Args: 181 | num_classes (uint): number of classes 182 | """ 183 | model = CifarResNet(ResNetBasicblock, 32, num_classes, True, False) 184 | return model 185 | 186 | def resnet32_gray(num_classes=10): 187 | """Constructs a ResNet-32 model for CIFAR-10 (by default) 188 | Args: 189 | num_classes (uint): number of classes 190 | """ 191 | model = CifarResNet(ResNetBasicblock, 32, num_classes, False, True) 192 | return model 193 | 194 | def resnet44(num_classes=10): 195 | """Constructs a ResNet-44 model for CIFAR-10 (by 
default) 196 | Args: 197 | num_classes (uint): number of classes 198 | """ 199 | model = CifarResNet(ResNetBasicblock, 44, num_classes) 200 | return model 201 | 202 | def resnet56(num_classes=397): 203 | """Constructs a ResNet-56 model for CIFAR-10 (by default) 204 | Args: 205 | num_classes (uint): number of classes 206 | """ 207 | model = CifarResNet(ResNetBasicblock, 56, num_classes) 208 | return model 209 | 210 | def resnet110(num_classes=10): 211 | """Constructs a ResNet-110 model for CIFAR-10 (by default) 212 | Args: 213 | num_classes (uint): number of classes 214 | """ 215 | model = CifarResNet(ResNetBasicblock, 110, num_classes) 216 | return model 217 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | # Script for running baseline (cross-entropy) 2 | 3 | CUDA_VISIBLE_DEVICES=0 python train.py --no_over --ratio 100 --decay 2e-4 --model resnet32 --dataset cifar10 --lr 0.1 --batch-size 128 --name 'ERM' --warm 200 --epoch 200 4 | 5 | # Script for running over-sampling 6 | 7 | CUDA_VISIBLE_DEVICES=0 python train.py --ratio 100 --decay 2e-4 --model resnet32 --dataset cifar10 --lr 0.1 --batch-size 128 --name 'Over' --warm 0 --epoch 200 8 | 9 | # SMOTE 10 | 11 | CUDA_VISIBLE_DEVICES=0 python train.py -s --ratio 100 --decay 2e-4 --model resnet32 --dataset cifar10 --lr 0.1 --batch-size 128 --name 'SMOTE' --warm 0 --epoch 200 12 | 13 | # Script for running re-weighting (RW) 14 | 15 | CUDA_VISIBLE_DEVICES=0 python train.py --no_over -c --eff_beta 1.0 --ratio 100 --decay 2e-4 --model resnet32 --dataset cifar10 --lr 0.1 --batch-size 128 --name 'Cost' --warm 0 --epoch 200 16 | 17 | # Script for running re-weighting with class-balanced loss (CB-RW) 18 | 19 | CUDA_VISIBLE_DEVICES=0 python train.py --no_over -c --eff_beta 0.999 --ratio 100 --decay 2e-4 --model resnet32 --dataset cifar10 --lr 0.1 --batch-size 128 --name 'CBLoss' --warm 0 --epoch 200 20 | 21 | # Script for running DRS 22 | 23 | CUDA_VISIBLE_DEVICES=0 python train.py --ratio 100 --decay 2e-4 --model resnet32 --dataset cifar10 --lr 0.1 --batch-size 128 --name 'DRS' --warm 160 --epoch 200 24 | 25 | # Script for running Focal 26 | 27 | CUDA_VISIBLE_DEVICES=0 python train.py --no_over --loss_type Focal --focal_gamma 1.0 --ratio 100 --decay 2e-4 --model resnet32 --dataset cifar10 --lr 0.1 --batch-size 128 --name 'Focal' --warm 160 --epoch 200 28 | 29 | # Script for running LDAM-DRW 30 | 31 | CUDA_VISIBLE_DEVICES=0 python train.py --no_over -c --loss_type LDAM --eff_beta 0.999 --ratio 100 --decay 2e-4 --model resnet32_norm --dataset cifar10 --lr 0.1 --batch-size 128 --name 'LDAM-DRW' --warm 160 --epoch 200 32 | 33 | # Script for running our method (M2m) 34 | CUDA_VISIBLE_DEVICES=0 python train.py -gen -r --ratio 100 --decay 2e-4 --model resnet32 --dataset cifar10 --lr 0.1 --batch-size 128 --name 'M2m' --beta 0.999 --lam 0.5 --gamma 0.9 --step_size 0.1 --attack_iter 10 --warm 160 --epoch 200 --net_g ./checkpoint/erm_r100_c10_trial1.t7 35 | 36 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. 
7 | from __future__ import print_function 8 | 9 | import csv 10 | import os 11 | 12 | import numpy as np 13 | import torch 14 | from torch.autograd import Variable, grad 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import torch.optim as optim 18 | from tqdm import tqdm 19 | 20 | from utils import random_perturb, make_step, inf_data_gen, Logger 21 | from utils import soft_cross_entropy, classwise_loss, LDAMLoss, FocalLoss 22 | from config import * 23 | 24 | 25 | LOGNAME = 'Imbalance_' + LOGFILE_BASE 26 | logger = Logger(LOGNAME) 27 | LOGDIR = logger.logdir 28 | 29 | LOG_CSV = os.path.join(LOGDIR, f'log_{SEED}.csv') 30 | LOG_CSV_HEADER = [ 31 | 'epoch', 'train loss', 'gen loss', 'train acc', 'gen_acc', 'prob_orig', 'prob_targ', 32 | 'test loss', 'major test acc', 'neutral test acc', 'minor test acc', 'test acc', 'f1 score' 33 | ] 34 | if not os.path.exists(LOG_CSV): 35 | with open(LOG_CSV, 'w') as f: 36 | csv_writer = csv.writer(f, delimiter=',') 37 | csv_writer.writerow(LOG_CSV_HEADER) 38 | 39 | 40 | def save_checkpoint(acc, model, optim, epoch, index=False): 41 | # Save checkpoint. 42 | print('Saving..') 43 | 44 | if isinstance(model, nn.DataParallel): 45 | model = model.module 46 | 47 | state = { 48 | 'net': model.state_dict(), 49 | 'optimizer': optim.state_dict(), 50 | 'acc': acc, 51 | 'epoch': epoch, 52 | 'rng_state': torch.get_rng_state() 53 | } 54 | 55 | if index: 56 | ckpt_name = 'ckpt_epoch' + str(epoch) + '_' + str(SEED) + '.t7' 57 | else: 58 | ckpt_name = 'ckpt_' + str(SEED) + '.t7' 59 | 60 | ckpt_path = os.path.join(LOGDIR, ckpt_name) 61 | torch.save(state, ckpt_path) 62 | 63 | 64 | def train_epoch(model, criterion, optimizer, data_loader, logger=None): 65 | model.train() 66 | 67 | train_loss = 0 68 | correct = 0 69 | total = 0 70 | 71 | for inputs, targets in tqdm(data_loader): 72 | # For SMOTE, get the samples from smote_loader instead of usual loader 73 | if epoch >= ARGS.warm and ARGS.smote: 74 | inputs, targets = next(smote_loader_inf) 75 | 76 | inputs, targets = inputs.to(device), targets.to(device) 77 | batch_size = inputs.size(0) 78 | 79 | outputs, _ = model(normalizer(inputs)) 80 | loss = criterion(outputs, targets).mean() 81 | 82 | train_loss += loss.item() * batch_size 83 | predicted = outputs.max(1)[1] 84 | total += batch_size 85 | correct += sum_t(predicted.eq(targets)) 86 | 87 | optimizer.zero_grad() 88 | loss.backward() 89 | optimizer.step() 90 | 91 | msg = 'Loss: %.3f| Acc: %.3f%% (%d/%d)' % \ 92 | (train_loss / total, 100. * correct / total, correct, total) 93 | if logger: 94 | logger.log(msg) 95 | else: 96 | print(msg) 97 | 98 | return train_loss / total, 100. 
* correct / total 99 | 100 | 101 | def uniform_loss(outputs): 102 | weights = torch.ones_like(outputs) / N_CLASSES 103 | 104 | return soft_cross_entropy(outputs, weights, reduction='mean') 105 | 106 | 107 | def classwise_loss(outputs, targets): 108 | out_1hot = torch.zeros_like(outputs) 109 | out_1hot.scatter_(1, targets.view(-1, 1), 1) 110 | return (outputs * out_1hot).sum(1).mean() 111 | 112 | 113 | def generation(model_g, model_r, inputs, seed_targets, targets, p_accept, 114 | gamma, lam, step_size, random_start=True, max_iter=10): 115 | model_g.eval() 116 | model_r.eval() 117 | criterion = nn.CrossEntropyLoss() 118 | 119 | if random_start: 120 | random_noise = random_perturb(inputs, 'l2', 0.5) 121 | inputs = torch.clamp(inputs + random_noise, 0, 1) 122 | 123 | for _ in range(max_iter): 124 | inputs = inputs.clone().detach().requires_grad_(True) 125 | outputs_g, _ = model_g(normalizer(inputs)) 126 | outputs_r, _ = model_r(normalizer(inputs)) 127 | 128 | loss = criterion(outputs_g, targets) + lam * classwise_loss(outputs_r, seed_targets) 129 | grad, = torch.autograd.grad(loss, [inputs]) 130 | 131 | inputs = inputs - make_step(grad, 'l2', step_size) 132 | inputs = torch.clamp(inputs, 0, 1) 133 | 134 | inputs = inputs.detach() 135 | 136 | outputs_g, _ = model_g(normalizer(inputs)) 137 | 138 | one_hot = torch.zeros_like(outputs_g) 139 | one_hot.scatter_(1, targets.view(-1, 1), 1) 140 | probs_g = torch.softmax(outputs_g, dim=1)[one_hot.to(torch.bool)] 141 | 142 | correct = (probs_g >= gamma) * torch.bernoulli(p_accept).byte().to(device) 143 | model_r.train() 144 | 145 | return inputs, correct 146 | 147 | 148 | def train_net(model_train, model_gen, criterion, optimizer_train, inputs_orig, targets_orig, gen_idx, gen_targets): 149 | batch_size = inputs_orig.size(0) 150 | 151 | inputs = inputs_orig.clone() 152 | targets = targets_orig.clone() 153 | 154 | ######################## 155 | 156 | bs = N_SAMPLES_PER_CLASS_T[targets_orig].repeat(gen_idx.size(0), 1) 157 | gs = N_SAMPLES_PER_CLASS_T[gen_targets].view(-1, 1) 158 | 159 | delta = F.relu(bs - gs) 160 | p_accept = 1 - ARGS.beta ** delta 161 | mask_valid = (p_accept.sum(1) > 0) 162 | 163 | gen_idx = gen_idx[mask_valid] 164 | gen_targets = gen_targets[mask_valid] 165 | p_accept = p_accept[mask_valid] 166 | 167 | select_idx = torch.multinomial(p_accept, 1, replacement=True).view(-1) 168 | p_accept = p_accept.gather(1, select_idx.view(-1, 1)).view(-1) 169 | 170 | seed_targets = targets_orig[select_idx] 171 | seed_images = inputs_orig[select_idx] 172 | 173 | gen_inputs, correct_mask = generation(model_gen, model_train, seed_images, seed_targets, gen_targets, p_accept, 174 | ARGS.gamma, ARGS.lam, ARGS.step_size, True, ARGS.attack_iter) 175 | 176 | ######################## 177 | 178 | # Only change the correctly generated samples 179 | num_gen = sum_t(correct_mask) 180 | num_others = batch_size - num_gen 181 | 182 | gen_c_idx = gen_idx[correct_mask] 183 | others_mask = torch.ones(batch_size, dtype=torch.bool, device=device) 184 | others_mask[gen_c_idx] = 0 185 | others_idx = others_mask.nonzero().view(-1) 186 | 187 | if num_gen > 0: 188 | gen_inputs_c = gen_inputs[correct_mask] 189 | gen_targets_c = gen_targets[correct_mask] 190 | 191 | inputs[gen_c_idx] = gen_inputs_c 192 | targets[gen_c_idx] = gen_targets_c 193 | 194 | outputs, _ = model_train(normalizer(inputs)) 195 | loss = criterion(outputs, targets) 196 | 197 | optimizer_train.zero_grad() 198 | loss.mean().backward() 199 | optimizer_train.step() 200 | 201 | # For logging the training 202 | 203 
| oth_loss_total = sum_t(loss[others_idx]) 204 | gen_loss_total = sum_t(loss[gen_c_idx]) 205 | 206 | _, predicted = torch.max(outputs[others_idx].data, 1) 207 | num_correct_oth = sum_t(predicted.eq(targets[others_idx])) 208 | 209 | num_correct_gen, p_g_orig, p_g_targ = 0, 0, 0 210 | success = torch.zeros(N_CLASSES, 2) 211 | 212 | if num_gen > 0: 213 | _, predicted_gen = torch.max(outputs[gen_c_idx].data, 1) 214 | num_correct_gen = sum_t(predicted_gen.eq(targets[gen_c_idx])) 215 | probs = torch.softmax(outputs[gen_c_idx], 1).data 216 | 217 | p_g_orig = probs.gather(1, seed_targets[correct_mask].view(-1, 1)) 218 | p_g_orig = sum_t(p_g_orig) 219 | 220 | p_g_targ = probs.gather(1, gen_targets_c.view(-1, 1)) 221 | p_g_targ = sum_t(p_g_targ) 222 | 223 | for i in range(N_CLASSES): 224 | if num_gen > 0: 225 | success[i, 0] = sum_t(gen_targets_c == i) 226 | success[i, 1] = sum_t(gen_targets == i) 227 | 228 | return oth_loss_total, gen_loss_total, num_others, num_correct_oth, num_gen, num_correct_gen, p_g_orig, p_g_targ, success 229 | 230 | 231 | def train_gen_epoch(net_t, net_g, criterion, optimizer, data_loader): 232 | net_t.train() 233 | net_g.eval() 234 | 235 | oth_loss, gen_loss = 0, 0 236 | correct_oth = 0 237 | correct_gen = 0 238 | total_oth, total_gen = 1e-6, 1e-6 239 | p_g_orig, p_g_targ = 0, 0 240 | t_success = torch.zeros(N_CLASSES, 2) 241 | 242 | for inputs, targets in tqdm(data_loader): 243 | batch_size = inputs.size(0) 244 | inputs, targets = inputs.to(device), targets.to(device) 245 | 246 | # Set a generation target for current batch with re-sampling 247 | if ARGS.imb_type != 'none': # Imbalanced 248 | # Keep the sample with this probability 249 | gen_probs = N_SAMPLES_PER_CLASS_T[targets] / N_SAMPLES_PER_CLASS_T[0] 250 | gen_index = (1 - torch.bernoulli(gen_probs)).nonzero() # Generation index 251 | gen_index = gen_index.view(-1) 252 | gen_targets = targets[gen_index] 253 | else: # Balanced 254 | gen_index = torch.arange(batch_size).view(-1) 255 | gen_targets = torch.randint(N_CLASSES, (batch_size,)).to(device).long() 256 | 257 | t_loss, g_loss, num_others, num_correct, num_gen, num_gen_correct, p_g_orig_batch, p_g_targ_batch, success \ 258 | = train_net(net_t, net_g, criterion, optimizer, inputs, targets, gen_index, gen_targets) 259 | 260 | oth_loss += t_loss 261 | gen_loss += g_loss 262 | total_oth += num_others 263 | correct_oth += num_correct 264 | total_gen += num_gen 265 | correct_gen += num_gen_correct 266 | p_g_orig += p_g_orig_batch 267 | p_g_targ += p_g_targ_batch 268 | t_success += success 269 | 270 | res = { 271 | 'train_loss': oth_loss / total_oth, 272 | 'gen_loss': gen_loss / total_gen, 273 | 'train_acc': 100. * correct_oth / total_oth, 274 | 'gen_acc': 100. 
* correct_gen / total_gen, 275 | 'p_g_orig': p_g_orig / total_gen, 276 | 'p_g_targ': p_g_targ / total_gen, 277 | 't_success': t_success 278 | } 279 | 280 | msg = 't_Loss: %.3f | g_Loss: %.3f | Acc: %.3f%% (%d/%d) | Acc_gen: %.3f%% (%d/%d) ' \ 281 | '| Prob_orig: %.3f | Prob_targ: %.3f' % ( 282 | res['train_loss'], res['gen_loss'], 283 | res['train_acc'], correct_oth, total_oth, 284 | res['gen_acc'], correct_gen, total_gen, 285 | res['p_g_orig'], res['p_g_targ'] 286 | ) 287 | if logger: 288 | logger.log(msg) 289 | else: 290 | print(msg) 291 | 292 | return res 293 | 294 | 295 | if __name__ == '__main__': 296 | TEST_ACC = 0 # best test accuracy 297 | BEST_VAL = 0 # best validation accuracy 298 | 299 | # Weights for virtual samples are generated 300 | logger.log('==> Building model: %s' % MODEL) 301 | net = models.__dict__[MODEL](N_CLASSES) 302 | net_seed = models.__dict__[MODEL](N_CLASSES) 303 | 304 | net, net_seed = net.to(device), net_seed.to(device) 305 | optimizer = optim.SGD(net.parameters(), lr=ARGS.lr, momentum=0.9, weight_decay=ARGS.decay) 306 | 307 | if ARGS.resume: 308 | # Load checkpoint. 309 | logger.log('==> Resuming from checkpoint..') 310 | ckpt_g = f'./checkpoint/{DATASET}/ratio{ARGS.ratio}/erm_trial1_{MODEL}.t7' 311 | 312 | if ARGS.net_both is not None: 313 | ckpt_t = torch.load(ARGS.net_both) 314 | net.load_state_dict(ckpt_t['net']) 315 | optimizer.load_state_dict(ckpt_t['optimizer']) 316 | START_EPOCH = ckpt_t['epoch'] + 1 317 | net_seed.load_state_dict(ckpt_t['net2']) 318 | else: 319 | if ARGS.net_t is not None: 320 | ckpt_t = torch.load(ARGS.net_t) 321 | net.load_state_dict(ckpt_t['net']) 322 | optimizer.load_state_dict(ckpt_t['optimizer']) 323 | START_EPOCH = ckpt_t['epoch'] + 1 324 | 325 | if ARGS.net_g is not None: 326 | ckpt_g = ARGS.net_g 327 | print(ckpt_g) 328 | ckpt_g = torch.load(ckpt_g) 329 | net_seed.load_state_dict(ckpt_g['net']) 330 | 331 | if N_GPUS > 1: 332 | logger.log('Multi-GPU mode: using %d GPUs for training.' 
% N_GPUS) 333 | net = nn.DataParallel(net) 334 | net_seed = nn.DataParallel(net_seed) 335 | elif N_GPUS == 1: 336 | logger.log('Single-GPU mode.') 337 | 338 | if ARGS.warm < START_EPOCH and ARGS.over: 339 | raise ValueError("warm < START_EPOCH") 340 | 341 | SUCCESS = torch.zeros(EPOCH, N_CLASSES, 2) 342 | test_stats = {} 343 | for epoch in range(START_EPOCH, EPOCH): 344 | logger.log(' * Epoch %d: %s' % (epoch, LOGDIR)) 345 | 346 | adjust_learning_rate(optimizer, LR, epoch) 347 | 348 | if epoch == ARGS.warm and ARGS.over: 349 | if ARGS.smote: 350 | logger.log("=============== Applying smote sampling ===============") 351 | smote_loader, _, _ = get_smote(DATASET, N_SAMPLES_PER_CLASS, BATCH_SIZE, transform_train, transform_test) 352 | smote_loader_inf = inf_data_gen(smote_loader) 353 | else: 354 | logger.log("=============== Applying over sampling ===============") 355 | train_loader, _, _ = get_oversampled(DATASET, N_SAMPLES_PER_CLASS, BATCH_SIZE, 356 | transform_train, transform_test) 357 | 358 | ## For Cost-Sensitive Learning ## 359 | 360 | if ARGS.cost and epoch >= ARGS.warm: 361 | beta = ARGS.eff_beta 362 | if beta < 1: 363 | effective_num = 1.0 - np.power(beta, N_SAMPLES_PER_CLASS) 364 | per_cls_weights = (1.0 - beta) / np.array(effective_num) 365 | else: 366 | per_cls_weights = 1 / np.array(N_SAMPLES_PER_CLASS) 367 | per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(N_SAMPLES_PER_CLASS) 368 | per_cls_weights = torch.FloatTensor(per_cls_weights).to(device) 369 | else: 370 | per_cls_weights = torch.ones(N_CLASSES).to(device) 371 | 372 | ## Choos a loss function ## 373 | 374 | if ARGS.loss_type == 'CE': 375 | criterion = nn.CrossEntropyLoss(weight=per_cls_weights, reduction='none').to(device) 376 | elif ARGS.loss_type == 'Focal': 377 | criterion = FocalLoss(weight=per_cls_weights, gamma=ARGS.focal_gamma, reduction='none').to(device) 378 | elif ARGS.loss_type == 'LDAM': 379 | criterion = LDAMLoss(cls_num_list=N_SAMPLES_PER_CLASS, max_m=0.5, s=30, weight=per_cls_weights, 380 | reduction='none').to(device) 381 | else: 382 | raise ValueError("Wrong Loss Type") 383 | 384 | ## Training ( ARGS.warm is used for deferred re-balancing ) ## 385 | 386 | if epoch >= ARGS.warm and ARGS.gen: 387 | train_stats = train_gen_epoch(net, net_seed, criterion, optimizer, train_loader) 388 | SUCCESS[epoch, :, :] = train_stats['t_success'].float() 389 | logger.log(SUCCESS[epoch, -10:, :]) 390 | np.save(LOGDIR + '/success.npy', SUCCESS.cpu().numpy()) 391 | else: 392 | train_loss, train_acc = train_epoch(net, criterion, optimizer, train_loader, logger) 393 | train_stats = {'train_loss': train_loss, 'train_acc': train_acc} 394 | if epoch == 159: 395 | save_checkpoint(train_acc, net, optimizer, epoch, True) 396 | 397 | ## Evaluation ## 398 | 399 | val_eval = evaluate(net, val_loader, logger=logger) 400 | val_acc = val_eval['acc'] 401 | if val_acc >= BEST_VAL: 402 | BEST_VAL = val_acc 403 | 404 | test_stats = evaluate(net, test_loader, logger=logger) 405 | TEST_ACC = test_stats['acc'] 406 | TEST_ACC_CLASS = test_stats['class_acc'] 407 | 408 | save_checkpoint(TEST_ACC, net, optimizer, epoch) 409 | logger.log("========== Class-wise test performance ( avg : {} ) ==========".format(TEST_ACC_CLASS.mean())) 410 | np.save(LOGDIR + '/classwise_acc.npy', TEST_ACC_CLASS.cpu()) 411 | 412 | def _convert_scala(x): 413 | if hasattr(x, 'item'): 414 | x = x.item() 415 | return x 416 | 417 | log_tr = ['train_loss', 'gen_loss', 'train_acc', 'gen_acc', 'p_g_orig', 'p_g_targ'] 418 | log_te = ['loss', 'major_acc', 
'neutral_acc', 'minor_acc', 'acc', 'f1_score'] 419 | 420 | log_vector = [epoch] + [train_stats.get(k, 0) for k in log_tr] + [test_stats.get(k, 0) for k in log_te] 421 | log_vector = list(map(_convert_scala, log_vector)) 422 | 423 | with open(LOG_CSV, 'a') as f: 424 | logwriter = csv.writer(f, delimiter=',') 425 | logwriter.writerow(log_vector) 426 | 427 | logger.log(' * %s' % LOGDIR) 428 | logger.log("Best Accuracy : {}".format(TEST_ACC)) 429 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | - msr_init: net parameter initialization. 4 | - progress_bar: progress bar mimic xlua.progress. 5 | ''' 6 | import os 7 | import sys 8 | import time 9 | from datetime import datetime 10 | import shutil 11 | import math 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | import torch.nn.init as init 17 | import numpy as np 18 | import importlib 19 | 20 | def source_import(file_path): 21 | """This function imports python module directly from source code using importlib""" 22 | spec = importlib.util.spec_from_file_location('', file_path) 23 | module = importlib.util.module_from_spec(spec) 24 | spec.loader.exec_module(module) 25 | return module 26 | 27 | 28 | def sum_t(tensor): 29 | return tensor.float().sum().item() 30 | 31 | 32 | class InputNormalize(nn.Module): 33 | ''' 34 | A module (custom layer) for normalizing the input to have a fixed 35 | mean and standard deviation (user-specified). 36 | ''' 37 | def __init__(self, new_mean, new_std): 38 | super(InputNormalize, self).__init__() 39 | new_std = new_std[..., None, None].cuda() 40 | new_mean = new_mean[..., None, None].cuda() 41 | 42 | # To prevent the updates the mean, std 43 | self.register_buffer("new_mean", new_mean) 44 | self.register_buffer("new_std", new_std) 45 | 46 | def forward(self, x): 47 | x = torch.clamp(x, 0, 1) 48 | x_normalized = (x - self.new_mean)/self.new_std 49 | return x_normalized 50 | 51 | 52 | class Logger(object): 53 | """Reference: https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514""" 54 | def __init__(self, fn): 55 | if not os.path.exists("./logs/"): 56 | os.mkdir("./logs/") 57 | 58 | logdir = 'logs/' + fn 59 | if not os.path.exists(logdir): 60 | os.mkdir(logdir) 61 | if len(os.listdir(logdir)) != 0: 62 | ans = input("log_dir is not empty. All data inside log_dir will be deleted. " 63 | "Will you proceed [y/N]? 
") 64 | if ans in ['y', 'Y']: 65 | shutil.rmtree(logdir) 66 | else: 67 | exit(1) 68 | self.set_dir(logdir) 69 | 70 | def set_dir(self, logdir, log_fn='log.txt'): 71 | self.logdir = logdir 72 | if not os.path.exists(logdir): 73 | os.mkdir(logdir) 74 | self.log_file = open(os.path.join(logdir, log_fn), 'a') 75 | 76 | def log(self, string): 77 | self.log_file.write('[%s] %s' % (datetime.now(), string) + '\n') 78 | self.log_file.flush() 79 | 80 | print('[%s] %s' % (datetime.now(), string)) 81 | sys.stdout.flush() 82 | 83 | def log_dirname(self, string): 84 | self.log_file.write('%s (%s)' % (string, self.logdir) + '\n') 85 | self.log_file.flush() 86 | 87 | print('%s (%s)' % (string, self.logdir)) 88 | sys.stdout.flush() 89 | 90 | 91 | ######## Loss ######## 92 | 93 | 94 | def soft_cross_entropy(input, labels, reduction='mean'): 95 | xent = (-labels * F.log_softmax(input, dim=1)).sum(1) 96 | if reduction == 'sum': 97 | return xent.sum() 98 | elif reduction == 'mean': 99 | return xent.mean() 100 | elif reduction == 'none': 101 | return xent 102 | else: 103 | raise NotImplementedError() 104 | 105 | 106 | def classwise_loss(outputs, targets): 107 | out_1hot = torch.ones_like(outputs) 108 | out_1hot.scatter_(1, targets.view(-1, 1), -1) 109 | return (outputs * out_1hot).mean() 110 | 111 | 112 | def focal_loss(input_values, gamma): 113 | """Computes the focal loss 114 | 115 | Reference: https://github.com/kaidic/LDAM-DRW/blob/master/losses.py 116 | """ 117 | p = torch.exp(-input_values) 118 | loss = (1 - p) ** gamma * input_values 119 | return loss 120 | 121 | 122 | class FocalLoss(nn.Module): 123 | """Reference: https://github.com/kaidic/LDAM-DRW/blob/master/losses.py""" 124 | def __init__(self, weight=None, gamma=0., reduction='mean'): 125 | super(FocalLoss, self).__init__() 126 | assert gamma >= 0 127 | self.gamma = gamma 128 | self.weight = weight 129 | self.reduction = reduction 130 | 131 | def forward(self, input, target): 132 | return focal_loss(F.cross_entropy(input, target, weight=self.weight, reduction=self.reduction), self.gamma) 133 | 134 | 135 | class LDAMLoss(nn.Module): 136 | """Reference: https://github.com/kaidic/LDAM-DRW/blob/master/losses.py""" 137 | def __init__(self, cls_num_list, max_m=0.5, weight=None, s=30, reduction='mean'): 138 | super(LDAMLoss, self).__init__() 139 | m_list = 1.0 / np.sqrt(np.sqrt(cls_num_list)) 140 | m_list = m_list * (max_m / np.max(m_list)) 141 | m_list = torch.cuda.FloatTensor(m_list) 142 | self.m_list = m_list 143 | self.scale = s 144 | self.weight = weight 145 | self.reduction = reduction 146 | 147 | def forward(self, x, target): 148 | index = torch.zeros_like(x, dtype=torch.uint8) 149 | index.scatter_(1, target.data.view(-1, 1), 1) 150 | 151 | index_float = index.type(torch.cuda.FloatTensor) 152 | batch_m = torch.matmul(self.m_list[None, :], index_float.transpose(0, 1)) 153 | batch_m = batch_m.view((-1, 1)) 154 | x_m = x - batch_m 155 | 156 | output = torch.where(index, x_m, x) 157 | return F.cross_entropy(self.scale * output, target, weight=self.weight, reduction=self.reduction) 158 | 159 | ######## Generation ######## 160 | 161 | 162 | def project(inputs, orig_inputs, attack, eps): 163 | diff = inputs - orig_inputs 164 | if attack == 'l2': 165 | diff = diff.renorm(p=2, dim=0, maxnorm=eps) 166 | elif attack == 'inf': 167 | diff = torch.clamp(diff, -eps, eps) 168 | return orig_inputs + diff 169 | 170 | 171 | def make_step(grad, attack, step_size): 172 | if attack == 'l2': 173 | grad_norm = torch.norm(grad.view(grad.shape[0], -1), dim=1).view(-1, 1, 1, 
1) 174 | scaled_grad = grad / (grad_norm + 1e-10) 175 | step = step_size * scaled_grad 176 | elif attack == 'inf': 177 | step = step_size * torch.sign(grad) 178 | else: 179 | step = step_size * grad 180 | return step 181 | 182 | 183 | def random_perturb(inputs, attack, eps): 184 | if attack == 'inf': 185 | r_inputs = 2 * (torch.rand_like(inputs) - 0.5) * eps 186 | else: 187 | r_inputs = (torch.rand_like(inputs) - 0.5).renorm(p=2, dim=1, maxnorm=eps) 188 | return r_inputs 189 | 190 | 191 | ######## Data ######## 192 | 193 | 194 | def make_imb_data(max_num, min_num, class_num, gamma): 195 | class_idx = torch.arange(1, class_num + 1).float() 196 | ratio = max_num / min_num 197 | b = (torch.pow(class_idx[-1], gamma) - ratio) / (ratio - 1) 198 | a = max_num * (1 + b) 199 | class_num_list = [] 200 | for i in range(class_num): 201 | class_num_list.append(int(torch.round(a / (torch.pow(class_idx[i], gamma) + b)))) 202 | print(class_num_list) 203 | 204 | return list(class_num_list) 205 | 206 | 207 | def make_imb_data2(max_num, class_num, gamma): 208 | mu = np.power(1/gamma, 1/(class_num - 1)) 209 | print(mu) 210 | class_num_list = [] 211 | for i in range(class_num): 212 | class_num_list.append(int(max_num * np.power(mu, i))) 213 | 214 | return list(class_num_list) 215 | 216 | 217 | def inf_data_gen(dataloader): 218 | while True: 219 | for images, targets in dataloader: 220 | yield images, targets 221 | 222 | 223 | def source_import(file_path): 224 | """This function imports python module directly from source code using importlib""" 225 | spec = importlib.util.spec_from_file_location('', file_path) 226 | module = importlib.util.module_from_spec(spec) 227 | spec.loader.exec_module(module) 228 | return module 229 | 230 | 231 | def get_mean_and_std(dataset): 232 | '''Compute the mean and std value of dataset.''' 233 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 234 | mean = torch.zeros(3) 235 | std = torch.zeros(3) 236 | print('==> Computing mean and std..') 237 | for inputs, targets in dataloader: 238 | for i in range(3): 239 | mean[i] += inputs[:,i,:,:].mean() 240 | std[i] += inputs[:,i,:,:].std() 241 | mean.div_(len(dataset)) 242 | std.div_(len(dataset)) 243 | return mean, std 244 | 245 | 246 | def init_params(net): 247 | '''Init layer parameters.''' 248 | for m in net.modules(): 249 | if isinstance(m, nn.Conv2d): 250 | init.kaiming_normal(m.weight, mode='fan_out') 251 | if m.bias: 252 | init.constant(m.bias, 0) 253 | elif isinstance(m, nn.BatchNorm2d): 254 | init.constant(m.weight, 1) 255 | init.constant(m.bias, 0) 256 | elif isinstance(m, nn.Linear): 257 | init.normal(m.weight, std=1e-3) 258 | if m.bias: 259 | init.constant(m.bias, 0) 260 | 261 | 262 | _, term_width = os.popen('stty size', 'r').read().split() 263 | term_width = int(term_width) 264 | 265 | TOTAL_BAR_LENGTH = 50. 266 | last_time = time.time() 267 | begin_time = last_time 268 | 269 | 270 | def progress_bar(current, total, msg=None): 271 | global last_time, begin_time 272 | if current == 0: 273 | begin_time = time.time() # Reset for new bar. 
274 | 275 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 276 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 277 | 278 | sys.stdout.write(' [') 279 | for i in range(cur_len): 280 | sys.stdout.write('=') 281 | sys.stdout.write('>') 282 | for i in range(rest_len): 283 | sys.stdout.write('.') 284 | sys.stdout.write(']') 285 | 286 | cur_time = time.time() 287 | step_time = cur_time - last_time 288 | last_time = cur_time 289 | tot_time = cur_time - begin_time 290 | 291 | L = [] 292 | L.append(' Step: %s' % format_time(step_time)) 293 | L.append(' | Tot: %s' % format_time(tot_time)) 294 | if msg: 295 | L.append(' | ' + msg) 296 | 297 | msg = ''.join(L) 298 | sys.stdout.write(msg) 299 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 300 | sys.stdout.write(' ') 301 | 302 | # Go back to the center of the bar. 303 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)): 304 | sys.stdout.write('\b') 305 | sys.stdout.write(' %d/%d ' % (current+1, total)) 306 | 307 | if current < total-1: 308 | sys.stdout.write('\r') 309 | else: 310 | sys.stdout.write('\n') 311 | sys.stdout.flush() 312 | 313 | 314 | def format_time(seconds): 315 | days = int(seconds / 3600/24) 316 | seconds = seconds - days*3600*24 317 | hours = int(seconds / 3600) 318 | seconds = seconds - hours*3600 319 | minutes = int(seconds / 60) 320 | seconds = seconds - minutes*60 321 | secondsf = int(seconds) 322 | seconds = seconds - secondsf 323 | millis = int(seconds*1000) 324 | 325 | f = '' 326 | i = 1 327 | if days > 0: 328 | f += str(days) + 'D' 329 | i += 1 330 | if hours > 0 and i <= 2: 331 | f += str(hours) + 'h' 332 | i += 1 333 | if minutes > 0 and i <= 2: 334 | f += str(minutes) + 'm' 335 | i += 1 336 | if secondsf > 0 and i <= 2: 337 | f += str(secondsf) + 's' 338 | i += 1 339 | if millis > 0 and i <= 2: 340 | f += str(millis) + 'ms' 341 | i += 1 342 | if f == '': 343 | f = '0ms' 344 | return f 345 | --------------------------------------------------------------------------------
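
Below is a minimal usage sketch that is **not part of the repository**: it loads the pre-trained g checkpoint referenced in `run.sh` (`./checkpoint/erm_r100_c10_trial1.t7`) and runs a forward pass, assuming the file follows the format written by `save_checkpoint()` in `train.py` (a dict whose `'net'` entry holds the model state dict). Note that `CifarResNet.forward` returns both the logits and a list of intermediate feature maps, and that `train.py` normalizes inputs with `utils.InputNormalize` before the forward pass; that normalization is omitted here for brevity.

```
# Minimal sketch (illustration only, not part of the repository).
import torch

import models  # models/__init__.py exposes resnet32 and friends

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the ResNet-32 backbone the same way train.py does (models.__dict__[MODEL]).
net = models.__dict__['resnet32'](num_classes=10).to(device)

# Load the provided checkpoint; 'net' holds the model state dict.
ckpt = torch.load('./checkpoint/erm_r100_c10_trial1.t7', map_location=device)
net.load_state_dict(ckpt['net'])
net.eval()

# CifarResNet.forward returns (logits, [intermediate feature maps]).
x = torch.rand(4, 3, 32, 32, device=device)  # dummy CIFAR-sized batch in [0, 1]
with torch.no_grad():
    logits, feats = net(x)
print(logits.shape)  # torch.Size([4, 10])
```

This mirrors how `train.py` restores the seed network g when `--net_g` is given (`net_seed.load_state_dict(ckpt_g['net'])`); the same pattern applies to any checkpoint produced by `save_checkpoint()`.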