├── .gitignore
├── README.md
├── data_loader.py
├── experiments
│   └── base_model
│       └── params.json
├── model.py
├── run.sh
├── search_hyperparams.py
├── show_results.py
├── train_semi.py
├── utils.py
└── vat.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__*
data/*
data
.vscode*
*.pyc

experiments/*
!experiments/base_model
!experiments/embedding_omni
!experiments/embedding_mini
experiments/base_model/*
experiments/embedding_omni/*
experiments/embedding_mini/*
!experiments/base_model/params.json
!experiments/embedding_omni/params.json
!experiments/embedding_mini/params.json

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# VAT-pytorch

## Hyper-parameter tuning
We sweep epsilon over [2, 5] in steps of 0.5 and keep the other hyper-parameters fixed at the values in `experiments/base_model/params.json`.
We use 100 labeled examples (n_labels = 100).

| epsilon       | 2     | 2.5   | 3     | 3.5   | 4     | 4.5   | 5     |
|:-------------:|:-----:|:-----:|:-----:|:-----:|:-----:|:-----:|:-----:|
| Test accuracy | 97.35 | 97.18 | 97.45 | 97.21 | 97.36 | 98.34 | 97.84 |

In the paper, the test accuracy is reported as 98.64 (±0.03).
The gap is probably due to rough hyper-parameter tuning or subtle implementation differences.
If the implementation or experimental settings deviate from the ones in the original paper, please let me know.
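
## Usage
A minimal sketch of how the scripts are invoked (MNIST is downloaded into `data/` on first use; each experiment directory must contain a `params.json`):

```sh
# train a single model with the configuration in experiments/base_model/params.json
python train_semi.py --model_dir experiments/base_model --data_dir data

# grid-search epsilon (copy a reference params.json into experiments/epsilon/ first)
python search_hyperparams.py --parent_dir experiments/epsilon --data_dir data

# print the best val/test accuracy of every run under a parent directory
python show_results.py --parent_dir experiments/epsilon --tag val
```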
47 | """ 48 | 49 | # TODO "transform" for pertutation invariant MNIST 50 | train_dataset = torchvision.datasets.MNIST( 51 | data_dir, train=True, transform=transforms.ToTensor(), download=True) 52 | test_dataset = torchvision.datasets.MNIST( 53 | data_dir, train=False, transform=transforms.ToTensor(), download=True) 54 | labels_set, val_set = split_datasets(train_dataset, params.n_labels, 55 | params.n_val) 56 | unlabels_set = list(set(range(len(train_dataset))) - set(val_set)) 57 | labeled_dataset = torch.utils.data.Subset(train_dataset, labels_set) 58 | unlabeled_dataset = torch.utils.data.Subset(train_dataset, unlabels_set) 59 | val_dataset = torch.utils.data.Subset(train_dataset, val_set) 60 | 61 | dataloaders = {} 62 | dataloaders['label'] = DataLoader( 63 | labeled_dataset, batch_size=params.nll_batch_size, shuffle=True) 64 | dataloaders['unlabel'] = DataLoader( 65 | unlabeled_dataset, batch_size=params.vat_batch_size, shuffle=True) 66 | dataloaders['val'] = DataLoader( 67 | val_dataset, batch_size=params.vat_batch_size, shuffle=False) 68 | dataloaders['test'] = DataLoader( 69 | test_dataset, batch_size=params.vat_batch_size, shuffle=False) 70 | 71 | return dataloaders 72 | 73 | 74 | if __name__ == '__main__': 75 | data_dir = 'data/' 76 | json_path = os.path.join('experiments/base_model', 'params.json') 77 | params = Params(json_path) 78 | dataloaders = fetch_dataloaders_MNIST(data_dir, params) 79 | dl = dataloaders['label'] 80 | # x, y = dl.__iter__().next() 81 | # print(x, y) 82 | # x, y = dl.__iter__().next() 83 | # print(x, y) 84 | -------------------------------------------------------------------------------- /experiments/base_model/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "SEED": 1, 3 | "alpha": 1, 4 | "epsilon": 3, 5 | "XI": 1e-6, 6 | "n_power": 1, 7 | "lr": 0.002, 8 | "n_iters": 100000, 9 | "n_labels": 100, 10 | "n_val": 1000, 11 | "nll_batch_size": 64, 12 | "vat_batch_size": 256, 13 | "n_summary_steps": 1000, 14 | "decay_iter": 50000, 15 | "decay_step_size": 10000, 16 | "decay_gamma": 0.5 17 | } -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FFNN(nn.Module): 7 | """ 8 | Feed-Forward Neural Network (FFNN) for MNIST. 9 | Total 4 hidden layers are used as 28*28 -> (1200, 600, 300, 150) -> 10. 10 | We apply batchnorm and ReLU. 11 | We add isotropic noise to every hidden layer to stablize training. 
12 | """ 13 | 14 | def __init__(self, params): 15 | super(FFNN, self).__init__() 16 | self.params = params 17 | self.fc1 = nn.Linear(28 * 28, 1200) 18 | self.fc2 = nn.Linear(1200, 600) 19 | self.fc3 = nn.Linear(600, 300) 20 | self.fc4 = nn.Linear(300, 150) 21 | self.fc5 = nn.Linear(150, 10) 22 | self.bn1 = nn.BatchNorm1d(1200) 23 | self.bn2 = nn.BatchNorm1d(600) 24 | self.bn3 = nn.BatchNorm1d(300) 25 | self.bn4 = nn.BatchNorm1d(150) 26 | 27 | def forward(self, X): 28 | out = X.view(X.size(0), -1) 29 | out = F.relu(self.bn1(self.fc1(out))) 30 | if self.training: out = out + out.clone().normal_(0, 0.5) 31 | out = F.relu(self.bn2(self.fc2(out))) 32 | if self.training: out = out + out.clone().normal_(0, 0.5) 33 | out = F.relu(self.bn3(self.fc3(out))) 34 | if self.training: out = out + out.clone().normal_(0, 0.5) 35 | out = F.relu(self.bn4(self.fc4(out))) 36 | if self.training: out = out + out.clone().normal_(0, 0.5) 37 | out = self.fc5(out) 38 | 39 | return out -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | python train_semi.py --model_dir experiments/epsilon1 4 | python train_semi.py --model_dir experiments/epsilon2 5 | python train_semi.py --model_dir experiments/epsilon3 6 | python train_semi.py --model_dir experiments/epsilon4 7 | python train_semi.py --model_dir experiments/epsilon5 -------------------------------------------------------------------------------- /search_hyperparams.py: -------------------------------------------------------------------------------- 1 | # Base code is from https://github.com/cs230-stanford/cs230-code-examples 2 | """Peform hyperparemeters search""" 3 | 4 | import argparse 5 | import os 6 | from subprocess import check_call 7 | from multiprocessing import Process 8 | import sys 9 | 10 | import utils 11 | 12 | PYTHON = sys.executable 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument( 15 | '--parent_dir', 16 | default='experiments/epsilon', 17 | help='Directory containing params.json') 18 | parser.add_argument( 19 | '--data_dir', default='data', help="Directory containing the dataset") 20 | 21 | 22 | def launch_training_job(parent_dir, data_dir, job_name, params, gpu_num): 23 | """Launch training of the model with a set of hyperparameters in parent_dir/job_name 24 | Args: 25 | model_dir: (string) directory containing config, weights and log 26 | data_dir: (string) directory containing the dataset 27 | params: (dict) containing hyperparameters 28 | """ 29 | # Create a new folder in parent_dir with unique_name "job_name" 30 | model_dir = os.path.join(parent_dir, job_name) 31 | if not os.path.exists(model_dir): 32 | os.makedirs(model_dir) 33 | 34 | # Write parameters in json file 35 | json_path = os.path.join(model_dir, 'params.json') 36 | params.save(json_path) 37 | 38 | # Launch training with this config 39 | cmd = "CUDA_VISIBLE_DEVICES={gpu_num} {python} train_semi.py --model_dir={model_dir} --data_dir {data_dir}".format( 40 | gpu_num=gpu_num, python=PYTHON, model_dir=model_dir, data_dir=data_dir) 41 | print(cmd) 42 | check_call(cmd, shell=True) 43 | 44 | 45 | if __name__ == "__main__": 46 | # Load the "reference" parameters from parent_dir json file 47 | args = parser.parse_args() 48 | json_path = os.path.join(args.parent_dir, 'params.json') 49 | assert os.path.isfile( 50 | json_path), "No json configuration file found at {}".format(json_path) 51 | params = utils.Params(json_path) 52 | 53 | # Perform 
    print(cmd)
    check_call(cmd, shell=True)


if __name__ == "__main__":
    # Load the "reference" parameters from parent_dir json file
    args = parser.parse_args()
    json_path = os.path.join(args.parent_dir, 'params.json')
    assert os.path.isfile(
        json_path), "No json configuration file found at {}".format(json_path)
    params = utils.Params(json_path)

    # Perform hyper-parameter search over a single parameter (epsilon)
    epsilons = [2, 2.5, 3, 3.5, 4, 4.5, 5]
    seeds = [1, 2, 3]

    proc_args = []
    for epsilon in epsilons:
        for seed in seeds:
            # Modify the relevant parameter in params
            params = utils.Params(json_path)
            params.epsilon = epsilon
            params.SEED = seed

            # Launch job (name has to be unique)
            job_name = "epsilon_{}_SEED_{}".format(epsilon, seed)
            proc_args.append(
                [args.parent_dir, args.data_dir, job_name, params])

    # Run at most num_workers * num_proc_per_worker jobs at a time,
    # assigning each job to a GPU in round-robin fashion.
    num_workers = 1
    num_proc_per_worker = 3
    max_proc = num_workers * num_proc_per_worker
    procs = []
    for count, proc_arg in enumerate(proc_args):
        gpu_num = count % num_workers
        proc = Process(target=launch_training_job, args=(*proc_arg, gpu_num))
        procs.append(proc)
        proc.start()

        if (count + 1) % max_proc == 0 or (count + 1) == len(proc_args):
            for proc in procs:
                proc.join()
            procs = []

--------------------------------------------------------------------------------
/show_results.py:
--------------------------------------------------------------------------------
import os
import json
import argparse
from collections import OrderedDict

parser = argparse.ArgumentParser()
parser.add_argument(
    '--parent_dir',
    default='experiments/embedding_omni',
    help='Directory containing params.json')
parser.add_argument('--tag', default='val', help='val or test')


def iterate_parent_dir(parent_dir, tag):
    results = OrderedDict()
    for dirname in os.listdir(parent_dir):
        child_dir = os.path.join(parent_dir, dirname)
        if os.path.isdir(child_dir):
            for filename in os.listdir(child_dir):
                if filename == 'results.json':
                    jsonname = os.path.join(child_dir, filename)
                    results[dirname] = read_acc_from_json(jsonname, tag)
    for key in sorted(results.keys()):
        print(key, results[key])


def read_acc_from_json(filename, tag):
    with open(filename) as f:
        data = json.load(f)
    return float(data['Best ' + tag + ' score'])


if __name__ == '__main__':
    args = parser.parse_args()
    iterate_parent_dir(args.parent_dir, args.tag)

--------------------------------------------------------------------------------
/train_semi.py:
--------------------------------------------------------------------------------
import os
import json
import argparse
import logging
import torch
import torch.nn as nn
from tqdm import tqdm

import utils
from model import FFNN
from vat import VAT
from data_loader import fetch_dataloaders_MNIST

parser = argparse.ArgumentParser()
parser.add_argument(
    '--data_dir', default='data', help="Directory containing the dataset")
parser.add_argument(
    '--model_dir',
    default='experiments/base_model',
    help="Directory containing params.json")


def train_single_iter(model, optimizer, loss_fn, reg_fn, dl_label, dl_unlabel,
                      params):
    model.train()

    # draw one labeled mini-batch and one unlabeled mini-batch per iteration
    label_X, label_y = next(iter(dl_label))
    unlabel_X, _ = next(iter(dl_unlabel))
    if params.cuda:
        label_X = label_X.cuda(non_blocking=True)
        label_y = label_y.cuda(non_blocking=True)
        unlabel_X = unlabel_X.cuda(non_blocking=True)

    label_logit = model(label_X)
    unlabel_logit = model(unlabel_X)
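    # Semi-supervised objective (following Miyato et al., Virtual Adversarial
    # Training): cross-entropy on the labeled batch plus the VAT smoothness
    # term on the unlabeled batch,
    #   loss = E_labeled[-log p(y|x)] + alpha * E_unlabeled[KL(p(.|x) || p(.|x + r_vadv))]
    # where reg_fn (vat.py) computes the KL term with r_vadv found by power iteration.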
    nll = loss_fn(label_logit, label_y)
    vat = reg_fn(unlabel_X, unlabel_logit)
    loss = nll + params.alpha * vat  # alpha comes from params.json (1 by default)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return nll.item(), vat.item()


def evaluate(model, dl, params):
    model.eval()

    total, correct = 0, 0
    for test_X, test_y in dl:
        if params.cuda:
            test_X = test_X.cuda(non_blocking=True)
            test_y = test_y.cuda(non_blocking=True)
        logit = model(test_X)
        preds = torch.argmax(logit, dim=1)
        correct += torch.sum(preds == test_y).item()
        total += preds.size(0)
    return float(correct / total)


def train_and_evaluate(model, optimizer, scheduler, loss_fn, reg_fn,
                       dataloaders, params):
    dl_label = dataloaders['label']
    dl_unlabel = dataloaders['unlabel']
    dl_val = dataloaders['val']
    dl_test = dataloaders['test']

    # training steps
    is_best = False
    best_val_score = -float('inf')
    best_test_score = -float('inf')
    plot_history = {'val_acc': [], 'test_acc': []}
    for step in tqdm(range(params.n_iters)):
        if step >= params.decay_iter:
            scheduler.step()
        nll, vat = train_single_iter(model, optimizer, loss_fn, reg_fn,
                                     dl_label, dl_unlabel, params)
        # report logs for each iter (mini-batch)
        logging.info(
            "Iteration {}/{} ; LOSS {:05.3f} ; NLL {:05.3f} ; VAT {:05.3f}".
            format(step + 1, params.n_iters, nll + vat, nll, vat))
        if (step + 1) % params.n_summary_steps == 0:
            val_score = evaluate(model, dl_val, params)
            test_score = evaluate(model, dl_test, params)
            plot_history['val_acc'].append(val_score)
            plot_history['test_acc'].append(test_score)
            logging.info("Val_score {:05.3f} ; Test_score {:05.3f}".format(
                val_score, test_score))
            is_best = val_score > best_val_score
            if is_best:
                best_val_score = val_score
                best_test_score = test_score
                logging.info("Found new best accuracy")
            print('[{}] Val score was {}'.format(step + 1, val_score))
            print('[{}] Test score was {}'.format(step + 1, test_score))
            print('Best val score was {}'.format(best_val_score))
            print('Best test score was {}'.format(best_test_score))

    # Store results
    results = {
        'Best val score': best_val_score,
        'Best test score': best_test_score
    }
    utils.save_dict_to_json(results,
                            os.path.join(args.model_dir, 'results.json'))
    utils.plot_training_results(args.model_dir, plot_history)


if __name__ == '__main__':
    args = parser.parse_args()
    json_path = os.path.join(args.model_dir, 'params.json')
    assert os.path.isfile(
        json_path), "No json configuration file found at {}".format(json_path)
    params = utils.Params(json_path)

    # Use GPU if available
    params.cuda = torch.cuda.is_available()

    # Set the random seed for reproducible experiments
    torch.manual_seed(params.SEED)
    if params.cuda: torch.cuda.manual_seed(params.SEED)

    # Set the logger
    utils.set_logger(os.path.join(args.model_dir, 'train.log'))

    # Define the model and optimizer
    if params.cuda:
        model = FFNN(params).cuda()
    else:
        model = FFNN(params)
    # TODO decay the learning rate linearly (StepLR with gamma=0.5 is used for now)
    optimizer = torch.optim.Adam(model.parameters(), lr=params.lr)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, params.decay_step_size, params.decay_gamma)
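    # Sketch for the TODO above (not enabled): with the existing logic in
    # train_and_evaluate (scheduler.step() is called once per iteration only
    # after decay_iter), a LambdaLR that decays the learning rate linearly to
    # zero over the last (n_iters - decay_iter) iterations would look like:
    # scheduler = torch.optim.lr_scheduler.LambdaLR(
    #     optimizer,
    #     lambda i: max(0.0, 1.0 - i / (params.n_iters - params.decay_iter)))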

    # fetch loss function and metrics
    loss_fn = nn.CrossEntropyLoss()
    # define reg_fn
    reg_fn = VAT(model, params)

    # fetch MNIST dataloaders
    dataloaders = fetch_dataloaders_MNIST(args.data_dir, params)

    # Train the model
    train_and_evaluate(model, optimizer, scheduler, loss_fn, reg_fn,
                       dataloaders, params)

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
# Base code is from https://github.com/cs230-stanford/cs230-code-examples
import json
import logging
import os
import shutil

import torch
import matplotlib.pyplot as plt


class Params():
    """Class that loads hyperparameters from a json file.
    Example:
    ```
    params = Params(json_path)
    print(params.learning_rate)
    params.learning_rate = 0.5  # change the value of learning_rate in params
    ```
    """

    def __init__(self, json_path):
        with open(json_path) as f:
            params = json.load(f)
            self.__dict__.update(params)

    def save(self, json_path):
        with open(json_path, 'w') as f:
            json.dump(self.__dict__, f, indent=4)

    def update(self, json_path):
        """Loads parameters from json file"""
        with open(json_path) as f:
            params = json.load(f)
            self.__dict__.update(params)

    @property
    def dict(self):
        """Gives dict-like access to Params instance by `params.dict['learning_rate']`"""
        return self.__dict__


class RunningAverage():
    """A simple class that maintains the running average of a quantity

    Example:
    ```
    loss_avg = RunningAverage()
    loss_avg.update(2)
    loss_avg.update(4)
    loss_avg()  # returns 3
    ```
    """

    def __init__(self):
        self.steps = 0
        self.total = 0

    def update(self, val):
        self.total += val
        self.steps += 1

    def __call__(self):
        return self.total / float(self.steps)


def set_logger(log_path):
    """Set the logger to log info in terminal and file `log_path`.
    In general, it is useful to have a logger so that every output to the terminal is saved
    in a permanent file. Here we save it to `model_dir/train.log`.
    Example:
    ```
    logging.info("Starting training...")
    ```
    Args:
        log_path: (string) where to log
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    if not logger.handlers:
        # Logging to a file
        file_handler = logging.FileHandler(log_path)
        file_handler.setFormatter(
            logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
        logger.addHandler(file_handler)

        # Logging to console NOTE
        # stream_handler = logging.StreamHandler()
        # stream_handler.setFormatter(logging.Formatter('%(message)s'))
        # logger.addHandler(stream_handler)


def save_dict_to_json(d, json_path):
    """Saves dict of floats in json file
    Args:
        d: (dict) of float-castable values (np.float, int, float, etc.)
        json_path: (string) path to json file
    """
    with open(json_path, 'w') as f:
        # We need to convert the values to float for json (it doesn't accept np.array, np.float, )
        d = {k: float(v) for k, v in d.items()}
        json.dump(d, f, indent=4)


def save_checkpoint(state, is_best, checkpoint):
    """Saves model and training parameters at checkpoint + 'last.pth.tar'.
    If is_best==True, also saves checkpoint + 'best.pth.tar'
    Args:
        state: (dict) contains model's state_dict, may contain other keys such as epoch, optimizer state_dict
        is_best: (bool) True if it is the best model seen till now
        checkpoint: (string) folder where parameters are to be saved
    """
    filepath = os.path.join(checkpoint, 'last.pth.tar')
    if not os.path.exists(checkpoint):
        print("Checkpoint Directory does not exist! Making directory {}".
              format(checkpoint))
        os.mkdir(checkpoint)
    else:
        # print("Checkpoint Directory exists! ")
        pass
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'best.pth.tar'))
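

# Example (sketch; these helpers are not called anywhere in this repo): a
# checkpoint dict that save_checkpoint and load_checkpoint agree on. The key
# names are the ones load_checkpoint reads; 'iter' is an arbitrary extra field.
#   save_checkpoint({'iter': step,
#                    'state_dict': model.state_dict(),
#                    'optim_dict': optimizer.state_dict()},
#                   is_best=is_best,
#                   checkpoint=args.model_dir)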


def load_checkpoint(checkpoint, model, optimizer=None):
    """Loads model parameters (state_dict) from file_path. If optimizer is provided, loads state_dict of
    optimizer assuming it is present in checkpoint.
    Args:
        checkpoint: (string) filename which needs to be loaded
        model: (torch.nn.Module) model for which the parameters are loaded
        optimizer: (torch.optim) optional: resume optimizer from checkpoint
    """
    if not os.path.exists(checkpoint):
        raise FileNotFoundError("File doesn't exist {}".format(checkpoint))
    checkpoint = torch.load(checkpoint)
    model.load_state_dict(checkpoint['state_dict'])
    # 'task_lr_dict' comes from another project; save_checkpoint above does not
    # write it, so only restore it when present
    if 'task_lr_dict' in checkpoint:
        model.task_lr = checkpoint['task_lr_dict']

    if optimizer:
        optimizer.load_state_dict(checkpoint['optim_dict'])

    return checkpoint


def plot_training_results(model_dir, plot_history):
    """
    Plot training results (procedure) during training.

    Args:
        plot_history: (dict) a dictionary containing historical values of what
            we want to plot
    """
    # tr_losses = plot_history['train_loss']
    # val_losses = plot_history['val_loss']
    # te_losses = plot_history['test_loss']
    # tr_accs = plot_history['train_acc']
    val_accs = plot_history['val_acc']
    te_accs = plot_history['test_acc']

    # plt.figure(0)
    # plt.plot(list(range(len(tr_losses))), tr_losses, label='train_loss')
    # plt.plot(list(range(len(val_losses))), val_losses, label='val_loss')
    # plt.plot(list(range(len(te_losses))), te_losses, label='test_loss')
    # plt.title('Loss trend')
    # plt.xlabel('episode')
    # plt.ylabel('ce loss')
    # plt.legend()
    # plt.savefig(os.path.join(model_dir, 'loss_trend'), dpi=200)
    # plt.clf()

    plt.figure(1)
    # plt.plot(list(range(len(tr_accs))), tr_accs, label='train_acc')
    plt.plot(list(range(len(val_accs))), val_accs, label='val_acc')
    plt.plot(list(range(len(te_accs))), te_accs, label='test_acc')
    plt.title('Accuracy trend')
    plt.xlabel('iter / 1000')
    plt.ylabel('accuracy')
    plt.legend()
    plt.savefig(os.path.join(model_dir, 'accuracy_trend'), dpi=200)
    plt.clf()

--------------------------------------------------------------------------------
/vat.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

# NOTE hyper-parameters we use in VAT
# n_power: number of power iterations used to approximate r_vadv
# XI: a small float used in the finite-difference approximation
# epsilon: how far the adversarial perturbation may deviate from the original data point X
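#
# For reference (summary of the method this file implements, Miyato et al.):
#   LDS(x) = KL( p(. | x) || p(. | x + r_vadv) )
#   r_vadv = argmax_{||r||_2 <= epsilon} KL( p(. | x) || p(. | x + r) )
# The argmax is approximated with power iteration: start from a random d and,
# for n_power steps, replace d by the gradient of KL(p(.|x) || p(.|x + XI * d/||d||))
# with respect to d; finally r_vadv = epsilon * d / ||d||.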


class VAT(nn.Module):
    """
    VAT regularization term, defined as an nn.Module so it can be used like a loss function.
    """

    def __init__(self, model, params):
        super(VAT, self).__init__()
        self.model = model
        self.n_power = params.n_power
        self.XI = params.XI
        self.epsilon = params.epsilon

    def forward(self, X, logit):
        vat_loss = virtual_adversarial_loss(X, logit, self.model, self.n_power,
                                            self.XI, self.epsilon)
        return vat_loss  # already averaged over the batch


def kl_divergence_with_logit(q_logit, p_logit):
    q = F.softmax(q_logit, dim=1)
    qlogq = torch.mean(torch.sum(q * F.log_softmax(q_logit, dim=1), dim=1))
    qlogp = torch.mean(torch.sum(q * F.log_softmax(p_logit, dim=1), dim=1))
    return qlogq - qlogp


def get_normalized_vector(d):
    # normalize d to unit L2 norm per example (the max-abs division is only for
    # numerical stability); assumes a 4-D input of shape (N, C, H, W)
    d_abs_max = torch.max(
        torch.abs(d.view(d.size(0), -1)), 1, keepdim=True)[0].view(
            d.size(0), 1, 1, 1)
    # print(d_abs_max.size())
    d /= (1e-12 + d_abs_max)
    d /= torch.sqrt(1e-6 + torch.sum(
        torch.pow(d, 2.0), tuple(range(1, len(d.size()))), keepdim=True))
    # print(torch.norm(d.view(d.size(0), -1), dim=1))
    return d


def generate_virtual_adversarial_perturbation(x, logit, model, n_power, XI,
                                              epsilon):
    d = torch.randn_like(x)

    for _ in range(n_power):
        d = XI * get_normalized_vector(d).requires_grad_()
        logit_m = model(x + d)
        dist = kl_divergence_with_logit(logit, logit_m)
        grad = torch.autograd.grad(dist, [d])[0]
        d = grad.detach()

    return epsilon * get_normalized_vector(d)


def virtual_adversarial_loss(x, logit, model, n_power, XI, epsilon):
    r_vadv = generate_virtual_adversarial_perturbation(x, logit, model,
                                                       n_power, XI, epsilon)
    logit_p = logit.detach()
    logit_m = model(x + r_vadv)
    loss = kl_divergence_with_logit(logit_p, logit_m)
    return loss
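

# --- hedged smoke test (not part of the original file) ------------------------
# A minimal sketch that builds the FFNN from model.py, fabricates an
# MNIST-shaped batch, and checks that the VAT loss is a finite scalar. The
# hyper-parameter values below mirror experiments/base_model/params.json but
# are hard-coded here only for illustration.
if __name__ == '__main__':
    from types import SimpleNamespace
    from model import FFNN

    params = SimpleNamespace(n_power=1, XI=1e-6, epsilon=3.0)
    model = FFNN(params)
    reg_fn = VAT(model, params)

    x = torch.randn(8, 1, 28, 28)   # fake unlabeled batch
    vat_loss = reg_fn(x, model(x))  # logits of the clean batch are reused
    print(vat_loss.item())          # a single non-negative KL value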