├── CIFARFS
│   ├── LoadUnlableCIFAR.py
│   ├── MAMLMeta.py
│   ├── MAML_TrainStd_CIFARFS.ipynb
│   ├── MetaFT.py
│   ├── attack.py
│   ├── learner.py
│   ├── metafgsm.py
│   ├── resnet.py
│   ├── train_trades_cifar.ipynb
│   └── trainfgsmrs_cifar.ipynb
├── LICENSE
├── LoadDataST.py
├── LoadUnlableData.py
├── MAMLMeta.py
├── MAML_TrainStd.ipynb
├── MODELMETA.py
├── MetaFT.py
├── MiniImagenet.py
├── Omniglot
│   ├── MAMLMeta.py
│   ├── MAML_TrainStd_Omniglot.ipynb
│   ├── MetaFTOmni.py
│   ├── attack.py
│   ├── learner.py
│   ├── train_trades_omniglot.ipynb
│   └── trainfgsm_omniglot.ipynb
├── README.md
├── StandardTrans.py
├── StandardTransAdv.ipynb
├── Visualization.py
├── attack.py
├── learner.py
├── mamlfgsmeps2.pt
├── mamltradesrseps2self.pt
├── metafgsm.py
├── metafgsminout.py
├── robust_vis_neuron.ipynb
├── train_trade.ipynb
├── trainfgsmrs.ipynb
└── vis_tool.py

/CIFARFS/LoadUnlableCIFAR.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | 
4 | # In[ ]:
5 | 
6 | 
7 | import os
8 | import torch
9 | # from torch.utils.data import Dataset
10 | # from torchvision.transforms import transforms
11 | import numpy as np
12 | import collections
13 | import random
14 | 
15 | 
16 | class UnlabData(object):
17 |     def __init__(self, seed=None):
18 |         tinyimg = np.array(np.load("stl_select.npy", allow_pickle=True))
19 | 
20 |         train_indices = np.arange(500)
21 |         self.train_data = []
22 |         # build one 500-image unlabeled pool (DataSubset) per meta-training class
23 |         for i in range(64):
24 |             temp_tiny = [tinyimg[i][j] for j in train_indices]
25 |             temp_tiny = np.array(temp_tiny)
26 |             self.train_data.append(DataSubset(temp_tiny))
27 | 
28 | class DataSubset(object):
29 |     def __init__(self, xs, num_examples=None, seed=None):
30 | 
31 |         if seed is not None:
32 |             np.random.seed(99)
33 |         self.xs = xs
34 |         self.n = len(xs)
35 |         self.batch_start = 0
36 |         self.cur_order = np.random.permutation(self.n)
37 | 
38 |     def get_next_batch(self, batch_size, multiple_passes=False, reshuffle_after_pass=True):
39 |         # np.random.seed(99)
40 |         if self.n < batch_size:
41 |             raise ValueError('Batch size can be at most the dataset size')
42 |         if not multiple_passes:
43 |             actual_batch_size = min(batch_size, self.n - self.batch_start)
44 |             if actual_batch_size <= 0:
45 |                 raise ValueError('Pass through the dataset is complete.')
46 |             batch_end = self.batch_start + actual_batch_size
47 |             batch_xs = self.xs[self.cur_order[self.batch_start : batch_end], ...]
48 |             self.batch_start += actual_batch_size
49 |             return batch_xs
50 |         actual_batch_size = min(batch_size, self.n - self.batch_start)
51 |         if actual_batch_size < batch_size:
52 |             if reshuffle_after_pass:
53 |                 self.cur_order = np.random.permutation(self.n)
54 |             self.batch_start = 0
55 |         batch_end = self.batch_start + batch_size
56 | 
57 | 
58 |         batch_xs = self.xs[self.cur_order[self.batch_start : batch_end], ...]
59 | 
60 |         # advance by a full batch so a reshuffled pass does not re-serve samples
61 |         self.batch_start += batch_size
62 |         return batch_xs
63 | 
64 | 
65 | 
66 | 
67 | 
68 | 
69 | if __name__ == '__main__':
70 |     # the following episode is to view one set of images via tensorboard.
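    # A minimal usage sketch for the loader above (assumes "stl_select.npy"
    # holds one pool of 500 unlabeled STL images per training class, as
    # __init__ expects):
    #
    #   unlab = UnlabData()
    #   pool = unlab.train_data[0]                         # DataSubset for class 0
    #   xs = pool.get_next_batch(64, multiple_passes=True)
    #   # with multiple_passes=True the pool cycles indefinitely, reshuffling
    #   # after each full pass, so xs always holds batch_size images.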
71 | from torchvision.utils import make_grid 72 | from matplotlib import pyplot as plt 73 | from tensorboardX import SummaryWriter 74 | import time 75 | 76 | -------------------------------------------------------------------------------- /CIFARFS/MAML_TrainStd_CIFARFS.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "scrolled": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import torch, os\n", 12 | "import numpy as np\n", 13 | "#from MiniImagenet import MiniImagenet\n", 14 | "import scipy.stats\n", 15 | "from torch.utils.data import DataLoader\n", 16 | "from torch.optim import lr_scheduler\n", 17 | "import random, sys, pickle\n", 18 | "import argparse\n", 19 | "\n", 20 | "\n", 21 | "from torchmeta.datasets import CIFARFS\n", 22 | "from torchmeta.transforms import Categorical, ClassSplitter, Rotation\n", 23 | "from torchvision.transforms import Compose, Resize, ToTensor\n", 24 | "from torchmeta.utils.data import BatchMetaDataLoader\n", 25 | "\n", 26 | "\n", 27 | "from MAMLMeta import Meta\n", 28 | "\n", 29 | "\n", 30 | "\n", 31 | "\n", 32 | "\n", 33 | "def main():\n", 34 | "\n", 35 | " torch.manual_seed(222)\n", 36 | " torch.cuda.manual_seed_all(222)\n", 37 | " np.random.seed(222)\n", 38 | "\n", 39 | " print(args)\n", 40 | "\n", 41 | " config = [\n", 42 | " ('conv2d', [16, 3, 3, 3, 2, 1]),\n", 43 | " ('relu', [True]),\n", 44 | " ('bn', [16]),\n", 45 | " #('max_pool2d', [2, 2, 0]),\n", 46 | " ('conv2d', [16, 16, 3, 3, 1, 1]),\n", 47 | " ('relu', [True]),\n", 48 | " ('bn', [16]),\n", 49 | " ('max_pool2d', [2, 2, 0]),\n", 50 | " ('conv2d', [32, 16, 3, 3, 1, 1]),\n", 51 | " ('relu', [True]),\n", 52 | " ('bn', [32]),\n", 53 | " #('max_pool2d', [2, 2, 0]),\n", 54 | " ('conv2d', [32, 32, 3, 3, 1, 1]),\n", 55 | " ('relu', [True]),\n", 56 | " ('bn', [32]),\n", 57 | " ('max_pool2d', [2, 2, 0]),\n", 58 | " ('flatten', []),\n", 59 | " ('linear', [args.n_way, 32 * 4 * 4])\n", 60 | " ]\n", 61 | "\n", 62 | " device = torch.device('cuda:0')\n", 63 | " \n", 64 | " start_epoch = 0\n", 65 | " start_step = 0\n", 66 | " filename = 'maml_eps8_cifar_5shot.pt'\n", 67 | " maml = Meta(args, config).to(device)\n", 68 | " if os.path.isfile(filename):\n", 69 | " print(\"=> loading checkpoint '{}'\".format(filename))\n", 70 | " checkpoint = torch.load(filename)\n", 71 | " start_epoch = checkpoint['epoch']\n", 72 | " start_step = checkpoint['step']\n", 73 | " maml.net.load_state_dict(checkpoint['state_dict'])\n", 74 | " #maml = maml.to(device)\n", 75 | " print(\"=> loaded checkpoint '{}' (epoch {})\"\n", 76 | " .format(filename, checkpoint['epoch']))\n", 77 | " else:\n", 78 | " print(\"=> no checkpoint found at '{}'\".format(filename))\n", 79 | "\n", 80 | " tmp = filter(lambda x: x.requires_grad, maml.parameters())\n", 81 | " num = sum(map(lambda x: np.prod(x.shape), tmp))\n", 82 | " print(maml)\n", 83 | " print('Total trainable tensors:', num)\n", 84 | "\n", 85 | " # batchsz here means total episode number\n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " data_train = CIFARFS(\"data\",\n", 90 | " # Number of ways\n", 91 | " num_classes_per_task=args.n_way,\n", 92 | " meta_train=True,\n", 93 | " meta_val=False,\n", 94 | " meta_test=False,\n", 95 | " # Resize the images to 28x28 and converts them to PyTorch tensors (from Torchvision)\n", 96 | " transform=Compose([Resize(32), ToTensor()]),\n", 97 | " # Transform the labels to integers (e.g. 
(\"Glagolitic/character01\", \"Sanskrit/character14\", ...) to (0, 1, ...))\n", 98 | " target_transform=Categorical(num_classes=args.n_way),\n", 99 | " # Creates new virtual classes with rotated versions of the images (from Santoro et al., 2016)\n", 100 | " #class_augmentations=[Rotation([90, 180, 270])],\n", 101 | " download=True)\n", 102 | " \n", 103 | " data_test = CIFARFS(\"data\",\n", 104 | " # Number of ways\n", 105 | " num_classes_per_task=args.n_way,\n", 106 | " meta_train=False,\n", 107 | " meta_val=False,\n", 108 | " meta_test=True,\n", 109 | " # Resize the images to 28x28 and converts them to PyTorch tensors (from Torchvision)\n", 110 | " transform=Compose([Resize(32), ToTensor()]),\n", 111 | " # Transform the labels to integers (e.g. (\"Glagolitic/character01\", \"Sanskrit/character14\", ...) to (0, 1, ...))\n", 112 | " target_transform=Categorical(num_classes=args.n_way),\n", 113 | " # Creates new virtual classes with rotated versions of the images (from Santoro et al., 2016)\n", 114 | " #class_augmentations=[Rotation([90, 180, 270])],\n", 115 | " download=True)\n", 116 | " data_train = ClassSplitter(data_train, shuffle=True, num_train_per_class=args.k_spt, num_test_per_class=args.k_qry)\n", 117 | " data_test = ClassSplitter(data_test, shuffle=True, num_train_per_class=args.k_spt, num_test_per_class=args.k_qry)\n", 118 | " \n", 119 | " best_cl = 0\n", 120 | " for epoch in range(args.epoch//10000):\n", 121 | " # fetch meta_batchsz num of episode each time\n", 122 | " db = BatchMetaDataLoader(data_train, batch_size=args.task_num, num_workers=0)\n", 123 | " #db = DataLoader(data_train, args.task_num, shuffle=True, num_workers=0, pin_memory=True)\n", 124 | "\n", 125 | " for step, batch_train in enumerate(db):\n", 126 | " #for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):\n", 127 | "\n", 128 | " x_spt, y_spt = batch_train[\"train\"]\n", 129 | " x_qry, y_qry = batch_train[\"test\"]\n", 130 | " x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)\n", 131 | "\n", 132 | " accs, accs_adv = maml(x_spt, y_spt, x_qry, y_qry)\n", 133 | " \n", 134 | " if step % 50 == 0:\n", 135 | " print('step:', step, '\\ttraining acc:', accs)\n", 136 | " print('step:', step, '\\ttraining acc_adv:', accs_adv)\n", 137 | " state = {'epoch': epoch, 'step': step, 'state_dict': maml.net.state_dict()}\n", 138 | " #torch.save(state, 'mamleps2cifar5shot.pt')\n", 139 | "\n", 140 | " if step % 10000 == 0: # evaluation\n", 141 | " #db_test = DataLoader(data_test, 1, shuffle=True, num_workers=0, pin_memory=True)\n", 142 | " db_test = BatchMetaDataLoader(data_test, batch_size=1, num_workers=0)\n", 143 | " accs_all_test = []\n", 144 | " accsadv_all_test = []\n", 145 | " accsadvpr_all_test = []\n", 146 | "\n", 147 | " for step_t, batch_test in enumerate(db_test):\n", 148 | " x_spt, y_spt = batch_test[\"train\"]\n", 149 | " x_qry, y_qry = batch_test[\"test\"]\n", 150 | "\n", 151 | "\n", 152 | " x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \\\n", 153 | " x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)\n", 154 | "\n", 155 | " accs, accs_adv, accs_adv_prior = maml.finetunning(x_spt, y_spt, x_qry, y_qry)\n", 156 | " accs_all_test.append(accs)\n", 157 | " accsadv_all_test.append(accs_adv)\n", 158 | " accsadvpr_all_test.append(accs_adv_prior)\n", 159 | " if step_t == 20000:\n", 160 | " break\n", 161 | "\n", 162 | " # [b, update_step+1]\n", 163 | " accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)\n", 164 | " 
accs_adv = np.array(accsadv_all_test).mean(axis=0).astype(np.float16)\n", 165 | " accs_adv_prior = np.array(accsadvpr_all_test).mean(axis=0).astype(np.float16)\n", 166 | " print('Test acc:', accs)\n", 167 | " print('Test acc_adv:', accs_adv)\n", 168 | " print('Test acc_adv_prior:', accs_adv_prior)\n", 169 | " if best_cl < accs[-1]:\n", 170 | " torch.save(state, 'maml_eps8_cifar_5shot.pt')\n", 171 | " best_cl = accs[-1]\n", 172 | " print(best_cl)\n", 173 | "\n", 174 | "\n", 175 | "\n", 176 | "if __name__ == '__main__':\n", 177 | "\n", 178 | " argparser = argparse.ArgumentParser()\n", 179 | " argparser.add_argument('--epoch', type=int, help='epoch number', default=60000)\n", 180 | " argparser.add_argument('--n_way', type=int, help='n way', default=5)\n", 181 | " argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=5)\n", 182 | " argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=15)\n", 183 | " argparser.add_argument('--imgsz', type=int, help='imgsz', default=32)\n", 184 | " argparser.add_argument('--imgc', type=int, help='imgc', default=3)\n", 185 | " argparser.add_argument('--task_num', type=int, help='meta batch size, namely task num', default=4)\n", 186 | " argparser.add_argument('--meta_lr', type=float, help='meta-level outer learning rate', default=1e-3)\n", 187 | " argparser.add_argument('--update_lr', type=float, help='task-level inner update learning rate', default=0.01)\n", 188 | " argparser.add_argument('--update_step', type=int, help='task-level inner update steps', default=5)\n", 189 | " argparser.add_argument('--update_step_test', type=int, help='update steps for finetunning', default=10)\n", 190 | " \n", 191 | " #argparser.add_argument('--fast', action=\"store_true\", help='whether to use fgsm')\n", 192 | "\n", 193 | " args = argparser.parse_args(args=[])\n", 194 | "\n", 195 | " main()" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [] 204 | } 205 | ], 206 | "metadata": { 207 | "kernelspec": { 208 | "display_name": "Python 3", 209 | "language": "python", 210 | "name": "python3" 211 | }, 212 | "language_info": { 213 | "codemirror_mode": { 214 | "name": "ipython", 215 | "version": 3 216 | }, 217 | "file_extension": ".py", 218 | "mimetype": "text/x-python", 219 | "name": "python", 220 | "nbconvert_exporter": "python", 221 | "pygments_lexer": "ipython3", 222 | "version": "3.7.6" 223 | } 224 | }, 225 | "nbformat": 4, 226 | "nbformat_minor": 4 227 | } 228 | -------------------------------------------------------------------------------- /CIFARFS/attack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from abc import ABCMeta, abstractmethod, abstractproperty 4 | from torch.nn import functional as F 5 | 6 | class AttackBase(metaclass=ABCMeta): 7 | @abstractmethod 8 | def attack(self, net, inp, label, target = None): 9 | ''' 10 | 11 | :param inp: batched images 12 | :param target: specify the indexes of target class, None represents untargeted attack 13 | :return: batched adversaril images 14 | ''' 15 | pass 16 | 17 | @abstractmethod 18 | def to(self, device): 19 | pass 20 | 21 | 22 | 23 | def clip_eta(eta, norm, eps, DEVICE = torch.device('cuda:2')): 24 | ''' 25 | helper functions to project eta into epsilon norm ball 26 | :param eta: Perturbation tensor (should be of size(N, C, H, W)) 27 | :param norm: which norm. 
should be in [1, 2, np.inf] 28 | :param eps: epsilon, bound of the perturbation 29 | :return: Projected perturbation 30 | ''' 31 | 32 | assert norm in [1, 2, np.inf], "norm should be in [1, 2, np.inf]" 33 | 34 | with torch.no_grad(): 35 | avoid_zero_div = torch.tensor(1e-12).to(DEVICE) 36 | eps = torch.tensor(eps).to(DEVICE) 37 | one = torch.tensor(1.0).to(DEVICE) 38 | 39 | if norm == np.inf: 40 | eta = torch.clamp(eta, -eps, eps) 41 | else: 42 | normalize = torch.norm(eta.reshape(eta.size(0), -1), p = norm, dim = -1, keepdim = False) 43 | normalize = torch.max(normalize, avoid_zero_div) 44 | 45 | normalize.unsqueeze_(dim = -1) 46 | normalize.unsqueeze_(dim=-1) 47 | normalize.unsqueeze_(dim=-1) 48 | 49 | factor = torch.min(one, eps / normalize) 50 | eta = eta * factor 51 | return eta 52 | 53 | 54 | 55 | class PGD(AttackBase): 56 | # ImageNet pre-trained mean and std 57 | # _mean = torch.tensor(np.array([0.485, 0.456, 0.406]).astype(np.float32)[np.newaxis, :, np.newaxis, np.newaxis]) 58 | # _std = torch.tensor(np.array([0.229, 0.224, 0.225]).astype(np.float32)[np.newaxis, :, np.newaxis, np.newaxis]) 59 | 60 | # _mean = torch.tensor(np.array([0]).astype(np.float32)[np.newaxis, :, np.newaxis, np.newaxis]) 61 | # _std = torch.tensor(np.array([1.0]).astype(np.float32)[np.newaxis, :, np.newaxis, np.newaxis]) 62 | def __init__(self, eps = 6 / 255.0, sigma = 3 / 255.0, nb_iter = 20, 63 | norm = np.inf, DEVICE = torch.device('cuda:2'), 64 | mean = torch.tensor(np.array([0]).astype(np.float32)[np.newaxis, :, np.newaxis, np.newaxis]), 65 | std = torch.tensor(np.array([1.0]).astype(np.float32)[np.newaxis, :, np.newaxis, np.newaxis]), random_start = True): 66 | ''' 67 | :param eps: maximum distortion of adversarial examples 68 | :param sigma: single step size 69 | :param nb_iter: number of attack iterations 70 | :param norm: which norm to bound the perturbations 71 | ''' 72 | self.eps = eps 73 | self.sigma = sigma 74 | self.nb_iter = nb_iter 75 | self.norm = norm 76 | self.criterion = torch.nn.CrossEntropyLoss().to(DEVICE) 77 | self.DEVICE = DEVICE 78 | self._mean = mean.to(DEVICE) 79 | self._std = std.to(DEVICE) 80 | self.random_start = random_start 81 | 82 | def single_attack(self, net, para, inp, label, eta, target = None): 83 | ''' 84 | Given the original image and the perturbation computed so far, computes 85 | a new perturbation. 
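        In normalized input space the update below is, informally,

            eta <- Pi_eps( clip_[0,1]( (inp + eta) + sigma * sign(grad_eta loss) ) - inp )

        where Pi_eps is the projection onto the eps-ball under self.norm
        (clip_eta above), and the [0, 1] clamp keeps the adversarial image in
        the valid pixel range; both are applied after undoing the mean/std
        normalization.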
86 | :param net: 87 | :param inp: original image 88 | :param label: 89 | :param eta: perturbation computed so far 90 | :return: a new perturbation 91 | ''' 92 | 93 | adv_inp = inp + eta 94 | 95 | #net.zero_grad() 96 | 97 | pred = net(adv_inp, para) 98 | 99 | 100 | loss = self.criterion(pred, label) 101 | grad_sign = torch.autograd.grad(loss, adv_inp, 102 | only_inputs=True, retain_graph = False)[0].sign() 103 | 104 | adv_inp = adv_inp + grad_sign * (self.sigma / self._std) 105 | tmp_adv_inp = adv_inp * self._std + self._mean 106 | 107 | tmp_inp = inp * self._std + self._mean 108 | tmp_adv_inp = torch.clamp(tmp_adv_inp, 0, 1) 109 | tmp_eta = tmp_adv_inp - tmp_inp 110 | tmp_eta = clip_eta(tmp_eta, norm=self.norm, eps=self.eps, DEVICE=self.DEVICE) 111 | 112 | eta = tmp_eta/ self._std 113 | 114 | 115 | # adv_inp = adv_inp + grad_sign * self.eps 116 | # adv_inp = torch.clamp(adv_inp, 0, 1) 117 | # eta = adv_inp - inp 118 | # eta = clip_eta(eta, norm=self.norm, eps=self.eps, DEVICE=self.DEVICE) 119 | 120 | return eta 121 | 122 | def attack(self, net, para, inp, label, target = None): 123 | 124 | if self.random_start: 125 | eta = torch.FloatTensor(*inp.shape).uniform_(-self.eps, self.eps) 126 | else: 127 | eta = torch.zeros_like(inp) 128 | eta = eta.to(self.DEVICE) 129 | eta = (eta - self._mean) / self._std 130 | net.eval() 131 | #print(torch.min(torch.min(torch.min(inp[0])))) 132 | 133 | inp.requires_grad = True 134 | eta.requires_grad = True 135 | for i in range(self.nb_iter): 136 | eta = self.single_attack(net, para, inp, label, eta, target) 137 | #print(i) 138 | 139 | #print(eta.max()) 140 | adv_inp = inp + eta 141 | tmp_adv_inp = adv_inp * self._std + self._mean 142 | tmp_adv_inp = torch.clamp(tmp_adv_inp, 0, 1) 143 | adv_inp = (tmp_adv_inp - self._mean) / self._std 144 | 145 | return adv_inp 146 | 147 | def to(self, device): 148 | self.DEVICE = device 149 | self._mean = self._mean.to(device) 150 | self._std = self._std.to(device) 151 | self.criterion = self.criterion.to(device) -------------------------------------------------------------------------------- /CIFARFS/learner.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[ ]: 5 | 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | import numpy as np 11 | 12 | 13 | 14 | class Learner(nn.Module): 15 | """ 16 | """ 17 | 18 | def __init__(self, config, imgc, imgsz): 19 | """ 20 | :param config: network config file, type:list of (string, list) 21 | :param imgc: 1 or 3 22 | :param imgsz: 28 or 84 23 | """ 24 | super(Learner, self).__init__() 25 | 26 | 27 | self.config = config 28 | 29 | # this dict contains all tensors needed to be optimized 30 | self.vars = nn.ParameterList() 31 | # running_mean and running_var 32 | self.vars_bn = nn.ParameterList() 33 | 34 | for i, (name, param) in enumerate(self.config): 35 | if name is 'conv2d': 36 | # [ch_out, ch_in, kernelsz, kernelsz] 37 | w = nn.Parameter(torch.ones(*param[:4])) 38 | # gain=1 according to cbfin's implementation 39 | torch.nn.init.kaiming_normal_(w) 40 | self.vars.append(w) 41 | # [ch_out] 42 | self.vars.append(nn.Parameter(torch.zeros(param[0]))) 43 | 44 | elif name is 'convt2d': 45 | # [ch_in, ch_out, kernelsz, kernelsz, stride, padding] 46 | w = nn.Parameter(torch.ones(*param[:4])) 47 | # gain=1 according to cbfin's implementation 48 | torch.nn.init.kaiming_normal_(w) 49 | self.vars.append(w) 50 | # [ch_in, ch_out] 51 | 
self.vars.append(nn.Parameter(torch.zeros(param[1]))) 52 | 53 | elif name is 'linear': 54 | # [ch_out, ch_in] 55 | w = nn.Parameter(torch.ones(*param)) 56 | # gain=1 according to cbfinn's implementation 57 | torch.nn.init.kaiming_normal_(w) 58 | self.vars.append(w) 59 | # [ch_out] 60 | self.vars.append(nn.Parameter(torch.zeros(param[0]))) 61 | 62 | elif name is 'bn': 63 | # [ch_out] 64 | w = nn.Parameter(torch.ones(param[0])) 65 | self.vars.append(w) 66 | # [ch_out] 67 | self.vars.append(nn.Parameter(torch.zeros(param[0]))) 68 | 69 | # must set requires_grad=False 70 | running_mean = nn.Parameter(torch.zeros(param[0]), requires_grad=False) 71 | running_var = nn.Parameter(torch.ones(param[0]), requires_grad=False) 72 | self.vars_bn.extend([running_mean, running_var]) 73 | 74 | 75 | elif name in ['tanh', 'relu', 'upsample', 'avg_pool2d', 'max_pool2d', 76 | 'flatten', 'reshape', 'leakyrelu', 'sigmoid']: 77 | continue 78 | else: 79 | raise NotImplementedError 80 | 81 | 82 | 83 | 84 | 85 | 86 | def extra_repr(self): 87 | info = '' 88 | 89 | for name, param in self.config: 90 | if name is 'conv2d': 91 | tmp = 'conv2d:(ch_in:%d, ch_out:%d, k:%dx%d, stride:%d, padding:%d)' %(param[1], param[0], param[2], param[3], param[4], param[5],) 92 | info += tmp + '\n' 93 | 94 | elif name is 'convt2d': 95 | tmp = 'convTranspose2d:(ch_in:%d, ch_out:%d, k:%dx%d, stride:%d, padding:%d)' %(param[0], param[1], param[2], param[3], param[4], param[5],) 96 | info += tmp + '\n' 97 | 98 | elif name is 'linear': 99 | tmp = 'linear:(in:%d, out:%d)'%(param[1], param[0]) 100 | info += tmp + '\n' 101 | 102 | elif name is 'leakyrelu': 103 | tmp = 'leakyrelu:(slope:%f)'%(param[0]) 104 | info += tmp + '\n' 105 | 106 | 107 | elif name is 'avg_pool2d': 108 | tmp = 'avg_pool2d:(k:%d, stride:%d, padding:%d)'%(param[0], param[1], param[2]) 109 | info += tmp + '\n' 110 | elif name is 'max_pool2d': 111 | tmp = 'max_pool2d:(k:%d, stride:%d, padding:%d)'%(param[0], param[1], param[2]) 112 | info += tmp + '\n' 113 | elif name in ['flatten', 'tanh', 'relu', 'upsample', 'reshape', 'sigmoid', 'use_logits', 'bn']: 114 | tmp = name + ':' + str(tuple(param)) 115 | info += tmp + '\n' 116 | else: 117 | raise NotImplementedError 118 | 119 | return info 120 | 121 | 122 | 123 | def forward(self, x, vars=None, bn_training=True): 124 | """ 125 | This function can be called by finetunning, however, in finetunning, we dont wish to update 126 | running_mean/running_var. Thought weights/bias of bn is updated, it has been separated by fast_weights. 127 | Indeed, to not update running_mean/running_var, we need set update_bn_statistics=False 128 | but weight/bias will be updated and not dirty initial theta parameters via fast_weiths. 129 | :param x: [b, 1, 28, 28] 130 | :param vars: 131 | :param bn_training: set False to not update 132 | :return: x, loss, likelihood, kld 133 | """ 134 | 135 | if vars is None: 136 | vars = self.vars 137 | 138 | idx = 0 139 | bn_idx = 0 140 | 141 | for name, param in self.config: 142 | if name is 'conv2d': 143 | w, b = vars[idx], vars[idx + 1] 144 | # remember to keep synchrozied of forward_encoder and forward_decoder! 145 | x = F.conv2d(x, w, b, stride=param[4], padding=param[5]) 146 | idx += 2 147 | # print(name, param, '\tout:', x.shape) 148 | elif name is 'convt2d': 149 | w, b = vars[idx], vars[idx + 1] 150 | # remember to keep synchrozied of forward_encoder and forward_decoder! 
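                # Note: the explicit `vars` argument is what makes this forward
                # pass MAML-friendly -- an inner-loop step can run the network
                # with task-specific fast weights without touching self.vars.
                # A hedged sketch of that inner step (update_lr as in the args):
                #
                #   loss = F.cross_entropy(net(x_spt), y_spt)
                #   grads = torch.autograd.grad(loss, net.parameters())
                #   fast_weights = [w - update_lr * g
                #                   for w, g in zip(net.parameters(), grads)]
                #   logits_q = net(x_qry, fast_weights)  # query pass, adapted weights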
151 | x = F.conv_transpose2d(x, w, b, stride=param[4], padding=param[5]) 152 | idx += 2 153 | # print(name, param, '\tout:', x.shape) 154 | elif name is 'linear': 155 | w, b = vars[idx], vars[idx + 1] 156 | x = F.linear(x, w, b) 157 | idx += 2 158 | # print('forward:', idx, x.norm().item()) 159 | elif name is 'bn': 160 | w, b = vars[idx], vars[idx + 1] 161 | running_mean, running_var = self.vars_bn[bn_idx], self.vars_bn[bn_idx+1] 162 | x = F.batch_norm(x, running_mean, running_var, weight=w, bias=b, training=bn_training) 163 | idx += 2 164 | bn_idx += 2 165 | 166 | elif name is 'flatten': 167 | # print(x.shape) 168 | x = x.view(x.size(0), -1) 169 | elif name is 'reshape': 170 | # [b, 8] => [b, 2, 2, 2] 171 | x = x.view(x.size(0), *param) 172 | elif name is 'relu': 173 | x = F.relu(x, inplace=param[0]) 174 | elif name is 'leakyrelu': 175 | x = F.leaky_relu(x, negative_slope=param[0], inplace=param[1]) 176 | elif name is 'tanh': 177 | x = F.tanh(x) 178 | elif name is 'sigmoid': 179 | x = torch.sigmoid(x) 180 | elif name is 'upsample': 181 | x = F.upsample_nearest(x, scale_factor=param[0]) 182 | elif name is 'max_pool2d': 183 | x = F.max_pool2d(x, param[0], param[1], param[2]) 184 | elif name is 'avg_pool2d': 185 | x = F.avg_pool2d(x, param[0], param[1], param[2]) 186 | 187 | else: 188 | raise NotImplementedError 189 | 190 | # make sure variable is used properly 191 | assert idx == len(vars) 192 | assert bn_idx == len(self.vars_bn) 193 | 194 | 195 | return x 196 | 197 | 198 | def zero_grad(self, vars=None): 199 | """ 200 | :param vars: 201 | :return: 202 | """ 203 | with torch.no_grad(): 204 | if vars is None: 205 | for p in self.vars: 206 | if p.grad is not None: 207 | p.grad.zero_() 208 | else: 209 | for p in vars: 210 | if p.grad is not None: 211 | p.grad.zero_() 212 | 213 | def parameters(self): 214 | """ 215 | override this function since initial parameters will return with a generator. 216 | :return: 217 | """ 218 | return self.vars 219 | 220 | -------------------------------------------------------------------------------- /CIFARFS/resnet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class BasicBlock(nn.Module): 13 | expansion = 1 14 | 15 | def __init__(self, in_planes, planes, stride=1): 16 | super(BasicBlock, self).__init__() 17 | self.conv1 = nn.Conv2d( 18 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 21 | stride=1, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | 24 | self.shortcut = nn.Sequential() 25 | if stride != 1 or in_planes != self.expansion*planes: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv2d(in_planes, self.expansion*planes, 28 | kernel_size=1, stride=stride, bias=False), 29 | nn.BatchNorm2d(self.expansion*planes) 30 | ) 31 | 32 | def forward(self, x): 33 | out = F.relu(self.bn1(self.conv1(x))) 34 | out = self.bn2(self.conv2(out)) 35 | out += self.shortcut(x) 36 | out = F.relu(out) 37 | return out 38 | 39 | 40 | class Bottleneck(nn.Module): 41 | expansion = 4 42 | 43 | def __init__(self, in_planes, planes, stride=1): 44 | super(Bottleneck, self).__init__() 45 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 46 | self.bn1 = nn.BatchNorm2d(planes) 47 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 48 | stride=stride, padding=1, bias=False) 49 | self.bn2 = nn.BatchNorm2d(planes) 50 | self.conv3 = nn.Conv2d(planes, self.expansion * 51 | planes, kernel_size=1, bias=False) 52 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 53 | 54 | self.shortcut = nn.Sequential() 55 | if stride != 1 or in_planes != self.expansion*planes: 56 | self.shortcut = nn.Sequential( 57 | nn.Conv2d(in_planes, self.expansion*planes, 58 | kernel_size=1, stride=stride, bias=False), 59 | nn.BatchNorm2d(self.expansion*planes) 60 | ) 61 | 62 | def forward(self, x): 63 | out = F.relu(self.bn1(self.conv1(x))) 64 | out = F.relu(self.bn2(self.conv2(out))) 65 | out = self.bn3(self.conv3(out)) 66 | out += self.shortcut(x) 67 | out = F.relu(out) 68 | return out 69 | 70 | 71 | class ResNet(nn.Module): 72 | def __init__(self, block, num_blocks, num_classes=64): 73 | super(ResNet, self).__init__() 74 | self.in_planes = 64 75 | 76 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, 77 | stride=1, padding=1, bias=False) 78 | self.bn1 = nn.BatchNorm2d(64) 79 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 80 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 81 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 82 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 83 | self.linear = nn.Linear(512*block.expansion, num_classes) 84 | 85 | def _make_layer(self, block, planes, num_blocks, stride): 86 | strides = [stride] + [1]*(num_blocks-1) 87 | layers = [] 88 | for stride in strides: 89 | layers.append(block(self.in_planes, planes, stride)) 90 | self.in_planes = planes * block.expansion 91 | return nn.Sequential(*layers) 92 | 93 | def forward(self, x): 94 | out0 = F.relu(self.bn1(self.conv1(x))) 95 | out1 = self.layer1(out0) 96 | out2 = self.layer2(out1) 97 | out3 = self.layer3(out2) 98 | out3d5 = self.layer4(out3) 99 | out4 = F.avg_pool2d(out3d5, 4) 100 | out5 = out4.view(out4.size(0), -1) 101 | out = self.linear(out5) 102 | return out1, out2, out3, out4, out 103 | 104 | 105 | def ResNet18(): 106 | return ResNet(BasicBlock, [2, 2, 2, 2]) 107 | 108 | 109 | def ResNet34(): 110 | return ResNet(BasicBlock, [3, 4, 6, 3]) 111 | 112 | 113 | def 
ResNet50(): 114 | return ResNet(Bottleneck, [3, 4, 6, 3]) 115 | 116 | 117 | def ResNet101(): 118 | return ResNet(Bottleneck, [3, 4, 23, 3]) 119 | 120 | 121 | def ResNet152(): 122 | return ResNet(Bottleneck, [3, 8, 36, 3]) 123 | 124 | 125 | def test(): 126 | net = ResNet18() 127 | y = net(torch.randn(1, 3, 32, 32)) 128 | print(y.size()) 129 | 130 | # test() 131 | -------------------------------------------------------------------------------- /CIFARFS/train_trades_cifar.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "scrolled": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import torch, os\n", 12 | "import numpy as np\n", 13 | "import scipy.stats\n", 14 | "from torch.utils.data import DataLoader\n", 15 | "from torch.optim import lr_scheduler\n", 16 | "import random, sys, pickle\n", 17 | "import argparse\n", 18 | "\n", 19 | "\n", 20 | "from LoadUnlableCIFAR import UnlabData\n", 21 | "from torchmeta.datasets import CIFARFS\n", 22 | "from torchmeta.transforms import Categorical, ClassSplitter, Rotation\n", 23 | "from torchvision.transforms import Compose, Resize, ToTensor\n", 24 | "from torchmeta.utils.data import BatchMetaDataLoader\n", 25 | "\n", 26 | "from resnet import ResNet18\n", 27 | "from MetaFT import Meta\n", 28 | "\n", 29 | "\n", 30 | "\n", 31 | "def mean_confidence_interval(accs, confidence=0.95):\n", 32 | " n = accs.shape[0]\n", 33 | " m, se = np.mean(accs), scipy.stats.sem(accs)\n", 34 | " h = se * scipy.stats.t._ppf((1 + confidence) / 2, n - 1)\n", 35 | " return m, h\n", 36 | "\n", 37 | "\n", 38 | "def main():\n", 39 | "\n", 40 | " torch.manual_seed(222)\n", 41 | " torch.cuda.manual_seed_all(222)\n", 42 | " np.random.seed(222)\n", 43 | "\n", 44 | " print(args)\n", 45 | "\n", 46 | " config = [\n", 47 | " ('conv2d', [16, 3, 3, 3, 2, 1]),\n", 48 | " ('relu', [True]),\n", 49 | " ('bn', [16]),\n", 50 | " #('max_pool2d', [2, 2, 0]),\n", 51 | " ('conv2d', [16, 16, 3, 3, 1, 1]),\n", 52 | " ('relu', [True]),\n", 53 | " ('bn', [16]),\n", 54 | " ('max_pool2d', [2, 2, 0]),\n", 55 | " ('conv2d', [32, 16, 3, 3, 1, 1]),\n", 56 | " ('relu', [True]),\n", 57 | " ('bn', [32]),\n", 58 | " #('max_pool2d', [2, 2, 0]),\n", 59 | " ('conv2d', [32, 32, 3, 3, 1, 1]),\n", 60 | " ('relu', [True]),\n", 61 | " ('bn', [32]),\n", 62 | " ('max_pool2d', [2, 2, 0]),\n", 63 | " ('flatten', []),\n", 64 | " ('linear', [args.n_way, 32 * 4 * 4])\n", 65 | " ]\n", 66 | "\n", 67 | " device = torch.device('cuda')\n", 68 | " \n", 69 | " best_cl = 0\n", 70 | " best_rb = 0\n", 71 | " filename = 'mamltrades_unlab_eps8_cifar_5shot.pt'\n", 72 | " maml = Meta(args, config, device).to(device)\n", 73 | " if os.path.isfile(filename):\n", 74 | " print(\"=> loading checkpoint '{}'\".format(filename))\n", 75 | " checkpoint = torch.load(filename)\n", 76 | " best_cl = checkpoint['cl']\n", 77 | " best_rb = checkpoint['rb']\n", 78 | " maml.net.load_state_dict(checkpoint['state_dict'])\n", 79 | " #maml = maml.to(device)\n", 80 | "# print(\"=> loaded checkpoint '{}' (epoch {})\"\n", 81 | "# .format(filename, checkpoint['epoch']))\n", 82 | " else:\n", 83 | " print(\"=> no checkpoint found at '{}'\".format(filename))\n", 84 | "\n", 85 | " tmp = filter(lambda x: x.requires_grad, maml.parameters())\n", 86 | " num = sum(map(lambda x: np.prod(x.shape), tmp))\n", 87 | " print(maml)\n", 88 | " print('Total trainable tensors:', num)\n", 89 | "\n", 90 | " # batchsz here means total episode number\n", 
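    "    # How the unlabeled data is used below (informal sketch): the\n",
    "    # pretrained ResNet18 pseudo-labels each support class, and a batch\n",
    "    # from the matching unlabeled STL pool is passed to Meta as x_unlab.\n",
    "    # The TRADES objective this notebook trains has the general form\n",
    "    #   loss = CE(f(x), y) + beta * KL( f(x_adv) || f(x) ),\n",
    "    # whose KL term needs no labels, so it can also be evaluated on\n",
    "    # x_unlab (the exact objective lives in MetaFT.py, not shown here).\n",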
91 | " \n", 92 | " \n", 93 | " \n", 94 | "\n", 95 | " data_train = CIFARFS(\"data\",\n", 96 | " # Number of ways\n", 97 | " num_classes_per_task=args.n_way,\n", 98 | " meta_train=True,\n", 99 | " meta_val=False,\n", 100 | " meta_test=False,\n", 101 | " # Resize the images to 28x28 and converts them to PyTorch tensors (from Torchvision)\n", 102 | " transform=Compose([Resize(32), ToTensor()]),\n", 103 | " # Transform the labels to integers (e.g. (\"Glagolitic/character01\", \"Sanskrit/character14\", ...) to (0, 1, ...))\n", 104 | " target_transform=Categorical(num_classes=args.n_way),\n", 105 | " # Creates new virtual classes with rotated versions of the images (from Santoro et al., 2016)\n", 106 | " #class_augmentations=[Rotation([90, 180, 270])],\n", 107 | " download=True)\n", 108 | " \n", 109 | " data_test = CIFARFS(\"data\",\n", 110 | " # Number of ways\n", 111 | " num_classes_per_task=args.n_way,\n", 112 | " meta_train=False,\n", 113 | " meta_val=False,\n", 114 | " meta_test=True,\n", 115 | " # Resize the images to 28x28 and converts them to PyTorch tensors (from Torchvision)\n", 116 | " transform=Compose([Resize(32), ToTensor()]),\n", 117 | " # Transform the labels to integers (e.g. (\"Glagolitic/character01\", \"Sanskrit/character14\", ...) to (0, 1, ...))\n", 118 | " target_transform=Categorical(num_classes=args.n_way),\n", 119 | " download=True)\n", 120 | " data_train = ClassSplitter(data_train, shuffle=True, num_train_per_class=args.k_spt, num_test_per_class=args.k_qry)\n", 121 | " data_test = ClassSplitter(data_test, shuffle=True, num_train_per_class=args.k_spt, num_test_per_class=args.k_qry)\n", 122 | " \n", 123 | " LST = UnlabData()\n", 124 | " batchsiz = 5\n", 125 | " \n", 126 | " \n", 127 | " #load a pretrained resnet\n", 128 | " filename = 'cifarfsresnet18.pt'\n", 129 | " resmodel = ResNet18().to(device)\n", 130 | " criterion = torch.nn.CrossEntropyLoss()\n", 131 | "\n", 132 | " if os.path.isfile(filename):\n", 133 | " print(\"=> loading checkpoint '{}'\".format(filename))\n", 134 | " checkpoint = torch.load(filename)\n", 135 | " resmodel.load_state_dict(checkpoint['state_dict'])\n", 136 | " else:\n", 137 | " print(\"=> no checkpoint found at '{}'\".format(filename))\n", 138 | " resmodel.eval()\n", 139 | "\n", 140 | " for epoch in range(args.epoch//10000):\n", 141 | " # fetch meta_batchsz num of episode each time\n", 142 | " db = BatchMetaDataLoader(data_train, batch_size=args.task_num, num_workers=0)\n", 143 | " #db = DataLoader(data_train, args.task_num, shuffle=True, num_workers=0, pin_memory=True)\n", 144 | "\n", 145 | " for step, batch_train in enumerate(db):\n", 146 | " #for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):\n", 147 | " \n", 148 | "\n", 149 | "\n", 150 | " x_spt, y_spt = batch_train[\"train\"]\n", 151 | " x_qry, y_qry = batch_train[\"test\"]\n", 152 | "\n", 153 | " x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)\n", 154 | " \n", 155 | "\n", 156 | " \n", 157 | " if LST:\n", 158 | " x_unlab = torch.zeros((args.task_num, args.n_way, batchsiz, 3, 32, 32))\n", 159 | " for i in range(args.task_num): \n", 160 | " with torch.no_grad():\n", 161 | " *_, outputs = resmodel(x_spt[i])\n", 162 | " _, y_unlab = outputs.max(1)\n", 163 | " index = y_unlab.cpu().numpy()\n", 164 | " for j in range(args.n_way):\n", 165 | " temp_train = LST.train_data[index[j]].get_next_batch(batchsiz,multiple_passes=True)\n", 166 | " x_unlab[i][j]= torch.from_numpy(temp_train.astype(np.float32))\n", 167 | " x_unlab = 
x_unlab.to(device)\n", 168 | " else:\n", 169 | " x_unlab = None\n", 170 | " \n", 171 | "\n", 172 | " accs, accs_adv = maml(x_spt, y_spt, x_qry, y_qry, x_unlab)\n", 173 | " \n", 174 | " if step % 500 == 0:\n", 175 | " print('step:', step, '\\ttraining acc:', accs)\n", 176 | " print('step:', step, '\\ttraining acc_adv:', accs_adv)\n", 177 | " \n", 178 | "\n", 179 | " if step % 10000 == 0: # evaluation\n", 180 | " #db_test = DataLoader(data_test, 1, shuffle=True, num_workers=0, pin_memory=True)\n", 181 | " db_test = BatchMetaDataLoader(data_test, batch_size=1, num_workers=0)\n", 182 | " accs_all_test = []\n", 183 | " accsadv_all_test = []\n", 184 | " accsadvpr_all_test = []\n", 185 | "\n", 186 | " for step_t, batch_test in enumerate(db_test):\n", 187 | " x_spt, y_spt = batch_test[\"train\"]\n", 188 | " x_qry, y_qry = batch_test[\"test\"]\n", 189 | "\n", 190 | "\n", 191 | " x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \\\n", 192 | " x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)\n", 193 | "\n", 194 | " accs, accs_adv, accs_adv_prior = maml.finetunning(x_spt, y_spt, x_qry, y_qry)\n", 195 | " accs_all_test.append(accs)\n", 196 | " accsadv_all_test.append(accs_adv)\n", 197 | " accsadvpr_all_test.append(accs_adv_prior)\n", 198 | " if step_t == 2000:\n", 199 | " break\n", 200 | "\n", 201 | " # [b, update_step+1]\n", 202 | " accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)\n", 203 | " accs_adv = np.array(accsadv_all_test).mean(axis=0).astype(np.float16)\n", 204 | " accs_adv_prior = np.array(accsadvpr_all_test).mean(axis=0).astype(np.float16)\n", 205 | " print('Test acc:', accs)\n", 206 | " print('Test acc_adv:', accs_adv)\n", 207 | " print('Test acc_adv_prior:', accs_adv_prior)\n", 208 | "\n", 209 | " if best_cl < accs[-1] and best_rb < accs_adv[-1]:\n", 210 | " best_cl = accs[-1]\n", 211 | " best_rb = accs_adv[-1]\n", 212 | " state = {'cl': best_cl, 'rb': best_rb, 'state_dict': maml.net.state_dict()}\n", 213 | " torch.save(state, 'mamltrades_unlab_eps8_cifar_5shot.pt')\n", 214 | " print(best_cl)\n", 215 | " print(best_rb)\n", 216 | "\n", 217 | "\n", 218 | "\n", 219 | "if __name__ == '__main__':\n", 220 | "\n", 221 | " argparser = argparse.ArgumentParser()\n", 222 | " argparser.add_argument('--epoch', type=int, help='epoch number', default=60000)\n", 223 | " argparser.add_argument('--n_way', type=int, help='n way', default=5)\n", 224 | " argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=5)\n", 225 | " argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=15)\n", 226 | " argparser.add_argument('--imgsz', type=int, help='imgsz', default=32)\n", 227 | " argparser.add_argument('--imgc', type=int, help='imgc', default=3)\n", 228 | " argparser.add_argument('--task_num', type=int, help='meta batch size, namely task num', default=4)\n", 229 | " argparser.add_argument('--meta_lr', type=float, help='meta-level outer learning rate', default=1e-3)\n", 230 | " argparser.add_argument('--update_lr', type=float, help='task-level inner update learning rate', default=0.01)\n", 231 | " argparser.add_argument('--update_step', type=int, help='task-level inner update steps', default=5)\n", 232 | " argparser.add_argument('--update_step_test', type=int, help='update steps for finetunning', default=10)\n", 233 | " \n", 234 | " #argparser.add_argument('--fast', action=\"store_true\", help='whether to use fgsm')\n", 235 | "\n", 236 | " args = argparser.parse_args(args=[])\n", 237 | "\n", 238 | " 
main()" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": null, 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [] 247 | } 248 | ], 249 | "metadata": { 250 | "kernelspec": { 251 | "display_name": "Python 3", 252 | "language": "python", 253 | "name": "python3" 254 | }, 255 | "language_info": { 256 | "codemirror_mode": { 257 | "name": "ipython", 258 | "version": 3 259 | }, 260 | "file_extension": ".py", 261 | "mimetype": "text/x-python", 262 | "name": "python", 263 | "nbconvert_exporter": "python", 264 | "pygments_lexer": "ipython3", 265 | "version": "3.7.6" 266 | } 267 | }, 268 | "nbformat": 4, 269 | "nbformat_minor": 4 270 | } 271 | -------------------------------------------------------------------------------- /CIFARFS/trainfgsmrs_cifar.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch, os\n", 10 | "import numpy as np\n", 11 | "#from MiniImagenet import MiniImagenet\n", 12 | "import scipy.stats\n", 13 | "from torch.utils.data import DataLoader\n", 14 | "from torch.optim import lr_scheduler\n", 15 | "import random, sys, pickle\n", 16 | "import argparse\n", 17 | "\n", 18 | "\n", 19 | "from torchmeta.datasets import CIFARFS\n", 20 | "from torchmeta.transforms import Categorical, ClassSplitter, Rotation\n", 21 | "from torchvision.transforms import Compose, Resize, ToTensor\n", 22 | "from torchmeta.utils.data import BatchMetaDataLoader\n", 23 | "\n", 24 | "\n", 25 | "from metafgsm import Meta\n", 26 | "\n", 27 | "\n", 28 | "\n", 29 | "def mean_confidence_interval(accs, confidence=0.95):\n", 30 | " n = accs.shape[0]\n", 31 | " m, se = np.mean(accs), scipy.stats.sem(accs)\n", 32 | " h = se * scipy.stats.t._ppf((1 + confidence) / 2, n - 1)\n", 33 | " return m, h\n", 34 | "\n", 35 | "\n", 36 | "def main():\n", 37 | "\n", 38 | " torch.manual_seed(222)\n", 39 | " torch.cuda.manual_seed_all(222)\n", 40 | " np.random.seed(222)\n", 41 | "\n", 42 | " print(args)\n", 43 | "\n", 44 | " config = [\n", 45 | " ('conv2d', [16, 3, 3, 3, 2, 1]),\n", 46 | " ('relu', [True]),\n", 47 | " ('bn', [16]),\n", 48 | " #('max_pool2d', [2, 2, 0]),\n", 49 | " ('conv2d', [16, 16, 3, 3, 1, 1]),\n", 50 | " ('relu', [True]),\n", 51 | " ('bn', [16]),\n", 52 | " ('max_pool2d', [2, 2, 0]),\n", 53 | " ('conv2d', [32, 16, 3, 3, 1, 1]),\n", 54 | " ('relu', [True]),\n", 55 | " ('bn', [32]),\n", 56 | " #('max_pool2d', [2, 2, 0]),\n", 57 | " ('conv2d', [32, 32, 3, 3, 1, 1]),\n", 58 | " ('relu', [True]),\n", 59 | " ('bn', [32]),\n", 60 | " ('max_pool2d', [2, 2, 0]),\n", 61 | " ('flatten', []),\n", 62 | " ('linear', [args.n_way, 32 * 4 * 4])\n", 63 | " ]\n", 64 | "\n", 65 | " device = torch.device('cuda')\n", 66 | " \n", 67 | " best_cl = 0\n", 68 | " best_rb = 0\n", 69 | " filename = 'mamlfgsm_eps8_cifar_1shot.pt'\n", 70 | " maml = Meta(args, config, device).to(device)\n", 71 | " if os.path.isfile(filename):\n", 72 | " print(\"=> loading checkpoint '{}'\".format(filename))\n", 73 | " checkpoint = torch.load(filename)\n", 74 | " best_cl = checkpoint['cl']\n", 75 | " best_rb = checkpoint['rb']\n", 76 | " maml.net.load_state_dict(checkpoint['state_dict'])\n", 77 | " #maml = maml.to(device)\n", 78 | " print(\"=> loaded checkpoint '{}' (epoch {})\"\n", 79 | " .format(filename, checkpoint['epoch']))\n", 80 | " else:\n", 81 | " print(\"=> no checkpoint found at '{}'\".format(filename))\n", 82 | "\n", 83 | " tmp = 
filter(lambda x: x.requires_grad, maml.parameters())\n", 84 | " num = sum(map(lambda x: np.prod(x.shape), tmp))\n", 85 | " print(maml)\n", 86 | " print('Total trainable tensors:', num)\n", 87 | "\n", 88 | " # batchsz here means total episode number\n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | "\n", 93 | " data_train = CIFARFS(\"data\",\n", 94 | " # Number of ways\n", 95 | " num_classes_per_task=args.n_way,\n", 96 | " meta_train=True,\n", 97 | " meta_val=False,\n", 98 | " meta_test=False,\n", 99 | " # Resize the images to 28x28 and converts them to PyTorch tensors (from Torchvision)\n", 100 | " transform=Compose([Resize(32), ToTensor()]),\n", 101 | " # Transform the labels to integers (e.g. (\"Glagolitic/character01\", \"Sanskrit/character14\", ...) to (0, 1, ...))\n", 102 | " target_transform=Categorical(num_classes=args.n_way),\n", 103 | " # Creates new virtual classes with rotated versions of the images (from Santoro et al., 2016)\n", 104 | " #class_augmentations=[Rotation([90, 180, 270])],\n", 105 | " download=True)\n", 106 | " \n", 107 | " data_test = CIFARFS(\"data\",\n", 108 | " # Number of ways\n", 109 | " num_classes_per_task=args.n_way,\n", 110 | " meta_train=False,\n", 111 | " meta_val=False,\n", 112 | " meta_test=True,\n", 113 | " # Resize the images to 28x28 and converts them to PyTorch tensors (from Torchvision)\n", 114 | " transform=Compose([Resize(32), ToTensor()]),\n", 115 | " # Transform the labels to integers (e.g. (\"Glagolitic/character01\", \"Sanskrit/character14\", ...) to (0, 1, ...))\n", 116 | " target_transform=Categorical(num_classes=args.n_way),\n", 117 | " # Creates new virtual classes with rotated versions of the images (from Santoro et al., 2016)\n", 118 | " #class_augmentations=[Rotation([90, 180, 270])],\n", 119 | " download=True)\n", 120 | " data_train = ClassSplitter(data_train, shuffle=True, num_train_per_class=args.k_spt, num_test_per_class=args.k_qry)\n", 121 | " data_test = ClassSplitter(data_test, shuffle=True, num_train_per_class=args.k_spt, num_test_per_class=args.k_qry)\n", 122 | " \n", 123 | "\n", 124 | " for epoch in range(args.epoch//10000):\n", 125 | " # fetch meta_batchsz num of episode each time\n", 126 | " db = BatchMetaDataLoader(data_train, batch_size=args.task_num, num_workers=0)\n", 127 | " #db = DataLoader(data_train, args.task_num, shuffle=True, num_workers=0, pin_memory=True)\n", 128 | "\n", 129 | " for step, batch_train in enumerate(db):\n", 130 | " #for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):\n", 131 | "\n", 132 | " x_spt, y_spt = batch_train[\"train\"]\n", 133 | " x_qry, y_qry = batch_train[\"test\"]\n", 134 | " x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)\n", 135 | "\n", 136 | " accs, accs_adv = maml(x_spt, y_spt, x_qry, y_qry)\n", 137 | " \n", 138 | " if step % 50 == 0:\n", 139 | " print('step:', step, '\\ttraining acc:', accs)\n", 140 | " print('step:', step, '\\ttraining acc_adv:', accs_adv)\n", 141 | " \n", 142 | "\n", 143 | " if step % 2000 == 0: # evaluation\n", 144 | " #db_test = DataLoader(data_test, 1, shuffle=True, num_workers=0, pin_memory=True)\n", 145 | " db_test = BatchMetaDataLoader(data_test, batch_size=1, num_workers=0)\n", 146 | " accs_all_test = []\n", 147 | " accsadv_all_test = []\n", 148 | " accsadvpr_all_test = []\n", 149 | "\n", 150 | " for batch_test in db_test:\n", 151 | " x_spt, y_spt = batch_test[\"train\"]\n", 152 | " x_qry, y_qry = batch_test[\"test\"]\n", 153 | "\n", 154 | "\n", 155 | " x_spt, y_spt, x_qry, y_qry = 
x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \\\n", 156 | " x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)\n", 157 | "\n", 158 | " accs, accs_adv, accs_adv_prior = maml.finetunning(x_spt, y_spt, x_qry, y_qry)\n", 159 | " accs_all_test.append(accs)\n", 160 | " accsadv_all_test.append(accs_adv)\n", 161 | " accsadvpr_all_test.append(accs_adv_prior)\n", 162 | "\n", 163 | " # [b, update_step+1]\n", 164 | " accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)\n", 165 | " accs_adv = np.array(accsadv_all_test).mean(axis=0).astype(np.float16)\n", 166 | " accs_adv_prior = np.array(accsadvpr_all_test).mean(axis=0).astype(np.float16)\n", 167 | " print('Test acc:', accs)\n", 168 | " print('Test acc_adv:', accs_adv)\n", 169 | " print('Test acc_adv_prior:', accs_adv_prior)\n", 170 | "\n", 171 | " if best_cl < accs[-1] and best_rb < accs_adv[-1]:\n", 172 | " best_cl = accs[-1]\n", 173 | " best_rb = accs_adv[-1]\n", 174 | " state = {'cl': best_cl, 'rb': best_rb, 'state_dict': maml.net.state_dict()}\n", 175 | " torch.save(state, 'mamlfgsm_eps8_cifar_1shot.pt')\n", 176 | " print(best_cl)\n", 177 | " print(best_rb)\n", 178 | "\n", 179 | "\n", 180 | "\n", 181 | "if __name__ == '__main__':\n", 182 | "\n", 183 | " argparser = argparse.ArgumentParser()\n", 184 | " argparser.add_argument('--epoch', type=int, help='epoch number', default=60000)\n", 185 | " argparser.add_argument('--n_way', type=int, help='n way', default=5)\n", 186 | " argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=1)\n", 187 | " argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=15)\n", 188 | " argparser.add_argument('--imgsz', type=int, help='imgsz', default=32)\n", 189 | " argparser.add_argument('--imgc', type=int, help='imgc', default=3)\n", 190 | " argparser.add_argument('--task_num', type=int, help='meta batch size, namely task num', default=4)\n", 191 | " argparser.add_argument('--meta_lr', type=float, help='meta-level outer learning rate', default=1e-3)\n", 192 | " argparser.add_argument('--update_lr', type=float, help='task-level inner update learning rate', default=0.01)\n", 193 | " argparser.add_argument('--update_step', type=int, help='task-level inner update steps', default=5)\n", 194 | " argparser.add_argument('--update_step_test', type=int, help='update steps for finetunning', default=10)\n", 195 | " \n", 196 | " #argparser.add_argument('--fast', action=\"store_true\", help='whether to use fgsm')\n", 197 | "\n", 198 | " args = argparser.parse_args(args=[])\n", 199 | "\n", 200 | " main()" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "metadata": {}, 207 | "outputs": [], 208 | "source": [] 209 | } 210 | ], 211 | "metadata": { 212 | "kernelspec": { 213 | "display_name": "Python 3", 214 | "language": "python", 215 | "name": "python3" 216 | }, 217 | "language_info": { 218 | "codemirror_mode": { 219 | "name": "ipython", 220 | "version": 3 221 | }, 222 | "file_extension": ".py", 223 | "mimetype": "text/x-python", 224 | "name": "python", 225 | "nbconvert_exporter": "python", 226 | "pygments_lexer": "ipython3", 227 | "version": "3.7.6" 228 | } 229 | }, 230 | "nbformat": 4, 231 | "nbformat_minor": 4 232 | } 233 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Ren Wang 4 | 5 | Permission is hereby granted, free of charge, to any person 
obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/LoadDataST.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | 
4 | # In[ ]:
5 | 
6 | 
7 | import os
8 | import torch
9 | from torch.utils.data import Dataset
10 | from torchvision.transforms import transforms
11 | import numpy as np
12 | import collections
13 | from PIL import Image
14 | import csv
15 | import random
16 | 
17 | 
18 | class ImagenetMini(Dataset):
19 |     """
20 |     put mini-imagenet files as:
21 |     root:
22 |         |- images/*.jpg includes all images
23 |         |- train.csv
24 |         |- test.csv
25 |         |- val.csv
26 |     NOTICE: meta-learning is different from general supervised learning, especially the concept of batch and set.
27 |     batch: contains several sets
28 |     sets: contains n_way * k_shot for meta-train set, n_way * n_query for meta-test set.
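    For example, a 5-way 5-shot episode with 15 query images per class has a
    25-image support set and a 75-image query set; labels are re-indexed to
    0..4 inside each episode.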
29 | """ 30 | 31 | def __init__(self, root, mode, batchsz, resize, startidx=0): 32 | """ 33 | :param root: root path of mini-imagenet 34 | :param mode: train, val or test 35 | :param batchsz: batch size of sets, not batch of imgs 36 | :param n_way: 37 | :param k_shot: 38 | :param k_query: num of qeruy imgs per class 39 | :param resize: resize to 40 | :param startidx: start to index label from startidx 41 | """ 42 | 43 | self.batchsz = batchsz #batch of imgs 44 | self.resize = resize # resize to 45 | self.startidx = startidx # index label not from 0, but from startidx 46 | print('shuffle DB :%s, b:%d, resize:%d' % ( 47 | mode, batchsz, resize)) 48 | 49 | if mode == 'train': 50 | # self.transform = transforms.Compose([lambda x: Image.open(x).convert('RGB'), 51 | # transforms.Resize((self.resize, self.resize)), 52 | # # transforms.RandomHorizontalFlip(), 53 | # # transforms.RandomRotation(5), 54 | # transforms.ToTensor(), 55 | # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 56 | # ]) 57 | self.transform = transforms.Compose([lambda x: Image.open(x).convert('RGB'), 58 | transforms.Resize((self.resize, self.resize)), 59 | # transforms.RandomHorizontalFlip(), 60 | # transforms.RandomRotation(5), 61 | transforms.ToTensor(), 62 | transforms.Normalize((0, 0, 0), (1, 1, 1)) 63 | ]) 64 | else: 65 | # self.transform = transforms.Compose([lambda x: Image.open(x).convert('RGB'), 66 | # transforms.Resize((self.resize, self.resize)), 67 | # transforms.ToTensor(), 68 | # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 69 | # ]) 70 | self.transform = transforms.Compose([lambda x: Image.open(x).convert('RGB'), 71 | transforms.Resize((self.resize, self.resize)), 72 | transforms.ToTensor(), 73 | transforms.Normalize((0, 0, 0), (1, 1, 1)) 74 | ]) 75 | 76 | self.path = os.path.join(root, 'images') # image path 77 | csvdata = self.loadCSV(os.path.join(root, mode + '.csv')) # csv path 78 | self.data = [] 79 | self.img2label = {} 80 | for i, (k, v) in enumerate(csvdata.items()): 81 | self.data.extend(v) # [[img1, img2, ...], [img111, ...]] 82 | self.img2label[k] = i + self.startidx # {"img_name[:9]":label} 83 | self.cls_num = len(self.data) 84 | 85 | # self.create_batch(self.batchsz) 86 | 87 | 88 | support_x = torch.FloatTensor(self.cls_num, 3, self.resize, self.resize) 89 | support_x_temp = np.array(self.data)#.tolist() 90 | #support_y = np.zeros((self.cls_num), dtype=np.int) 91 | flatten_x = [os.path.join(self.path, item) for item in support_x_temp] 92 | temp = [self.img2label[item[:9]] for item in support_x_temp] 93 | support_y = np.array(temp).astype(np.int32) 94 | 95 | for i, path in enumerate(flatten_x): 96 | support_x[i] = self.transform(path) 97 | 98 | self.loading_data = DataSubset(support_x, torch.LongTensor(support_y)) 99 | 100 | def loadCSV(self, csvf): 101 | """ 102 | return a dict saving the information of csv 103 | :param splitFile: csv file name 104 | :return: {label:[file1, file2 ...]} 105 | """ 106 | dictLabels = {} 107 | with open(csvf) as csvfile: 108 | csvreader = csv.reader(csvfile, delimiter=',') 109 | next(csvreader, None) # skip (filename, label) 110 | for i, row in enumerate(csvreader): 111 | filename = row[0] 112 | label = row[1] 113 | # append filename to current label 114 | if label in dictLabels.keys(): 115 | dictLabels[label].append(filename) 116 | else: 117 | dictLabels[label] = [filename] 118 | return dictLabels 119 | 120 | # def create_batch(self, batchsz): 121 | # """ 122 | # create batch for meta-learning. 
123 | # ×episode× here means batch, and it means how many sets we want to retain. 124 | # :param episodes: batch size 125 | # :return: 126 | # """ 127 | # self.x_batch = [] # support set batch 128 | # for b in range(batchsz): # for each batch 129 | # # 1.select n_way classes randomly 130 | # selected_cls = np.random.choice(self.cls_num, self.n_way, False) # no duplicate 131 | # np.random.shuffle(selected_cls) 132 | # support_x = [] 133 | # query_x = [] 134 | # for cls in selected_cls: 135 | # # 2. select k_shot + k_query for each class 136 | # selected_imgs_idx = np.random.choice(len(self.data[cls]), self.k_shot + self.k_query, False) 137 | # np.random.shuffle(selected_imgs_idx) 138 | # indexDtrain = np.array(selected_imgs_idx[:self.k_shot]) # idx for Dtrain 139 | # indexDtest = np.array(selected_imgs_idx[self.k_shot:]) # idx for Dtest 140 | # support_x.append( 141 | # np.array(self.data[cls])[indexDtrain].tolist()) # get all images filename for current Dtrain 142 | # query_x.append(np.array(self.data[cls])[indexDtest].tolist()) 143 | 144 | # # shuffle the correponding relation between support set and query set 145 | # random.shuffle(support_x) 146 | # random.shuffle(query_x) 147 | 148 | # self.support_x_batch.append(support_x) # append set to current sets 149 | # self.query_x_batch.append(query_x) # append sets to current sets 150 | 151 | # def __getitem__(self, index): 152 | # """ 153 | # index means index of sets, 0<= index <= batchsz-1 154 | # :param index: 155 | # :return: 156 | # """ 157 | # # [setsz, 3, resize, resize] 158 | # support_x = torch.FloatTensor(self.setsz, 3, self.resize, self.resize) 159 | # # [setsz] 160 | # support_y = np.zeros((self.setsz), dtype=np.int) 161 | # # [querysz, 3, resize, resize] 162 | # query_x = torch.FloatTensor(self.querysz, 3, self.resize, self.resize) 163 | # # [querysz] 164 | # query_y = np.zeros((self.querysz), dtype=np.int) 165 | 166 | # flatten_support_x = [os.path.join(self.path, item) 167 | # for sublist in self.support_x_batch[index] for item in sublist] 168 | # support_y = np.array( 169 | # [self.img2label[item[:9]] # filename:n0153282900000005.jpg, the first 9 characters treated as label 170 | # for sublist in self.support_x_batch[index] for item in sublist]).astype(np.int32) 171 | 172 | # flatten_query_x = [os.path.join(self.path, item) 173 | # for sublist in self.query_x_batch[index] for item in sublist] 174 | # query_y = np.array([self.img2label[item[:9]] 175 | # for sublist in self.query_x_batch[index] for item in sublist]).astype(np.int32) 176 | 177 | # # print('global:', support_y, query_y) 178 | # # support_y: [setsz] 179 | # # query_y: [querysz] 180 | # # unique: [n-way], sorted 181 | # unique = np.unique(support_y) 182 | # random.shuffle(unique) 183 | # # relative means the label ranges from 0 to n-way 184 | # support_y_relative = np.zeros(self.setsz) 185 | # query_y_relative = np.zeros(self.querysz) 186 | # for idx, l in enumerate(unique): 187 | # support_y_relative[support_y == l] = idx 188 | # query_y_relative[query_y == l] = idx 189 | 190 | # # print('relative:', support_y_relative, query_y_relative) 191 | 192 | # for i, path in enumerate(flatten_support_x): 193 | # support_x[i] = self.transform(path) 194 | 195 | # for i, path in enumerate(flatten_query_x): 196 | # query_x[i] = self.transform(path) 197 | # # print(support_set_y) 198 | # # return support_x, torch.LongTensor(support_y), query_x, torch.LongTensor(query_y) 199 | 200 | # return support_x, torch.LongTensor(support_y_relative), query_x, 
torch.LongTensor(query_y_relative) 201 | 202 | # def __len__(self): 203 | # # as we have built up to batchsz of sets, you can sample some small batch size of sets. 204 | # return self.batchsz 205 | 206 | 207 | 208 | class DataSubset(object): 209 | def __init__(self, xs, ys, num_examples=None, seed=None): 210 | 211 | if seed is not None: 212 | np.random.seed(99) 213 | self.xs = xs 214 | self.n = len(xs) 215 | self.ys = ys 216 | self.batch_start = 0 217 | self.cur_order = np.random.permutation(self.n) 218 | 219 | def get_next_batch(self, batch_size, multiple_passes=False, reshuffle_after_pass=True): 220 | # np.random.seed(99) 221 | if self.n < batch_size: 222 | raise ValueError('Batch size can be at most the dataset size') 223 | if not multiple_passes: 224 | actual_batch_size = min(batch_size, self.n - self.batch_start) 225 | if actual_batch_size <= 0: 226 | raise ValueError('Pass through the dataset is complete.') 227 | batch_end = self.batch_start + actual_batch_size 228 | batch_xs = self.xs[self.cur_order[self.batch_start : batch_end], ...] 229 | batch_ys = self.ys[self.cur_order[self.batch_start : batch_end], ...] 230 | self.batch_start += actual_batch_size 231 | return batch_xs, batch_ys 232 | actual_batch_size = min(batch_size, self.n - self.batch_start) 233 | if actual_batch_size < batch_size: 234 | if reshuffle_after_pass: 235 | self.cur_order = np.random.permutation(self.n) 236 | self.batch_start = 0 237 | batch_end = self.batch_start + batch_size 238 | # batch_xs = self.xs[self.cur_order[self.batch_start : batch_end], ...] 239 | # batch_ys = self.ys[self.cur_order[self.batch_start : batch_end], ...] 240 | 241 | batch_xs = self.xs[self.cur_order[self.batch_start : batch_end], ...] 242 | batch_ys = self.ys[self.cur_order[self.batch_start : batch_end], ...] 243 | 244 | 245 | self.batch_start += actual_batch_size 246 | return batch_xs, batch_ys 247 | 248 | 249 | if __name__ == '__main__': 250 | # the following episode is to view one set of images via tensorboard. 
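# --- illustrative usage sketch for the DataSubset class above (the toy
# --- tensors and sizes here are hypothetical, not part of the original pipeline):
# subset = DataSubset(torch.zeros(10, 3, 84, 84), torch.arange(10))
# xb, yb = subset.get_next_batch(4, multiple_passes=True)
# # xb: [4, 3, 84, 84], yb: [4]; once a pass through all 10 samples completes,
# # the traversal order is re-drawn when reshuffle_after_pass=True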
251 | from torchvision.utils import make_grid 252 | from matplotlib import pyplot as plt 253 | from tensorboardX import SummaryWriter 254 | import time 255 | 256 | plt.ion() 257 | 258 | tb = SummaryWriter('runs', 'mini-imagenet') 259 | mini = MiniImagenet('../../../dataset/', mode='train', batchsz=1000, resize=168) 260 | 261 | for i, set_ in enumerate(mini): 262 | # support_x: [k_shot*n_way, 3, 84, 84] 263 | support_x, support_y, query_x, query_y = set_ 264 | 265 | support_x = make_grid(support_x, nrow=2) 266 | query_x = make_grid(query_x, nrow=2) 267 | 268 | plt.figure(1) 269 | plt.imshow(support_x.transpose(2, 0).numpy()) 270 | plt.pause(0.5) 271 | plt.figure(2) 272 | plt.imshow(query_x.transpose(2, 0).numpy()) 273 | plt.pause(0.5) 274 | 275 | tb.add_image('support_x', support_x) 276 | tb.add_image('query_x', query_x) 277 | 278 | time.sleep(5) 279 | 280 | tb.close() 281 | -------------------------------------------------------------------------------- /LoadUnlableData.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[ ]: 5 | 6 | 7 | import os 8 | import torch 9 | # from torch.utils.data import Dataset 10 | # from torchvision.transforms import transforms 11 | import numpy as np 12 | import collections 13 | # from PIL import Image 14 | # import csv 15 | import random 16 | 17 | 18 | class UnlabData(object): 19 | def __init__(self, seed=None): 20 | tinyimg = np.array(np.load("Img_select.npy", allow_pickle=True)) 21 | 22 | train_indices = np.arange(500) 23 | self.train_data = [] 24 | 25 | for i in range(64): 26 | temp_tiny = [tinyimg[i][j] for j in train_indices] 27 | temp_tiny = np.array(temp_tiny) 28 | self.train_data.append(DataSubset(temp_tiny)) 29 | 30 | class DataSubset(object): 31 | def __init__(self, xs, num_examples=None, seed=None): 32 | 33 | if seed is not None: 34 | np.random.seed(99) 35 | self.xs = xs 36 | self.n = len(xs) 37 | self.batch_start = 0 38 | self.cur_order = np.random.permutation(self.n) 39 | 40 | def get_next_batch(self, batch_size, multiple_passes=False, reshuffle_after_pass=True): 41 | # np.random.seed(99) 42 | if self.n < batch_size: 43 | raise ValueError('Batch size can be at most the dataset size') 44 | if not multiple_passes: 45 | actual_batch_size = min(batch_size, self.n - self.batch_start) 46 | if actual_batch_size <= 0: 47 | raise ValueError('Pass through the dataset is complete.') 48 | batch_end = self.batch_start + actual_batch_size 49 | batch_xs = self.xs[self.cur_order[self.batch_start : batch_end], ...] 50 | self.batch_start += actual_batch_size 51 | return batch_xs 52 | actual_batch_size = min(batch_size, self.n - self.batch_start) 53 | if actual_batch_size < batch_size: 54 | if reshuffle_after_pass: 55 | self.cur_order = np.random.permutation(self.n) 56 | self.batch_start = 0 57 | batch_end = self.batch_start + batch_size 58 | 59 | 60 | batch_xs = self.xs[self.cur_order[self.batch_start : batch_end], ...] 61 | 62 | 63 | self.batch_start += actual_batch_size 64 | return batch_xs 65 | 66 | 67 | 68 | 69 | 70 | 71 | if __name__ == '__main__': 72 | # the following episode is to view one set of images via tensorboard. 
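# --- illustrative usage sketch for the UnlabData class above (assumes
# --- "Img_select.npy" holds 64 class-wise stacks with at least 500 images
# --- each, as indexed in __init__):
# unlab = UnlabData()
# batch = unlab.train_data[0].get_next_batch(20, multiple_passes=True)
# # batch: 20 unlabeled images for class 0, cycling with reshuffles between passes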
73 | from torchvision.utils import make_grid
74 | from matplotlib import pyplot as plt
75 | from tensorboardX import SummaryWriter
76 | import time
77 |
78 | # plt.ion()
79 |
80 | # tb = SummaryWriter('runs', 'mini-imagenet')
81 | # mini = MiniImagenet('../../../dataset/', mode='train', batchsz=1000, resize=168)
82 |
83 | # for i, set_ in enumerate(mini):
84 | # # support_x: [k_shot*n_way, 3, 84, 84]
85 | # support_x, support_y, query_x, query_y = set_
86 |
87 | # support_x = make_grid(support_x, nrow=2)
88 | # query_x = make_grid(query_x, nrow=2)
89 |
90 | # plt.figure(1)
91 | # plt.imshow(support_x.transpose(2, 0).numpy())
92 | # plt.pause(0.5)
93 | # plt.figure(2)
94 | # plt.imshow(query_x.transpose(2, 0).numpy())
95 | # plt.pause(0.5)
96 |
97 | # tb.add_image('support_x', support_x)
98 | # tb.add_image('query_x', query_x)
99 |
100 | # time.sleep(5)
101 |
102 | # tb.close()
103 |
-------------------------------------------------------------------------------- /MODELMETA.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | # In[ ]:
5 |
6 | # normal MAML; also contains the adversarial-accuracy calculation
7 | import torch
8 | from torch import nn
9 | from torch import optim
10 | from torch.nn import functional as F
11 | from torch.utils.data import TensorDataset, DataLoader
12 |
13 | from torch.autograd import Variable
14 | import numpy as np
15 |
16 | from learner import Learner
17 | from copy import deepcopy
18 | from attack import PGD
19 | from Visualization import RobustVis
20 |
21 |
22 |
23 |
24 | class Meta(nn.Module):
25 | """
26 | Meta Learner
27 | """
28 | def __init__(self, args, config, device):
29 | """
30 | :param args:
31 | """
32 | super(Meta, self).__init__()
33 |
34 | self.update_lr = args.update_lr
35 | self.meta_lr = args.meta_lr
36 | self.n_way = args.n_way
37 | self.k_spt = args.k_spt
38 | self.k_qry = args.k_qry
39 | self.task_num = args.task_num
40 | self.update_step = args.update_step
41 | self.update_step_test = args.update_step_test
42 | self.device = device
43 |
44 |
45 | self.net = Learner(config, args.imgc, args.imgsz)
46 | #self.netadv = Learner(config, args.imgc, args.imgsz)
47 | self.meta_optim = optim.Adam(self.net.parameters(), lr=self.meta_lr)
48 | #self.meta_optimadv = optim.Adam(self.netadv.parameters(), lr=self.meta_lr)
49 |
50 |
51 |
52 |
53 | def clip_grad_by_norm_(self, grad, max_norm):
54 | """
55 | in-place gradient clipping.
56 | :param grad: list of gradients
57 | :param max_norm: maximum norm allowable
58 | :return: total 2-norm divided by the number of gradient tensors
59 | """
60 |
61 | total_norm = 0
62 | counter = 0
63 | for g in grad:
64 | param_norm = g.data.norm(2)
65 | total_norm += param_norm.item() ** 2
66 | counter += 1
67 | total_norm = total_norm ** (1.
/ 2) 68 | 69 | clip_coef = max_norm / (total_norm + 1e-6) 70 | if clip_coef < 1: 71 | for g in grad: 72 | g.data.mul_(clip_coef) 73 | 74 | return total_norm/counter 75 | 76 | 77 | def forward(self, x_spt, y_spt, x_qry, y_qry, x_nat): 78 | """ 79 | :param x_spt: [b, setsz, c_, h, w] 80 | :param y_spt: [b, setsz] 81 | :param x_qry: [b, querysz, c_, h, w] 82 | :param y_qry: [b, querysz] 83 | :return: 84 | """ 85 | task_num, setsz, c_, h, w = x_spt.size() 86 | querysz = x_qry.size(1) 87 | 88 | losses_q = [0 for _ in range(self.update_step + 1)] # losses_q[i] is the loss on step i 89 | corrects = [0 for _ in range(self.update_step + 1)] 90 | 91 | need_adv = False 92 | #AT 93 | optimizer = torch.optim.SGD(self.net.parameters(), lr=self.update_lr, momentum=0.9, weight_decay=5e-4) 94 | eps, step = (2.0,10) 95 | losses_q_adv = [0 for _ in range(self.update_step + 1)] 96 | corrects_adv = [0 for _ in range(self.update_step + 1)] 97 | 98 | 99 | 100 | 101 | for i in range(task_num): 102 | 103 | # 1. run the i-th task and compute loss for k=0 104 | logits = self.net(x_spt[i], vars=None, bn_training=True) 105 | loss = F.cross_entropy(logits, y_spt[i]) 106 | grad = torch.autograd.grad(loss, self.net.parameters()) 107 | fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, self.net.parameters()))) 108 | 109 | #PGD AT 110 | if need_adv: 111 | at = PGD(eps=eps / 255.0, sigma=2 / 255.0, nb_iter=step) 112 | data = x_qry[i] 113 | label = y_qry[i] 114 | optimizer.zero_grad() 115 | adv_inp_adv = at.attack(self.net, fast_weights, data, label) 116 | optimizer.zero_grad() 117 | self.net.train() 118 | 119 | # this is the loss and accuracy before first update 120 | with torch.no_grad(): 121 | # [setsz, nway] 122 | logits_q = self.net(x_qry[i], self.net.parameters(), bn_training=True) 123 | loss_q = F.cross_entropy(logits_q, y_qry[i]) 124 | losses_q[0] += loss_q 125 | 126 | pred_q = F.softmax(logits_q, dim=1).argmax(dim=1) 127 | correct = torch.eq(pred_q, y_qry[i]).sum().item() 128 | corrects[0] = corrects[0] + correct 129 | 130 | 131 | #PGD AT 132 | if need_adv: 133 | data = x_qry[i] 134 | label = y_qry[i] 135 | optimizer.zero_grad() 136 | adv_inp = at.attack(self.net, self.net.parameters(), data, label) 137 | optimizer.zero_grad() 138 | self.net.train() 139 | with torch.no_grad(): 140 | logits_q_adv = self.net(adv_inp, self.net.parameters(), bn_training=True) 141 | loss_q_adv = F.cross_entropy(logits_q_adv, label) 142 | losses_q_adv[0] += loss_q_adv 143 | 144 | pred_q_adv = F.softmax(logits_q_adv, dim=1).argmax(dim=1) 145 | correct_adv = torch.eq(pred_q_adv, label).sum().item() 146 | corrects_adv[0] = corrects_adv[0] + correct_adv 147 | 148 | # this is the loss and accuracy after the first update 149 | with torch.no_grad(): 150 | # [setsz, nway] 151 | logits_q = self.net(x_qry[i], fast_weights, bn_training=True) 152 | loss_q = F.cross_entropy(logits_q, y_qry[i]) 153 | losses_q[1] += loss_q 154 | # [setsz] 155 | pred_q = F.softmax(logits_q, dim=1).argmax(dim=1) 156 | correct = torch.eq(pred_q, y_qry[i]).sum().item() 157 | corrects[1] = corrects[1] + correct 158 | 159 | 160 | 161 | #PGD AT 162 | if need_adv: 163 | logits_q_adv = self.net(adv_inp_adv, fast_weights, bn_training=True) 164 | loss_q_adv = F.cross_entropy(logits_q_adv, label) 165 | losses_q_adv[1] += loss_q_adv 166 | 167 | pred_q_adv = F.softmax(logits_q_adv, dim=1).argmax(dim=1) 168 | correct_adv = torch.eq(pred_q_adv, label).sum().item() 169 | corrects_adv[1] = corrects_adv[1] + correct_adv 170 | 171 | 172 | for k in range(1, 
self.update_step):
173 | # 1. run the i-th task and compute loss for k=1~K-1
174 | logits = self.net(x_spt[i], fast_weights, bn_training=True)
175 | loss = F.cross_entropy(logits, y_spt[i])
176 | # 2. compute grad on theta_pi
177 | grad = torch.autograd.grad(loss, fast_weights)
178 | # 3. theta_pi = theta_pi - train_lr * grad
179 | fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights)))
180 |
181 | logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
182 | # loss_q is overwritten each step; only the value from the last update step is kept.
183 | loss_q = F.cross_entropy(logits_q, y_qry[i])
184 | losses_q[k + 1] += loss_q
185 |
186 |
187 | #PGD AT
188 | if need_adv:
189 | at = PGD(eps=eps / 255.0, sigma=2 / 255.0, nb_iter=step)
190 | data = x_qry[i]
191 | label = y_qry[i]
192 | optimizer.zero_grad()
193 | adv_inp_adv = at.attack(self.net, fast_weights, data, label)
194 | optimizer.zero_grad()
195 |
196 | logits_q_adv = self.net(adv_inp_adv, fast_weights, bn_training=True)
197 | loss_q_adv = F.cross_entropy(logits_q_adv, label)
198 | losses_q_adv[k + 1] += loss_q_adv
199 |
200 |
201 | with torch.no_grad():
202 | pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
203 | correct = torch.eq(pred_q, y_qry[i]).sum().item() # .item() gives a python scalar
204 | corrects[k + 1] = corrects[k + 1] + correct
205 |
206 | #PGD AT
207 | if need_adv:
208 | pred_q_adv = F.softmax(logits_q_adv, dim=1).argmax(dim=1)
209 | correct_adv = torch.eq(pred_q_adv, label).sum().item()
210 | corrects_adv[k + 1] = corrects_adv[k + 1] + correct_adv
211 |
212 |
213 |
214 | # end of all tasks
215 | # average the last-step query loss across all tasks
216 | loss_q = losses_q[-1] / task_num
217 |
218 | loss_q_adv = losses_q_adv[-1] / task_num # tracked for logging only; not back-propagated here
219 |
220 | # optimize theta parameters
221 | self.meta_optim.zero_grad()
222 | loss_q.backward()
223 |
224 | self.meta_optim.step()
225 |
226 |
227 |
228 | accs = np.array(corrects) / (querysz * task_num)
229 | accs_adv = np.array(corrects_adv) / (querysz * task_num)
230 |
231 | return accs, accs_adv
232 |
233 |
234 | def finetunning(self, x_spt, y_spt, x_qry, y_qry, x_ss, num_ind, kwargs):
235 | """
236 | :param x_spt: [setsz, c_, h, w]
237 | :param y_spt: [setsz]
238 | :param x_qry: [querysz, c_, h, w]
239 | :param y_qry: [querysz]
240 | :return: (loss_p, x_p) from the RobustVis pass over x_ss
241 | """
242 | assert len(x_spt.shape) == 4
243 |
244 | querysz = x_qry.size(0)
245 |
246 | corrects = [0 for _ in range(self.update_step_test + 1)]
247 |
248 | need_adv = True
249 | beta = 0 # unused in this method
250 | optimizer = torch.optim.SGD(self.net.parameters(), lr=self.update_lr, momentum=0.9, weight_decay=5e-4)
251 | eps, step = (2, 10)
252 | corrects_adv = [0 for _ in range(self.update_step_test + 1)]
253 | corrects_adv_prior = [0 for _ in range(self.update_step_test + 1)]
254 |
255 |
256 | # to avoid ruining the running_mean/var and bn weight/bias state,
257 | # we fine-tune on a deep copy of the model instead of self.net
258 | net = deepcopy(self.net)
259 |
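# (evaluation attack budget: with eps, step = (2, 10) above, the PGD attack
#  constructed below is l_inf-bounded with eps = 2/255, per-step size
#  sigma = 2/255, and nb_iter = 10 iterations, via attack.PGD)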
260 | # 1. run the i-th task and compute loss for k=0
261 | logits = net(x_spt)
262 | loss = F.cross_entropy(logits, y_spt)
263 | grad = torch.autograd.grad(loss, net.parameters())
264 | fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, net.parameters())))
265 |
266 |
267 |
268 | #PGD AT
269 | if need_adv:
270 | at = PGD(eps=eps / 255.0, sigma=2 / 255.0, nb_iter=step)
271 | data = x_qry
272 | label = y_qry
273 | optimizer.zero_grad()
274 | adv_inp_adv = at.attack(net, fast_weights, data, label)
275 |
276 |
277 |
278 |
279 | # this is the loss and accuracy before first update
280 | with torch.no_grad():
281 | # [setsz, nway]
282 | logits_q = net(x_qry, net.parameters(), bn_training=True)
283 | # [setsz]
284 | pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
285 |
286 | # indices of the samples predicted correctly on clean inputs
287 | corr_ind = torch.eq(pred_q, y_qry).nonzero()
288 | # scalar
289 | correct = torch.eq(pred_q, y_qry).sum().item()
290 | corrects[0] = corrects[0] + correct
291 |
292 |
293 | #PGD AT
294 | if need_adv:
295 | data = x_qry
296 | label = y_qry
297 | optimizer.zero_grad()
298 | adv_inp = at.attack(net, net.parameters(), data, label)
299 | with torch.no_grad():
300 | logits_q_adv = net(adv_inp, net.parameters(), bn_training=True)
301 | pred_q_adv = F.softmax(logits_q_adv, dim=1).argmax(dim=1)
302 | correct_adv = torch.eq(pred_q_adv, label).sum().item()
303 | correct_adv_prior = torch.eq(pred_q_adv[corr_ind], label[corr_ind]).sum().item()
304 | corrects_adv[0] = corrects_adv[0] + correct_adv
305 | corrects_adv_prior[0] = corrects_adv_prior[0] + correct_adv_prior/len(corr_ind) # assumes len(corr_ind) > 0
306 |
307 | # this is the loss and accuracy after the first update
308 | with torch.no_grad():
309 | # [setsz, nway]
310 | logits_q = net(x_qry, fast_weights, bn_training=True)
311 | # [setsz]
312 | pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
313 | # indices of the samples predicted correctly on clean inputs
314 | corr_ind = torch.eq(pred_q, y_qry).nonzero()
315 | # scalar
316 | correct = torch.eq(pred_q, y_qry).sum().item()
317 | corrects[1] = corrects[1] + correct
318 |
319 |
320 | #PGD AT
321 | if need_adv:
322 | logits_q_adv = net(adv_inp_adv, fast_weights, bn_training=True)
323 | pred_q_adv = F.softmax(logits_q_adv, dim=1).argmax(dim=1)
324 | correct_adv = torch.eq(pred_q_adv, label).sum().item()
325 | correct_adv_prior = torch.eq(pred_q_adv[corr_ind], label[corr_ind]).sum().item()
326 | corrects_adv[1] = corrects_adv[1] + correct_adv
327 | corrects_adv_prior[1] = corrects_adv_prior[1] + correct_adv_prior/len(corr_ind)
328 |
329 |
330 | for k in range(1, self.update_step_test):
331 | # 1. run the i-th task and compute loss for k=1~K-1
332 | logits = net(x_spt, fast_weights, bn_training=True)
333 | loss = F.cross_entropy(logits, y_spt)
334 | # 2. compute grad on theta_pi
335 | grad = torch.autograd.grad(loss, fast_weights)
336 | # 3. theta_pi = theta_pi - train_lr * grad
337 | fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights)))
338 |
339 | logits_q = net(x_qry, fast_weights, bn_training=True)
340 | # loss_q is overwritten each step; only the value from the last update step is kept.
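# (the loop above is the standard MAML inner update,
#  fast_weights <- fast_weights - update_lr * grad(support loss);
#  the query set is re-evaluated after every step, so corrects[k + 1]
#  tracks accuracy as a function of the number of adaptation steps)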
341 | loss_q = F.cross_entropy(logits_q, y_qry)
342 |
343 |
344 |
345 | #PGD AT
346 | if need_adv:
347 | at = PGD(eps=eps / 255.0, sigma=2 / 255.0, nb_iter=step)
348 | data = x_qry
349 | label = y_qry
350 | optimizer.zero_grad()
351 | adv_inp_adv = at.attack(net, fast_weights, data, label)
352 |
353 | logits_q_adv = net(adv_inp_adv, fast_weights, bn_training=True)
354 | loss_q_adv = F.cross_entropy(logits_q_adv, label)
355 |
356 |
357 |
358 | with torch.no_grad():
359 | pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
360 | # indices of the samples predicted correctly on clean inputs
361 | corr_ind = torch.eq(pred_q, y_qry).nonzero()
362 | correct = torch.eq(pred_q, y_qry).sum().item() # .item() gives a python scalar
363 | corrects[k + 1] = corrects[k + 1] + correct
364 |
365 |
366 | #PGD AT
367 | if need_adv:
368 | pred_q_adv = F.softmax(logits_q_adv, dim=1).argmax(dim=1)
369 | correct_adv = torch.eq(pred_q_adv, label).sum().item()
370 | correct_adv_prior = torch.eq(pred_q_adv[corr_ind], label[corr_ind]).sum().item()
371 | corrects_adv[k + 1] = corrects_adv[k + 1] + correct_adv
372 | corrects_adv_prior[k + 1] = corrects_adv_prior[k + 1] + correct_adv_prior/len(corr_ind)
373 |
374 |
375 | robustmod = RobustVis(net, self.device)
376 | loss_p, x_p = robustmod.forward(x_ss.to(self.device), num_ind, **kwargs)
377 |
378 |
379 | del net
380 |
381 | accs = np.array(corrects) / querysz # accuracy stats are computed but not returned by this visualization variant
382 |
383 | accs_adv = np.array(corrects_adv) / querysz
384 |
385 | accs_adv_prior = np.array(corrects_adv_prior)
386 |
387 | return loss_p, x_p
388 |
389 |
390 |
391 | def main():
392 | pass
393 |
394 |
395 | if __name__ == '__main__':
396 | main()
397 |
398 |
399 |
-------------------------------------------------------------------------------- /MiniImagenet.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | # In[ ]:
5 |
6 |
7 | import os
8 | import torch
9 | from torch.utils.data import Dataset
10 | from torchvision.transforms import transforms
11 | import numpy as np
12 | import collections
13 | from PIL import Image
14 | import csv
15 | import random
16 |
17 |
18 | class MiniImagenet(Dataset):
19 | """
20 | put mini-imagenet files as :
21 | root :
22 | |- images/*.jpg includes all images
23 | |- train.csv
24 | |- test.csv
25 | |- val.csv
26 | NOTICE: meta-learning is different from general supervised learning, especially the concept of batch and set.
27 | batch: contains several sets
28 | sets: contains n_way * k_shot images for the meta-train set and n_way * k_query for the meta-test set.
29 | """ 30 | 31 | def __init__(self, root, mode, batchsz, n_way, k_shot, k_query, resize, startidx=0): 32 | """ 33 | :param root: root path of mini-imagenet 34 | :param mode: train, val or test 35 | :param batchsz: batch size of sets, not batch of imgs 36 | :param n_way: 37 | :param k_shot: 38 | :param k_query: num of qeruy imgs per class 39 | :param resize: resize to 40 | :param startidx: start to index label from startidx 41 | """ 42 | 43 | self.batchsz = batchsz # batch of set, not batch of imgs 44 | self.n_way = n_way # n-way 45 | self.k_shot = k_shot # k-shot 46 | self.k_query = k_query # for evaluation 47 | self.setsz = self.n_way * self.k_shot # num of samples per set 48 | self.querysz = self.n_way * self.k_query # number of samples per set for evaluation 49 | self.resize = resize # resize to 50 | self.startidx = startidx # index label not from 0, but from startidx 51 | print('shuffle DB :%s, b:%d, %d-way, %d-shot, %d-query, resize:%d' % ( 52 | mode, batchsz, n_way, k_shot, k_query, resize)) 53 | 54 | if mode == 'train': 55 | # self.transform = transforms.Compose([lambda x: Image.open(x).convert('RGB'), 56 | # transforms.Resize((self.resize, self.resize)), 57 | # # transforms.RandomHorizontalFlip(), 58 | # # transforms.RandomRotation(5), 59 | # transforms.ToTensor(), 60 | # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 61 | # ]) 62 | self.transform = transforms.Compose([lambda x: Image.open(x).convert('RGB'), 63 | transforms.Resize((self.resize, self.resize)), 64 | # transforms.RandomHorizontalFlip(), 65 | # transforms.RandomRotation(5), 66 | transforms.ToTensor(), 67 | transforms.Normalize((0, 0, 0), (1, 1, 1)) 68 | ]) 69 | else: 70 | # self.transform = transforms.Compose([lambda x: Image.open(x).convert('RGB'), 71 | # transforms.Resize((self.resize, self.resize)), 72 | # transforms.ToTensor(), 73 | # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 74 | # ]) 75 | self.transform = transforms.Compose([lambda x: Image.open(x).convert('RGB'), 76 | transforms.Resize((self.resize, self.resize)), 77 | transforms.ToTensor(), 78 | transforms.Normalize((0, 0, 0), (1, 1, 1)) 79 | ]) 80 | 81 | self.path = os.path.join(root, 'images') # image path 82 | csvdata = self.loadCSV(os.path.join(root, mode + '.csv')) # csv path 83 | self.data = [] 84 | self.img2label = {} 85 | for i, (k, v) in enumerate(csvdata.items()): 86 | self.data.append(v) # [[img1, img2, ...], [img111, ...]] 87 | self.img2label[k] = i + self.startidx # {"img_name[:9]":label} 88 | self.cls_num = len(self.data) 89 | 90 | self.create_batch(self.batchsz) 91 | 92 | def loadCSV(self, csvf): 93 | """ 94 | return a dict saving the information of csv 95 | :param splitFile: csv file name 96 | :return: {label:[file1, file2 ...]} 97 | """ 98 | dictLabels = {} 99 | with open(csvf) as csvfile: 100 | csvreader = csv.reader(csvfile, delimiter=',') 101 | next(csvreader, None) # skip (filename, label) 102 | for i, row in enumerate(csvreader): 103 | filename = row[0] 104 | label = row[1] 105 | # append filename to current label 106 | if label in dictLabels.keys(): 107 | dictLabels[label].append(filename) 108 | else: 109 | dictLabels[label] = [filename] 110 | return dictLabels 111 | 112 | def create_batch(self, batchsz): 113 | """ 114 | create batch for meta-learning. 115 | ×episode× here means batch, and it means how many sets we want to retain. 
116 | :param batchsz: number of episodes to pre-build
117 | :return:
118 | """
119 | self.support_x_batch = [] # support set batch
120 | self.query_x_batch = [] # query set batch
121 | for b in range(batchsz): # for each batch
122 | # 1.select n_way classes randomly
123 | selected_cls = np.random.choice(self.cls_num, self.n_way, False) # no duplicate
124 | np.random.shuffle(selected_cls)
125 | support_x = []
126 | query_x = []
127 | for cls in selected_cls:
128 | # 2. select k_shot + k_query for each class
129 | selected_imgs_idx = np.random.choice(len(self.data[cls]), self.k_shot + self.k_query, False)
130 | np.random.shuffle(selected_imgs_idx)
131 | indexDtrain = np.array(selected_imgs_idx[:self.k_shot]) # idx for Dtrain
132 | indexDtest = np.array(selected_imgs_idx[self.k_shot:]) # idx for Dtest
133 | support_x.append(
134 | np.array(self.data[cls])[indexDtrain].tolist()) # get all images filename for current Dtrain
135 | query_x.append(np.array(self.data[cls])[indexDtest].tolist())
136 |
137 | # shuffle the corresponding relation between support set and query set
138 | random.shuffle(support_x)
139 | random.shuffle(query_x)
140 |
141 | self.support_x_batch.append(support_x) # append set to current sets
142 | self.query_x_batch.append(query_x) # append sets to current sets
143 |
144 | def __getitem__(self, index):
145 | """
146 | index means index of sets, 0<= index <= batchsz-1
147 | :param index:
148 | :return:
149 | """
150 | # [setsz, 3, resize, resize]
151 | support_x = torch.FloatTensor(self.setsz, 3, self.resize, self.resize)
152 | # [setsz]
153 | support_y = np.zeros((self.setsz), dtype=np.int64) # np.int was removed in NumPy >= 1.24
154 | # [querysz, 3, resize, resize]
155 | query_x = torch.FloatTensor(self.querysz, 3, self.resize, self.resize)
156 | # [querysz]
157 | query_y = np.zeros((self.querysz), dtype=np.int64)
158 |
159 | flatten_support_x = [os.path.join(self.path, item)
160 | for sublist in self.support_x_batch[index] for item in sublist]
161 | support_y = np.array(
162 | [self.img2label[item[:9]] # filename:n0153282900000005.jpg, the first 9 characters treated as label
163 | for sublist in self.support_x_batch[index] for item in sublist]).astype(np.int32)
164 |
165 | flatten_query_x = [os.path.join(self.path, item)
166 | for sublist in self.query_x_batch[index] for item in sublist]
167 | query_y = np.array([self.img2label[item[:9]]
168 | for sublist in self.query_x_batch[index] for item in sublist]).astype(np.int32)
169 |
170 | # print('global:', support_y, query_y)
171 | # support_y: [setsz]
172 | # query_y: [querysz]
173 | # unique: [n-way], sorted
174 | unique = np.unique(support_y)
175 | random.shuffle(unique)
176 | # relative means the label ranges from 0 to n-way
177 | support_y_relative = np.zeros(self.setsz)
178 | query_y_relative = np.zeros(self.querysz)
179 | for idx, l in enumerate(unique):
180 | support_y_relative[support_y == l] = idx
181 | query_y_relative[query_y == l] = idx
182 |
183 | # print('relative:', support_y_relative, query_y_relative)
184 |
185 | for i, path in enumerate(flatten_support_x):
186 | support_x[i] = self.transform(path)
187 |
188 | for i, path in enumerate(flatten_query_x):
189 | query_x[i] = self.transform(path)
190 | # print(support_set_y)
191 | # return support_x, torch.LongTensor(support_y), query_x, torch.LongTensor(query_y)
192 |
193 | return support_x, torch.LongTensor(support_y_relative), query_x, torch.LongTensor(query_y_relative)
194 |
195 | def __len__(self):
196 | # we pre-built batchsz episodes, so a loader can draw (smaller) batches of episodes from them.
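# (len(dataset) therefore equals batchsz: a DataLoader over this dataset
#  yields the pre-built episodes, each one a
#  (support_x, support_y, query_x, query_y) tuple from __getitem__ above)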
197 | return self.batchsz 198 | 199 | 200 | if __name__ == '__main__': 201 | # the following episode is to view one set of images via tensorboard. 202 | from torchvision.utils import make_grid 203 | from matplotlib import pyplot as plt 204 | from tensorboardX import SummaryWriter 205 | import time 206 | 207 | plt.ion() 208 | 209 | tb = SummaryWriter('runs', 'mini-imagenet') 210 | mini = MiniImagenet('../../../dataset/', mode='train', n_way=5, k_shot=1, k_query=1, batchsz=1000, resize=168) 211 | 212 | for i, set_ in enumerate(mini): 213 | # support_x: [k_shot*n_way, 3, 84, 84] 214 | support_x, support_y, query_x, query_y = set_ 215 | 216 | support_x = make_grid(support_x, nrow=2) 217 | query_x = make_grid(query_x, nrow=2) 218 | 219 | plt.figure(1) 220 | plt.imshow(support_x.transpose(2, 0).numpy()) 221 | plt.pause(0.5) 222 | plt.figure(2) 223 | plt.imshow(query_x.transpose(2, 0).numpy()) 224 | plt.pause(0.5) 225 | 226 | tb.add_image('support_x', support_x) 227 | tb.add_image('query_x', query_x) 228 | 229 | time.sleep(5) 230 | 231 | tb.close() 232 | 233 | -------------------------------------------------------------------------------- /Omniglot/MAML_TrainStd_Omniglot.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "scrolled": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import torch, os\n", 12 | "import numpy as np\n", 13 | "#from MiniImagenet import MiniImagenet\n", 14 | "import scipy.stats\n", 15 | "from torch.utils.data import DataLoader\n", 16 | "from torch.optim import lr_scheduler\n", 17 | "import random, sys, pickle\n", 18 | "import argparse\n", 19 | "\n", 20 | "\n", 21 | "from torchmeta.datasets import Omniglot\n", 22 | "from torchmeta.transforms import Categorical, ClassSplitter, Rotation\n", 23 | "from torchvision.transforms import Compose, Resize, ToTensor\n", 24 | "from torchmeta.utils.data import BatchMetaDataLoader\n", 25 | "\n", 26 | "\n", 27 | "from MAMLMeta import Meta\n", 28 | "\n", 29 | "\n", 30 | "\n", 31 | "def mean_confidence_interval(accs, confidence=0.95):\n", 32 | " n = accs.shape[0]\n", 33 | " m, se = np.mean(accs), scipy.stats.sem(accs)\n", 34 | " h = se * scipy.stats.t._ppf((1 + confidence) / 2, n - 1)\n", 35 | " return m, h\n", 36 | "\n", 37 | "\n", 38 | "def main():\n", 39 | "\n", 40 | " torch.manual_seed(222)\n", 41 | " torch.cuda.manual_seed_all(222)\n", 42 | " np.random.seed(222)\n", 43 | "\n", 44 | " print(args)\n", 45 | "\n", 46 | " config = [\n", 47 | " ('conv2d', [128, 1, 3, 3, 2, 0]),\n", 48 | " ('relu', [True]),\n", 49 | " ('bn', [128]),\n", 50 | " ('conv2d', [128, 128, 3, 3, 2, 0]),\n", 51 | " ('relu', [True]),\n", 52 | " ('bn', [128]),\n", 53 | " ('conv2d', [128, 128, 3, 3, 2, 0]),\n", 54 | " ('relu', [True]),\n", 55 | " ('bn', [128]),\n", 56 | " ('conv2d', [128, 128, 2, 2, 1, 0]),\n", 57 | " ('relu', [True]),\n", 58 | " ('bn', [128]),\n", 59 | " ('flatten', []),\n", 60 | " ('linear', [args.n_way, 128])\n", 61 | " ]\n", 62 | "\n", 63 | " device = torch.device('cuda:0')\n", 64 | " \n", 65 | " best_cl = 0\n", 66 | " best_rb = 0\n", 67 | " filename = 'maml_eps10_omniglot_5way.pt'\n", 68 | " maml = Meta(args, config).to(device)\n", 69 | " if os.path.isfile(filename):\n", 70 | " print(\"=> loading checkpoint '{}'\".format(filename))\n", 71 | " checkpoint = torch.load(filename)\n", 72 | " best_cl = checkpoint['cl']\n", 73 | " best_rb = checkpoint['rb']\n", 74 | " 
maml.net.load_state_dict(checkpoint['state_dict'])\n", 75 | " #maml = maml.to(device)\n", 76 | "# print(\"=> loaded checkpoint '{}' (epoch {})\"\n", 77 | "# .format(filename, checkpoint['epoch']))\n", 78 | " else:\n", 79 | " print(\"=> no checkpoint found at '{}'\".format(filename))\n", 80 | "\n", 81 | " tmp = filter(lambda x: x.requires_grad, maml.parameters())\n", 82 | " num = sum(map(lambda x: np.prod(x.shape), tmp))\n", 83 | " print(maml)\n", 84 | " print('Total trainable tensors:', num)\n", 85 | "\n", 86 | " \n", 87 | " data_train = Omniglot(\"data\",\n", 88 | " # Number of ways\n", 89 | " num_classes_per_task=args.n_way,\n", 90 | " meta_train=True,\n", 91 | " meta_val=False,\n", 92 | " meta_test=False,\n", 93 | " # Resize the images to 28x28 and converts them to PyTorch tensors (from Torchvision)\n", 94 | " transform=Compose([Resize(28), ToTensor()]),\n", 95 | " # Transform the labels to integers (e.g. (\"Glagolitic/character01\", \"Sanskrit/character14\", ...) to (0, 1, ...))\n", 96 | " target_transform=Categorical(num_classes=args.n_way),\n", 97 | " # Creates new virtual classes with rotated versions of the images (from Santoro et al., 2016)\n", 98 | " #class_augmentations=[Rotation([90, 180, 270])],\n", 99 | " download=True)\n", 100 | " \n", 101 | " data_test = Omniglot(\"data\",\n", 102 | " # Number of ways\n", 103 | " num_classes_per_task=args.n_way,\n", 104 | " meta_train=False,\n", 105 | " meta_val=False,\n", 106 | " meta_test=True,\n", 107 | " # Resize the images to 28x28 and converts them to PyTorch tensors (from Torchvision)\n", 108 | " transform=Compose([Resize(28), ToTensor()]),\n", 109 | " # Transform the labels to integers (e.g. (\"Glagolitic/character01\", \"Sanskrit/character14\", ...) to (0, 1, ...))\n", 110 | " target_transform=Categorical(num_classes=args.n_way),\n", 111 | " # Creates new virtual classes with rotated versions of the images (from Santoro et al., 2016)\n", 112 | " #class_augmentations=[Rotation([90, 180, 270])],\n", 113 | " download=True)\n", 114 | " data_train = ClassSplitter(data_train, shuffle=True, num_train_per_class=args.k_spt, num_test_per_class=args.k_qry)\n", 115 | " data_test = ClassSplitter(data_test, shuffle=True, num_train_per_class=args.k_spt, num_test_per_class=args.k_qry)\n", 116 | "\n", 117 | " \n", 118 | "\n", 119 | " for epoch in range(args.epoch//10000):\n", 120 | " # fetch meta_batchsz num of episode each time\n", 121 | " db = BatchMetaDataLoader(data_train, batch_size=args.task_num, num_workers=0)\n", 122 | " #db = DataLoader(data_train, args.task_num, shuffle=True, num_workers=0, pin_memory=True)\n", 123 | "\n", 124 | " for step, batch_train in enumerate(db):\n", 125 | " #for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):\n", 126 | "\n", 127 | " x_spt, y_spt = batch_train[\"train\"]\n", 128 | " x_qry, y_qry = batch_train[\"test\"]\n", 129 | " x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)\n", 130 | "\n", 131 | " accs, accs_adv = maml(x_spt, y_spt, x_qry, y_qry)\n", 132 | "\n", 133 | " \n", 134 | " if step % 5000 == 0:\n", 135 | " print('step:', step, '\\ttraining acc:', accs)\n", 136 | " print('step:', step, '\\ttraining acc_adv:', accs_adv)\n", 137 | "# state = {'epoch': epoch, 'step': step, 'state_dict': maml.net.state_dict()}\n", 138 | "# torch.save(state, 'mamleps8Omniglot1shot.pt')\n", 139 | "\n", 140 | " if step % 20000 == 0: # evaluation\n", 141 | " #db_test = DataLoader(data_test, 1, shuffle=True, num_workers=0, pin_memory=True)\n", 142 | " db_test = 
BatchMetaDataLoader(data_test, batch_size=1, num_workers=0)\n", 143 | " accs_all_test = []\n", 144 | " accsadv_all_test = []\n", 145 | " accsadvpr_all_test = []\n", 146 | "\n", 147 | " for step_t, batch_test in enumerate(db_test):\n", 148 | " x_spt, y_spt = batch_test[\"train\"]\n", 149 | " x_qry, y_qry = batch_test[\"test\"]\n", 150 | " \n", 151 | "\n", 152 | "\n", 153 | " x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \\\n", 154 | " x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)\n", 155 | "\n", 156 | " accs, accs_adv, accs_adv_prior = maml.finetunning(x_spt, y_spt, x_qry, y_qry)\n", 157 | " accs_all_test.append(accs)\n", 158 | " accsadv_all_test.append(accs_adv)\n", 159 | " accsadvpr_all_test.append(accs_adv_prior)\n", 160 | " if step_t == 10000:\n", 161 | " break\n", 162 | "\n", 163 | " accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)\n", 164 | " accs_adv = np.array(accsadv_all_test).mean(axis=0).astype(np.float16)\n", 165 | " accs_adv_prior = np.array(accsadvpr_all_test).mean(axis=0).astype(np.float16)\n", 166 | " print('Test acc:', accs)\n", 167 | " print('Test acc_adv:', accs_adv)\n", 168 | " print('Test acc_adv_prior:', accs_adv_prior)\n", 169 | " if best_cl < accs[-1]:\n", 170 | " best_cl = accs[-1]\n", 171 | " state = {'cl': best_cl, 'rb': best_rb, 'state_dict': maml.net.state_dict()}\n", 172 | " torch.save(state, 'maml_eps10_omniglot_5way.pt')\n", 173 | " print(best_cl)\n", 174 | "\n", 175 | "\n", 176 | "\n", 177 | "if __name__ == '__main__':\n", 178 | "\n", 179 | " argparser = argparse.ArgumentParser()\n", 180 | " argparser.add_argument('--epoch', type=int, help='epoch number', default=60000)\n", 181 | " argparser.add_argument('--n_way', type=int, help='n way', default=5)\n", 182 | " argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=1)\n", 183 | " argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=15)\n", 184 | " argparser.add_argument('--imgsz', type=int, help='imgsz', default=32)\n", 185 | " argparser.add_argument('--imgc', type=int, help='imgc', default=3)\n", 186 | " argparser.add_argument('--task_num', type=int, help='meta batch size, namely task num', default=4)\n", 187 | " argparser.add_argument('--meta_lr', type=float, help='meta-level outer learning rate', default=1e-3)\n", 188 | " argparser.add_argument('--update_lr', type=float, help='task-level inner update learning rate', default=0.01)\n", 189 | " argparser.add_argument('--update_step', type=int, help='task-level inner update steps', default=5)\n", 190 | " argparser.add_argument('--update_step_test', type=int, help='update steps for finetunning', default=10)\n", 191 | " \n", 192 | " #argparser.add_argument('--fast', action=\"store_true\", help='whether to use fgsm')\n", 193 | "\n", 194 | " args = argparser.parse_args(args=[])\n", 195 | "\n", 196 | " main()" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": null, 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [] 205 | } 206 | ], 207 | "metadata": { 208 | "kernelspec": { 209 | "display_name": "Python 3", 210 | "language": "python", 211 | "name": "python3" 212 | }, 213 | "language_info": { 214 | "codemirror_mode": { 215 | "name": "ipython", 216 | "version": 3 217 | }, 218 | "file_extension": ".py", 219 | "mimetype": "text/x-python", 220 | "name": "python", 221 | "nbconvert_exporter": "python", 222 | "pygments_lexer": "ipython3", 223 | "version": "3.7.6" 224 | } 225 | }, 226 | "nbformat": 4, 227 | 
"nbformat_minor": 4 228 | } 229 | -------------------------------------------------------------------------------- /Omniglot/attack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from abc import ABCMeta, abstractmethod, abstractproperty 4 | from torch.nn import functional as F 5 | 6 | class AttackBase(metaclass=ABCMeta): 7 | @abstractmethod 8 | def attack(self, net, inp, label, target = None): 9 | ''' 10 | 11 | :param inp: batched images 12 | :param target: specify the indexes of target class, None represents untargeted attack 13 | :return: batched adversaril images 14 | ''' 15 | pass 16 | 17 | @abstractmethod 18 | def to(self, device): 19 | pass 20 | 21 | 22 | 23 | def clip_eta(eta, norm, eps, DEVICE = torch.device('cuda:2')): 24 | ''' 25 | helper functions to project eta into epsilon norm ball 26 | :param eta: Perturbation tensor (should be of size(N, C, H, W)) 27 | :param norm: which norm. should be in [1, 2, np.inf] 28 | :param eps: epsilon, bound of the perturbation 29 | :return: Projected perturbation 30 | ''' 31 | 32 | assert norm in [1, 2, np.inf], "norm should be in [1, 2, np.inf]" 33 | 34 | with torch.no_grad(): 35 | avoid_zero_div = torch.tensor(1e-12).to(DEVICE) 36 | eps = torch.tensor(eps).to(DEVICE) 37 | one = torch.tensor(1.0).to(DEVICE) 38 | 39 | if norm == np.inf: 40 | eta = torch.clamp(eta, -eps, eps) 41 | else: 42 | normalize = torch.norm(eta.reshape(eta.size(0), -1), p = norm, dim = -1, keepdim = False) 43 | normalize = torch.max(normalize, avoid_zero_div) 44 | 45 | normalize.unsqueeze_(dim = -1) 46 | normalize.unsqueeze_(dim=-1) 47 | normalize.unsqueeze_(dim=-1) 48 | 49 | factor = torch.min(one, eps / normalize) 50 | eta = eta * factor 51 | return eta 52 | 53 | 54 | 55 | class PGD(AttackBase): 56 | # ImageNet pre-trained mean and std 57 | # _mean = torch.tensor(np.array([0.485, 0.456, 0.406]).astype(np.float32)[np.newaxis, :, np.newaxis, np.newaxis]) 58 | # _std = torch.tensor(np.array([0.229, 0.224, 0.225]).astype(np.float32)[np.newaxis, :, np.newaxis, np.newaxis]) 59 | 60 | # _mean = torch.tensor(np.array([0]).astype(np.float32)[np.newaxis, :, np.newaxis, np.newaxis]) 61 | # _std = torch.tensor(np.array([1.0]).astype(np.float32)[np.newaxis, :, np.newaxis, np.newaxis]) 62 | def __init__(self, eps = 6 / 255.0, sigma = 3 / 255.0, nb_iter = 20, 63 | norm = np.inf, DEVICE = torch.device('cuda:2'), 64 | mean = torch.tensor(np.array([0]).astype(np.float32)[np.newaxis, :, np.newaxis, np.newaxis]), 65 | std = torch.tensor(np.array([1.0]).astype(np.float32)[np.newaxis, :, np.newaxis, np.newaxis]), random_start = True): 66 | ''' 67 | :param eps: maximum distortion of adversarial examples 68 | :param sigma: single step size 69 | :param nb_iter: number of attack iterations 70 | :param norm: which norm to bound the perturbations 71 | ''' 72 | self.eps = eps 73 | self.sigma = sigma 74 | self.nb_iter = nb_iter 75 | self.norm = norm 76 | self.criterion = torch.nn.CrossEntropyLoss().to(DEVICE) 77 | self.DEVICE = DEVICE 78 | self._mean = mean.to(DEVICE) 79 | self._std = std.to(DEVICE) 80 | self.random_start = random_start 81 | 82 | def single_attack(self, net, para, inp, label, eta, target = None): 83 | ''' 84 | Given the original image and the perturbation computed so far, computes 85 | a new perturbation. 
86 | :param net:
87 | :param inp: original image
88 | :param label:
89 | :param eta: perturbation computed so far
90 | :return: a new perturbation
91 | '''
92 |
93 | adv_inp = inp + eta
94 |
95 | #net.zero_grad()
96 |
97 | pred = net(adv_inp, para)
98 |
99 |
100 | loss = self.criterion(pred, label)
101 | grad_sign = torch.autograd.grad(loss, adv_inp,
102 | only_inputs=True, retain_graph = False)[0].sign()
103 |
104 | adv_inp = adv_inp + grad_sign * (self.sigma / self._std)
105 | tmp_adv_inp = adv_inp * self._std + self._mean
106 |
107 | tmp_inp = inp * self._std + self._mean
108 | tmp_adv_inp = torch.clamp(tmp_adv_inp, 0, 1)
109 | tmp_eta = tmp_adv_inp - tmp_inp
110 | tmp_eta = clip_eta(tmp_eta, norm=self.norm, eps=self.eps, DEVICE=self.DEVICE)
111 |
112 | eta = tmp_eta / self._std
113 |
114 |
115 | # adv_inp = adv_inp + grad_sign * self.eps
116 | # adv_inp = torch.clamp(adv_inp, 0, 1)
117 | # eta = adv_inp - inp
118 | # eta = clip_eta(eta, norm=self.norm, eps=self.eps, DEVICE=self.DEVICE)
119 |
120 | return eta
121 |
122 | def attack(self, net, para, inp, label, target = None):
123 |
124 | if self.random_start:
125 | eta = torch.FloatTensor(*inp.shape).uniform_(-self.eps, self.eps)
126 | else:
127 | eta = torch.zeros_like(inp)
128 | eta = eta.to(self.DEVICE)
129 | eta = (eta - self._mean) / self._std
130 | net.eval()
131 | #print(torch.min(torch.min(torch.min(inp[0]))))
132 |
133 | inp.requires_grad = True
134 | eta.requires_grad = True
135 | for i in range(self.nb_iter):
136 | eta = self.single_attack(net, para, inp, label, eta, target)
137 | #print(i)
138 |
139 | #print(eta.max())
140 | adv_inp = inp + eta
141 | tmp_adv_inp = adv_inp * self._std + self._mean
142 | tmp_adv_inp = torch.clamp(tmp_adv_inp, 0, 1)
143 | adv_inp = (tmp_adv_inp - self._mean) / self._std
144 |
145 | return adv_inp
146 |
147 | def to(self, device):
148 | self.DEVICE = device
149 | self._mean = self._mean.to(device)
150 | self._std = self._std.to(device)
151 | self.criterion = self.criterion.to(device)
-------------------------------------------------------------------------------- /Omniglot/learner.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | # In[ ]:
5 |
6 |
7 | import torch
8 | from torch import nn
9 | from torch.nn import functional as F
10 | import numpy as np
11 |
12 |
13 |
14 | class Learner(nn.Module):
15 | """ Functional-style learner: all weights live in self.vars, so MAML can pass adapted (fast) weights to forward() explicitly.
16 | """
17 |
18 | def __init__(self, config, imgc, imgsz):
19 | """
20 | :param config: network config file, type:list of (string, list)
21 | :param imgc: 1 or 3
22 | :param imgsz: 28 or 84
23 | """
24 | super(Learner, self).__init__()
25 |
26 |
27 | self.config = config
28 |
29 | # this list contains all tensors needed to be optimized
30 | self.vars = nn.ParameterList()
31 | # running_mean and running_var
32 | self.vars_bn = nn.ParameterList()
33 |
34 | for i, (name, param) in enumerate(self.config):
35 | if name == 'conv2d':
36 | # [ch_out, ch_in, kernelsz, kernelsz]
37 | w = nn.Parameter(torch.ones(*param[:4]))
38 | # gain=1 according to cbfinn's implementation
39 | torch.nn.init.kaiming_normal_(w)
40 | self.vars.append(w)
41 | # [ch_out]
42 | self.vars.append(nn.Parameter(torch.zeros(param[0])))
43 |
44 | elif name == 'convt2d':
45 | # [ch_in, ch_out, kernelsz, kernelsz, stride, padding]
46 | w = nn.Parameter(torch.ones(*param[:4]))
47 | # gain=1 according to cbfinn's implementation
48 | torch.nn.init.kaiming_normal_(w)
49 | self.vars.append(w)
50 | # [ch_out] bias (param[1] is ch_out for a transposed conv)
51 | self.vars.append(nn.Parameter(torch.zeros(param[1])))
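# (config convention, inferred from the branches in this constructor and
#  from the notebooks' config lists:
#    ('conv2d',  [ch_out, ch_in, k, k, stride, padding])
#    ('convt2d', [ch_in, ch_out, k, k, stride, padding])
#    ('linear',  [ch_out, ch_in])
#    ('bn',      [ch_out])
#  e.g. ('conv2d', [64, 1, 3, 3, 2, 0]) maps 1 -> 64 channels with a
#  3x3 kernel, stride 2, no padding)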
52 |
53 | elif name == 'linear':
54 | # [ch_out, ch_in]
55 | w = nn.Parameter(torch.ones(*param))
56 | # gain=1 according to cbfinn's implementation
57 | torch.nn.init.kaiming_normal_(w)
58 | self.vars.append(w)
59 | # [ch_out]
60 | self.vars.append(nn.Parameter(torch.zeros(param[0])))
61 |
62 | elif name == 'bn':
63 | # [ch_out]
64 | w = nn.Parameter(torch.ones(param[0]))
65 | self.vars.append(w)
66 | # [ch_out]
67 | self.vars.append(nn.Parameter(torch.zeros(param[0])))
68 |
69 | # must set requires_grad=False
70 | running_mean = nn.Parameter(torch.zeros(param[0]), requires_grad=False)
71 | running_var = nn.Parameter(torch.ones(param[0]), requires_grad=False)
72 | self.vars_bn.extend([running_mean, running_var])
73 |
74 |
75 | elif name in ['tanh', 'relu', 'upsample', 'avg_pool2d', 'max_pool2d',
76 | 'flatten', 'reshape', 'leakyrelu', 'sigmoid']:
77 | continue
78 | else:
79 | raise NotImplementedError
80 |
81 |
82 |
83 |
84 |
85 |
86 | def extra_repr(self):
87 | info = ''
88 |
89 | for name, param in self.config:
90 | if name == 'conv2d':
91 | tmp = 'conv2d:(ch_in:%d, ch_out:%d, k:%dx%d, stride:%d, padding:%d)' %(param[1], param[0], param[2], param[3], param[4], param[5],)
92 | info += tmp + '\n'
93 |
94 | elif name == 'convt2d':
95 | tmp = 'convTranspose2d:(ch_in:%d, ch_out:%d, k:%dx%d, stride:%d, padding:%d)' %(param[0], param[1], param[2], param[3], param[4], param[5],)
96 | info += tmp + '\n'
97 |
98 | elif name == 'linear':
99 | tmp = 'linear:(in:%d, out:%d)'%(param[1], param[0])
100 | info += tmp + '\n'
101 |
102 | elif name == 'leakyrelu':
103 | tmp = 'leakyrelu:(slope:%f)'%(param[0])
104 | info += tmp + '\n'
105 |
106 |
107 | elif name == 'avg_pool2d':
108 | tmp = 'avg_pool2d:(k:%d, stride:%d, padding:%d)'%(param[0], param[1], param[2])
109 | info += tmp + '\n'
110 | elif name == 'max_pool2d':
111 | tmp = 'max_pool2d:(k:%d, stride:%d, padding:%d)'%(param[0], param[1], param[2])
112 | info += tmp + '\n'
113 | elif name in ['flatten', 'tanh', 'relu', 'upsample', 'reshape', 'sigmoid', 'use_logits', 'bn']:
114 | tmp = name + ':' + str(tuple(param))
115 | info += tmp + '\n'
116 | else:
117 | raise NotImplementedError
118 |
119 | return info
120 |
121 |
122 |
123 | def forward(self, x, vars=None, bn_training=True):
124 | """
125 | This function can be called by fine-tuning; however, during fine-tuning we do not wish to update
126 | running_mean/running_var. Although the bn weight/bias are updated, they are kept separate in fast_weights.
127 | So, to leave running_mean/running_var untouched, call forward(..., bn_training=False); the bn weight/bias
128 | still adapt through fast_weights without dirtying the initial theta parameters.
129 | :param x: [b, 1, 28, 28]
130 | :param vars: optional list of (fast) weights; defaults to self.vars
131 | :param bn_training: set False to freeze running_mean/running_var
132 | :return: x, the network output
133 | """
134 |
135 | if vars is None:
136 | vars = self.vars
137 |
138 | idx = 0
139 | bn_idx = 0
140 |
141 | for name, param in self.config:
142 | if name == 'conv2d':
143 | w, b = vars[idx], vars[idx + 1]
144 | # remember to keep forward_encoder and forward_decoder synchronized!
145 | x = F.conv2d(x, w, b, stride=param[4], padding=param[5])
146 | idx += 2
147 | # print(name, param, '\tout:', x.shape)
148 | elif name == 'convt2d':
149 | w, b = vars[idx], vars[idx + 1]
150 | # remember to keep forward_encoder and forward_decoder synchronized!
151 | x = F.conv_transpose2d(x, w, b, stride=param[4], padding=param[5])
152 | idx += 2
153 | # print(name, param, '\tout:', x.shape)
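# (note: the functional F.conv2d / F.conv_transpose2d calls above take
#  their weights explicitly from `vars`, so forward(x, fast_weights)
#  evaluates the network at adapted parameters without mutating
#  self.vars -- this is what lets MAML's inner loop differentiate
#  through the weight update)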
154 | elif name == 'linear':
155 | w, b = vars[idx], vars[idx + 1]
156 | x = F.linear(x, w, b)
157 | idx += 2
158 | # print('forward:', idx, x.norm().item())
159 | elif name == 'bn':
160 | w, b = vars[idx], vars[idx + 1]
161 | running_mean, running_var = self.vars_bn[bn_idx], self.vars_bn[bn_idx+1]
162 | x = F.batch_norm(x, running_mean, running_var, weight=w, bias=b, training=bn_training)
163 | idx += 2
164 | bn_idx += 2
165 |
166 | elif name == 'flatten':
167 | # print(x.shape)
168 | x = x.view(x.size(0), -1)
169 | elif name == 'reshape':
170 | # [b, 8] => [b, 2, 2, 2]
171 | x = x.view(x.size(0), *param)
172 | elif name == 'relu':
173 | x = F.relu(x, inplace=param[0])
174 | elif name == 'leakyrelu':
175 | x = F.leaky_relu(x, negative_slope=param[0], inplace=param[1])
176 | elif name == 'tanh':
177 | x = torch.tanh(x) # F.tanh is deprecated in newer PyTorch
178 | elif name == 'sigmoid':
179 | x = torch.sigmoid(x)
180 | elif name == 'upsample':
181 | x = F.upsample_nearest(x, scale_factor=param[0]) # deprecated; newer PyTorch uses F.interpolate(x, scale_factor=param[0])
182 | elif name == 'max_pool2d':
183 | x = F.max_pool2d(x, param[0], param[1], param[2])
184 | elif name == 'avg_pool2d':
185 | x = F.avg_pool2d(x, param[0], param[1], param[2])
186 |
187 | else:
188 | raise NotImplementedError
189 |
190 | # make sure every variable was consumed exactly once
191 | assert idx == len(vars)
192 | assert bn_idx == len(self.vars_bn)
193 |
194 |
195 | return x
196 |
197 |
198 | def zero_grad(self, vars=None):
199 | """
200 | :param vars:
201 | :return:
202 | """
203 | with torch.no_grad():
204 | if vars is None:
205 | for p in self.vars:
206 | if p.grad is not None:
207 | p.grad.zero_()
208 | else:
209 | for p in vars:
210 | if p.grad is not None:
211 | p.grad.zero_()
212 |
213 | def parameters(self):
214 | """
215 | override nn.Module.parameters(), since it would otherwise return a generator over the registered parameters.
216 | :return: 217 | """ 218 | return self.vars 219 | 220 | -------------------------------------------------------------------------------- /Omniglot/train_trades_omniglot.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch, os\n", 10 | "import numpy as np\n", 11 | "import scipy.stats\n", 12 | "from torch.utils.data import DataLoader\n", 13 | "from torch.optim import lr_scheduler\n", 14 | "import random, sys, pickle\n", 15 | "import argparse\n", 16 | "\n", 17 | "\n", 18 | "from torchmeta.datasets import Omniglot\n", 19 | "from torchmeta.transforms import Categorical, ClassSplitter, Rotation\n", 20 | "from torchvision.transforms import Compose, Resize, ToTensor\n", 21 | "from torchmeta.utils.data import BatchMetaDataLoader\n", 22 | "\n", 23 | "from MetaFTOmni import Meta\n", 24 | "\n", 25 | "\n", 26 | "\n", 27 | "def mean_confidence_interval(accs, confidence=0.95):\n", 28 | " n = accs.shape[0]\n", 29 | " m, se = np.mean(accs), scipy.stats.sem(accs)\n", 30 | " h = se * scipy.stats.t._ppf((1 + confidence) / 2, n - 1)\n", 31 | " return m, h\n", 32 | "\n", 33 | "\n", 34 | "def main():\n", 35 | "\n", 36 | " torch.manual_seed(222)\n", 37 | " torch.cuda.manual_seed_all(222)\n", 38 | " np.random.seed(222)\n", 39 | "\n", 40 | " print(args)\n", 41 | "\n", 42 | " config = [\n", 43 | " ('conv2d', [64, 1, 3, 3, 2, 0]),\n", 44 | " ('relu', [True]),\n", 45 | " ('bn', [64]),\n", 46 | " ('conv2d', [64, 64, 3, 3, 2, 0]),\n", 47 | " ('relu', [True]),\n", 48 | " ('bn', [64]),\n", 49 | " ('conv2d', [64, 64, 3, 3, 2, 0]),\n", 50 | " ('relu', [True]),\n", 51 | " ('bn', [64]),\n", 52 | " ('conv2d', [64, 64, 2, 2, 1, 0]),\n", 53 | " ('relu', [True]),\n", 54 | " ('bn', [64]),\n", 55 | " ('flatten', []),\n", 56 | " ('linear', [args.n_way, 64])\n", 57 | " ]\n", 58 | "\n", 59 | " device = torch.device('cuda')\n", 60 | " \n", 61 | " best_cl = 0\n", 62 | " best_rb = 0\n", 63 | " filename = 'mamltrades_eps10_omniglot_5way.pt'\n", 64 | " maml = Meta(args, config, device).to(device)\n", 65 | " if os.path.isfile(filename):\n", 66 | " print(\"=> loading checkpoint '{}'\".format(filename))\n", 67 | " checkpoint = torch.load(filename)\n", 68 | " best_cl = checkpoint['cl']\n", 69 | " best_rb = checkpoint['rb']\n", 70 | " maml.net.load_state_dict(checkpoint['state_dict'])\n", 71 | " #maml = maml.to(device)\n", 72 | "# print(\"=> loaded checkpoint '{}' (epoch {})\"\n", 73 | "# .format(filename, checkpoint['epoch']))\n", 74 | " else:\n", 75 | " print(\"=> no checkpoint found at '{}'\".format(filename))\n", 76 | "\n", 77 | " tmp = filter(lambda x: x.requires_grad, maml.parameters())\n", 78 | " num = sum(map(lambda x: np.prod(x.shape), tmp))\n", 79 | " print(maml)\n", 80 | " print('Total trainable tensors:', num)\n", 81 | "\n", 82 | " # batchsz here means total episode number\n", 83 | " \n", 84 | " print(best_cl)\n", 85 | " print(best_rb)\n", 86 | " \n", 87 | "\n", 88 | " data_train = Omniglot(\"data\",\n", 89 | " # Number of ways\n", 90 | " num_classes_per_task=args.n_way,\n", 91 | " meta_train=True,\n", 92 | " meta_val=False,\n", 93 | " meta_test=False,\n", 94 | " # Resize the images to 28x28 and converts them to PyTorch tensors (from Torchvision)\n", 95 | " transform=Compose([Resize(28), ToTensor()]),\n", 96 | " # Transform the labels to integers (e.g. (\"Glagolitic/character01\", \"Sanskrit/character14\", ...) 
to (0, 1, ...))\n", 97 | " target_transform=Categorical(num_classes=args.n_way),\n", 98 | " # Creates new virtual classes with rotated versions of the images (from Santoro et al., 2016)\n", 99 | " #class_augmentations=[Rotation([90, 180, 270])],\n", 100 | " download=True)\n", 101 | " \n", 102 | " data_test = Omniglot(\"data\",\n", 103 | " # Number of ways\n", 104 | " num_classes_per_task=args.n_way,\n", 105 | " meta_train=False,\n", 106 | " meta_val=False,\n", 107 | " meta_test=True,\n", 108 | " # Resize the images to 28x28 and converts them to PyTorch tensors (from Torchvision)\n", 109 | " transform=Compose([Resize(28), ToTensor()]),\n", 110 | " # Transform the labels to integers (e.g. (\"Glagolitic/character01\", \"Sanskrit/character14\", ...) to (0, 1, ...))\n", 111 | " target_transform=Categorical(num_classes=args.n_way),\n", 112 | " # Creates new virtual classes with rotated versions of the images (from Santoro et al., 2016)\n", 113 | " #class_augmentations=[Rotation([90, 180, 270])],\n", 114 | " download=True)\n", 115 | " data_train = ClassSplitter(data_train, shuffle=True, num_train_per_class=args.k_spt, num_test_per_class=args.k_qry)\n", 116 | " data_test = ClassSplitter(data_test, shuffle=True, num_train_per_class=args.k_spt, num_test_per_class=args.k_qry)\n", 117 | " \n", 118 | " LST = None #UnlabData()\n", 119 | " batchsiz = 20\n", 120 | " \n", 121 | " \n", 122 | "\n", 123 | "\n", 124 | " for epoch in range(args.epoch//10000):\n", 125 | " # fetch meta_batchsz num of episode each time\n", 126 | " db = BatchMetaDataLoader(data_train, batch_size=args.task_num, num_workers=0)\n", 127 | " #db = DataLoader(data_train, args.task_num, shuffle=True, num_workers=0, pin_memory=True)\n", 128 | "\n", 129 | " for step, batch_train in enumerate(db):\n", 130 | " #for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):\n", 131 | " \n", 132 | "\n", 133 | "\n", 134 | " x_spt, y_spt = batch_train[\"train\"]\n", 135 | " x_qry, y_qry = batch_train[\"test\"]\n", 136 | "\n", 137 | " x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)\n", 138 | " \n", 139 | "\n", 140 | " \n", 141 | " if LST:\n", 142 | " x_unlab = torch.zeros((args.task_num, args.n_way, batchsiz, 3, 32, 32))\n", 143 | " for i in range(args.task_num): \n", 144 | " with torch.no_grad():\n", 145 | " *_, outputs = resmodel(x_spt[i])\n", 146 | " _, y_unlab = outputs.max(1)\n", 147 | " index = y_unlab.cpu().numpy()\n", 148 | " for j in range(args.n_way):\n", 149 | " temp_train = LST.train_data[index[j]].get_next_batch(batchsiz,multiple_passes=True)\n", 150 | " x_unlab[i][j]= torch.from_numpy(temp_train.astype(np.float32))\n", 151 | " x_unlab = x_unlab.to(device)\n", 152 | " else:\n", 153 | " x_unlab = None\n", 154 | " \n", 155 | "\n", 156 | " accs, accs_adv = maml(x_spt, y_spt, x_qry, y_qry, x_unlab)\n", 157 | " \n", 158 | " if step % 5000 == 0:\n", 159 | " print('step:', step, '\\ttraining acc:', accs)\n", 160 | " print('step:', step, '\\ttraining acc_adv:', accs_adv)\n", 161 | " \n", 162 | "\n", 163 | " if step % 10000== 0: # evaluation\n", 164 | " #db_test = DataLoader(data_test, 1, shuffle=True, num_workers=0, pin_memory=True)\n", 165 | " db_test = BatchMetaDataLoader(data_test, batch_size=1, num_workers=0)\n", 166 | " accs_all_test = []\n", 167 | " accsadv_all_test = []\n", 168 | " accsadvpr_all_test = []\n", 169 | "\n", 170 | " for step_t, batch_test in enumerate(db_test):\n", 171 | " x_spt, y_spt = batch_test[\"train\"]\n", 172 | " x_qry, y_qry = batch_test[\"test\"]\n", 173 | "\n", 174 | 
"\n", 175 | " x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \\\n", 176 | " x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)\n", 177 | "\n", 178 | " accs, accs_adv, accs_adv_prior = maml.finetunning(x_spt, y_spt, x_qry, y_qry)\n", 179 | " accs_all_test.append(accs)\n", 180 | " accsadv_all_test.append(accs_adv)\n", 181 | " accsadvpr_all_test.append(accs_adv_prior)\n", 182 | " if step_t == 10000:\n", 183 | " break\n", 184 | "\n", 185 | " # [b, update_step+1]\n", 186 | " accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)\n", 187 | " accs_adv = np.array(accsadv_all_test).mean(axis=0).astype(np.float16)\n", 188 | " accs_adv_prior = np.array(accsadvpr_all_test).mean(axis=0).astype(np.float16)\n", 189 | " print('Test acc:', accs)\n", 190 | " print('Test acc_adv:', accs_adv)\n", 191 | " print('Test acc_adv_prior:', accs_adv_prior)\n", 192 | "\n", 193 | " if best_cl < accs[-1] and best_rb < accs_adv[-1]:\n", 194 | " best_cl = accs[-1]\n", 195 | " best_rb = accs_adv[-1]\n", 196 | " state = {'cl': best_cl, 'rb': best_rb, 'state_dict': maml.net.state_dict()}\n", 197 | " torch.save(state, 'mamltrades_eps10_omniglot_5way.pt')\n", 198 | " print(best_cl)\n", 199 | " print(best_rb)\n", 200 | "\n", 201 | "\n", 202 | "\n", 203 | "if __name__ == '__main__':\n", 204 | "\n", 205 | " argparser = argparse.ArgumentParser()\n", 206 | " argparser.add_argument('--epoch', type=int, help='epoch number', default=30000)\n", 207 | " argparser.add_argument('--n_way', type=int, help='n way', default=5)\n", 208 | " argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=1)\n", 209 | " argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=15)\n", 210 | " argparser.add_argument('--imgsz', type=int, help='imgsz', default=32)\n", 211 | " argparser.add_argument('--imgc', type=int, help='imgc', default=3)\n", 212 | " argparser.add_argument('--task_num', type=int, help='meta batch size, namely task num', default=4)\n", 213 | " argparser.add_argument('--meta_lr', type=float, help='meta-level outer learning rate', default=1e-3)\n", 214 | " argparser.add_argument('--update_lr', type=float, help='task-level inner update learning rate', default=0.01)\n", 215 | " argparser.add_argument('--update_step', type=int, help='task-level inner update steps', default=5)\n", 216 | " argparser.add_argument('--update_step_test', type=int, help='update steps for finetunning', default=10)\n", 217 | " \n", 218 | " #argparser.add_argument('--fast', action=\"store_true\", help='whether to use fgsm')\n", 219 | "\n", 220 | " args = argparser.parse_args(args=[])\n", 221 | "\n", 222 | " main()" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "metadata": {}, 229 | "outputs": [], 230 | "source": [] 231 | } 232 | ], 233 | "metadata": { 234 | "kernelspec": { 235 | "display_name": "Python 3", 236 | "language": "python", 237 | "name": "python3" 238 | }, 239 | "language_info": { 240 | "codemirror_mode": { 241 | "name": "ipython", 242 | "version": 3 243 | }, 244 | "file_extension": ".py", 245 | "mimetype": "text/x-python", 246 | "name": "python", 247 | "nbconvert_exporter": "python", 248 | "pygments_lexer": "ipython3", 249 | "version": "3.7.6" 250 | } 251 | }, 252 | "nbformat": 4, 253 | "nbformat_minor": 4 254 | } 255 | -------------------------------------------------------------------------------- /Omniglot/trainfgsm_omniglot.ipynb: -------------------------------------------------------------------------------- 
1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch, os\n", 10 | "import numpy as np\n", 11 | "#from MiniImagenet import MiniImagenet\n", 12 | "import scipy.stats\n", 13 | "from torch.utils.data import DataLoader\n", 14 | "from torch.optim import lr_scheduler\n", 15 | "import random, sys, pickle\n", 16 | "import argparse\n", 17 | "\n", 18 | "\n", 19 | "from torchmeta.datasets import Omniglot\n", 20 | "from torchmeta.transforms import Categorical, ClassSplitter, Rotation\n", 21 | "from torchvision.transforms import Compose, Resize, ToTensor\n", 22 | "from torchmeta.utils.data import BatchMetaDataLoader\n", 23 | "\n", 24 | "\n", 25 | "from fgsmmeta import Meta\n", 26 | "\n", 27 | "\n", 28 | "\n", 29 | "def mean_confidence_interval(accs, confidence=0.95):\n", 30 | " n = accs.shape[0]\n", 31 | " m, se = np.mean(accs), scipy.stats.sem(accs)\n", 32 | " h = se * scipy.stats.t._ppf((1 + confidence) / 2, n - 1)\n", 33 | " return m, h\n", 34 | "\n", 35 | "\n", 36 | "def main():\n", 37 | "\n", 38 | " torch.manual_seed(222)\n", 39 | " torch.cuda.manual_seed_all(222)\n", 40 | " np.random.seed(222)\n", 41 | "\n", 42 | " print(args)\n", 43 | "\n", 44 | " config = [\n", 45 | " ('conv2d', [64, 1, 3, 3, 2, 0]),\n", 46 | " ('relu', [True]),\n", 47 | " ('bn', [64]),\n", 48 | " ('conv2d', [64, 64, 3, 3, 2, 0]),\n", 49 | " ('relu', [True]),\n", 50 | " ('bn', [64]),\n", 51 | " ('conv2d', [64, 64, 3, 3, 2, 0]),\n", 52 | " ('relu', [True]),\n", 53 | " ('bn', [64]),\n", 54 | " ('conv2d', [64, 64, 2, 2, 1, 0]),\n", 55 | " ('relu', [True]),\n", 56 | " ('bn', [64]),\n", 57 | " ('flatten', []),\n", 58 | " ('linear', [args.n_way, 64])\n", 59 | " ]\n", 60 | " \n", 61 | " \n", 62 | "\n", 63 | " device = torch.device('cuda')\n", 64 | " \n", 65 | " best_cl = 0\n", 66 | " best_rb = 0\n", 67 | " filename = 'mamlAQ_eps10_omniglot_5way.pt'\n", 68 | " maml = Meta(args, config, device).to(device)\n", 69 | " if os.path.isfile(filename):\n", 70 | " print(\"=> loading checkpoint '{}'\".format(filename))\n", 71 | " checkpoint = torch.load(filename)\n", 72 | " best_cl = checkpoint['cl']\n", 73 | " best_rb = checkpoint['rb']\n", 74 | " maml.net.load_state_dict(checkpoint['state_dict'])\n", 75 | " #maml = maml.to(device)\n", 76 | "# print(\"=> loaded checkpoint '{}' (epoch {})\"\n", 77 | "# .format(filename, checkpoint['epoch']))\n", 78 | " else:\n", 79 | " print(\"=> no checkpoint found at '{}'\".format(filename))\n", 80 | "\n", 81 | " tmp = filter(lambda x: x.requires_grad, maml.parameters())\n", 82 | " num = sum(map(lambda x: np.prod(x.shape), tmp))\n", 83 | " print(maml)\n", 84 | " print('Total trainable tensors:', num)\n", 85 | "\n", 86 | " # batchsz here means total episode number\n", 87 | " print(best_cl)\n", 88 | " print(best_rb)\n", 89 | " \n", 90 | " \n", 91 | "\n", 92 | " data_train = Omniglot(\"data\",\n", 93 | " # Number of ways\n", 94 | " num_classes_per_task=args.n_way,\n", 95 | " meta_train=True,\n", 96 | " meta_val=False,\n", 97 | " meta_test=False,\n", 98 | " # Resize the images to 28x28 and converts them to PyTorch tensors (from Torchvision)\n", 99 | " transform=Compose([Resize(28), ToTensor()]),\n", 100 | " # Transform the labels to integers (e.g. (\"Glagolitic/character01\", \"Sanskrit/character14\", ...) 
to (0, 1, ...))\n",
101 | "                     target_transform=Categorical(num_classes=args.n_way),\n",
102 | "                     # Creates new virtual classes with rotated versions of the images (from Santoro et al., 2016)\n",
103 | "                     #class_augmentations=[Rotation([90, 180, 270])],\n",
104 | "                     download=True)\n",
105 | "    \n",
106 | "    data_test = Omniglot(\"data\",\n",
107 | "                     # Number of ways\n",
108 | "                     num_classes_per_task=args.n_way,\n",
109 | "                     meta_train=False,\n",
110 | "                     meta_val=False,\n",
111 | "                     meta_test=True,\n",
112 | "                     # Resize the images to 28x28 and convert them to PyTorch tensors (from Torchvision)\n",
113 | "                     transform=Compose([Resize(28), ToTensor()]),\n",
114 | "                     # Transform the labels to integers (e.g. (\"Glagolitic/character01\", \"Sanskrit/character14\", ...) to (0, 1, ...))\n",
115 | "                     target_transform=Categorical(num_classes=args.n_way),\n",
116 | "                     # Creates new virtual classes with rotated versions of the images (from Santoro et al., 2016)\n",
117 | "                     #class_augmentations=[Rotation([90, 180, 270])],\n",
118 | "                     download=True)\n",
119 | "    data_train = ClassSplitter(data_train, shuffle=True, num_train_per_class=args.k_spt, num_test_per_class=args.k_qry)\n",
120 | "    data_test = ClassSplitter(data_test, shuffle=True, num_train_per_class=args.k_spt, num_test_per_class=args.k_qry)\n",
121 | "    \n",
122 | "\n",
123 | "    for epoch in range(args.epoch//10000):\n",
124 | "        # fetch meta_batchsz num of episode each time\n",
125 | "        db = BatchMetaDataLoader(data_train, batch_size=args.task_num, num_workers=0)\n",
126 | "        #db = DataLoader(data_train, args.task_num, shuffle=True, num_workers=0, pin_memory=True)\n",
127 | "\n",
128 | "        for step, batch_train in enumerate(db):\n",
129 | "        #for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):\n",
130 | "\n",
131 | "            x_spt, y_spt = batch_train[\"train\"]\n",
132 | "            x_qry, y_qry = batch_train[\"test\"]\n",
133 | "            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)\n",
134 | "\n",
135 | "            accs, accs_adv = maml(x_spt, y_spt, x_qry, y_qry)\n",
136 | "            \n",
137 | "            if step % 5000 == 0:\n",
138 | "                print('step:', step, '\\ttraining acc:', accs)\n",
139 | "                print('step:', step, '\\ttraining acc_adv:', accs_adv)\n",
140 | "            \n",
141 | "\n",
142 | "            if step % 10000 == 0 and (epoch == 0 or epoch == 2):  # evaluation\n",
143 | "                #db_test = DataLoader(data_test, 1, shuffle=True, num_workers=0, pin_memory=True)\n",
144 | "                db_test = BatchMetaDataLoader(data_test, batch_size=1, num_workers=0)\n",
145 | "                accs_all_test = []\n",
146 | "                accsadv_all_test = []\n",
147 | "                accsadvpr_all_test = []\n",
148 | "\n",
149 | "                for step_t, batch_test in enumerate(db_test):\n",
150 | "                    x_spt, y_spt = batch_test[\"train\"]\n",
151 | "                    x_qry, y_qry = batch_test[\"test\"]\n",
152 | "\n",
153 | "\n",
154 | "                    x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \\\n",
155 | "                                                 x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)\n",
156 | "\n",
157 | "                    accs, accs_adv, accs_adv_prior = maml.finetunning(x_spt, y_spt, x_qry, y_qry)\n",
158 | "                    accs_all_test.append(accs)\n",
159 | "                    accsadv_all_test.append(accs_adv)\n",
160 | "                    accsadvpr_all_test.append(accs_adv_prior)\n",
161 | "                    if step_t == 1000:\n",
162 | "                        break\n",
163 | "\n",
164 | "                # [b, update_step+1]\n",
165 | "                accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)\n",
166 | "                accs_adv = np.array(accsadv_all_test).mean(axis=0).astype(np.float16)\n",
167 | "                accs_adv_prior = 
np.array(accsadvpr_all_test).mean(axis=0).astype(np.float16)\n",
168 | "                print('Test acc:', accs)\n",
169 | "                print('Test acc_adv:', accs_adv)\n",
170 | "                print('Test acc_adv_prior:', accs_adv_prior)\n",
171 | "\n",
172 | "                if best_cl < accs[-1] and best_rb < accs_adv[-1]:\n",
173 | "                    best_cl = accs[-1]\n",
174 | "                    best_rb = accs_adv[-1]\n",
175 | "                    state = {'cl': best_cl, 'rb': best_rb, 'state_dict': maml.net.state_dict()}\n",
176 | "                    torch.save(state, 'mamlAQ_eps10_omniglot_5way.pt')  # must match the checkpoint filename loaded above\n",
177 | "                    print(best_cl)\n",
178 | "                    print(best_rb)\n",
179 | "\n",
180 | "\n",
181 | "\n",
182 | "if __name__ == '__main__':\n",
183 | "\n",
184 | "    argparser = argparse.ArgumentParser()\n",
185 | "    argparser.add_argument('--epoch', type=int, help='epoch number', default=60000)\n",
186 | "    argparser.add_argument('--n_way', type=int, help='n way', default=5)\n",
187 | "    argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=1)\n",
188 | "    argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=15)\n",
189 | "    argparser.add_argument('--imgsz', type=int, help='imgsz', default=28)\n",
190 | "    argparser.add_argument('--imgc', type=int, help='imgc', default=1)\n",
191 | "    argparser.add_argument('--task_num', type=int, help='meta batch size, namely task num', default=4)\n",
192 | "    argparser.add_argument('--meta_lr', type=float, help='meta-level outer learning rate', default=1e-3)\n",
193 | "    argparser.add_argument('--update_lr', type=float, help='task-level inner update learning rate', default=0.01)\n",
194 | "    argparser.add_argument('--update_step', type=int, help='task-level inner update steps', default=5)\n",
195 | "    argparser.add_argument('--update_step_test', type=int, help='update steps for finetunning', default=10)\n",
196 | "    \n",
197 | "    #argparser.add_argument('--fast', action=\"store_true\", help='whether to use fgsm')\n",
198 | "\n",
199 | "    args = argparser.parse_args(args=[])\n",
200 | "\n",
201 | "    main()"
202 |   ]
203 |  },
204 |  {
205 |   "cell_type": "code",
206 |   "execution_count": null,
207 |   "metadata": {},
208 |   "outputs": [],
209 |   "source": []
210 |  }
211 | ],
212 | "metadata": {
213 |  "kernelspec": {
214 |   "display_name": "Python 3",
215 |   "language": "python",
216 |   "name": "python3"
217 |  },
218 |  "language_info": {
219 |   "codemirror_mode": {
220 |    "name": "ipython",
221 |    "version": 3
222 |   },
223 |   "file_extension": ".py",
224 |   "mimetype": "text/x-python",
225 |   "name": "python",
226 |   "nbconvert_exporter": "python",
227 |   "pygments_lexer": "ipython3",
228 |   "version": "3.7.6"
229 |  }
230 | },
231 | "nbformat": 4,
232 | "nbformat_minor": 4
233 | }
234 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
 1 | # MetaAdv
 2 | ### Platform
 3 | * Python: 3.7
 4 | * PyTorch: 1.5.0
 5 | ### Dataset
 6 | We use the benchmark dataset MiniImageNet, which can be downloaded [here](https://drive.google.com/file/d/1HkgrkAwukzEZA0TpO7010PkAOREb2Nuk/view) and [here](https://github.com/twitter/meta-learning-lstm/tree/master/data/miniImagenet). 
CIFAR-FS and Omniglot are available through the torchmeta package [here](https://github.com/tristandeleu/pytorch-meta).
 7 | ### Model
 8 | We use a four-layer convolutional network.
 9 | ## Standard MAML training
10 | Run MAML_TrainStd.ipynb; associated files: MAMLMeta.py, attack.py, learner.py.
11 | * The attack power (epsilon) is set in MAMLMeta.py.
12 | * The device in MAML_TrainStd.ipynb and attack.py must be set to the same GPU; the same holds for all adversarial-training scripts below. A usage sketch of the attack API is given at the end of this README.
13 | ## MAML + FGSM-RS (random start)
14 | Run trainfgsmrs.ipynb; associated files: metafgsm.py, attack.py, learner.py. To also apply adversarial training in the inner loop, replace metafgsm.py with metafgsminout.py.
15 | ## MAML + TRADES-RS
16 | Run train_trade.ipynb; associated files: MetaFT.py, LoadUnlableData.py. The unlabeled data can be downloaded from [here](https://drive.google.com/file/d/1QpEQFDC8SGoek6k20YFCksKbWEh6j5ei/view?usp=sharing)
17 | ## Standard training + few-shot fine-tuning (meta-testing)
18 | Run StandardTransNew.ipynb; associated files: LoadDataST.py, StandardTrans.py. StandardTransAdv.ipynb adds adversarial training to the model-training stage.
19 | ## Unlabeled data selection
20 | Run figureselection.ipynb.
21 | 
22 | ## Visualization
23 | Run robust_vis_neuron.ipynb; associated files: Visualization.py, vis_tool.py, MODELMETA.py.
24 | * Maximizing a neuron's output via a perturbation on the input reveals an interpretable feature under a robust model, whereas only "random noise" appears under a standard MAML model.
25 | * The fine-tuned model shows a feature similar to the original model's at the same neuron, which suggests that robustness is preserved through fine-tuning even without adversarial training in the fine-tuning stage.
26 | 
27 | ## CIFAR-FS and Omniglot
28 | Run the .ipynb files in the two folders "CIFARFS" and "Omniglot".
29 | 
30 | 
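## Attack usage example
For reference, a minimal sketch of how the PGD class in attack.py is driven against the functional `Learner` (the epsilon, step size, dummy batch, and device below are illustrative, not the paper's settings):

```python
import torch
from attack import PGD
from learner import Learner

# the four-layer-conv config used by the MiniImageNet notebooks (5-way, 84x84 RGB)
config = [
    ('conv2d', [32, 3, 3, 3, 1, 0]), ('relu', [True]), ('bn', [32]), ('max_pool2d', [2, 2, 0]),
    ('conv2d', [32, 32, 3, 3, 1, 0]), ('relu', [True]), ('bn', [32]), ('max_pool2d', [2, 2, 0]),
    ('conv2d', [32, 32, 3, 3, 1, 0]), ('relu', [True]), ('bn', [32]), ('max_pool2d', [2, 2, 0]),
    ('conv2d', [32, 32, 3, 3, 1, 0]), ('relu', [True]), ('bn', [32]), ('max_pool2d', [2, 1, 0]),
    ('flatten', []), ('linear', [5, 32 * 5 * 5]),
]

device = torch.device('cuda:0')                 # must match the device used inside attack.py
net = Learner(config, imgc=3, imgsz=84).to(device)
x = torch.rand(8, 3, 84, 84, device=device)     # dummy query batch
y = torch.randint(0, 5, (8,), device=device)

at = PGD(eps=2 / 255.0, sigma=2 / 255.0, nb_iter=10, DEVICE=device)
adv_x = at.attack(net, net.parameters(), x, y)            # PGD within the L_inf ball
logits = net(adv_x, net.parameters(), bn_training=True)   # use these logits for robust accuracy
```

metafgsm.py is not reproduced in this README; a generic FGSM-RS step consistent with attack.py's conventions (random start inside the epsilon ball, then one signed-gradient step) would look like the following hedged sketch, where `fgsm_rs` is a hypothetical helper for illustration only:

```python
import torch
import torch.nn.functional as F

def fgsm_rs(net, params, x, y, eps=2 / 255.0, alpha=2.5 / 255.0):
    # hypothetical helper: random start as in PGD(random_start=True) from attack.py
    delta = torch.empty_like(x).uniform_(-eps, eps).requires_grad_(True)
    loss = F.cross_entropy(net(torch.clamp(x + delta, 0, 1), params, bn_training=True), y)
    grad = torch.autograd.grad(loss, delta)[0]
    # one signed-gradient step, projected back into the eps-ball and the valid pixel range
    delta = torch.clamp(delta + alpha * grad.sign(), -eps, eps).detach()
    return torch.clamp(x + delta, 0, 1)
```

31 | ## Citation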
32 | If you use this code, please cite the following reference:
33 | 
34 | ```
35 | @inproceedings{wangfast,
36 |   title={On Fast Adversarial Robustness Adaptation in Model-Agnostic Meta-Learning},
37 |   author={Wang, Ren and Xu, Kaidi and Liu, Sijia and Chen, Pin-Yu and Weng, Tsui-Wei and Gan, Chuang and Wang, Meng},
38 |   booktitle={International Conference on Learning Representations (ICLR)},
39 |   year={2021}
40 | }
41 | ```
42 | 
-------------------------------------------------------------------------------- /StandardTrans.py: --------------------------------------------------------------------------------
  1 | #!/usr/bin/env python
  2 | # coding: utf-8
  3 | 
  4 | # In[ ]:
  5 | 
  6 | # Standard-training transfer baseline; also contains the adversarial accuracy calculation
  7 | import torch
  8 | from torch import nn
  9 | from torch import optim
 10 | from torch.nn import functional as F
 11 | from torch.utils.data import TensorDataset, DataLoader
 12 | 
 13 | import numpy as np
 14 | 
 15 | from learner import Learner
 16 | from copy import deepcopy
 17 | from attack import PGD
 18 | 
 19 | 
 20 | 
 21 | class Transfer(nn.Module):
 22 |     """
 23 |     Transfer learner (standard training + few-shot fine-tuning at meta-test time)
 24 |     """
 25 |     def __init__(self, args, config):
 26 |         """
 27 |         :param args:
 28 |         """
 29 |         super(Transfer, self).__init__()
 30 | 
 31 |         self.update_lr = args.update_lr
 32 |         self.meta_lr = args.meta_lr
 33 |         self.n_way = args.n_way
 34 |         self.k_spt = args.k_spt
 35 |         self.k_qry = args.k_qry
 36 |         #self.task_num = args.task_num
 37 |         self.update_step = args.update_step
 38 |         self.update_step_test = args.update_step_test
 39 | 
 40 | 
 41 | #         self.net = Learner(config, args.imgc, args.imgsz)
 42 |         #self.netadv = Learner(config, args.imgc, args.imgsz)
 43 | #         self.trans_optim = optim.Adam(self.net.parameters(), lr=self.meta_lr)
 44 |         #self.meta_optimadv = optim.Adam(self.netadv.parameters(), lr=self.meta_lr)
 45 | 
 46 | 
 47 | 
 48 | 
 49 |     def clip_grad_by_norm_(self, grad, max_norm):
 50 |         """
 51 |         in-place gradient clipping.
 52 |         :param grad: list of gradients
 53 |         :param max_norm: maximum norm allowable
 54 |         :return:
 55 |         """
 56 | 
 57 |         total_norm = 0
 58 |         counter = 0
 59 |         for g in grad:
 60 |             param_norm = g.data.norm(2)
 61 |             total_norm += param_norm.item() ** 2
 62 |             counter += 1
 63 |         total_norm = total_norm ** (1. / 2)
 64 | 
 65 |         clip_coef = max_norm / (total_norm + 1e-6)
 66 |         if clip_coef < 1:
 67 |             for g in grad:
 68 |                 g.data.mul_(clip_coef)
 69 | 
 70 |         return total_norm/counter
 71 | 
 72 | 
 73 | #     def forward(self, x_sq, y_sq):
 74 | #         """
 75 | #         :param x_spt:   [b, setsz, c_, h, w]
 76 | #         :param y_spt:   [b, setsz]
 77 | #         :param x_qry:   [b, querysz, c_, h, w]
 78 | #         :param y_qry:   [b, querysz]
 79 | #         :return:
 80 | #         """
 81 | #         setsz, c_, h, w = x_sq.size()
 82 | #         querysz = x_sq.size(0)
 83 | 
 84 | #         # losses_q = [0 for _ in range(self.update_step + 1)]  # losses_q[i] is the loss on step i
 85 | #         # corrects = [0 for _ in range(self.update_step + 1)]
 86 | 
 87 | 
 88 | 
 89 | #         # 1. 
run the i-th task and compute loss for k=0 90 | # logits = self.net(x_sq, vars=None, bn_training=True) 91 | # loss = F.cross_entropy(logits, y_sq) 92 | 93 | # pred = F.softmax(logits, dim=1).argmax(dim=1) 94 | # correct = torch.eq(pred, y_sq).sum().item() 95 | 96 | 97 | 98 | # loss = loss 99 | 100 | 101 | # self.trans_optim.zero_grad() 102 | # loss.backward() 103 | # self.trans_optim.step() 104 | 105 | 106 | 107 | # accs = np.array(correct) / (querysz) 108 | # # accs_adv = np.array(corrects_adv) / (querysz * task_num) 109 | 110 | # return accs 111 | 112 | 113 | def finetunning(self, x_spt, y_spt, x_qry, y_qry, net): 114 | """ 115 | :param x_spt: [setsz, c_, h, w] 116 | :param y_spt: [setsz] 117 | :param x_qry: [querysz, c_, h, w] 118 | :param y_qry: [querysz] 119 | :return: 120 | """ 121 | assert len(x_spt.shape) == 4 122 | 123 | 124 | configtest = [ 125 | ('conv2d', [32, 3, 3, 3, 1, 0]), 126 | ('relu', [True]), 127 | ('bn', [32]), 128 | ('max_pool2d', [2, 2, 0]), 129 | ('conv2d', [32, 32, 3, 3, 1, 0]), 130 | ('relu', [True]), 131 | ('bn', [32]), 132 | ('max_pool2d', [2, 2, 0]), 133 | ('conv2d', [32, 32, 3, 3, 1, 0]), 134 | ('relu', [True]), 135 | ('bn', [32]), 136 | ('max_pool2d', [2, 2, 0]), 137 | ('conv2d', [32, 32, 3, 3, 1, 0]), 138 | ('relu', [True]), 139 | ('bn', [32]), 140 | ('max_pool2d', [2, 1, 0]), 141 | ('flatten', []), 142 | ('linear', [5, 32 * 5 * 5]) 143 | ] 144 | 145 | studentnet = Learner(configtest, 3, 84)#.to('cuda:3') 146 | for i in range(0,16): 147 | studentnet.parameters()[i] = net.parameters()[i] 148 | studentnet.to('cuda:3') 149 | 150 | 151 | 152 | querysz = x_qry.size(0) 153 | 154 | corrects = [0 for _ in range(self.update_step_test + 1)] 155 | 156 | need_adv = True 157 | optimizer = torch.optim.SGD(studentnet.parameters(), lr=self.update_lr, momentum=0.9, weight_decay=5e-4) 158 | eps, step = (2,10) 159 | corrects_adv = [0 for _ in range(self.update_step_test + 1)] 160 | corrects_adv_prior = [0 for _ in range(self.update_step_test + 1)] 161 | 162 | 163 | # in order to not ruin the state of running_mean/variance and bn_weight/bias 164 | # we finetunning on the copied model instead of self.net 165 | #net = deepcopy(self.net) 166 | 167 | # 1. 
run the i-th task and compute loss for k=0 168 | logits = studentnet(x_spt) 169 | loss = F.cross_entropy(logits, y_spt) 170 | grad = torch.autograd.grad(loss, studentnet.parameters()) 171 | fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, studentnet.parameters()))) 172 | 173 | 174 | 175 | #PGD AT 176 | if need_adv: 177 | at = PGD(eps=eps / 255.0, sigma=2 / 255.0, nb_iter=step) 178 | # data = x_spt 179 | # label = y_spt 180 | # optimizer.zero_grad() 181 | # adv_inp = at.attack(self.net, self.net.parameters(), data, label) 182 | # logits = self.net(adv_inp, self.net.parameters(), bn_training=True) 183 | # loss = F.cross_entropy(logits, label) 184 | # grad = torch.autograd.grad(loss, self.net.parameters()) 185 | # fast_weights_adv = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, self.net.parameters()))) 186 | data = x_qry 187 | label = y_qry 188 | optimizer.zero_grad() 189 | adv_inp_adv = at.attack(studentnet, fast_weights, data, label) 190 | 191 | 192 | 193 | 194 | # this is the loss and accuracy before first update 195 | with torch.no_grad(): 196 | # [setsz, nway] 197 | logits_q = studentnet(x_qry, studentnet.parameters(), bn_training=True) 198 | # [setsz] 199 | pred_q = F.softmax(logits_q, dim=1).argmax(dim=1) 200 | 201 | #find the correct index 202 | corr_ind = (torch.eq(pred_q, y_qry) == True).nonzero() 203 | # scalar 204 | correct = torch.eq(pred_q, y_qry).sum().item() 205 | corrects[0] = corrects[0] + correct 206 | 207 | 208 | #PGD AT 209 | if need_adv: 210 | data = x_qry 211 | label = y_qry 212 | optimizer.zero_grad() 213 | adv_inp = at.attack(studentnet, studentnet.parameters(), data, label) 214 | with torch.no_grad(): 215 | logits_q_adv = studentnet(adv_inp, studentnet.parameters(), bn_training=True) 216 | pred_q_adv = F.softmax(logits_q_adv, dim=1).argmax(dim=1) 217 | correct_adv = torch.eq(pred_q_adv, label).sum().item() 218 | correct_adv_prior = torch.eq(pred_q_adv[corr_ind], label[corr_ind]).sum().item() 219 | corrects_adv[0] = corrects_adv[0] + correct_adv 220 | corrects_adv_prior[0] = corrects_adv_prior[0] + correct_adv_prior/len(corr_ind) 221 | 222 | # this is the loss and accuracy after the first update 223 | with torch.no_grad(): 224 | # [setsz, nway] 225 | logits_q = studentnet(x_qry, fast_weights, bn_training=True) 226 | # [setsz] 227 | pred_q = F.softmax(logits_q, dim=1).argmax(dim=1) 228 | #find the correct index 229 | corr_ind = (torch.eq(pred_q, y_qry) == True).nonzero() 230 | # scalar 231 | correct = torch.eq(pred_q, y_qry).sum().item() 232 | corrects[1] = corrects[1] + correct 233 | 234 | 235 | #PGD AT 236 | if need_adv: 237 | logits_q_adv = studentnet(adv_inp_adv, fast_weights, bn_training=True) 238 | pred_q_adv = F.softmax(logits_q_adv, dim=1).argmax(dim=1) 239 | correct_adv = torch.eq(pred_q_adv, label).sum().item() 240 | correct_adv_prior = torch.eq(pred_q_adv[corr_ind], label[corr_ind]).sum().item() 241 | corrects_adv[1] = corrects_adv[1] + correct_adv 242 | corrects_adv_prior[1] = corrects_adv_prior[1] + correct_adv_prior/len(corr_ind) 243 | 244 | 245 | for k in range(1, self.update_step_test): 246 | # 1. run the i-th task and compute loss for k=1~K-1 247 | logits = studentnet(x_spt, fast_weights, bn_training=True) 248 | loss = F.cross_entropy(logits, y_spt) 249 | # 2. compute grad on theta_pi 250 | grad = torch.autograd.grad(loss, fast_weights) 251 | # 3. 
theta_pi = theta_pi - train_lr * grad 252 | fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights))) 253 | 254 | logits_q = studentnet(x_qry, fast_weights, bn_training=True) 255 | # loss_q will be overwritten and just keep the loss_q on last update step. 256 | loss_q = F.cross_entropy(logits_q, y_qry) 257 | 258 | 259 | 260 | #PGD AT 261 | if need_adv: 262 | at = PGD(eps=eps / 255.0, sigma=2 / 255.0, nb_iter=step) 263 | # data = x_spt 264 | # label = y_spt 265 | # optimizer.zero_grad() 266 | # adv_inp = at.attack(self.net, fast_weights_adv, data, label) 267 | # logits = self.net(adv_inp, fast_weights_adv, bn_training=True) 268 | # loss = F.cross_entropy(logits, label) 269 | # grad = torch.autograd.grad(loss, fast_weights_adv) 270 | # fast_weights_adv = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights_adv))) 271 | data = x_qry 272 | label = y_qry 273 | optimizer.zero_grad() 274 | adv_inp_adv = at.attack(studentnet, fast_weights, data, label) 275 | 276 | logits_q_adv = studentnet(adv_inp_adv, fast_weights, bn_training=True) 277 | loss_q_adv = F.cross_entropy(logits_q_adv, label) 278 | 279 | 280 | 281 | with torch.no_grad(): 282 | 283 | pred_q = F.softmax(logits_q, dim=1).argmax(dim=1) 284 | #find the correct index 285 | corr_ind = (torch.eq(pred_q, y_qry) == True).nonzero() 286 | correct = torch.eq(pred_q, y_qry).sum().item() # convert to numpy 287 | corrects[k + 1] = corrects[k + 1] + correct 288 | 289 | 290 | #PGD AT 291 | if need_adv: 292 | pred_q_adv = F.softmax(logits_q_adv, dim=1).argmax(dim=1) 293 | correct_adv = torch.eq(pred_q_adv, label).sum().item() 294 | correct_adv_prior = torch.eq(pred_q_adv[corr_ind], label[corr_ind]).sum().item() 295 | corrects_adv[k + 1] = corrects_adv[k + 1] + correct_adv 296 | corrects_adv_prior[k + 1] = corrects_adv_prior[k + 1] + correct_adv_prior/len(corr_ind) 297 | 298 | 299 | del studentnet 300 | 301 | accs = np.array(corrects) / querysz 302 | 303 | accs_adv = np.array(corrects_adv) / querysz 304 | 305 | accs_adv_prior = np.array(corrects_adv_prior) 306 | 307 | return accs, accs_adv, accs_adv_prior 308 | 309 | 310 | 311 | 312 | def main(): 313 | pass 314 | 315 | 316 | if __name__ == '__main__': 317 | main() 318 | 319 | -------------------------------------------------------------------------------- /StandardTransAdv.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch, os\n", 10 | "import numpy as np\n", 11 | "from MiniImagenet import MiniImagenet\n", 12 | "import scipy.stats\n", 13 | "from torch.utils.data import DataLoader\n", 14 | "from torch.optim import lr_scheduler\n", 15 | "import random, sys, pickle\n", 16 | "import argparse\n", 17 | "\n", 18 | "from torch import nn\n", 19 | "from torch import optim\n", 20 | "from torch.nn import functional as F\n", 21 | "\n", 22 | "from learner import Learner\n", 23 | "# from copy import deepcopy\n", 24 | "from attack import PGD\n", 25 | "\n", 26 | "\n", 27 | "from StandardTrans import Transfer\n", 28 | "from LoadDataST import ImagenetMini\n", 29 | "\n", 30 | "\n", 31 | "\n", 32 | "def mean_confidence_interval(accs, confidence=0.95):\n", 33 | " n = accs.shape[0]\n", 34 | " m, se = np.mean(accs), scipy.stats.sem(accs)\n", 35 | " h = se * scipy.stats.t._ppf((1 + confidence) / 2, n - 1)\n", 36 | " return m, h\n", 37 | "\n", 38 | "\n", 39 | "def main():\n", 40 | "\n", 41 | 
"# torch.manual_seed(222)\n", 42 | "# torch.cuda.manual_seed_all(222)\n", 43 | "# np.random.seed(222)\n", 44 | "\n", 45 | " print(args)\n", 46 | "\n", 47 | " config = [\n", 48 | " ('conv2d', [32, 3, 3, 3, 1, 0]),\n", 49 | " ('relu', [True]),\n", 50 | " ('bn', [32]),\n", 51 | " ('max_pool2d', [2, 2, 0]),\n", 52 | " ('conv2d', [32, 32, 3, 3, 1, 0]),\n", 53 | " ('relu', [True]),\n", 54 | " ('bn', [32]),\n", 55 | " ('max_pool2d', [2, 2, 0]),\n", 56 | " ('conv2d', [32, 32, 3, 3, 1, 0]),\n", 57 | " ('relu', [True]),\n", 58 | " ('bn', [32]),\n", 59 | " ('max_pool2d', [2, 2, 0]),\n", 60 | " ('conv2d', [32, 32, 3, 3, 1, 0]),\n", 61 | " ('relu', [True]),\n", 62 | " ('bn', [32]),\n", 63 | " ('max_pool2d', [2, 1, 0]),\n", 64 | " ('flatten', []),\n", 65 | " ('linear', [64, 32 * 5 * 5])\n", 66 | " ]\n", 67 | "\n", 68 | " device = torch.device('cuda:4')\n", 69 | " \n", 70 | " start_epoch = 0\n", 71 | " start_step = 0\n", 72 | " filename = 'standardtrainadv.pt'\n", 73 | " transferlearn = Learner(config, args.imgc, args.imgsz).to(device)\n", 74 | "# trans_optim = optim.Adam(transferlearn.parameters(), lr=args.meta_lr)\n", 75 | " criterion = torch.nn.CrossEntropyLoss()\n", 76 | " \n", 77 | " if os.path.isfile(filename):\n", 78 | " print(\"=> loading checkpoint '{}'\".format(filename))\n", 79 | " checkpoint = torch.load(filename)\n", 80 | " start_step = checkpoint['step']\n", 81 | " transferlearn.load_state_dict(checkpoint['state_dict'])\n", 82 | " #maml = maml.to(device)\n", 83 | " print(\"=> loaded checkpoint '{}' (step {})\"\n", 84 | " .format(filename, checkpoint['step']))\n", 85 | " else:\n", 86 | " print(\"=> no checkpoint found at '{}'\".format(filename))\n", 87 | "\n", 88 | " trans_optim = torch.optim.SGD(transferlearn.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)\n", 89 | " lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(trans_optim, milestones=[30, 85, 108], gamma=0.1)\n", 90 | " #trans_optim = Lamb(transferlearn.parameters(), lr=0.01, weight_decay=1e-4, betas=(.9, .999), adam=False)\n", 91 | " tmp = filter(lambda x: x.requires_grad, transferlearn.parameters())\n", 92 | " num = sum(map(lambda x: np.prod(x.shape), tmp))\n", 93 | " print(transferlearn)\n", 94 | " print(type(transferlearn))\n", 95 | " print('Total trainable tensors:', num)\n", 96 | "\n", 97 | " # batchsz here means total episode number\n", 98 | "# mini = MiniImagenet('../../../dataset/', mode='train', n_way=args.n_way, k_shot=5,\n", 99 | "# k_query=5,\n", 100 | "# batchsz=10000, resize=args.imgsz)\n", 101 | " mini = ImagenetMini('../../../dataset/', mode='train',\n", 102 | " batchsz=10000, resize=args.imgsz)\n", 103 | " mini_test = MiniImagenet('../../../dataset/', mode='test', n_way=5, k_shot=1,\n", 104 | " k_query=args.k_qry,\n", 105 | " batchsz=100, resize=args.imgsz)\n", 106 | " \n", 107 | " \n", 108 | "\n", 109 | "\n", 110 | "# for epoch in range(args.epoch//10000):\n", 111 | "# # fetch meta_batchsz num of episode each time\n", 112 | "# db = DataLoader(mini, args.task_num, shuffle=True, num_workers=0, pin_memory=True)\n", 113 | " at = PGD(eps=2 / 255.0, sigma=2 / 255.0, nb_iter=10)\n", 114 | "\n", 115 | " for step in range(start_step, 30000):\n", 116 | "\n", 117 | "\n", 118 | " x, y = mini.loading_data.get_next_batch(64, multiple_passes=True, reshuffle_after_pass=True)\n", 119 | "\n", 120 | "\n", 121 | "\n", 122 | "\n", 123 | "\n", 124 | " x_sq = x.to(device)\n", 125 | " y_sq = y.to(device)\n", 126 | "\n", 127 | " querysz = x_sq.size(0)\n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " trans_optim.zero_grad()\n", 
132 | "\n", 133 | "\n", 134 | " adv_inp = at.attack(transferlearn, transferlearn.parameters(), x_sq, y_sq)\n", 135 | " trans_optim.zero_grad()\n", 136 | " transferlearn.train()\n", 137 | " logits = transferlearn(adv_inp, vars=None, bn_training=True)\n", 138 | " loss = criterion(logits, y_sq)\n", 139 | "\n", 140 | "# acc = torch_accuracy(pred, label, (1,))\n", 141 | "# advacc = acc[0].item()\n", 142 | "# advloss = loss.item()\n", 143 | " (loss * 1.0).backward()\n", 144 | " trans_optim.step()\n", 145 | " \n", 146 | " trans_optim.zero_grad()\n", 147 | "\n", 148 | "\n", 149 | "\n", 150 | "\n", 151 | " logits = transferlearn(x_sq, vars=None, bn_training=True)\n", 152 | " loss = criterion(logits, y_sq)\n", 153 | " with torch.no_grad():\n", 154 | " pred = F.softmax(logits, dim=1).argmax(dim=1)\n", 155 | " correct = torch.eq(pred, y_sq).sum().item()\n", 156 | "\n", 157 | "\n", 158 | " trans_optim.zero_grad()\n", 159 | " loss.backward()\n", 160 | " trans_optim.step()\n", 161 | " lr_scheduler.step()\n", 162 | "\n", 163 | "\n", 164 | "\n", 165 | " accs = np.array(correct) / (querysz)\n", 166 | "\n", 167 | " if step % 30 == 0: \n", 168 | " print('step:', step, '\\ttraining acc:', accs)\n", 169 | "# print('step:', step, '\\ttraining acc_adv:', accs_adv)\n", 170 | " state = {'step': step, 'state_dict': transferlearn.state_dict()}\n", 171 | " torch.save(state, 'standardtrainadv.pt')\n", 172 | "\n", 173 | " if step % 500 == 0: # evaluation\n", 174 | "\n", 175 | " db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=0, pin_memory=True)\n", 176 | " accs_all_test = []\n", 177 | " accsadv_all_test = []\n", 178 | " accsadvpr_all_test = []\n", 179 | " transfertst = Transfer(args, config).to(device)\n", 180 | "\n", 181 | "\n", 182 | "\n", 183 | "\n", 184 | "\n", 185 | "\n", 186 | " for x_spt, y_spt, x_qry, y_qry in db_test:\n", 187 | " x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \\\n", 188 | " x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)\n", 189 | "\n", 190 | " accs, accs_adv, accs_adv_prior = transfertst.finetunning(x_spt, y_spt, x_qry, y_qry, transferlearn)\n", 191 | " accs_all_test.append(accs)\n", 192 | " accsadv_all_test.append(accs_adv)\n", 193 | " accsadvpr_all_test.append(accs_adv_prior)\n", 194 | "\n", 195 | " # [b, update_step+1]\n", 196 | " accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)\n", 197 | " accs_adv = np.array(accsadv_all_test).mean(axis=0).astype(np.float16)\n", 198 | " accs_adv_prior = np.array(accsadvpr_all_test).mean(axis=0).astype(np.float16)\n", 199 | " print('Test acc:', accs)\n", 200 | " print('Test acc_adv:', accs_adv)\n", 201 | " print('Test acc_adv_prior:', accs_adv_prior)\n", 202 | "\n", 203 | "\n", 204 | "if __name__ == '__main__':\n", 205 | "\n", 206 | " argparser = argparse.ArgumentParser()\n", 207 | " argparser.add_argument('--epoch', type=int, help='epoch number', default=100000)\n", 208 | " argparser.add_argument('--n_way', type=int, help='n way', default=5)\n", 209 | " argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=1)\n", 210 | " argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=15)\n", 211 | " argparser.add_argument('--imgsz', type=int, help='imgsz', default=84)\n", 212 | " argparser.add_argument('--imgc', type=int, help='imgc', default=3)\n", 213 | " argparser.add_argument('--task_num', type=int, help='meta batch size, namely task num', default=1)\n", 214 | " argparser.add_argument('--meta_lr', type=float, help='meta-level outer 
learning rate', default=1e-3)\n", 215 | " argparser.add_argument('--update_lr', type=float, help='task-level inner update learning rate', default=0.01)\n", 216 | " argparser.add_argument('--update_step', type=int, help='task-level inner update steps', default=5)\n", 217 | " argparser.add_argument('--update_step_test', type=int, help='update steps for finetunning', default=10)\n", 218 | " \n", 219 | " #argparser.add_argument('--fast', action=\"store_true\", help='whether to use fgsm')\n", 220 | "\n", 221 | " args = argparser.parse_args(args=[])\n", 222 | "\n", 223 | " main()" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": null, 229 | "metadata": {}, 230 | "outputs": [], 231 | "source": [] 232 | } 233 | ], 234 | "metadata": { 235 | "kernelspec": { 236 | "display_name": "Python 3", 237 | "language": "python", 238 | "name": "python3" 239 | }, 240 | "language_info": { 241 | "codemirror_mode": { 242 | "name": "ipython", 243 | "version": 3 244 | }, 245 | "file_extension": ".py", 246 | "mimetype": "text/x-python", 247 | "name": "python", 248 | "nbconvert_exporter": "python", 249 | "pygments_lexer": "ipython3", 250 | "version": "3.7.7" 251 | } 252 | }, 253 | "nbformat": 4, 254 | "nbformat_minor": 4 255 | } 256 | -------------------------------------------------------------------------------- /Visualization.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[ ]: 5 | 6 | 7 | import torch 8 | # import torch.nn as nn 9 | import torch.nn.functional as F 10 | import numpy as np 11 | 12 | # import dill 13 | import os 14 | if int(os.environ.get("NOTEBOOK_MODE", 0)) == 1: 15 | from tqdm import tqdm_notebook as tqdm 16 | else: 17 | from tqdm import tqdm 18 | 19 | # from . import helpers 20 | # from . 
import attack_steps 21 | from learner import Learner 22 | 23 | class RobustVis(torch.nn.Module): 24 | 25 | def __init__(self, model, device): 26 | 27 | super(RobustVis, self).__init__() 28 | #self.normalize = helpers.InputNormalize(dataset.mean, dataset.std) 29 | configtest = [ 30 | ('conv2d', [32, 3, 3, 3, 1, 0]), 31 | ('relu', [True]), 32 | ('bn', [32]), 33 | ('max_pool2d', [2, 2, 0]), 34 | ('conv2d', [32, 32, 3, 3, 1, 0]), 35 | ('relu', [True]), 36 | ('bn', [32]), 37 | ('max_pool2d', [2, 2, 0]), 38 | ('conv2d', [32, 32, 3, 3, 1, 0]), 39 | ('relu', [True]), 40 | ('bn', [32]), 41 | ('max_pool2d', [2, 2, 0]), 42 | ('conv2d', [32, 32, 3, 3, 1, 0]), 43 | ('relu', [True]), 44 | ('bn', [32]), 45 | ('max_pool2d', [2, 1, 0]), 46 | ('flatten', []) 47 | ] 48 | 49 | copymod = Learner(configtest, 3, 84)#.to('cuda:3') 50 | for i in range(0,16): 51 | copymod.parameters()[i] = model.parameters()[i] 52 | 53 | 54 | self.model = copymod.to(device) 55 | self.model.eval() 56 | self.device = device 57 | 58 | def forward(self, x, target, *_, constraint, eps, step_size, iterations, criterion, 59 | random_start=False, random_restarts=False, do_tqdm=False, 60 | targeted=False, custom_loss=None, should_normalize=False, 61 | orig_input=None, use_best=False, sigma=0.000001): 62 | 63 | 64 | # Can provide a different input to make the feasible set around 65 | # instead of the initial point 66 | 67 | if orig_input is None: orig_input = x.detach() 68 | orig_input = orig_input.to(self.device) 69 | 70 | # Multiplier for gradient ascent [untargeted] or descent [targeted] 71 | m = -1 if targeted else 1 72 | 73 | # Initialize step class 74 | # step = STEPS[constraint](eps=eps, orig_input=orig_input, step_size=step_size) 75 | 76 | def calc_loss(inp, index): 77 | 78 | # if should_normalize: 79 | # inp = self.normalize(inp) 80 | output_vec = self.model(inp) 81 | output = [output_vec[i][index] for i in range(output_vec.size()[0])] 82 | # if custom_loss: 83 | # return custom_loss(self.model, inp, target) 84 | 85 | return output 86 | 87 | 88 | def get_pert_examples(x): 89 | 90 | pert = torch.empty(x.shape).normal_(mean=0,std=sigma).to(self.device) 91 | #print(torch.max(pert).item()) 92 | # Random start (to escape certain types of gradient masking) 93 | if random_start: 94 | x = torch.clamp(x + step.random_perturb(x), 0, 1) 95 | 96 | iterator = range(iterations) 97 | if do_tqdm: iterator = tqdm(iterator) 98 | 99 | # Keep track of the "best" (worst-case) loss and its 100 | # corresponding input 101 | best_loss = None 102 | best_x = None 103 | 104 | # A function that updates the best loss and best input 105 | def replace_best(loss, bloss, x, bx): 106 | if bloss is None: 107 | bx = x.clone().detach() 108 | bloss = losses.clone().detach() 109 | else: 110 | replace = m * bloss < m * loss 111 | bx[replace] = x[replace].clone().detach() 112 | bloss[replace] = loss[replace] 113 | 114 | return bloss, bx 115 | 116 | delta = torch.zeros_like(x, requires_grad=True).requires_grad_(True) 117 | 118 | 119 | 120 | # W = torch.zeros(2048, requires_grad=False) 121 | # W[1858] = 1 122 | 123 | 124 | # 125 | x0 = x 126 | 127 | # step_d = STEPS[constraint](eps=eps, orig_input=delta, step_size=step_size) 128 | # step_m = attack_steps.LinfStep1(eps=eps, orig_input=M, step_size=step_size) 129 | 130 | # 131 | for _ in iterator: 132 | delta = delta.clone().detach().requires_grad_(True).to(self.device) 133 | 134 | x = x0 + pert + delta 135 | 136 | x = torch.clamp(x, 0, 1) 137 | losses = calc_loss(x, target) 138 | 139 | 140 | # W1 = W.unsqueeze(0).expand(10, 
-1) 141 | 142 | #W1 = W 143 | # losses = losses * W1 144 | 145 | loss = losses#torch.mean(losses) 146 | 147 | grad_d = torch.autograd.grad(loss, delta) 148 | # 149 | # print(type(grad_d)) 150 | # print(len(grad_d)) 151 | # print(len(grad_d[0])) 152 | # print(len(grad_d[0][0])) 153 | 154 | 155 | with torch.no_grad(): 156 | args = [losses, best_loss, x, best_x] 157 | best_loss, best_x = replace_best(*args) if use_best else (losses, x) 158 | 159 | delta = grad_d[0] * step_size + delta 160 | # delta = step_d.project(delta) 161 | 162 | # #additional inf_norm constraint (for clean label attack) 163 | # max_d = x0+20.0/255#torch.min(20/M.cpu().detach().numpy() + x0.cpu().detach().numpy(), eps) 164 | # min_d = x0-20.0/255#torch.max(-20/M.cpu().detach().numpy() + x0.cpu().detach().numpy(), 0) 165 | # delta = torch.where(delta > min_d, min_d, delta) 166 | # delta = torch.where(delta < max_d, max_d, delta) 167 | 168 | # M = step_m.make_step(grad_m) * m + M 169 | # M = step_m.project(M, gamma) 170 | 171 | # #weight method 172 | # W = step_w.make_step(grad_w) * m + W 173 | # W = step_w.project(W) 174 | # # 175 | 176 | if do_tqdm: iterator.set_description("Current loss: {l}".format(l=loss)) 177 | 178 | 179 | 180 | 181 | 182 | x = x0 + delta 183 | # loss_ave = loss.mean(0) 184 | 185 | if not use_best: return losses, torch.clamp(x,0,1).clone().detach() 186 | 187 | 188 | losses = calc_loss(x, target) 189 | args = [losses, best_loss, x, best_x] 190 | best_loss, best_x = replace_best(*args) 191 | return best_loss, best_x 192 | 193 | 194 | 195 | # Random restarts: repeat the attack and find the worst-case 196 | # example for each input in the batch 197 | if random_restarts: 198 | to_ret = None 199 | 200 | orig_cpy = x.clone().detach() 201 | for _ in range(random_restarts): 202 | pert_loss, pertimg = get_pert_examples(orig_cpy) 203 | 204 | if to_ret is None: 205 | to_ret = pertimg.detach() 206 | 207 | output = calc_loss(pertimg, target) 208 | # corr, = helpers.accuracy(output, target, topk=(1,), exact=True) 209 | # corr = corr.byte() 210 | # misclass = ~corr 211 | # to_ret[misclass] = adv[misclass] 212 | 213 | pert_ret = to_ret 214 | else: 215 | pert_loss, pert_ret = get_pert_examples(x) 216 | 217 | return pert_loss, pert_ret 218 | 219 | # class AttackerModel(torch.nn.Module): 220 | # def __init__(self, model, dataset): 221 | # super(AttackerModel, self).__init__() 222 | # self.normalizer = helpers.InputNormalize(dataset.mean, dataset.std) 223 | # self.model = model 224 | # self.attacker = Attacker(model, dataset) 225 | 226 | # def forward(self, inp, target=None, make_adv=False, with_latent=False, 227 | # fake_relu=False, with_image=True, **attacker_kwargs): 228 | # if make_adv: 229 | # assert target is not None 230 | # prev_training = bool(self.training) 231 | # self.eval() 232 | # adv = self.attacker(inp, target, **attacker_kwargs) 233 | # if prev_training: 234 | # self.train() 235 | 236 | # inp = adv 237 | 238 | # if with_image: 239 | # normalized_inp = self.normalizer(inp) 240 | # output = self.model(normalized_inp, with_latent=with_latent, 241 | # fake_relu=fake_relu) 242 | # else: 243 | # output = None 244 | 245 | # return (output, inp) 246 | -------------------------------------------------------------------------------- /attack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from abc import ABCMeta, abstractmethod, abstractproperty 4 | from torch.nn import functional as F 5 | 6 | class AttackBase(metaclass=ABCMeta): 7 | 
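    # Abstract attack interface: concrete attacks (e.g., PGD below) implement attack() and to().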
@abstractmethod 8 | def attack(self, net, inp, label, target = None): 9 | ''' 10 | 11 | :param inp: batched images 12 | :param target: specify the indexes of target class, None represents untargeted attack 13 | :return: batched adversaril images 14 | ''' 15 | pass 16 | 17 | @abstractmethod 18 | def to(self, device): 19 | pass 20 | 21 | 22 | 23 | def clip_eta(eta, norm, eps, DEVICE = torch.device('cuda:2')): 24 | ''' 25 | helper functions to project eta into epsilon norm ball 26 | :param eta: Perturbation tensor (should be of size(N, C, H, W)) 27 | :param norm: which norm. should be in [1, 2, np.inf] 28 | :param eps: epsilon, bound of the perturbation 29 | :return: Projected perturbation 30 | ''' 31 | 32 | assert norm in [1, 2, np.inf], "norm should be in [1, 2, np.inf]" 33 | 34 | with torch.no_grad(): 35 | avoid_zero_div = torch.tensor(1e-12).to(DEVICE) 36 | eps = torch.tensor(eps).to(DEVICE) 37 | one = torch.tensor(1.0).to(DEVICE) 38 | 39 | if norm == np.inf: 40 | eta = torch.clamp(eta, -eps, eps) 41 | else: 42 | normalize = torch.norm(eta.reshape(eta.size(0), -1), p = norm, dim = -1, keepdim = False) 43 | normalize = torch.max(normalize, avoid_zero_div) 44 | 45 | normalize.unsqueeze_(dim = -1) 46 | normalize.unsqueeze_(dim=-1) 47 | normalize.unsqueeze_(dim=-1) 48 | 49 | factor = torch.min(one, eps / normalize) 50 | eta = eta * factor 51 | return eta 52 | 53 | 54 | 55 | class PGD(AttackBase): 56 | # ImageNet pre-trained mean and std 57 | # _mean = torch.tensor(np.array([0.485, 0.456, 0.406]).astype(np.float32)[np.newaxis, :, np.newaxis, np.newaxis]) 58 | # _std = torch.tensor(np.array([0.229, 0.224, 0.225]).astype(np.float32)[np.newaxis, :, np.newaxis, np.newaxis]) 59 | 60 | # _mean = torch.tensor(np.array([0]).astype(np.float32)[np.newaxis, :, np.newaxis, np.newaxis]) 61 | # _std = torch.tensor(np.array([1.0]).astype(np.float32)[np.newaxis, :, np.newaxis, np.newaxis]) 62 | def __init__(self, eps = 6 / 255.0, sigma = 3 / 255.0, nb_iter = 20, 63 | norm = np.inf, DEVICE = torch.device('cuda:2'), 64 | mean = torch.tensor(np.array([0]).astype(np.float32)[np.newaxis, :, np.newaxis, np.newaxis]), 65 | std = torch.tensor(np.array([1.0]).astype(np.float32)[np.newaxis, :, np.newaxis, np.newaxis]), random_start = True): 66 | ''' 67 | :param eps: maximum distortion of adversarial examples 68 | :param sigma: single step size 69 | :param nb_iter: number of attack iterations 70 | :param norm: which norm to bound the perturbations 71 | ''' 72 | self.eps = eps 73 | self.sigma = sigma 74 | self.nb_iter = nb_iter 75 | self.norm = norm 76 | self.criterion = torch.nn.CrossEntropyLoss().to(DEVICE) 77 | self.DEVICE = DEVICE 78 | self._mean = mean.to(DEVICE) 79 | self._std = std.to(DEVICE) 80 | self.random_start = random_start 81 | 82 | def single_attack(self, net, para, inp, label, eta, target = None): 83 | ''' 84 | Given the original image and the perturbation computed so far, computes 85 | a new perturbation. 
86 | :param net: 87 | :param inp: original image 88 | :param label: 89 | :param eta: perturbation computed so far 90 | :return: a new perturbation 91 | ''' 92 | 93 | adv_inp = inp + eta 94 | 95 | #net.zero_grad() 96 | 97 | pred = net(adv_inp, para) 98 | 99 | 100 | loss = self.criterion(pred, label) 101 | grad_sign = torch.autograd.grad(loss, adv_inp, 102 | only_inputs=True, retain_graph = False)[0].sign() 103 | 104 | adv_inp = adv_inp + grad_sign * (self.sigma / self._std) 105 | tmp_adv_inp = adv_inp * self._std + self._mean 106 | 107 | tmp_inp = inp * self._std + self._mean 108 | tmp_adv_inp = torch.clamp(tmp_adv_inp, 0, 1) 109 | tmp_eta = tmp_adv_inp - tmp_inp 110 | tmp_eta = clip_eta(tmp_eta, norm=self.norm, eps=self.eps, DEVICE=self.DEVICE) 111 | 112 | eta = tmp_eta/ self._std 113 | 114 | 115 | # adv_inp = adv_inp + grad_sign * self.eps 116 | # adv_inp = torch.clamp(adv_inp, 0, 1) 117 | # eta = adv_inp - inp 118 | # eta = clip_eta(eta, norm=self.norm, eps=self.eps, DEVICE=self.DEVICE) 119 | 120 | return eta 121 | 122 | def attack(self, net, para, inp, label, target = None): 123 | 124 | if self.random_start: 125 | eta = torch.FloatTensor(*inp.shape).uniform_(-self.eps, self.eps) 126 | else: 127 | eta = torch.zeros_like(inp) 128 | eta = eta.to(self.DEVICE) 129 | eta = (eta - self._mean) / self._std 130 | net.eval() 131 | #print(torch.min(torch.min(torch.min(inp[0])))) 132 | 133 | inp.requires_grad = True 134 | eta.requires_grad = True 135 | for i in range(self.nb_iter): 136 | eta = self.single_attack(net, para, inp, label, eta, target) 137 | #print(i) 138 | 139 | #print(eta.max()) 140 | adv_inp = inp + eta 141 | tmp_adv_inp = adv_inp * self._std + self._mean 142 | tmp_adv_inp = torch.clamp(tmp_adv_inp, 0, 1) 143 | adv_inp = (tmp_adv_inp - self._mean) / self._std 144 | 145 | return adv_inp 146 | 147 | def to(self, device): 148 | self.DEVICE = device 149 | self._mean = self._mean.to(device) 150 | self._std = self._std.to(device) 151 | self.criterion = self.criterion.to(device) -------------------------------------------------------------------------------- /learner.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[ ]: 5 | 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | import numpy as np 11 | 12 | 13 | 14 | class Learner(nn.Module): 15 | """ 16 | """ 17 | 18 | def __init__(self, config, imgc, imgsz): 19 | """ 20 | :param config: network config file, type:list of (string, list) 21 | :param imgc: 1 or 3 22 | :param imgsz: 28 or 84 23 | """ 24 | super(Learner, self).__init__() 25 | 26 | 27 | self.config = config 28 | 29 | # this dict contains all tensors needed to be optimized 30 | self.vars = nn.ParameterList() 31 | # running_mean and running_var 32 | self.vars_bn = nn.ParameterList() 33 | 34 | for i, (name, param) in enumerate(self.config): 35 | if name is 'conv2d': 36 | # [ch_out, ch_in, kernelsz, kernelsz] 37 | w = nn.Parameter(torch.ones(*param[:4])) 38 | # gain=1 according to cbfin's implementation 39 | torch.nn.init.kaiming_normal_(w) 40 | self.vars.append(w) 41 | # [ch_out] 42 | self.vars.append(nn.Parameter(torch.zeros(param[0]))) 43 | 44 | elif name is 'convt2d': 45 | # [ch_in, ch_out, kernelsz, kernelsz, stride, padding] 46 | w = nn.Parameter(torch.ones(*param[:4])) 47 | # gain=1 according to cbfin's implementation 48 | torch.nn.init.kaiming_normal_(w) 49 | self.vars.append(w) 50 | # [ch_in, ch_out] 51 | 
self.vars.append(nn.Parameter(torch.zeros(param[1]))) 52 | 53 | elif name is 'linear': 54 | # [ch_out, ch_in] 55 | w = nn.Parameter(torch.ones(*param)) 56 | # gain=1 according to cbfinn's implementation 57 | torch.nn.init.kaiming_normal_(w) 58 | self.vars.append(w) 59 | # [ch_out] 60 | self.vars.append(nn.Parameter(torch.zeros(param[0]))) 61 | 62 | elif name is 'bn': 63 | # [ch_out] 64 | w = nn.Parameter(torch.ones(param[0])) 65 | self.vars.append(w) 66 | # [ch_out] 67 | self.vars.append(nn.Parameter(torch.zeros(param[0]))) 68 | 69 | # must set requires_grad=False 70 | running_mean = nn.Parameter(torch.zeros(param[0]), requires_grad=False) 71 | running_var = nn.Parameter(torch.ones(param[0]), requires_grad=False) 72 | self.vars_bn.extend([running_mean, running_var]) 73 | 74 | 75 | elif name in ['tanh', 'relu', 'upsample', 'avg_pool2d', 'max_pool2d', 76 | 'flatten', 'reshape', 'leakyrelu', 'sigmoid']: 77 | continue 78 | else: 79 | raise NotImplementedError 80 | 81 | 82 | 83 | 84 | 85 | 86 | def extra_repr(self): 87 | info = '' 88 | 89 | for name, param in self.config: 90 | if name is 'conv2d': 91 | tmp = 'conv2d:(ch_in:%d, ch_out:%d, k:%dx%d, stride:%d, padding:%d)' %(param[1], param[0], param[2], param[3], param[4], param[5],) 92 | info += tmp + '\n' 93 | 94 | elif name is 'convt2d': 95 | tmp = 'convTranspose2d:(ch_in:%d, ch_out:%d, k:%dx%d, stride:%d, padding:%d)' %(param[0], param[1], param[2], param[3], param[4], param[5],) 96 | info += tmp + '\n' 97 | 98 | elif name is 'linear': 99 | tmp = 'linear:(in:%d, out:%d)'%(param[1], param[0]) 100 | info += tmp + '\n' 101 | 102 | elif name is 'leakyrelu': 103 | tmp = 'leakyrelu:(slope:%f)'%(param[0]) 104 | info += tmp + '\n' 105 | 106 | 107 | elif name is 'avg_pool2d': 108 | tmp = 'avg_pool2d:(k:%d, stride:%d, padding:%d)'%(param[0], param[1], param[2]) 109 | info += tmp + '\n' 110 | elif name is 'max_pool2d': 111 | tmp = 'max_pool2d:(k:%d, stride:%d, padding:%d)'%(param[0], param[1], param[2]) 112 | info += tmp + '\n' 113 | elif name in ['flatten', 'tanh', 'relu', 'upsample', 'reshape', 'sigmoid', 'use_logits', 'bn']: 114 | tmp = name + ':' + str(tuple(param)) 115 | info += tmp + '\n' 116 | else: 117 | raise NotImplementedError 118 | 119 | return info 120 | 121 | 122 | 123 | def forward(self, x, vars=None, bn_training=True): 124 | """ 125 | This function can be called by finetunning, however, in finetunning, we dont wish to update 126 | running_mean/running_var. Thought weights/bias of bn is updated, it has been separated by fast_weights. 127 | Indeed, to not update running_mean/running_var, we need set update_bn_statistics=False 128 | but weight/bias will be updated and not dirty initial theta parameters via fast_weiths. 129 | :param x: [b, 1, 28, 28] 130 | :param vars: 131 | :param bn_training: set False to not update 132 | :return: x, loss, likelihood, kld 133 | """ 134 | 135 | if vars is None: 136 | vars = self.vars 137 | 138 | idx = 0 139 | bn_idx = 0 140 | 141 | for name, param in self.config: 142 | if name is 'conv2d': 143 | w, b = vars[idx], vars[idx + 1] 144 | # remember to keep synchrozied of forward_encoder and forward_decoder! 145 | x = F.conv2d(x, w, b, stride=param[4], padding=param[5]) 146 | idx += 2 147 | # print(name, param, '\tout:', x.shape) 148 | elif name is 'convt2d': 149 | w, b = vars[idx], vars[idx + 1] 150 | # remember to keep synchrozied of forward_encoder and forward_decoder! 
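                # vars holds (weight, bias) pairs in config order; idx advances by two for each parameterized layer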
151 | x = F.conv_transpose2d(x, w, b, stride=param[4], padding=param[5]) 152 | idx += 2 153 | # print(name, param, '\tout:', x.shape) 154 | elif name is 'linear': 155 | w, b = vars[idx], vars[idx + 1] 156 | x = F.linear(x, w, b) 157 | idx += 2 158 | # print('forward:', idx, x.norm().item()) 159 | elif name is 'bn': 160 | w, b = vars[idx], vars[idx + 1] 161 | running_mean, running_var = self.vars_bn[bn_idx], self.vars_bn[bn_idx+1] 162 | x = F.batch_norm(x, running_mean, running_var, weight=w, bias=b, training=bn_training) 163 | idx += 2 164 | bn_idx += 2 165 | 166 | elif name is 'flatten': 167 | # print(x.shape) 168 | x = x.view(x.size(0), -1) 169 | elif name is 'reshape': 170 | # [b, 8] => [b, 2, 2, 2] 171 | x = x.view(x.size(0), *param) 172 | elif name is 'relu': 173 | x = F.relu(x, inplace=param[0]) 174 | elif name is 'leakyrelu': 175 | x = F.leaky_relu(x, negative_slope=param[0], inplace=param[1]) 176 | elif name is 'tanh': 177 | x = F.tanh(x) 178 | elif name is 'sigmoid': 179 | x = torch.sigmoid(x) 180 | elif name is 'upsample': 181 | x = F.upsample_nearest(x, scale_factor=param[0]) 182 | elif name is 'max_pool2d': 183 | x = F.max_pool2d(x, param[0], param[1], param[2]) 184 | elif name is 'avg_pool2d': 185 | x = F.avg_pool2d(x, param[0], param[1], param[2]) 186 | 187 | else: 188 | raise NotImplementedError 189 | 190 | # make sure variable is used properly 191 | assert idx == len(vars) 192 | assert bn_idx == len(self.vars_bn) 193 | 194 | 195 | return x 196 | 197 | 198 | def zero_grad(self, vars=None): 199 | """ 200 | :param vars: 201 | :return: 202 | """ 203 | with torch.no_grad(): 204 | if vars is None: 205 | for p in self.vars: 206 | if p.grad is not None: 207 | p.grad.zero_() 208 | else: 209 | for p in vars: 210 | if p.grad is not None: 211 | p.grad.zero_() 212 | 213 | def parameters(self): 214 | """ 215 | override this function since initial parameters will return with a generator. 
216 | :return: 217 | """ 218 | return self.vars 219 | 220 | -------------------------------------------------------------------------------- /mamlfgsmeps2.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangren09/MetaAdv/ca2e3583a69f6d92a3ece9813fe8bb1590a9d724/mamlfgsmeps2.pt -------------------------------------------------------------------------------- /mamltradesrseps2self.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangren09/MetaAdv/ca2e3583a69f6d92a3ece9813fe8bb1590a9d724/mamltradesrseps2self.pt -------------------------------------------------------------------------------- /trainfgsmrs.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "scrolled": false 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import torch, os\n", 12 | "import numpy as np\n", 13 | "from MiniImagenet import MiniImagenet\n", 14 | "import scipy.stats\n", 15 | "from torch.utils.data import DataLoader\n", 16 | "from torch.optim import lr_scheduler\n", 17 | "import random, sys, pickle\n", 18 | "import argparse\n", 19 | "import time\n", 20 | "\n", 21 | "#from ANIP import Meta\n", 22 | "#from metafgsmanil import Meta\n", 23 | "#from metafgsm import Meta\n", 24 | "#from MAMLMeta import Meta\n", 25 | "#from meta import Meta\n", 26 | "#from Adv_Quer import Meta\n", 27 | "#from metafgsmnewnew import Meta\n", 28 | "from metafgsm import Meta\n", 29 | "\n", 30 | "\n", 31 | "def mean_confidence_interval(accs, confidence=0.95):\n", 32 | " n = accs.shape[0]\n", 33 | " m, se = np.mean(accs), scipy.stats.sem(accs)\n", 34 | " h = se * scipy.stats.t._ppf((1 + confidence) / 2, n - 1)\n", 35 | " return m, h\n", 36 | "\n", 37 | "\n", 38 | "def main():\n", 39 | "\n", 40 | " torch.manual_seed(222)\n", 41 | " torch.cuda.manual_seed_all(222)\n", 42 | " np.random.seed(222)\n", 43 | "\n", 44 | " print(args)\n", 45 | "\n", 46 | " config = [\n", 47 | " ('conv2d', [32, 3, 3, 3, 1, 0]),\n", 48 | " ('relu', [True]),\n", 49 | " ('bn', [32]),\n", 50 | " ('max_pool2d', [2, 2, 0]),\n", 51 | " ('conv2d', [32, 32, 3, 3, 1, 0]),\n", 52 | " ('relu', [True]),\n", 53 | " ('bn', [32]),\n", 54 | " ('max_pool2d', [2, 2, 0]),\n", 55 | " ('conv2d', [32, 32, 3, 3, 1, 0]),\n", 56 | " ('relu', [True]),\n", 57 | " ('bn', [32]),\n", 58 | " ('max_pool2d', [2, 2, 0]),\n", 59 | " ('conv2d', [32, 32, 3, 3, 1, 0]),\n", 60 | " ('relu', [True]),\n", 61 | " ('bn', [32]),\n", 62 | " ('max_pool2d', [2, 1, 0]),\n", 63 | " ('flatten', []),\n", 64 | " ('linear', [args.n_way, 32 * 5 * 5])\n", 65 | " ]\n", 66 | "\n", 67 | " device = torch.device('cuda:2')\n", 68 | " maml = Meta(args, config, device).to(device)\n", 69 | " \n", 70 | " \n", 71 | " start_epoch = 0\n", 72 | " start_step = 0\n", 73 | " filename = 'mamlfgsmeps4_2.pt'\n", 74 | " #maml = Meta(args, config).to(device)\n", 75 | " if os.path.isfile(filename):\n", 76 | " print(\"=> loading checkpoint '{}'\".format(filename))\n", 77 | " checkpoint = torch.load(filename)\n", 78 | " start_epoch = checkpoint['epoch']\n", 79 | " start_step = checkpoint['step']\n", 80 | " maml.net.load_state_dict(checkpoint['state_dict'])\n", 81 | " #maml = maml.to(device)\n", 82 | " print(\"=> loaded checkpoint '{}' (epoch {})\"\n", 83 | " .format(filename, checkpoint['epoch']))\n", 84 | " else:\n", 85 | " print(\"=> no checkpoint found at 
'{}'\".format(filename))\n", 86 | " \n", 87 | "\n", 88 | " tmp = filter(lambda x: x.requires_grad, maml.parameters())\n", 89 | " num = sum(map(lambda x: np.prod(x.shape), tmp))\n", 90 | " print(maml)\n", 91 | " print('Total trainable tensors:', num)\n", 92 | "\n", 93 | " # batchsz here means total episode number\n", 94 | " mini = MiniImagenet('../../../dataset/', mode='train', n_way=args.n_way, k_shot=args.k_spt,\n", 95 | " k_query=args.k_qry,\n", 96 | " batchsz=10000, resize=args.imgsz)\n", 97 | " mini_test = MiniImagenet('../../../dataset/', mode='test', n_way=args.n_way, k_shot=args.k_spt,\n", 98 | " k_query=args.k_qry,\n", 99 | " batchsz=100, resize=args.imgsz)\n", 100 | "\n", 101 | " for epoch in range(args.epoch//10000):\n", 102 | " # fetch meta_batchsz num of episode each time\n", 103 | " db = DataLoader(mini, args.task_num, shuffle=True, num_workers=0, pin_memory=True)\n", 104 | "\n", 105 | " for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):\n", 106 | " if step == 1:\n", 107 | " t = time.perf_counter()\n", 108 | " if step == 499:\n", 109 | " ExecTime = time.perf_counter() - t\n", 110 | " print(ExecTime)\n", 111 | " if step == 501:\n", 112 | " t = time.perf_counter()\n", 113 | " if step == 999:\n", 114 | " ExecTime = time.perf_counter() - t\n", 115 | " print(ExecTime)\n", 116 | "\n", 117 | " x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)\n", 118 | "\n", 119 | " accs, accs_adv = maml(x_spt, y_spt, x_qry, y_qry)\n", 120 | "\n", 121 | " if step % 30 == 0:\n", 122 | " print('step:', step, '\\ttraining acc:', accs)\n", 123 | " print('step:', step, '\\ttraining acc_adv:', accs_adv)\n", 124 | " state = {'epoch': epoch, 'step': step, 'state_dict': maml.net.state_dict()}\n", 125 | " torch.save(state, 'mamlfgsmeps4_2.pt')\n", 126 | "\n", 127 | " if step % 1000 == 0: # evaluation\n", 128 | " db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=0, pin_memory=True)\n", 129 | " accs_all_test = []\n", 130 | " accsadv_all_test = []\n", 131 | " accsadvpr_all_test = []\n", 132 | "\n", 133 | " for x_spt, y_spt, x_qry, y_qry in db_test:\n", 134 | " x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \\\n", 135 | " x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)\n", 136 | "\n", 137 | " accs, accs_adv, accs_adv_prior = maml.finetunning(x_spt, y_spt, x_qry, y_qry)\n", 138 | " accs_all_test.append(accs)\n", 139 | " accsadv_all_test.append(accs_adv)\n", 140 | " accsadvpr_all_test.append(accs_adv_prior)\n", 141 | "\n", 142 | " # [b, update_step+1]\n", 143 | " accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)\n", 144 | " accs_adv = np.array(accsadv_all_test).mean(axis=0).astype(np.float16)\n", 145 | " accs_adv_prior = np.array(accsadvpr_all_test).mean(axis=0).astype(np.float16)\n", 146 | " print('Test acc:', accs)\n", 147 | " print('Test acc_adv:', accs_adv)\n", 148 | " print('Test acc_adv_prior:', accs_adv_prior)\n", 149 | "\n", 150 | "\n", 151 | "if __name__ == '__main__':\n", 152 | "\n", 153 | " argparser = argparse.ArgumentParser()\n", 154 | " argparser.add_argument('--epoch', type=int, help='epoch number', default=60000)\n", 155 | " argparser.add_argument('--n_way', type=int, help='n way', default=5)\n", 156 | " argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=1)\n", 157 | " argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=15)\n", 158 | " argparser.add_argument('--imgsz', type=int, help='imgsz', default=84)\n", 159 | 
" argparser.add_argument('--imgc', type=int, help='imgc', default=3)\n", 160 | " argparser.add_argument('--task_num', type=int, help='meta batch size, namely task num', default=4)\n", 161 | " argparser.add_argument('--meta_lr', type=float, help='meta-level outer learning rate', default=1e-3)\n", 162 | " argparser.add_argument('--update_lr', type=float, help='task-level inner update learning rate', default=0.01)\n", 163 | " argparser.add_argument('--update_step', type=int, help='task-level inner update steps', default=5)\n", 164 | " argparser.add_argument('--update_step_test', type=int, help='update steps for finetunning', default=10)\n", 165 | " \n", 166 | " #argparser.add_argument('--fast', action=\"store_true\", help='whether to use fgsm')\n", 167 | "\n", 168 | " args = argparser.parse_args(args=[])\n", 169 | "\n", 170 | " main()" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": null, 176 | "metadata": {}, 177 | "outputs": [], 178 | "source": [] 179 | } 180 | ], 181 | "metadata": { 182 | "kernelspec": { 183 | "display_name": "Python 3", 184 | "language": "python", 185 | "name": "python3" 186 | }, 187 | "language_info": { 188 | "codemirror_mode": { 189 | "name": "ipython", 190 | "version": 3 191 | }, 192 | "file_extension": ".py", 193 | "mimetype": "text/x-python", 194 | "name": "python", 195 | "nbconvert_exporter": "python", 196 | "pygments_lexer": "ipython3", 197 | "version": "3.7.7" 198 | } 199 | }, 200 | "nbformat": 4, 201 | "nbformat_minor": 4 202 | } 203 | -------------------------------------------------------------------------------- /vis_tool.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from sklearn.decomposition import PCA 4 | from sklearn import manifold 5 | #import seaborn as sns 6 | 7 | def get_axis(axarr, H, W, i, j): 8 | H, W = H - 1, W - 1 9 | if not (H or W): 10 | ax = axarr 11 | elif not (H and W): 12 | ax = axarr[max(i, j)] 13 | else: 14 | ax = axarr[i][j] 15 | return ax 16 | 17 | def show_image_row(xlist, ylist=None, fontsize=12, tlist=None, filename=None, baseline=None): 18 | H, W = len(xlist), len(xlist[0]) 19 | fig, axarr = plt.subplots(H * 3, W, figsize=(2.5 * W, 2.5 * H * 3)) 20 | for w in range(W): 21 | for h in range(H): 22 | ax = get_axis(axarr, H * 3, W, 3 * h, w) 23 | #clean input 24 | # ax.imshow(baseline[0][w].permute(1, 2, 0)) 25 | #noise input 26 | ax.imshow(1 - baseline[0][w].permute(1, 2, 0)) 27 | ax.xaxis.set_ticks([]) 28 | ax.yaxis.set_ticks([]) 29 | ax.xaxis.set_ticklabels([]) 30 | ax.yaxis.set_ticklabels([]) 31 | 32 | ax = get_axis(axarr, H * 3, W, 3 * h + 1, w) 33 | #clean input 34 | # ax.imshow(xlist[h][w].permute(1, 2, 0)) 35 | #noise input 36 | ax.imshow(1-xlist[h][w].permute(1, 2, 0)) 37 | ax.xaxis.set_ticks([]) 38 | ax.yaxis.set_ticks([]) 39 | ax.xaxis.set_ticklabels([]) 40 | ax.yaxis.set_ticklabels([]) 41 | if ylist and w == 0: 42 | ax.set_ylabel(ylist[h], fontsize=fontsize) 43 | if tlist: 44 | ax.set_title(tlist[h][w], fontsize=fontsize) 45 | 46 | ax = get_axis(axarr, H * 3, W, 3 * h + 2, w) 47 | ax.imshow(1 - abs(xlist[h][w].permute(1, 2, 0) - baseline[0][w].permute(1, 2, 0))) 48 | #ax.imshow(abs(xlist[h][w].permute(1, 2, 0) - baseline[0][w].permute(1, 2, 0))) 49 | ax.xaxis.set_ticks([]) 50 | ax.yaxis.set_ticks([]) 51 | ax.xaxis.set_ticklabels([]) 52 | ax.yaxis.set_ticklabels([]) 53 | if filename is not None: 54 | plt.savefig(filename, bbox_inches='tight') 55 | plt.show() 56 | 57 | 58 | def 
--------------------------------------------------------------------------------
/vis_tool.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from sklearn.decomposition import PCA
4 | from sklearn import manifold
5 | import seaborn as sns  # needed by plot_tsne below; previously commented out
6 | 
7 | def get_axis(axarr, H, W, i, j):
8 |     H, W = H - 1, W - 1  # plt.subplots squeezes singleton axes, so index accordingly
9 |     if not (H or W):
10 |         ax = axarr
11 |     elif not (H and W):
12 |         ax = axarr[max(i, j)]
13 |     else:
14 |         ax = axarr[i][j]
15 |     return ax
16 | 
17 | def show_image_row(xlist, ylist=None, fontsize=12, tlist=None, filename=None, baseline=None):
18 |     H, W = len(xlist), len(xlist[0])
19 |     fig, axarr = plt.subplots(H * 3, W, figsize=(2.5 * W, 2.5 * H * 3))
20 |     for w in range(W):
21 |         for h in range(H):
22 |             ax = get_axis(axarr, H * 3, W, 3 * h, w)
23 |             # clean input:
24 |             # ax.imshow(baseline[0][w].permute(1, 2, 0))
25 |             # noise input (inverted for visibility):
26 |             ax.imshow(1 - baseline[0][w].permute(1, 2, 0))
27 |             ax.xaxis.set_ticks([])
28 |             ax.yaxis.set_ticks([])
29 |             ax.xaxis.set_ticklabels([])
30 |             ax.yaxis.set_ticklabels([])
31 | 
32 |             ax = get_axis(axarr, H * 3, W, 3 * h + 1, w)
33 |             # clean input:
34 |             # ax.imshow(xlist[h][w].permute(1, 2, 0))
35 |             # noise input (inverted for visibility):
36 |             ax.imshow(1 - xlist[h][w].permute(1, 2, 0))
37 |             ax.xaxis.set_ticks([])
38 |             ax.yaxis.set_ticks([])
39 |             ax.xaxis.set_ticklabels([])
40 |             ax.yaxis.set_ticklabels([])
41 |             if ylist and w == 0:
42 |                 ax.set_ylabel(ylist[h], fontsize=fontsize)
43 |             if tlist:
44 |                 ax.set_title(tlist[h][w], fontsize=fontsize)
45 | 
46 |             ax = get_axis(axarr, H * 3, W, 3 * h + 2, w)
47 |             ax.imshow(1 - abs(xlist[h][w].permute(1, 2, 0) - baseline[0][w].permute(1, 2, 0)))
48 |             # ax.imshow(abs(xlist[h][w].permute(1, 2, 0) - baseline[0][w].permute(1, 2, 0)))
49 |             ax.xaxis.set_ticks([])
50 |             ax.yaxis.set_ticks([])
51 |             ax.xaxis.set_ticklabels([])
52 |             ax.yaxis.set_ticklabels([])
53 |     if filename is not None:
54 |         plt.savefig(filename, bbox_inches='tight')
55 |     plt.show()
56 | 
57 | 
58 | def show_image_column(xlist, ylist=None, fontsize=12, tlist=None, filename=None):
59 |     W, H = len(xlist), len(xlist[0])
60 |     fig, axarr = plt.subplots(H * 2, W, figsize=(2.5 * W, 2.5 * H * 2))
61 |     for w in range(W):
62 |         for h in range(H):
63 |             ax = get_axis(axarr, H * 2, W, 2 * h, w)
64 |             ax.imshow(xlist[w][h].permute(1, 2, 0))
65 |             ax.xaxis.set_ticks([])
66 |             ax.yaxis.set_ticks([])
67 |             ax.xaxis.set_ticklabels([])
68 |             ax.yaxis.set_ticklabels([])
69 |             if ylist and h == 0:
70 |                 ax.set_title(ylist[w], fontsize=fontsize)
71 |             if tlist:
72 |                 ax.set_title(tlist[w][h], fontsize=fontsize)
73 | 
74 |             ax = get_axis(axarr, H * 2, W, 2 * h + 1, w)
75 |             ax.imshow(xlist[w][h].permute(1, 2, 0) - xlist[0][h].permute(1, 2, 0))  # difference to the first column
76 |             ax.xaxis.set_ticks([])
77 |             ax.yaxis.set_ticks([])
78 |             ax.xaxis.set_ticklabels([])
79 |             ax.yaxis.set_ticklabels([])
80 |     if filename is not None:
81 |         plt.savefig(filename, bbox_inches='tight')
82 |     plt.show()
83 | 
84 | def filter_data(metadata, criteria, value):
85 |     crit = [True] * len(metadata)
86 |     for c, v in zip(criteria, value):
87 |         v = [v] if not isinstance(v, list) else v
88 |         crit &= metadata[c].isin(v)  # pandas broadcasts the boolean list into a Series
89 |     metadata_int = metadata[crit]
90 |     exp_ids = metadata_int['exp_id'].tolist()
91 |     return exp_ids
92 | 
93 | def plot_axis(ax, x, y, xlabel, ylabel, **kwargs):
94 |     ax.plot(x, y, **kwargs)
95 |     ax.set_xlabel(xlabel, fontsize=14)
96 |     ax.set_ylabel(ylabel, fontsize=14)
97 | 
98 | 
99 | def plot_tsne(x, y, npca=50, markersize=10):
100 |     Xlow = PCA(n_components=npca).fit_transform(x)  # PCA first to denoise and speed up t-SNE
101 |     Y = manifold.TSNE(n_components=2).fit_transform(Xlow)
102 |     palette = sns.color_palette("Paired", len(np.unique(y)))
103 |     color_dict = {l: c for l, c in zip(range(len(np.unique(y))), palette)}
104 |     colors = [color_dict[l] for l in y]
105 |     plt.scatter(Y[:, 0], Y[:, 1], markersize, colors, 'o')
106 |     plt.show()
--------------------------------------------------------------------------------
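Finally, a quick smoke test for plot_tsne above, using random features in place of real network activations; the shapes here are hypothetical, and any (n_samples, n_features) array with integer labels works, as long as npca does not exceed either dimension.

import numpy as np
from vis_tool import plot_tsne

x = np.random.randn(200, 64)           # 200 fake 64-d feature vectors
y = np.random.randint(0, 5, size=200)  # 5 fake class labels
plot_tsne(x, y, npca=50)               # PCA down to 50 dims, then a 2-d t-SNE scatter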