# ===========================================================================
# Consolidated repository dump. Original layout:
#     .gitignore                       (ignores results/* and data/*)
#     __pycache__/model.cpython-35.pyc (stray compiled binary, not code)
#     model.py    -- VQ layer, MLP encoder/decoder, VQVae wrapper
#     test_vq.py  -- gradient smoke-test for the VQ layer
#     train.py    -- MNIST training script
#
# Review fixes applied:
#   * VectorQuantization.backward allocated its one-hot matrix with
#     torch.zeros on the CPU unconditionally, crashing under CUDA
#     -> allocate via grad_output.new_zeros (same device/dtype).
#   * row_wise_distance ended with .squeeze(), which silently dropped a
#     dimension whenever a == 1 or b == 1; torch.stack copies replaced
#     with broadcasting.
#   * test_vq.py's comment claimed x grads "should be all zeros"; the
#     straight-through estimator actually copies grad_output to x.
#   * Removed/deprecated APIs (.data[0], volatile=True) replaced with
#     .item() and torch.no_grad().
#   * Runnable scripts wrapped in _test_vq_main() / _train_main() so the
#     module imports without side effects; torchvision/argparse imports
#     moved inside _train_main() (debug-only `import ipdb` dropped).
# ===========================================================================

# ---------------------------------------------------------------------------
# model.py
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.autograd import Variable


class VectorQuantization(torch.autograd.Function):
    """Nearest-neighbour codebook lookup with straight-through gradients."""

    @staticmethod
    def forward(ctx, x, emb):
        """Quantize each row of ``x`` to its closest codebook row.

        x:   (bz, D)  encoder outputs
        emb: (K, D)   codebook
        returns: (bz, D) selected codebook rows
        """
        dist = row_wise_distance(x, emb)
        indices = torch.min(dist, -1)[1]
        # Stash what backward needs: which codebook row won for each input.
        ctx.indices = indices
        ctx.emb_num = emb.size(0)
        ctx.bz = x.size(0)
        return torch.index_select(emb, 0, indices)

    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through estimator.

        x receives grad_output unchanged; emb accumulates grad_output into
        the rows that were selected in forward (one-hot scatter + matmul).
        """
        indices = ctx.indices.view(-1, 1)
        # FIX: allocate on grad_output's device/dtype — the original built a
        # CPU FloatTensor here and crashed when training on CUDA.
        one_hot = grad_output.new_zeros(ctx.bz, ctx.emb_num)
        one_hot.scatter_(1, indices, 1)
        grad_emb = torch.mm(one_hot.t(), grad_output)
        return grad_output, grad_emb


def row_wise_distance(m1, m2):
    """Pairwise squared L2 distances.

    m1: (a, p), m2: (b, p) -> (a, b) where dist[i, j] = ||m1[i] - m2[j]||^2.

    FIX: broadcasting instead of torch.stack (the original materialized two
    (a, b, p) copies), and no trailing .squeeze(), which used to drop a
    dimension whenever a == 1 or b == 1.
    """
    return torch.sum((m1.unsqueeze(1) - m2.unsqueeze(0)) ** 2, 2)


class VQLayer(nn.Module):
    """Learnable codebook of K embedding vectors of dimension D."""

    def __init__(self, D, K):
        super(VQLayer, self).__init__()
        self.emb = nn.Embedding(K, D)
        self.K = K
        self.D = D

    def forward(self, x):
        """x: (bz, D) -> quantized (bz, D)."""
        return VectorQuantization.apply(x, self.emb.weight)


class VQVae(nn.Module):
    """Encoder -> vector-quantization bottleneck -> decoder."""

    def __init__(self, enc, dec, emb_dim, emb_num):
        super(VQVae, self).__init__()
        self.enc = enc
        self.dec = dec
        self.vqlayer = VQLayer(D=emb_dim, K=emb_num)

    def forward(self, x):
        # z_e / z_q are kept on self so the training loop can build the
        # commitment and codebook losses from them after the forward pass.
        self.z_e = self.enc(x)
        self.z_q = self.vqlayer(self.z_e)
        self.x_reconst = self.dec(self.z_q)
        return self.x_reconst

    def sample_from_modes(self):
        """Decode every codebook vector: one sample per discrete mode."""
        zq = self.vqlayer.emb.weight
        samples = self.dec(zq)
        return samples


class MLEenc(nn.Module):
    """Two-layer MLP encoder: input_dim -> 400 -> emb_dim."""

    def __init__(self, input_dim, emb_dim):
        super(MLEenc, self).__init__()
        self.fc1 = nn.Linear(input_dim, 400)
        self.fc2 = nn.Linear(400, emb_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        h1 = self.relu(self.fc1(x))
        return self.fc2(h1)


class MLEdec(nn.Module):
    """Two-layer MLP decoder with sigmoid output: emb_dim -> 400 -> input_dim."""

    def __init__(self, emb_dim, input_dim):
        super(MLEdec, self).__init__()
        self.fc1 = nn.Linear(emb_dim, 400)
        self.fc2 = nn.Linear(400, input_dim)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        h = self.relu(self.fc1(x))
        return self.sigmoid(self.fc2(h))


# ---------------------------------------------------------------------------
# test_vq.py  (smoke test; call _test_vq_main() by hand)
# ---------------------------------------------------------------------------
def _test_vq_main():
    """Print codebook, quantization, and gradients for a tiny VQ layer."""
    K = 3
    D = 2
    bz = 5

    x = Variable(torch.rand(bz, D), requires_grad=True)
    vq = VQLayer(D, K)
    y = vq(x)
    z = torch.sum(y)
    z.backward()

    print("embs are", vq.emb.weight.data)
    print("quantization", y.data)

    # if emb_i is chosen for k times, then ith row should be all k
    print("emb grads", vq.emb.weight.grad.data)

    # FIX(comment): the original claimed these "should be all zeros"; the
    # straight-through estimator copies grad_output to x, so they are ones.
    print("x grads", x.grad.data)


# ---------------------------------------------------------------------------
# train.py  (call _train_main() to run; needs torchvision, hence the
# function-local imports so importing this module stays side-effect free)
# ---------------------------------------------------------------------------
def _train_main():
    """Train a VQ-VAE on MNIST and dump per-epoch codebook samples."""
    import argparse
    from torch import optim
    from torch.nn import functional as F
    from torchvision import datasets, transforms
    from torchvision.utils import save_image

    parser = argparse.ArgumentParser(description='VAE MNIST Example')
    parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: 128)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='enables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--emb-dim', default=500, type=int)
    parser.add_argument('--emb-num', default=10, type=int)
    parser.add_argument('--beta', default=0.25, type=float)

    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=True, download=True,
                       transform=transforms.ToTensor()),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=False, transform=transforms.ToTensor()),
        batch_size=args.batch_size, shuffle=True, **kwargs)

    enc = MLEenc(784, args.emb_dim)
    dec = MLEdec(args.emb_dim, 784)
    vqvae = VQVae(enc, dec, args.emb_dim, args.emb_num)
    if args.cuda:
        vqvae.cuda()

    optimizer = optim.Adam(vqvae.parameters(), lr=1e-3)

    def get_loss(data, vqvae):
        """Return (reconstruction, commitment, codebook) losses for a batch."""
        recon_data = vqvae(data)

        # reconst loss
        reconst_loss = F.binary_cross_entropy(recon_data, data)

        # cluster assignment ("commitment") loss: pull z_e toward frozen z_q
        detach_z_q = Variable(vqvae.z_q.data, requires_grad=False)
        cls_assg_loss = torch.sum((vqvae.z_e - detach_z_q).pow(2))
        cls_assg_loss /= args.batch_size

        # cluster update ("codebook") loss: pull emb toward frozen z_e
        detach_z_e = Variable(vqvae.z_e.data, requires_grad=False)
        z_q = vqvae.vqlayer(detach_z_e)
        cls_dist_loss = torch.sum((detach_z_e - z_q).pow(2))
        cls_dist_loss /= args.batch_size

        return reconst_loss, cls_assg_loss, cls_dist_loss

    def train(epoch):
        vqvae.train()
        train_loss = 0
        for batch_idx, (data, _) in enumerate(train_loader):
            data = Variable(data)
            data = data.view(-1, 784)
            if args.cuda:
                data = data.cuda()

            # get losses
            reconst_loss, cls_assg_loss, cls_dist_loss = get_loss(data, vqvae)

            optimizer.zero_grad()
            # get grad for dec and enc
            loss = reconst_loss + args.beta * cls_assg_loss
            loss.backward()

            # clear the grads in vqlayer: the straight-through grads from the
            # reconstruction pass are NOT the codebook update we want
            vqvae.vqlayer.emb.zero_grad()
            # cluster update loss
            cls_dist_loss.backward()  # get grad in emb
            loss += cls_dist_loss  # logging only; grads already computed

            # all grads good. Update
            optimizer.step()
            # FIX: .data[0] was removed from PyTorch; use .item()
            train_loss += loss.item()
            if batch_idx % args.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    loss.item() / len(data)))

        print('====> Epoch: {} Average loss: {:.4f}'.format(
            epoch, train_loss / len(train_loader.dataset)))

    def test(epoch):
        vqvae.eval()
        test_loss = 0
        # FIX: volatile=True was removed from PyTorch; torch.no_grad() is
        # the supported way to disable autograd during evaluation.
        with torch.no_grad():
            for i, (data, _) in enumerate(test_loader):
                if args.cuda:
                    data = data.cuda()
                data = Variable(data)
                data = data.view(-1, 784)

                reconst_loss, cls_assg_loss, cls_dist_loss = get_loss(data, vqvae)
                test_loss += \
                    (reconst_loss + args.beta * cls_assg_loss + cls_dist_loss).item()

        test_loss /= len(test_loader.dataset)
        print('====> Test set loss: {:.4f}'.format(test_loss))

    for epoch in range(1, args.epochs + 1):
        train(epoch)
        test(epoch)

        # sample from each of discrete vector
        samples = vqvae.sample_from_modes()
        save_image(samples.data.view(args.emb_num, 1, 28, 28),
                   'results/sample_' + str(epoch) + '.png')


if __name__ == "__main__":
    _train_main()