├── README.md ├── main.py └── results ├── ep_300_cps_10.png ├── ep_300_cps_100.png ├── ep_300_cps_300.png └── ep_300_cps_784.png /README.md: -------------------------------------------------------------------------------- 1 | PyTorch implementation of Neural Processes (NP) by Garnelo et al. (https://arxiv.org/abs/1807.01622) 2 | # MNIST image completion 3 | The task is to complete an image given some number (in \[1;784\]) of context points (pixel coordinates) at which we know the greyscale pixel intensity (in \[0;1\]). 4 | 5 | # Results 6 | The first row shows the observed greyscale context points. Unobserved pixels are in blue. 7 | The five rows below show realisations of different samples of the global latent variable `z` given the context points above. Compare with Figure 4 in the [paper](https://arxiv.org/abs/1807.01622). 8 | ##### 10 context points 9 | ![10 context points](results/ep_300_cps_10.png?raw=true "10 context points") 10 | ##### 100 context points 11 | ![100 context points](results/ep_300_cps_100.png?raw=true "100 context points") 12 | ##### 300 context points 13 | ![300 context points](results/ep_300_cps_300.png?raw=true "300 context points") 14 | ##### 784 context points (full image) 15 | ![784 context points](results/ep_300_cps_784.png?raw=true "784 context points (full image)") 16 |
"""Neural Processes (NP) for MNIST image completion.

PyTorch implementation of Neural Processes by Garnelo et al.
(https://arxiv.org/abs/1807.01622).

How to run
    ``python main.py`` trains the model. The script saves examples of
    reconstructed images at the end of every epoch in ``results/`` as
    ``results/ep_<epoch>_cps_<number of context points>.png``.

Requirements
    - Python 3
    - PyTorch 0.4.1 or later (tested with 1.0.1)

Reference results (300 epochs; 10, 100, 300 and 784 context points)
    results/ep_300_cps_10.png, results/ep_300_cps_100.png,
    results/ep_300_cps_300.png, results/ep_300_cps_784.png, available under
    https://raw.githubusercontent.com/geniki/neural-processes/85db084b4dc24a65081ab24c58d9494a3ee9e11d/results/

Other NP implementations
    R + TensorFlow - https://github.com/kasparmartens/NeuralProcesses
"""
import os
import argparse
import random

import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

# MNIST images are IMG_SIZE x IMG_SIZE; a flattened image has NUM_PIXELS entries.
IMG_SIZE = 28
NUM_PIXELS = IMG_SIZE * IMG_SIZE

parser = argparse.ArgumentParser(description='Neural Processes (NP) for MNIST image completion')
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                    help='input batch size for training')
parser.add_argument('--epochs', type=int, default=300, metavar='N',
                    help='number of epochs to train')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=1, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--r_dim', type=int, default=128, metavar='N',
                    help='dimension of r, the hidden representation of the context points')
parser.add_argument('--z_dim', type=int, default=128, metavar='N',
                    help='dimension of z, the global latent variable')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Seed both RNGs: torch drives weight init and latent sampling, `random`
# drives the choice of context points.
torch.manual_seed(args.seed)
random.seed(args.seed)
device = torch.device("cuda" if args.cuda else "cpu")

kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
    batch_size=args.batch_size, shuffle=True, **kwargs)


def get_context_idx(N):
    """Return N distinct flat pixel indices in [0, NUM_PIXELS) as a tensor on `device`."""
    idx = random.sample(range(0, NUM_PIXELS), N)
    return torch.tensor(idx, device=device)


def generate_grid(h, w):
    """Build the coordinate grid of an h x w image in row-major (flattened) order.

    Returns a tensor of shape (1, h * w, 2) whose entry k holds the
    (row, column) position of flat pixel k, each rescaled to [0, 1].
    """
    rows = torch.linspace(0, 1, h, device=device)
    cols = torch.linspace(0, 1, w, device=device)
    # Flat index k = row * w + col: the row coordinate stays constant over each
    # run of w pixels while the column coordinate cycles. (The previous
    # construction swapped the two axes, which is only equivalent when h == w.)
    row_coords = rows.view(-1, 1).expand(h, w).contiguous().view(-1)
    col_coords = cols.repeat(h)
    grid = torch.stack([row_coords, col_coords], dim=1)
    return grid.unsqueeze(0)


def idx_to_y(idx, data):
    """Select the pixel intensities of `data` (batch, pixels, 1) at flat indices `idx`."""
    return torch.index_select(data, dim=1, index=idx)


def idx_to_x(idx, batch_size):
    """Map flat pixel indices to their coordinates in the module-level `x_grid`.

    The coordinates are the normalised (row, column) positions in [0, 1], not
    integer positions: with the 28x28 grid, flat index 35 (row 1, column 7)
    maps to (1/27, 7/27). The result is expanded (without copying) to
    shape (batch_size, len(idx), 2).
    """
    x = torch.index_select(x_grid, dim=1, index=idx)
    return x.expand(batch_size, -1, -1)


class NP(nn.Module):
    """Neural Process with an MLP encoder `h`, mean aggregation and an MLP decoder `g`."""

    def __init__(self, args):
        """Create the layers; `args.r_dim` and `args.z_dim` set the sizes of r and z."""
        super().__init__()
        self.r_dim = args.r_dim
        self.z_dim = args.z_dim

        # Encoder: (x1, x2, y) -> r_i
        self.h_1 = nn.Linear(3, 400)
        self.h_2 = nn.Linear(400, 400)
        self.h_3 = nn.Linear(400, self.r_dim)

        # Aggregated r -> parameters of the Gaussian over z
        self.r_to_z_mean = nn.Linear(self.r_dim, self.z_dim)
        self.r_to_z_logvar = nn.Linear(self.r_dim, self.z_dim)

        # Decoder: (z, x1, x2) -> y_hat
        self.g_1 = nn.Linear(self.z_dim + 2, 400)
        self.g_2 = nn.Linear(400, 400)
        self.g_3 = nn.Linear(400, 400)
        self.g_4 = nn.Linear(400, 400)
        self.g_5 = nn.Linear(400, 1)

    def h(self, x_y):
        """Encode each (x, y) pair into a representation of size r_dim."""
        x_y = F.relu(self.h_1(x_y))
        x_y = F.relu(self.h_2(x_y))
        x_y = F.relu(self.h_3(x_y))
        return x_y

    def aggregate(self, r):
        """Average the per-point representations over the points dimension (dim 1)."""
        return torch.mean(r, dim=1)

    def reparameterise(self, z, num_targets=NUM_PIXELS):
        """Draw one z per batch element and repeat it for every target point.

        Args:
            z: tuple (mu, logvar) of the Gaussian over z.
            num_targets: number of target points the sample is expanded to
                (defaults to the number of pixels of an MNIST image).

        Returns:
            Tensor of shape (batch, num_targets, z_dim); the same sample is
            shared by all target points of one batch element.
        """
        mu, logvar = z
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z_sample = eps.mul(std).add_(mu)
        return z_sample.unsqueeze(1).expand(-1, num_targets, -1)

    def g(self, z_sample, x_target):
        """Decode the intensity in (0, 1) at every target coordinate given z."""
        z_x = torch.cat([z_sample, x_target], dim=2)
        hidden = F.relu(self.g_1(z_x))
        hidden = F.relu(self.g_2(hidden))
        hidden = F.relu(self.g_3(hidden))
        hidden = F.relu(self.g_4(hidden))
        return torch.sigmoid(self.g_5(hidden))

    def xy_to_z_params(self, x, y):
        """Return (mu, logvar) of the Gaussian over z given points (x, y)."""
        x_y = torch.cat([x, y], dim=2)
        r_i = self.h(x_y)
        r = self.aggregate(r_i)

        mu = self.r_to_z_mean(r)
        logvar = self.r_to_z_logvar(r)

        return mu, logvar

    def forward(self, x_context, y_context, x_all=None, y_all=None):
        """Reconstruct the whole image, including the provided context points.

        Args:
            x_context, y_context: coordinates and intensities of the context points.
            x_all, y_all: coordinates and intensities of all pixels; required in
                training mode, unused in eval mode.

        Returns:
            (y_hat, z_all, z_context) where the last two are (mu, logvar)
            tuples. In eval mode z_all is z_context.

        Raises:
            ValueError: in training mode when x_all or y_all is missing.
        """
        z_context = self.xy_to_z_params(x_context, y_context)  # (mu, logvar) of z
        if self.training:  # loss function will try to keep z_context close to z_all
            if x_all is None or y_all is None:
                raise ValueError("x_all and y_all are required in training mode")
            z_all = self.xy_to_z_params(x_all, y_all)
        else:  # at test time we don't have the image so we use only the context
            z_all = z_context

        x_target = x_grid.expand(y_context.shape[0], -1, -1)
        z_sample = self.reparameterise(z_all, num_targets=x_target.shape[1])
        y_hat = self.g(z_sample, x_target)

        return y_hat, z_all, z_context


def kl_div_gaussians(mu_q, logvar_q, mu_p, logvar_p):
    """Return KL(q || p) between diagonal Gaussians, summed over all elements."""
    var_p = torch.exp(logvar_p)
    kl_div = (torch.exp(logvar_q) + (mu_q - mu_p) ** 2) / var_p \
             - 1.0 \
             + logvar_p - logvar_q
    return 0.5 * kl_div.sum()


def np_loss(y_hat, y, z_all, z_context):
    """Summed binary cross-entropy of the reconstruction plus KL(z_all || z_context)."""
    BCE = F.binary_cross_entropy(y_hat, y, reduction="sum")
    KLD = kl_div_gaussians(z_all[0], z_all[1], z_context[0], z_context[1])
    return BCE + KLD


model = NP(args).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
x_grid = generate_grid(IMG_SIZE, IMG_SIZE)
os.makedirs("results/", exist_ok=True)


def train(epoch):
    """Train for one epoch, drawing a new random context set for every batch."""
    model.train()
    train_loss = 0
    for batch_idx, (y_all, _) in enumerate(train_loader):
        batch_size = y_all.shape[0]
        y_all = y_all.to(device).view(batch_size, -1, 1)

        N = random.randint(1, NUM_PIXELS)  # number of context points
        context_idx = get_context_idx(N)
        x_context = idx_to_x(context_idx, batch_size)
        y_context = idx_to_y(context_idx, y_all)
        x_all = x_grid.expand(batch_size, -1, -1)

        optimizer.zero_grad()
        y_hat, z_all, z_context = model(x_context, y_context, x_all, y_all)

        loss = np_loss(y_hat, y_all, z_all, z_context)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(y_all), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(y_all)))
    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))


def test(epoch):
    """Evaluate on the test set and save example reconstructions as PNG files.

    The loss uses 300 context points per batch. In eval mode the model returns
    z_all equal to z_context, so the KL term of the loss is zero.
    """
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (y_all, _) in enumerate(test_loader):
            y_all = y_all.to(device).view(y_all.shape[0], -1, 1)
            batch_size = y_all.shape[0]

            N = 300
            context_idx = get_context_idx(N)
            x_context = idx_to_x(context_idx, batch_size)
            y_context = idx_to_y(context_idx, y_all)

            y_hat, z_all, z_context = model(x_context, y_context)
            test_loss += np_loss(y_hat, y_all, z_all, z_context).item()

            if i == 0:  # save PNG of reconstructed examples
                plot_Ns = [10, 100, 300, NUM_PIXELS]
                num_examples = min(batch_size, 16)
                for N in plot_Ns:
                    recons = []
                    context_idx = get_context_idx(N)
                    x_context = idx_to_x(context_idx, batch_size)
                    y_context = idx_to_y(context_idx, y_all)
                    # five reconstructions, each from a fresh sample of z
                    for _ in range(5):
                        y_hat, _, _ = model(x_context, y_context)
                        recons.append(y_hat[:num_examples])
                    recons = torch.cat(recons).view(-1, 1, IMG_SIZE, IMG_SIZE).expand(-1, 3, -1, -1)
                    # first row: context pixels in greyscale on a blue background
                    background = torch.tensor([0., 0., 1.], device=device)
                    background = background.view(1, -1, 1).expand(num_examples, 3, NUM_PIXELS).contiguous()
                    context_pixels = y_all[:num_examples].view(num_examples, 1, -1)[:, :, context_idx]
                    context_pixels = context_pixels.expand(num_examples, 3, -1)
                    background[:, :, context_idx] = context_pixels
                    comparison = torch.cat([background.view(-1, 3, IMG_SIZE, IMG_SIZE),
                                            recons])
                    save_image(comparison.cpu(),
                               'results/ep_' + str(epoch) +
                               '_cps_' + str(N) + '.png', nrow=num_examples)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))


for epoch in range(1, args.epochs + 1):
    train(epoch)
    test(epoch)