#!/bin/bash
#
# scripts/train_5s_5c.sh
# 5-shot, 5-class meta-training.
# Hyper-parameters follow https://github.com/twitter/meta-learning-lstm

python main.py --mode train \
               --n-shot 5 \
               --n-eval 15 \
               --n-class 5 \
               --input-size 4 \
               --hidden-size 20 \
               --lr 1e-3 \
               --episode 50000 \
               --episode-val 100 \
               --epoch 8 \
               --batch-size 25 \
               --image-size 84 \
               --grad-clip 0.25 \
               --bn-momentum 0.95 \
               --bn-eps 1e-3 \
               --data miniimagenet \
               --data-root data/miniImagenet/ \
               --pin-mem True \
               --log-freq 50 \
               --val-freq 1000

# ---------------------------------------------------------------------------
# scripts/eval_5s_5c.sh
# 5-shot, 5-class evaluation; same hyper-parameters as training.
# Point --resume at the checkpoint you want to evaluate (and match --seed).

python main.py --mode test \
               --resume logs-719/ckpts/meta-learner-42000.pth.tar \
               --n-shot 5 \
               --n-eval 15 \
               --n-class 5 \
               --input-size 4 \
               --hidden-size 20 \
               --lr 1e-3 \
               --episode 50000 \
               --episode-val 100 \
               --epoch 8 \
               --batch-size 25 \
               --image-size 84 \
               --grad-clip 0.25 \
               --bn-momentum 0.95 \
               --bn-eps 1e-3 \
               --data miniimagenet \
               --data-root data/miniImagenet/ \
               --pin-mem True \
               --log-freq 100
torch 8 | import torch.nn as nn 9 | import numpy as np 10 | 11 | class Learner(nn.Module): 12 | 13 | def __init__(self, image_size, bn_eps, bn_momentum, n_classes): 14 | super(Learner, self).__init__() 15 | self.model = nn.ModuleDict({'features': nn.Sequential(OrderedDict([ 16 | ('conv1', nn.Conv2d(3, 32, 3, padding=1)), 17 | ('norm1', nn.BatchNorm2d(32, bn_eps, bn_momentum)), 18 | ('relu1', nn.ReLU(inplace=False)), 19 | ('pool1', nn.MaxPool2d(2)), 20 | 21 | ('conv2', nn.Conv2d(32, 32, 3, padding=1)), 22 | ('norm2', nn.BatchNorm2d(32, bn_eps, bn_momentum)), 23 | ('relu2', nn.ReLU(inplace=False)), 24 | ('pool2', nn.MaxPool2d(2)), 25 | 26 | ('conv3', nn.Conv2d(32, 32, 3, padding=1)), 27 | ('norm3', nn.BatchNorm2d(32, bn_eps, bn_momentum)), 28 | ('relu3', nn.ReLU(inplace=False)), 29 | ('pool3', nn.MaxPool2d(2)), 30 | 31 | ('conv4', nn.Conv2d(32, 32, 3, padding=1)), 32 | ('norm4', nn.BatchNorm2d(32, bn_eps, bn_momentum)), 33 | ('relu4', nn.ReLU(inplace=False)), 34 | ('pool4', nn.MaxPool2d(2))])) 35 | }) 36 | 37 | clr_in = image_size // 2**4 38 | self.model.update({'cls': nn.Linear(32 * clr_in * clr_in, n_classes)}) 39 | self.criterion = nn.CrossEntropyLoss() 40 | 41 | def forward(self, x): 42 | x = self.model.features(x) 43 | x = torch.reshape(x, [x.size(0), -1]) 44 | outputs = self.model.cls(x) 45 | return outputs 46 | 47 | def get_flat_params(self): 48 | return torch.cat([p.view(-1) for p in self.model.parameters()], 0) 49 | 50 | def copy_flat_params(self, cI): 51 | idx = 0 52 | for p in self.model.parameters(): 53 | plen = p.view(-1).size(0) 54 | p.data.copy_(cI[idx: idx+plen].view_as(p)) 55 | idx += plen 56 | 57 | def transfer_params(self, learner_w_grad, cI): 58 | # Use load_state_dict only to copy the running mean/var in batchnorm, the values of the parameters 59 | # are going to be replaced by cI 60 | self.load_state_dict(learner_w_grad.state_dict()) 61 | # replace nn.Parameters with tensors from cI (NOT nn.Parameters anymore). 
62 | idx = 0 63 | for m in self.model.modules(): 64 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.Linear): 65 | wlen = m._parameters['weight'].view(-1).size(0) 66 | m._parameters['weight'] = cI[idx: idx+wlen].view_as(m._parameters['weight']).clone() 67 | idx += wlen 68 | if m._parameters['bias'] is not None: 69 | blen = m._parameters['bias'].view(-1).size(0) 70 | m._parameters['bias'] = cI[idx: idx+blen].view_as(m._parameters['bias']).clone() 71 | idx += blen 72 | 73 | def reset_batch_stats(self): 74 | for m in self.modules(): 75 | if isinstance(m, nn.BatchNorm2d): 76 | m.reset_running_stats() 77 | 78 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Optimization as a Model for Few-shot Learning 2 | Pytorch implementation of [Optimization as a Model for Few-shot Learning](https://openreview.net/forum?id=rJY0-Kcll) in ICLR 2017 (Oral) 3 | 4 | ![Model Architecture](https://i.imgur.com/lydKeUc.png) 5 | 6 | ## Prerequisites 7 | - python 3+ 8 | - pytorch 0.4+ (developed on 1.0.1 with cuda 9.0) 9 | - [pillow](https://pillow.readthedocs.io/en/stable/installation.html) 10 | - [tqdm](https://tqdm.github.io/) (a nice progress bar) 11 | 12 | ## Data 13 | - Mini-Imagenet as described [here](https://github.com/twitter/meta-learning-lstm/tree/master/data/miniImagenet) 14 | - You can download it from [here](https://drive.google.com/file/d/1rV3aj_hgfNTfCakffpPm7Vhpr1in87CR/view?usp=sharing) (~2.7GB, google drive link) 15 | 16 | ## Preparation 17 | - Make sure Mini-Imagenet is split properly. For example: 18 | ``` 19 | - data/ 20 | - miniImagenet/ 21 | - train/ 22 | - n01532829/ 23 | - n0153282900000005.jpg 24 | - ... 25 | - n01558993/ 26 | - ... 27 | - val/ 28 | - n01855672/ 29 | - ... 30 | - test/ 31 | - ... 32 | - main.py 33 | - ... 
```
- The data will already be laid out this way if you download and extract Mini-Imagenet from the link above
- Check out `scripts/train_5s_5c.sh` and make sure `--data-root` points to the right location

## Run
For 5-shot, 5-class training, run
```bash
bash scripts/train_5s_5c.sh
```
Hyper-parameters follow the [author's repo](https://github.com/twitter/meta-learning-lstm).

For 5-shot, 5-class evaluation, run *(remember to change the `--resume` and `--seed` arguments)*
```bash
bash scripts/eval_5s_5c.sh
```

## Notes
- Results (this repo was developed following the [pytorch reproducibility guideline](https://pytorch.org/docs/stable/notes/randomness.html)):

|seed|train episodes|val episodes|val acc mean|val acc std|test episodes|test acc mean|test acc std|
|-|-|-|-|-|-|-|-|
|719|41000|100|59.08|9.9|100|56.59|8.4|
| -| -| -| -| -|250|57.85|8.6|
| -| -| -| -| -|600|57.76|8.6|
| 53|44000|100|58.04|9.1|100|57.85|7.7|
| -| -| -| -| -|250|57.83|8.3|
| -| -| -| -| -|600|58.14|8.5|

- The results I get from directly running the author's repo can be found [here](https://i.imgur.com/rtagm2c.png); I obtain slightly better performance (~5%), but neither result matches the number reported in the paper (60%) *(discussion and help are welcome!)*.
- Training with the default settings takes ~2.5 hours on a single Titan Xp while occupying ~2GB of GPU memory.
- The implementation replicates two learners, similar to the author's repo:
  - `learner_w_grad` functions as a regular model; its gradients and loss are the inputs to the meta-learner.
  - `learner_wo_grad` constructs the graph for the meta-learner:
    - All the parameters in `learner_wo_grad` are replaced by the `cI` output by the meta-learner.
    - `nn.Parameter`s in this model are cast to plain `torch.Tensor`s to connect the graph to the meta-learner.
- There are several ways to **copy** parameters from the meta-learner to the learner, depending on the scenario:
  - `copy_flat_params`: we only need the parameter values and keep the original `grad_fn`.
  - `transfer_params`: we want the values as well as the `grad_fn` (from `cI` to `learner_wo_grad`).
    - `.data.copy_` vs. `clone()` -> the latter retains all the properties of a tensor, including its `grad_fn`.
    - To maintain the batch statistics, `load_state_dict` is used (from `learner_w_grad` to `learner_wo_grad`).

## References
- [CloserLookFewShot](https://github.com/wyharveychen/CloserLookFewShot) (data loader)
- [pytorch-meta-optimizer](https://github.com/ikostrikov/pytorch-meta-optimizer) (casting `nn.Parameter`s to `torch.Tensor` is inspired by this repo)
- [meta-learning-lstm](https://github.com/twitter/meta-learning-lstm) (author's repo in Lua Torch)
31 | self.labels = sorted(os.listdir(root)) 32 | images = [glob.glob(os.path.join(root, label, '*')) for label in self.labels] 33 | 34 | self.episode_loader = [data.DataLoader( 35 | ClassDataset(images=images[idx], label=idx, transform=transform), 36 | batch_size=n_shot+n_eval, shuffle=True, num_workers=0) for idx, _ in enumerate(self.labels)] 37 | 38 | def __getitem__(self, idx): 39 | return next(iter(self.episode_loader[idx])) 40 | 41 | def __len__(self): 42 | return len(self.labels) 43 | 44 | 45 | class ClassDataset(data.Dataset): 46 | 47 | def __init__(self, images, label, transform=None): 48 | """Args: 49 | images (list of str): each item is a path to an image of the same label 50 | label (int): the label of all the images 51 | """ 52 | self.images = images 53 | self.label = label 54 | self.transform = transform 55 | 56 | def __getitem__(self, idx): 57 | image = PILI.open(self.images[idx]).convert('RGB') 58 | if self.transform is not None: 59 | image = self.transform(image) 60 | 61 | return image, self.label 62 | 63 | def __len__(self): 64 | return len(self.images) 65 | 66 | 67 | class EpisodicSampler(data.Sampler): 68 | 69 | def __init__(self, total_classes, n_class, n_episode): 70 | self.total_classes = total_classes 71 | self.n_class = n_class 72 | self.n_episode = n_episode 73 | 74 | def __iter__(self): 75 | for i in range(self.n_episode): 76 | yield torch.randperm(self.total_classes)[:self.n_class] 77 | 78 | def __len__(self): 79 | return self.n_episode 80 | 81 | 82 | def prepare_data(args): 83 | 84 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 85 | 86 | train_set = EpisodeDataset(args.data_root, 'train', args.n_shot, args.n_eval, 87 | transform=transforms.Compose([ 88 | transforms.RandomResizedCrop(args.image_size), 89 | transforms.RandomHorizontalFlip(), 90 | transforms.ColorJitter( 91 | brightness=0.4, 92 | contrast=0.4, 93 | saturation=0.4, 94 | hue=0.2), 95 | transforms.ToTensor(), 96 | normalize])) 97 | 98 
| val_set = EpisodeDataset(args.data_root, 'val', args.n_shot, args.n_eval, 99 | transform=transforms.Compose([ 100 | transforms.Resize(args.image_size * 8 // 7), 101 | transforms.CenterCrop(args.image_size), 102 | transforms.ToTensor(), 103 | normalize])) 104 | 105 | test_set = EpisodeDataset(args.data_root, 'test', args.n_shot, args.n_eval, 106 | transform=transforms.Compose([ 107 | transforms.Resize(args.image_size * 8 // 7), 108 | transforms.CenterCrop(args.image_size), 109 | transforms.ToTensor(), 110 | normalize])) 111 | 112 | train_loader = data.DataLoader(train_set, num_workers=args.n_workers, pin_memory=args.pin_mem, 113 | batch_sampler=EpisodicSampler(len(train_set), args.n_class, args.episode)) 114 | 115 | val_loader = data.DataLoader(val_set, num_workers=2, pin_memory=False, 116 | batch_sampler=EpisodicSampler(len(val_set), args.n_class, args.episode_val)) 117 | 118 | test_loader = data.DataLoader(test_set, num_workers=2, pin_memory=False, 119 | batch_sampler=EpisodicSampler(len(test_set), args.n_class, args.episode_val)) 120 | 121 | return train_loader, val_loader, test_loader 122 | -------------------------------------------------------------------------------- /metalearner.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, absolute_import 2 | 3 | import pdb 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class MetaLSTMCell(nn.Module): 10 | """C_t = f_t * C_{t-1} + i_t * \tilde{C_t}""" 11 | def __init__(self, input_size, hidden_size, n_learner_params): 12 | super(MetaLSTMCell, self).__init__() 13 | """Args: 14 | input_size (int): cell input size, default = 20 15 | hidden_size (int): should be 1 16 | n_learner_params (int): number of learner's parameters 17 | """ 18 | self.input_size = input_size 19 | self.hidden_size = hidden_size 20 | self.n_learner_params = n_learner_params 21 | self.WF = nn.Parameter(torch.Tensor(input_size + 2, 
hidden_size)) 22 | self.WI = nn.Parameter(torch.Tensor(input_size + 2, hidden_size)) 23 | self.cI = nn.Parameter(torch.Tensor(n_learner_params, 1)) 24 | self.bI = nn.Parameter(torch.Tensor(1, hidden_size)) 25 | self.bF = nn.Parameter(torch.Tensor(1, hidden_size)) 26 | 27 | self.reset_parameters() 28 | 29 | def reset_parameters(self): 30 | for weight in self.parameters(): 31 | nn.init.uniform_(weight, -0.01, 0.01) 32 | 33 | # want initial forget value to be high and input value to be low so that 34 | # model starts with gradient descent 35 | nn.init.uniform_(self.bF, 4, 6) 36 | nn.init.uniform_(self.bI, -5, -4) 37 | 38 | def init_cI(self, flat_params): 39 | self.cI.data.copy_(flat_params.unsqueeze(1)) 40 | 41 | def forward(self, inputs, hx=None): 42 | """Args: 43 | inputs = [x_all, grad]: 44 | x_all (torch.Tensor of size [n_learner_params, input_size]): outputs from previous LSTM 45 | grad (torch.Tensor of size [n_learner_params]): gradients from learner 46 | hx = [f_prev, i_prev, c_prev]: 47 | f (torch.Tensor of size [n_learner_params, 1]): forget gate 48 | i (torch.Tensor of size [n_learner_params, 1]): input gate 49 | c (torch.Tensor of size [n_learner_params, 1]): flattened learner parameters 50 | """ 51 | x_all, grad = inputs 52 | batch, _ = x_all.size() 53 | 54 | if hx is None: 55 | f_prev = torch.zeros((batch, self.hidden_size)).to(self.WF.device) 56 | i_prev = torch.zeros((batch, self.hidden_size)).to(self.WI.device) 57 | c_prev = self.cI 58 | hx = [f_prev, i_prev, c_prev] 59 | 60 | f_prev, i_prev, c_prev = hx 61 | 62 | # f_t = sigmoid(W_f * [grad_t, loss_t, theta_{t-1}, f_{t-1}] + b_f) 63 | f_next = torch.mm(torch.cat((x_all, c_prev, f_prev), 1), self.WF) + self.bF.expand_as(f_prev) 64 | # i_t = sigmoid(W_i * [grad_t, loss_t, theta_{t-1}, i_{t-1}] + b_i) 65 | i_next = torch.mm(torch.cat((x_all, c_prev, i_prev), 1), self.WI) + self.bI.expand_as(i_prev) 66 | # next cell/params 67 | c_next = torch.sigmoid(f_next).mul(c_prev) - torch.sigmoid(i_next).mul(grad) 
68 | 69 | return c_next, [f_next, i_next, c_next] 70 | 71 | def extra_repr(self): 72 | s = '{input_size}, {hidden_size}, {n_learner_params}' 73 | return s.format(**self.__dict__) 74 | 75 | 76 | class MetaLearner(nn.Module): 77 | 78 | def __init__(self, input_size, hidden_size, n_learner_params): 79 | super(MetaLearner, self).__init__() 80 | """Args: 81 | input_size (int): for the first LSTM layer, default = 4 82 | hidden_size (int): for the first LSTM layer, default = 20 83 | n_learner_params (int): number of learner's parameters 84 | """ 85 | self.lstm = nn.LSTMCell(input_size=input_size, hidden_size=hidden_size) 86 | self.metalstm = MetaLSTMCell(input_size=hidden_size, hidden_size=1, n_learner_params=n_learner_params) 87 | 88 | def forward(self, inputs, hs=None): 89 | """Args: 90 | inputs = [loss, grad_prep, grad] 91 | loss (torch.Tensor of size [1, 2]) 92 | grad_prep (torch.Tensor of size [n_learner_params, 2]) 93 | grad (torch.Tensor of size [n_learner_params]) 94 | 95 | hs = [(lstm_hn, lstm_cn), [metalstm_fn, metalstm_in, metalstm_cn]] 96 | """ 97 | loss, grad_prep, grad = inputs 98 | loss = loss.expand_as(grad_prep) 99 | inputs = torch.cat((loss, grad_prep), 1) # [n_learner_params, 4] 100 | 101 | if hs is None: 102 | hs = [None, None] 103 | 104 | lstmhx, lstmcx = self.lstm(inputs, hs[0]) 105 | flat_learner_unsqzd, metalstm_hs = self.metalstm([lstmhx, grad], hs[1]) 106 | 107 | return flat_learner_unsqzd.squeeze(), [(lstmhx, lstmcx), metalstm_hs] 108 | 109 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, absolute_import 2 | 3 | import os 4 | import pdb 5 | import logging 6 | 7 | import torch 8 | import numpy as np 9 | 10 | 11 | class GOATLogger: 12 | 13 | def __init__(self, args): 14 | args.save = args.save + '-{}'.format(args.seed) 15 | 16 | self.mode = args.mode 17 | self.save_root = 
args.save 18 | self.log_freq = args.log_freq 19 | 20 | if self.mode == 'train': 21 | if not os.path.exists(self.save_root): 22 | os.mkdir(self.save_root) 23 | filename = os.path.join(self.save_root, 'console.log') 24 | logging.basicConfig(level=logging.DEBUG, 25 | format='%(asctime)s.%(msecs)03d - %(message)s', 26 | datefmt='%b-%d %H:%M:%S', 27 | filename=filename, 28 | filemode='w') 29 | console = logging.StreamHandler() 30 | console.setLevel(logging.INFO) 31 | console.setFormatter(logging.Formatter('%(message)s')) 32 | logging.getLogger('').addHandler(console) 33 | 34 | logging.info("Logger created at {}".format(filename)) 35 | else: 36 | logging.basicConfig(level=logging.INFO, 37 | format='%(asctime)s.%(msecs)03d - %(message)s', 38 | datefmt='%b-%d %H:%M:%S') 39 | 40 | logging.info("Random Seed: {}".format(args.seed)) 41 | self.reset_stats() 42 | 43 | def reset_stats(self): 44 | if self.mode == 'train': 45 | self.stats = {'train': {'loss': [], 'acc': []}, 46 | 'eval': {'loss': [], 'acc': []}} 47 | else: 48 | self.stats = {'eval': {'loss': [], 'acc': []}} 49 | 50 | def batch_info(self, **kwargs): 51 | if kwargs['phase'] == 'train': 52 | self.stats['train']['loss'].append(kwargs['loss']) 53 | self.stats['train']['acc'].append(kwargs['acc']) 54 | 55 | if kwargs['eps'] % self.log_freq == 0 and kwargs['eps'] != 0: 56 | loss_mean = np.mean(self.stats['train']['loss']) 57 | acc_mean = np.mean(self.stats['train']['acc']) 58 | #self.draw_stats() 59 | self.loginfo("[{:5d}/{:5d}] loss: {:6.4f} ({:6.4f}), acc: {:6.3f}% ({:6.3f}%)".format(\ 60 | kwargs['eps'], kwargs['totaleps'], kwargs['loss'], loss_mean, kwargs['acc'], acc_mean)) 61 | 62 | elif kwargs['phase'] == 'eval': 63 | self.stats['eval']['loss'].append(kwargs['loss']) 64 | self.stats['eval']['acc'].append(kwargs['acc']) 65 | 66 | elif kwargs['phase'] == 'evaldone': 67 | loss_mean = np.mean(self.stats['eval']['loss']) 68 | loss_std = np.std(self.stats['eval']['loss']) 69 | acc_mean = 
np.mean(self.stats['eval']['acc']) 70 | acc_std = np.std(self.stats['eval']['acc']) 71 | self.loginfo("[{:5d}] Eval ({:3d} episode) - loss: {:6.4f} +- {:6.4f}, acc: {:6.3f} +- {:5.3f}%".format(\ 72 | kwargs['eps'], kwargs['totaleps'], loss_mean, loss_std, acc_mean, acc_std)) 73 | 74 | self.reset_stats() 75 | return acc_mean 76 | 77 | else: 78 | raise ValueError("phase {} not supported".format(kwargs['phase'])) 79 | 80 | def logdebug(self, strout): 81 | logging.debug(strout) 82 | def loginfo(self, strout): 83 | logging.info(strout) 84 | 85 | 86 | def accuracy(output, target, topk=(1,)): 87 | with torch.no_grad(): 88 | maxk = max(topk) 89 | batch_size = target.size(0) 90 | 91 | _, pred = output.topk(maxk, 1, True, True) 92 | pred = pred.t() 93 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 94 | 95 | res = [] 96 | for k in topk: 97 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 98 | res.append(correct_k.mul_(100.0 / batch_size)) 99 | return res[0].item() if len(res) == 1 else [r.item() for r in res] 100 | 101 | 102 | def save_ckpt(episode, metalearner, optim, save): 103 | if not os.path.exists(os.path.join(save, 'ckpts')): 104 | os.mkdir(os.path.join(save, 'ckpts')) 105 | 106 | torch.save({ 107 | 'episode': episode, 108 | 'metalearner': metalearner.state_dict(), 109 | 'optim': optim.state_dict() 110 | }, os.path.join(save, 'ckpts', 'meta-learner-{}.pth.tar'.format(episode))) 111 | 112 | 113 | def resume_ckpt(metalearner, optim, resume, device): 114 | ckpt = torch.load(resume, map_location=device) 115 | last_episode = ckpt['episode'] 116 | metalearner.load_state_dict(ckpt['metalearner']) 117 | optim.load_state_dict(ckpt['optim']) 118 | return last_episode, metalearner, optim 119 | 120 | 121 | def preprocess_grad_loss(x): 122 | p = 10 123 | indicator = (x.abs() >= np.exp(-p)).to(torch.float32) 124 | 125 | # preproc1 126 | x_proc1 = indicator * torch.log(x.abs() + 1e-8) / p + (1 - indicator) * -1 127 | # preproc2 128 | x_proc2 = indicator * 
torch.sign(x) + (1 - indicator) * np.exp(p) * x 129 | return torch.stack((x_proc1, x_proc2), 1) 130 | 131 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, absolute_import 2 | 3 | import os 4 | import pdb 5 | import copy 6 | import random 7 | import argparse 8 | 9 | import torch 10 | import torch.nn as nn 11 | import numpy as np 12 | from tqdm import tqdm 13 | 14 | from learner import Learner 15 | from metalearner import MetaLearner 16 | from dataloader import prepare_data 17 | from utils import * 18 | 19 | 20 | FLAGS = argparse.ArgumentParser() 21 | FLAGS.add_argument('--mode', choices=['train', 'test']) 22 | # Hyper-parameters 23 | FLAGS.add_argument('--n-shot', type=int, 24 | help="How many examples per class for training (k, n_support)") 25 | FLAGS.add_argument('--n-eval', type=int, 26 | help="How many examples per class for evaluation (n_query)") 27 | FLAGS.add_argument('--n-class', type=int, 28 | help="How many classes (N, n_way)") 29 | FLAGS.add_argument('--input-size', type=int, 30 | help="Input size for the first LSTM") 31 | FLAGS.add_argument('--hidden-size', type=int, 32 | help="Hidden size for the first LSTM") 33 | FLAGS.add_argument('--lr', type=float, 34 | help="Learning rate") 35 | FLAGS.add_argument('--episode', type=int, 36 | help="Episodes to train") 37 | FLAGS.add_argument('--episode-val', type=int, 38 | help="Episodes to eval") 39 | FLAGS.add_argument('--epoch', type=int, 40 | help="Epoch to train for an episode") 41 | FLAGS.add_argument('--batch-size', type=int, 42 | help="Batch size when training an episode") 43 | FLAGS.add_argument('--image-size', type=int, 44 | help="Resize image to this size") 45 | FLAGS.add_argument('--grad-clip', type=float, 46 | help="Clip gradients larger than this number") 47 | FLAGS.add_argument('--bn-momentum', type=float, 48 | help="Momentum 
parameter in BatchNorm2d") 49 | FLAGS.add_argument('--bn-eps', type=float, 50 | help="Eps parameter in BatchNorm2d") 51 | 52 | # Paths 53 | FLAGS.add_argument('--data', choices=['miniimagenet'], 54 | help="Name of dataset") 55 | FLAGS.add_argument('--data-root', type=str, 56 | help="Location of data") 57 | FLAGS.add_argument('--resume', type=str, 58 | help="Location to pth.tar") 59 | FLAGS.add_argument('--save', type=str, default='logs', 60 | help="Location to logs and ckpts") 61 | # Others 62 | FLAGS.add_argument('--cpu', action='store_true', 63 | help="Set this to use CPU, default use CUDA") 64 | FLAGS.add_argument('--n-workers', type=int, default=4, 65 | help="How many processes for preprocessing") 66 | FLAGS.add_argument('--pin-mem', type=bool, default=False, 67 | help="DataLoader pin_memory") 68 | FLAGS.add_argument('--log-freq', type=int, default=100, 69 | help="Logging frequency") 70 | FLAGS.add_argument('--val-freq', type=int, default=1000, 71 | help="Validation frequency") 72 | FLAGS.add_argument('--seed', type=int, 73 | help="Random seed") 74 | 75 | 76 | def meta_test(eps, eval_loader, learner_w_grad, learner_wo_grad, metalearner, args, logger): 77 | for subeps, (episode_x, episode_y) in enumerate(tqdm(eval_loader, ascii=True)): 78 | train_input = episode_x[:, :args.n_shot].reshape(-1, *episode_x.shape[-3:]).to(args.dev) # [n_class * n_shot, :] 79 | train_target = torch.LongTensor(np.repeat(range(args.n_class), args.n_shot)).to(args.dev) # [n_class * n_shot] 80 | test_input = episode_x[:, args.n_shot:].reshape(-1, *episode_x.shape[-3:]).to(args.dev) # [n_class * n_eval, :] 81 | test_target = torch.LongTensor(np.repeat(range(args.n_class), args.n_eval)).to(args.dev) # [n_class * n_eval] 82 | 83 | # Train learner with metalearner 84 | learner_w_grad.reset_batch_stats() 85 | learner_wo_grad.reset_batch_stats() 86 | learner_w_grad.train() 87 | learner_wo_grad.eval() 88 | cI = train_learner(learner_w_grad, metalearner, train_input, train_target, args) 89 | 90 
| learner_wo_grad.transfer_params(learner_w_grad, cI) 91 | output = learner_wo_grad(test_input) 92 | loss = learner_wo_grad.criterion(output, test_target) 93 | acc = accuracy(output, test_target) 94 | 95 | logger.batch_info(loss=loss.item(), acc=acc, phase='eval') 96 | 97 | return logger.batch_info(eps=eps, totaleps=args.episode_val, phase='evaldone') 98 | 99 | 100 | def train_learner(learner_w_grad, metalearner, train_input, train_target, args): 101 | cI = metalearner.metalstm.cI.data 102 | hs = [None] 103 | for _ in range(args.epoch): 104 | for i in range(0, len(train_input), args.batch_size): 105 | x = train_input[i:i+args.batch_size] 106 | y = train_target[i:i+args.batch_size] 107 | 108 | # get the loss/grad 109 | learner_w_grad.copy_flat_params(cI) 110 | output = learner_w_grad(x) 111 | loss = learner_w_grad.criterion(output, y) 112 | acc = accuracy(output, y) 113 | learner_w_grad.zero_grad() 114 | loss.backward() 115 | grad = torch.cat([p.grad.data.view(-1) / args.batch_size for p in learner_w_grad.parameters()], 0) 116 | 117 | # preprocess grad & loss and metalearner forward 118 | grad_prep = preprocess_grad_loss(grad) # [n_learner_params, 2] 119 | loss_prep = preprocess_grad_loss(loss.data.unsqueeze(0)) # [1, 2] 120 | metalearner_input = [loss_prep, grad_prep, grad.unsqueeze(1)] 121 | cI, h = metalearner(metalearner_input, hs[-1]) 122 | hs.append(h) 123 | 124 | #print("training loss: {:8.6f} acc: {:6.3f}, mean grad: {:8.6f}".format(loss, acc, torch.mean(grad))) 125 | 126 | return cI 127 | 128 | 129 | def main(): 130 | 131 | args, unparsed = FLAGS.parse_known_args() 132 | if len(unparsed) != 0: 133 | raise NameError("Argument {} not recognized".format(unparsed)) 134 | 135 | if args.seed is None: 136 | args.seed = random.randint(0, 1e3) 137 | random.seed(args.seed) 138 | np.random.seed(args.seed) 139 | torch.manual_seed(args.seed) 140 | 141 | if args.cpu: 142 | args.dev = torch.device('cpu') 143 | else: 144 | if not torch.cuda.is_available(): 145 | raise 
RuntimeError("GPU unavailable.") 146 | 147 | torch.backends.cudnn.deterministic = True 148 | torch.backends.cudnn.benchmark = False 149 | args.dev = torch.device('cuda') 150 | 151 | logger = GOATLogger(args) 152 | 153 | # Get data 154 | train_loader, val_loader, test_loader = prepare_data(args) 155 | 156 | # Set up learner, meta-learner 157 | learner_w_grad = Learner(args.image_size, args.bn_eps, args.bn_momentum, args.n_class).to(args.dev) 158 | learner_wo_grad = copy.deepcopy(learner_w_grad) 159 | metalearner = MetaLearner(args.input_size, args.hidden_size, learner_w_grad.get_flat_params().size(0)).to(args.dev) 160 | metalearner.metalstm.init_cI(learner_w_grad.get_flat_params()) 161 | 162 | # Set up loss, optimizer, learning rate scheduler 163 | optim = torch.optim.Adam(metalearner.parameters(), args.lr) 164 | 165 | if args.resume: 166 | logger.loginfo("Initialized from: {}".format(args.resume)) 167 | last_eps, metalearner, optim = resume_ckpt(metalearner, optim, args.resume, args.dev) 168 | 169 | if args.mode == 'test': 170 | _ = meta_test(last_eps, test_loader, learner_w_grad, learner_wo_grad, metalearner, args, logger) 171 | return 172 | 173 | best_acc = 0.0 174 | logger.loginfo("Start training") 175 | # Meta-training 176 | for eps, (episode_x, episode_y) in enumerate(train_loader): 177 | # episode_x.shape = [n_class, n_shot + n_eval, c, h, w] 178 | # episode_y.shape = [n_class, n_shot + n_eval] --> NEVER USED 179 | train_input = episode_x[:, :args.n_shot].reshape(-1, *episode_x.shape[-3:]).to(args.dev) # [n_class * n_shot, :] 180 | train_target = torch.LongTensor(np.repeat(range(args.n_class), args.n_shot)).to(args.dev) # [n_class * n_shot] 181 | test_input = episode_x[:, args.n_shot:].reshape(-1, *episode_x.shape[-3:]).to(args.dev) # [n_class * n_eval, :] 182 | test_target = torch.LongTensor(np.repeat(range(args.n_class), args.n_eval)).to(args.dev) # [n_class * n_eval] 183 | 184 | # Train learner with metalearner 185 | learner_w_grad.reset_batch_stats() 186 
| learner_wo_grad.reset_batch_stats() 187 | learner_w_grad.train() 188 | learner_wo_grad.train() 189 | cI = train_learner(learner_w_grad, metalearner, train_input, train_target, args) 190 | 191 | # Train meta-learner with validation loss 192 | learner_wo_grad.transfer_params(learner_w_grad, cI) 193 | output = learner_wo_grad(test_input) 194 | loss = learner_wo_grad.criterion(output, test_target) 195 | acc = accuracy(output, test_target) 196 | 197 | optim.zero_grad() 198 | loss.backward() 199 | nn.utils.clip_grad_norm_(metalearner.parameters(), args.grad_clip) 200 | optim.step() 201 | 202 | logger.batch_info(eps=eps, totaleps=args.episode, loss=loss.item(), acc=acc, phase='train') 203 | 204 | # Meta-validation 205 | if eps % args.val_freq == 0 and eps != 0: 206 | save_ckpt(eps, metalearner, optim, args.save) 207 | acc = meta_test(eps, val_loader, learner_w_grad, learner_wo_grad, metalearner, args, logger) 208 | if acc > best_acc: 209 | best_acc = acc 210 | logger.loginfo("* Best accuracy so far *\n") 211 | 212 | logger.loginfo("Done") 213 | 214 | 215 | if __name__ == '__main__': 216 | main() 217 | --------------------------------------------------------------------------------