├── LICENSE.md
├── README.md
├── data.py
├── layer_norm.py
├── layer_norm_lstm.py
├── main.py
├── meta_optimizer.py
├── model.py
└── utils.py

/LICENSE.md:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 Ilya Kostrikov

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Intro

PyTorch implementation of [Learning to learn by gradient descent by gradient descent](https://arxiv.org/abs/1606.04474).

## Run

```bash
python main.py
```

### TODO
- [x] Initial implementation
- [x] Toy data
- [x] LSTM updates
- [ ] Refactor, find a better way to organize the modules
- [ ] Compare with standard optimizers
- [x] Real data
- [ ] More difficult models

--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
import torch


def get_batch(batch_size):
    x = torch.randn(batch_size, 10)
    x = x - 2 * x.pow(2)
    y = x.sum(1)
    return x, y

--------------------------------------------------------------------------------
/layer_norm.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class LayerNorm1D(nn.Module):

    def __init__(self, num_outputs, eps=1e-5, affine=True):
        super(LayerNorm1D, self).__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(1, num_outputs))
        self.bias = nn.Parameter(torch.zeros(1, num_outputs))

    def forward(self, inputs):
        input_mean = inputs.mean(1, keepdim=True).expand_as(inputs)
        input_std = inputs.std(1, keepdim=True).expand_as(inputs)
        x = (inputs - input_mean) / (input_std + self.eps)
        return x * self.weight.expand_as(x) + self.bias.expand_as(x)

--------------------------------------------------------------------------------
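A quick sanity check for `LayerNorm1D` (illustrative, not part of the repository; assumes a PyTorch version ≥ 0.4 where tensors and Variables are merged): since `weight` starts at ones and `bias` at zeros, each output row should have roughly zero mean and unit standard deviation.

```python
import torch

from layer_norm import LayerNorm1D

ln = LayerNorm1D(num_outputs=8)
x = torch.randn(4, 8) * 5 + 3      # arbitrary scale and shift
y = ln(x)
print(y.mean(1))                   # roughly 0 for every row
print(y.std(1))                    # roughly 1 for every row
```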
/layer_norm_lstm.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from layer_norm import LayerNorm1D


class LayerNormLSTMCell(nn.Module):

    def __init__(self, num_inputs, num_hidden, forget_gate_bias=-1):
        super(LayerNormLSTMCell, self).__init__()

        self.forget_gate_bias = forget_gate_bias
        self.num_hidden = num_hidden
        self.fc_i2h = nn.Linear(num_inputs, 4 * num_hidden)
        self.fc_h2h = nn.Linear(num_hidden, 4 * num_hidden)

        self.ln_i2h = LayerNorm1D(4 * num_hidden)
        self.ln_h2h = LayerNorm1D(4 * num_hidden)

        self.ln_h2o = LayerNorm1D(num_hidden)

    def forward(self, inputs, state):
        hx, cx = state
        i2h = self.fc_i2h(inputs)
        h2h = self.fc_h2h(hx)
        x = self.ln_i2h(i2h) + self.ln_h2h(h2h)
        gates = x.split(self.num_hidden, 1)

        in_gate = F.sigmoid(gates[0])
        forget_gate = F.sigmoid(gates[1] + self.forget_gate_bias)
        out_gate = F.sigmoid(gates[2])
        in_transform = F.tanh(gates[3])

        cx = forget_gate * cx + in_gate * in_transform
        hx = out_gate * F.tanh(self.ln_h2o(cx))
        return hx, cx

--------------------------------------------------------------------------------
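A minimal usage sketch for the cell above (illustrative, not from the repository; assumes a PyTorch version that still runs this Variable-era code): one step over a batch of 4 with 16 inputs and 32 hidden units.

```python
import torch

from layer_norm_lstm import LayerNormLSTMCell

cell = LayerNormLSTMCell(num_inputs=16, num_hidden=32)
hx = torch.zeros(4, 32)            # hidden state, one row per batch element
cx = torch.zeros(4, 32)            # cell state
x = torch.randn(4, 16)
hx, cx = cell(x, (hx, cx))         # both outputs have shape (4, 32)
```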
/main.py:
--------------------------------------------------------------------------------
import argparse
import operator
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from data import get_batch
from meta_optimizer import MetaModel, MetaOptimizer, FastMetaOptimizer
from model import Model
from torch.autograd import Variable
from torchvision import datasets, transforms

parser = argparse.ArgumentParser(description='PyTorch learning-to-learn example')
parser.add_argument('--batch_size', type=int, default=32, metavar='N',
                    help='batch size (default: 32)')
parser.add_argument('--optimizer_steps', type=int, default=100, metavar='N',
                    help='number of meta optimizer steps (default: 100)')
parser.add_argument('--truncated_bptt_step', type=int, default=20, metavar='N',
                    help='number of steps after which BPTT is truncated (default: 20)')
parser.add_argument('--updates_per_epoch', type=int, default=10, metavar='N',
                    help='updates per epoch (default: 10)')
parser.add_argument('--max_epoch', type=int, default=10000, metavar='N',
                    help='number of epochs (default: 10000)')
parser.add_argument('--hidden_size', type=int, default=10, metavar='N',
                    help='hidden size of the meta optimizer (default: 10)')
parser.add_argument('--num_layers', type=int, default=2, metavar='N',
                    help='number of LSTM layers (default: 2)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

assert args.optimizer_steps % args.truncated_bptt_step == 0

kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)


def main():
    # Create a meta optimizer that wraps a model into a meta model
    # to keep track of the meta updates.
    meta_model = Model()
    if args.cuda:
        meta_model.cuda()

    meta_optimizer = FastMetaOptimizer(MetaModel(meta_model), args.num_layers, args.hidden_size)
    if args.cuda:
        meta_optimizer.cuda()

    optimizer = optim.Adam(meta_optimizer.parameters(), lr=1e-3)

    for epoch in range(args.max_epoch):
        decrease_in_loss = 0.0
        final_loss = 0.0
        train_iter = iter(train_loader)
        for i in range(args.updates_per_epoch):

            # Sample a new model
            model = Model()
            if args.cuda:
                model.cuda()

            x, y = next(train_iter)
            if args.cuda:
                x, y = x.cuda(), y.cuda()
            x, y = Variable(x), Variable(y)

            # Compute the initial loss of the model
            f_x = model(x)
            initial_loss = F.nll_loss(f_x, y)

            for k in range(args.optimizer_steps // args.truncated_bptt_step):
                # Keep states for truncated BPTT
                meta_optimizer.reset_lstm(
                    keep_states=k > 0, model=model, use_cuda=args.cuda)

                loss_sum = 0
                prev_loss = torch.zeros(1)
                if args.cuda:
                    prev_loss = prev_loss.cuda()
                for j in range(args.truncated_bptt_step):
                    x, y = next(train_iter)
                    if args.cuda:
                        x, y = x.cuda(), y.cuda()
                    x, y = Variable(x), Variable(y)

                    # First we need to compute the gradients of the model
                    f_x = model(x)
                    loss = F.nll_loss(f_x, y)
                    model.zero_grad()
                    loss.backward()

                    # Perform a meta update using gradients from the model
                    # and return the current meta model saved in the optimizer
                    meta_model = meta_optimizer.meta_update(model, loss.data)

                    # Compute a loss for a step of the meta optimizer
                    f_x = meta_model(x)
                    loss = F.nll_loss(f_x, y)

                    loss_sum += (loss - Variable(prev_loss))

                    prev_loss = loss.data

                # Update the parameters of the meta optimizer
                meta_optimizer.zero_grad()
                loss_sum.backward()
                for param in meta_optimizer.parameters():
                    param.grad.data.clamp_(-1, 1)
                optimizer.step()

            # Compute the relative decrease in the loss function w.r.t. the
            # initial value
            decrease_in_loss += loss.data[0] / initial_loss.data[0]
            final_loss += loss.data[0]

        print("Epoch: {}, final loss {}, average final/initial loss ratio: {}".format(epoch, final_loss / args.updates_per_epoch,
                                                                                      decrease_in_loss / args.updates_per_epoch))


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
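One of the open TODO items in the README is a comparison with standard optimizers. A hypothetical baseline along those lines (not part of the repository; `adam_baseline` is an invented helper, and `.item()` assumes PyTorch ≥ 0.4) would train the same `Model` with plain Adam and report the same final/initial loss ratio that `main.py` prints, e.g. via `adam_baseline(iter(train_loader), steps=args.optimizer_steps)`.

```python
import torch.nn.functional as F
import torch.optim as optim

from model import Model


def adam_baseline(train_iter, steps=100, lr=1e-3):
    """Train a fresh Model with plain Adam and return the final/initial loss ratio."""
    model = Model()
    opt = optim.Adam(model.parameters(), lr=lr)
    initial_loss = None
    for _ in range(steps):
        x, y = next(train_iter)
        loss = F.nll_loss(model(x), y)
        if initial_loss is None:
            initial_loss = loss.item()
        opt.zero_grad()
        loss.backward()
        opt.step()
    return loss.item() / initial_loss
```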
/meta_optimizer.py:
--------------------------------------------------------------------------------
from functools import reduce
from operator import mul

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import math
from utils import preprocess_gradients
from layer_norm_lstm import LayerNormLSTMCell
from layer_norm import LayerNorm1D


class MetaOptimizer(nn.Module):

    def __init__(self, model, num_layers, hidden_size):
        super(MetaOptimizer, self).__init__()
        self.meta_model = model

        self.hidden_size = hidden_size

        self.linear1 = nn.Linear(3, hidden_size)
        self.ln1 = LayerNorm1D(hidden_size)

        # Register the cells in a ModuleList so that their parameters are
        # returned by self.parameters() and trained by the outer optimizer.
        self.lstms = nn.ModuleList()
        for i in range(num_layers):
            self.lstms.append(LayerNormLSTMCell(hidden_size, hidden_size))

        self.linear2 = nn.Linear(hidden_size, 1)
        self.linear2.weight.data.mul_(0.1)
        self.linear2.bias.data.fill_(0.0)

    def cuda(self):
        super(MetaOptimizer, self).cuda()
        for i in range(len(self.lstms)):
            self.lstms[i].cuda()

    def reset_lstm(self, keep_states=False, model=None, use_cuda=False):
        self.meta_model.reset()
        self.meta_model.copy_params_from(model)

        if keep_states:
            for i in range(len(self.lstms)):
                self.hx[i] = Variable(self.hx[i].data)
                self.cx[i] = Variable(self.cx[i].data)
        else:
            self.hx = []
            self.cx = []
            for i in range(len(self.lstms)):
                self.hx.append(Variable(torch.zeros(1, self.hidden_size)))
                self.cx.append(Variable(torch.zeros(1, self.hidden_size)))
                if use_cuda:
                    self.hx[i], self.cx[i] = self.hx[i].cuda(), self.cx[i].cuda()

    def forward(self, x):
        # Gradients preprocessing
        x = F.tanh(self.ln1(self.linear1(x)))

        for i in range(len(self.lstms)):
            if x.size(0) != self.hx[i].size(0):
                self.hx[i] = self.hx[i].expand(x.size(0), self.hx[i].size(1))
                self.cx[i] = self.cx[i].expand(x.size(0), self.cx[i].size(1))

            self.hx[i], self.cx[i] = self.lstms[i](x, (self.hx[i], self.cx[i]))
            x = self.hx[i]

        x = self.linear2(x)
        return x.squeeze()

    def meta_update(self, model_with_grads, loss):
        # First we need to create a flat version of parameters and gradients
        grads = []

        for module in model_with_grads.children():
            grads.append(module._parameters['weight'].grad.data.view(-1))
            grads.append(module._parameters['bias'].grad.data.view(-1))

        flat_params = self.meta_model.get_flat_params()
        flat_grads = preprocess_gradients(torch.cat(grads))

        inputs = Variable(torch.cat((flat_grads, flat_params.data), 1))

        # Meta update itself
        flat_params = flat_params + self(inputs)

        self.meta_model.set_flat_params(flat_params)

        # Finally, copy values from the meta model to the normal one.
        self.meta_model.copy_params_to(model_with_grads)
        return self.meta_model.model

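The optimizer above is coordinatewise: every scalar parameter of the optimizee is treated as one row of the batch, so on the first step the (1, hidden_size) LSTM state is expanded to one row per coordinate, as in `MetaOptimizer.forward`. A shape-only illustration (the coordinate count below is for `model.Model`: 28·28·32 + 32 + 32·10 + 10 = 25450):

```python
import torch

hidden_size = 10
num_coordinates = 25450                       # all scalar weights and biases of model.Model
hx = torch.zeros(1, hidden_size)
hx = hx.expand(num_coordinates, hidden_size)  # one hidden-state row per coordinate
print(hx.shape)                               # torch.Size([25450, 10])
```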

class FastMetaOptimizer(nn.Module):

    def __init__(self, model, num_layers, hidden_size):
        super(FastMetaOptimizer, self).__init__()
        self.meta_model = model

        self.linear1 = nn.Linear(6, 2)
        self.linear1.bias.data[0] = 1

    def forward(self, x):
        # Compute the two per-coordinate gates from the preprocessed inputs
        x = F.sigmoid(self.linear1(x))
        return x.split(1, 1)

    def reset_lstm(self, keep_states=False, model=None, use_cuda=False):
        self.meta_model.reset()
        self.meta_model.copy_params_from(model)

        if keep_states:
            self.f = Variable(self.f.data)
            self.i = Variable(self.i.data)
        else:
            self.f = Variable(torch.zeros(1, 1))
            self.i = Variable(torch.zeros(1, 1))
            if use_cuda:
                self.f = self.f.cuda()
                self.i = self.i.cuda()

    def meta_update(self, model_with_grads, loss):
        # First we need to create a flat version of parameters and gradients
        grads = []

        for module in model_with_grads.children():
            grads.append(module._parameters['weight'].grad.data.view(-1).unsqueeze(-1))
            grads.append(module._parameters['bias'].grad.data.view(-1).unsqueeze(-1))

        flat_params = self.meta_model.get_flat_params().unsqueeze(-1)
        flat_grads = torch.cat(grads)

        self.i = self.i.expand(flat_params.size(0), 1)
        self.f = self.f.expand(flat_params.size(0), 1)

        loss = loss.expand_as(flat_grads)
        inputs = Variable(torch.cat((preprocess_gradients(flat_grads), flat_params.data, loss), 1))
        inputs = torch.cat((inputs, self.f, self.i), 1)
        self.f, self.i = self(inputs)

        # Meta update itself
        flat_params = self.f * flat_params - self.i * Variable(flat_grads)
        flat_params = flat_params.view(-1)

        self.meta_model.set_flat_params(flat_params)

        # Finally, copy values from the meta model to the normal one.
        self.meta_model.copy_params_to(model_with_grads)
        return self.meta_model.model

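`FastMetaOptimizer` learns two sigmoid gates per coordinate and applies the update theta ← f·theta − i·grad, i.e. a learned, coordinatewise decay plus step size. A toy illustration of that rule with hand-picked gate values (illustrative only, not the learned ones):

```python
import torch

theta = torch.tensor([1.0, -2.0, 0.5])   # current parameters
grad = torch.tensor([0.1, 0.3, -0.2])    # their gradients
f = torch.full_like(theta, 0.99)         # decay gate, near 1 keeps the weight
i = torch.full_like(theta, 0.01)         # step-size gate, acts like a learning rate
theta = f * theta - i * grad             # same form as FastMetaOptimizer.meta_update
print(theta)                             # tensor([ 0.9890, -1.9830,  0.4970])
```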

# A helper class that keeps track of meta updates
# It's done by replacing parameters with variables and applying updates to
# them.


class MetaModel:

    def __init__(self, model):
        self.model = model

    def reset(self):
        for module in self.model.children():
            module._parameters['weight'] = Variable(
                module._parameters['weight'].data)
            module._parameters['bias'] = Variable(
                module._parameters['bias'].data)

    def get_flat_params(self):
        params = []

        for module in self.model.children():
            params.append(module._parameters['weight'].view(-1))
            params.append(module._parameters['bias'].view(-1))

        return torch.cat(params)

    def set_flat_params(self, flat_params):
        # Restore original shapes
        offset = 0
        for i, module in enumerate(self.model.children()):
            weight_shape = module._parameters['weight'].size()
            bias_shape = module._parameters['bias'].size()

            weight_flat_size = reduce(mul, weight_shape, 1)
            bias_flat_size = reduce(mul, bias_shape, 1)

            module._parameters['weight'] = flat_params[
                offset:offset + weight_flat_size].view(*weight_shape)
            module._parameters['bias'] = flat_params[
                offset + weight_flat_size:offset + weight_flat_size + bias_flat_size].view(*bias_shape)

            offset += weight_flat_size + bias_flat_size

    def copy_params_from(self, model):
        for modelA, modelB in zip(self.model.parameters(), model.parameters()):
            modelA.data.copy_(modelB.data)

    def copy_params_to(self, model):
        for modelA, modelB in zip(self.model.parameters(), model.parameters()):
            modelB.data.copy_(modelA.data)

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable


class Model(nn.Module):

    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = nn.Linear(28 * 28, 32)
        self.linear2 = nn.Linear(32, 10)

    def forward(self, inputs):
        x = inputs.view(-1, 28 * 28)
        x = F.relu(self.linear1(x))
        x = self.linear2(x)
        return F.log_softmax(x)

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import math
import torch


def preprocess_gradients(x):
    p = 10
    eps = 1e-6
    indicator = (x.abs() > math.exp(-p)).float()
    x1 = (x.abs() + eps).log() / p * indicator - (1 - indicator)
    x2 = x.sign() * indicator + math.exp(p) * x * (1 - indicator)

    return torch.cat((x1, x2), 1)

--------------------------------------------------------------------------------
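For reference, `preprocess_gradients` implements the gradient preprocessing from Appendix A of the paper with p = 10: gradients with |g| above e^(-p) are mapped to (log|g|/p, sign(g)), and very small ones to (-1, e^p·g). A quick illustration (printed values rounded):

```python
import torch

from utils import preprocess_gradients

g = torch.tensor([[1.0], [1e-3], [1e-8]])
print(preprocess_gradients(g))
# roughly [[ 0.00, 1.00    ],   # |g| large: (log|g| / p, sign(g))
#          [-0.69, 1.00    ],
#          [-1.00, 0.00022 ]]   # |g| tiny:  (-1, exp(p) * g)
```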