├── Dirichlet.lua
├── Gaussian.lua
├── GaussianCriterion.lua
├── KLCriterion.lua
├── Label.lua
├── LatentGMM.lua
├── NormalGamma.lua
├── README.md
├── VAE.lua
├── gaussianGamma.lua
├── gaussianMixture.lua
├── plot.py
├── plot_full.py
├── plot_latent.py
├── plot_recon.py
├── resVAE.lua
└── save
    ├── .DS_Store
    └── spiral.t7

/Dirichlet.lua:
--------------------------------------------------------------------------------
local Dirichlet, parent = torch.class( 'nn.Dirichlet', 'nn.Module' )
require 'cephes'

function Dirichlet:__init(K)
   parent.__init(self)
   self.K  = K
   self.a  = torch.Tensor(K)
   self.Ga = torch.Tensor(K)
   self.a0 = torch.Tensor(K)

   self.stats = torch.Tensor(K)
end

function Dirichlet:parameters()
   return {self.a}, {self.Ga}
end

function Dirichlet:setPrior(a0)
   self.a0:fill(a0)
end

function Dirichlet:setParameters(a)
   self.a:copy(a)
end

function Dirichlet:updateExpectedStats()
   -- Expected sufficient statistics t() = [ E log pi_1, E log pi_2, ... ]
   -- where E log pi_k = digamma(a[k]) - digamma( sum(a) )
   local a = self.a:clone()
   self.stats:zero()
   self.stats:add(cephes.digamma(a)):add(-cephes.digamma(a:sum()))

   return self.stats
end

function Dirichlet:accGradParameters(input, gradOutput, scale)
   -- Accumulate the natural gradient: a - a0 - expected cluster counts
   self.Ga:add(self.a):add(-1, self.a0):add(-1, input[3][{{},1}])
end

--------------------------------------------------------------------------------
/Gaussian.lua:
--------------------------------------------------------------------------------
local Gaussian, parent = torch.class('nn.Gaussian', 'nn.Module')

function Gaussian:__init(K, D, N)
   parent.__init(self)
   self.K = K
   self.D = D
   self.N = N
   self.stats = torch.Tensor(4, K, D)
   self.llh   = torch.Tensor(K, N)
end

function Gaussian:observe(data)
   local D = self.D

   -- data [ NxD ]
   self.stats[1][1]:copy( data:sum(1) )
   self.stats[2][1]:copy( torch.cmul(data,data):sum(1) )
   self.stats[3][1]:fill( data:size(1) )
   self.stats[4][1]:fill( data:size(1) )

   return self.stats
end

function Gaussian:getLogLikelihood(E_NG, Tx)
   local K = self.K
   local N = self.N
   local D = self.D

   self.llh:zero() -- [K, N]
   -- log P(x|k, mu, Sig) = < t(g,m), (x, x2, 1, 1) >
   -- t(g,m): [4, K, D]
   -- Tx[1]:  [N, D]
   -- Tx[2]:  [N, D]
   self.llh:addmm( E_NG[1], Tx[1]:t() )
           :addmm( E_NG[2], Tx[2]:t() )
           :add(E_NG[3]:sum(2):expand(K,N))
           :add(E_NG[4]:sum(2):expand(K,N))
           :add(-0.5*D*math.log(2*math.pi)) -- normalising constant: one -0.5*log(2*pi) per dimension

   return self.llh
end

function Gaussian:getMixtureStats(phi, Tx, scale)
   local K = self.K
   local N = self.N
   local D = self.D
   -- txz = sum_n [phi x, phi x2, phi, phi]  [KxD]
   -- phi [KxN]
   -- Tx  [NxD]
   self.stats:zero()
   self.stats[1]:mm(phi, Tx[1])
   self.stats[2]:mm(phi, Tx[2])
   self.stats[3]:copy( phi:sum(2):expand(K,D) )
   self.stats[4]:copy( self.stats[3] )

   self.stats[1]:mul(scale)
   self.stats[2]:mul(scale)
   self.stats[3]:mul(scale)
   self.stats[4]:mul(scale)

   return self.stats
end

--------------------------------------------------------------------------------
/GaussianCriterion.lua:
--------------------------------------------------------------------------------
require 'nn'
--[[
Taken from https://github.com/y0ast/VAE-Torch
]]--

local GaussianCriterion, parent = torch.class('nn.GaussianCriterion', 'nn.Criterion')

function GaussianCriterion:__init(scale)
   parent.__init(self)
   self.scale = scale or 1.0
end

function GaussianCriterion:updateOutput(input, target)
   -- negative log likelihood, so the sign is flipped:
   -- -log p(x) = 0.5 * log(sigma^2) + 0.5 * log(2*pi) + 0.5 * (x - mu)^2 / sigma^2
   -- input[1] = mu
   -- input[2] = log(sigma^2)

   local Gelement = torch.mul(input[2],0.5):add(0.5 * math.log(2 * math.pi))
   Gelement:add(torch.add(target,-1,input[1]):pow(2):cdiv(torch.exp(input[2]) + 1e-10):mul(0.5))

   self.output = torch.sum(Gelement)

   return self.output
end

function GaussianCriterion:updateGradInput(input, target)
   self.gradInput = {}

   -- dL/dmu = - (x - mu) / sigma^2,  with 1/sigma^2 = exp(-log(sigma^2))
   self.gradInput[1] = torch.exp(-input[2]):cmul(torch.add(target,-1,input[1])):mul(-1)

   -- dL/dlog(sigma^2) = 0.5 - 0.5 * (x - mu)^2 / sigma^2
   self.gradInput[2] = torch.exp(-input[2]):cmul(torch.add(target,-1,input[1]):pow(2)):mul(-0.5):add(0.5)

   self.gradInput[1]:mul(self.scale)
   self.gradInput[2]:mul(self.scale)

   return self.gradInput
end

--------------------------------------------------------------------------------
/KLCriterion.lua:
--------------------------------------------------------------------------------
local KLCriterion, parent = torch.class('nn.KLCriterion', 'nn.Criterion')

function KLCriterion:__init(scale)
   parent.__init(self)
   self.scale = scale or 1.0
end

function KLCriterion:updateOutput(input, target)
   local m1   = input[1]
   local var1 = input[2]

   local m2   = target[1] or torch.Tensor():resizeAs(m1):zero()
   local var2 = target[2] or torch.Tensor():resizeAs(var1):fill(1)

   -- KL( N(m1,var1) || N(m2,var2) )
   --   = 0.5*log(var2/var1) + (var1 + (m1 - m2)^2) / (2*var2) - 0.5
   --   = 0.5*( log(var2) - log(var1) + (var1 + (m1 - m2)^2)/var2 - 1 )
   -- For m2 = 0, var2 = 1: KL = -0.5*log(var1) + 0.5*(var1 + m1^2) - 0.5

   local KLDelements = m1 - m2
   KLDelements:pow(2):add(var1):cdiv(var2 + 1e-10)
              :add(-1)
              :add(-1, torch.log(var1)):add(torch.log(var2))
              :mul(0.5)

   self.output = torch.sum(KLDelements)

   return self.output
end

function KLCriterion:updateGradInput(input, target)
   self.gradInput = {}

   local m1   = input[1]
   local var1 = input[2]
   local m2   = target[1] or torch.Tensor():resizeAs(m1):zero()
   local var2 = target[2] or torch.Tensor():resizeAs(var1):fill(1)

   -- dKL/dm1 = (m1 - m2)/var2
   self.gradInput[1] = m1:clone():add(-1, m2):cdiv(var2 + 1e-10)

   -- dKL/dvar1 = 0.5 * (1/var2 - 1/var1)
   self.gradInput[2] = self.gradInput[2] or var1.new()
   self.gradInput[2]:resizeAs(var1)
   self.gradInput[2]:copy(var2):pow(-1):add(-1, torch.pow(var1, -1) ):mul(0.5)

   self.gradInput[1]:mul(self.scale)
   self.gradInput[2]:mul(self.scale)

   return self.gradInput
end
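
-- Usage sketch (illustrative, not from the original file): with an empty
-- target table the criterion falls back to a standard-normal prior,
-- i.e. m2 = 0 and var2 = 1:
--
--   local crit = nn.KLCriterion(1.0)
--   local kl   = crit:forward({mean, var}, {})   -- KL( N(mean,var) || N(0,I) )
--   local grad = crit:backward({mean, var}, {})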

--------------------------------------------------------------------------------
/Label.lua:
--------------------------------------------------------------------------------
local Label, parent = torch.class( 'nn.Label', 'nn.Module' )

function Label:__init(K, N)
   parent.__init(self)
   self.K = K
   self.N = N
   self.phi = torch.Tensor(K, N)
end

function Label:reset()
   self.phi:fill(1):div(self.K + 1e-10)
   return self.phi
end

function Label:setParameters(llh, E_dir)
   local K = self.K
   local N = self.N

   -- phi(k,n) = q(z_n = k) = 1/Z * exp( energy(z_n = k) )
   -- Z = sum_j exp( energy(z_n = j) )

   self.phi:zero()
   self.phi:add(llh):add(E_dir:repeatTensor(N,1):t())

   -- subtract the per-column max before exponentiating, for numerical stability
   local max = torch.max(self.phi,1)
   self.phi:add(-1, max:expand(K, N))
   self.phi:exp()
   self.phi:cdiv(self.phi:sum(1):expand(K,N) + 1e-10)

   return self.phi
end

function Label:assignLabel()
   local _, idx = self.phi:max(1)
   return idx:view(self.N)
end

--------------------------------------------------------------------------------
/LatentGMM.lua:
--------------------------------------------------------------------------------
require 'nn'
require 'GaussianCriterion'
require 'optim'
require 'nngraph'
require 'KLCriterion'
require 'Dirichlet'
require 'NormalGamma'
require 'Gaussian'
require 'Label'
local nninit = require 'nninit'

torch.manualSeed(2)
data = torch.load('save/spiral.t7')

-- To train on MNIST instead, uncomment these lines (requires the 'mnist' package):
--local mnist = require 'mnist'
--data = mnist.traindataset().data:div(255):double()
--data = data:view(data:size(1), data:size(2)*data:size(3))

local N = data:size(1)
local Dy = data:size(2)
local Dx = 2
local batch = 100
local batchScale = N/batch
local eta = 0.0001
local eta_latent = 0.1
local optimiser = 'adam'
local latentOptimiser = 'sgd'
local max_Epoch = 10000
local K = 20
local max_iter = 500

-- ResNet
function resNetBlock(inputSize, hiddenSize)
   local input = - nn.Identity()
   local resBranch = input
            - nn.Linear(inputSize, hiddenSize):init('weight', nninit.normal, 0, 0.001)
                                              :init('bias'  , nninit.normal, 0, 0.001)
            - nn.Tanh()
            - nn.Linear(hiddenSize, inputSize):init('weight', nninit.normal, 0, 0.001)
                                              :init('bias'  , nninit.normal, 0, 0.001)
   local skipBranch = input
            - nn.Identity()
   local output = {resBranch, skipBranch}
            - nn.CAddTable()
   return nn.gModule({input}, {output})
end

function globalMixing()
   local phi = - nn.Identity() -- [K, N ]
   local hk  = - nn.Identity() -- [K, Dx]
   local Jk  = - nn.Identity() -- [K, Dx]

   local phiT = phi - nn.Transpose({1,2}) -- [N, K]
   local Ehk  = {phiT, hk} - nn.MM()      -- [N, Dx]
   local EJk  = {phiT, Jk} - nn.MM()      -- [N, Dx]

   return nn.gModule({phi, hk, Jk}, {Ehk, EJk})
end

function gaussianMeanfield()
   local hy  = - nn.Identity() -- [N, Dx]
   local Jy  = - nn.Identity() -- [N, Dx]
   local Ehk = - nn.Identity()
   local EJk = - nn.Identity()

   local hx = {hy, Ehk} - nn.CAddTable() -- mu/var
   local Jx = {Jy, EJk} - nn.CAddTable() -- -1/2(1/var)

   local var = Jx - nn.MulConstant(-2)
                  - nn.Power(-1)
   local mean = {hx, var}
                  - nn.CMulTable()

   return nn.gModule({hy, Jy, Ehk, EJk}, {mean, var})
end

function createSampler()
   -- Sampler
   local mean = - nn.Identity()
   local var  = - nn.Identity()
   local rand = - nn.Identity()
   local std = var - nn.Power(0.5)
   local noise = {std, rand}
               - nn.CMulTable()
   local x = {mean, noise}
               - nn.CAddTable()

   return nn.gModule({mean, var, rand}, {x})
end

-- Network
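-- The recogniser below outputs the natural parameters of a diagonal Gaussian
-- rather than (mean, var): hy = mean/var and Jy = -1/(2*var). In this
-- parameterisation the mean-field update in gaussianMeanfield() reduces to
-- adding natural parameters, e.g. for a single site:
--   Jx = Jy + EJk          -- precision terms add up
--   var_x  = -1/(2*Jx)
--   mean_x = hx * var_x    -- with hx = hy + Ehk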
function createNetwork(Dy, Dx)
   local hiddenSize = 100
   -- Recogniser
   -- NOTE: the blocks below are sized in Dy throughout, which assumes Dy == Dx
   -- (true for the 2-D spiral data used here).
   local input = - nn.Identity()
   local hidden = input
            - resNetBlock(Dy, hiddenSize)

   local mean = hidden
            - resNetBlock(Dy, hiddenSize)

   local logVar = hidden
            - nn.Linear(Dy, hiddenSize)
            - nn.Tanh()
            - nn.Linear(hiddenSize, Dy):init('bias' , nninit.normal, -5, 0.001) -- start with small variance

   local Jy = logVar
            - nn.Exp()             -- var
            - nn.Power(-1)         -- 1/var
            - nn.MulConstant(-0.5) -- -1/(2*var)
   local hy = {mean, Jy}
            - nn.CMulTable()       -- -mean/(2*var)
            - nn.MulConstant(-2)   -- mean/var

   local recogniser = nn.gModule( {input}, {hy, Jy})

   -- Generator
   local X_sample = - nn.Identity()
   local h = X_sample
            - resNetBlock(Dy, hiddenSize)

   local recon_mean = h
            - resNetBlock(Dy, hiddenSize)
   local recon_logVar = h
            - nn.Linear(Dy, hiddenSize)
            - nn.Tanh()
            - nn.Linear(hiddenSize, Dy)

   local generator = nn.gModule({X_sample}, {recon_mean, recon_logVar})

   return recogniser, generator
end

-- Latent Model
local dir = nn.Dirichlet(K)
local NG = nn.NormalGamma(K,Dx)
local gaussian = nn.Gaussian(K, Dx, batch)
local label = nn.Label(K, batch)

-- set prior
local m0 = torch.Tensor(Dx):zero()
local l0 = torch.Tensor(Dx):fill(1)
local a0 = torch.Tensor(Dx):fill(1)
local b0 = torch.Tensor(Dx):fill(1)
NG:setPrior(m0, l0, a0, b0)
dir:setPrior(1)

-- Initialise parameters
local m1 = torch.rand(K, Dx):add(-0.5):mul(2)
local l1 = torch.Tensor(K, Dx):fill(1)
local a1 = torch.Tensor(K, Dx):fill(10)
local b1 = torch.Tensor(K, Dx):fill(1)
local pi1 = torch.randn(K):div(0.1)
NG:setParameters(m1, l1, a1, b1)
dir:setParameters(pi1)

local latentContainer = nn.Container()
                        :add(NG)
                        :add(dir)
local latentParams, gradLatentParams = latentContainer:getParameters()

local recogniser, generator = createNetwork(Dy, Dx)
local sampler = createSampler()
local latentGMM = gaussianMeanfield()
local globalMixing = globalMixing()

local container = nn.Container()
                  :add(recogniser)
                  :add(generator)

local ReconCrit = nn.GaussianCriterion( batchScale )
local KLCrit = nn.KLCriterion( batchScale )

local params, gradParams = container:getParameters()

function feval(param)

   if param ~= params then
      params:copy(param)
   end

   container:zeroGradParameters()
   latentContainer:zeroGradParameters()

   -- Recogniser
   local hy, Jy = unpack( recogniser:forward(y) )

   -- Get global expected stats
   local E_NG = NG:updateExpectedStats()
   local E_dir = dir:updateExpectedStats()
   local mean_x, var_x

   ------ Latent GMM ----
   -- Initialise phi [KxN]
   local phi = label:reset()
   local Tx, Ehk, EJk
   local llh_prev = 0.0
   for i=1, max_iter do

      -- From {hy, Jy, phi, E_NG[1], E_NG[2]} -> {hx, Jx} -> {mean_x, var_x} -> {x, x2}
      Ehk, EJk = unpack( globalMixing:forward( {phi, E_NG[1], E_NG[2]} ) )
      mean_x, var_x = unpack( latentGMM:forward({hy, Jy, Ehk, EJk}) )

      Tx = {mean_x, mean_x:clone():pow(2):add(var_x) } --TODO fix this
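      -- Tx holds the expected sufficient statistics of x under
      -- q(x_n) = N(mean_x, var_x): E[x] = mean_x and E[x^2] = mean_x^2 + var_x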
      -- Get gaussian expected llh
      local llh = gaussian:getLogLikelihood(E_NG, Tx)

      -- Update Label parameters
      phi = label:setParameters(llh, E_dir)

      local sumllh = llh:sum()
      if torch.abs( llh_prev - sumllh ) < 1e-7 then
         break
      end
      llh_prev = sumllh
      if i == max_iter then
         print('max iter reached')
      end

   end

   -- Compute Mixture stats
   local Txz = gaussian:getMixtureStats(phi, Tx, batchScale)

   -- Do sampling
   local rand = torch.randn(var_x:size()):mul(1)
   local xs = sampler:forward({mean_x, var_x, rand})

   -----------------------

   local recon = generator:forward(xs)
   local reconLoss = ReconCrit:forward(recon, y)

   local gradRecon = ReconCrit:backward(recon, y)
   local gradXs = generator:backward(xs, gradRecon)
   local gradMean, gradVar, __ = unpack( sampler:backward({mean_x, var_x, rand}, gradXs ) )
   local gradHy, gradJy, gradEhk, gradEJk = unpack( latentGMM:backward({hy, Jy, Ehk, EJk}, {gradMean, gradVar}) )
   recogniser:backward(y, {gradHy, gradJy})

   -- Update global parameters
   --local gradPHI, gradHK, gradJK = unpack( globalMixing:backward({phi, E_NG[1], E_NG[2]} ,{gradEhk, gradEJk}))
   --NG:backward(Txz, {gradHK, gradJK}) -- gives NaN with the sgd optimiser
   NG:backward(Txz)
   dir:backward(Txz)

   local var_k  = EJk:clone():mul(-2):pow(-1)
   local mean_k = Ehk:clone():cmul(var_k)
   local KLLoss = KLCrit:forward({mean_x, var_x}, {mean_k, var_k})

   local gradMean, gradVar = unpack( KLCrit:backward({mean_x, var_x}, {mean_k, var_k}) )
   local gradHy, gradJy, gradEhk, gradEJk = unpack( latentGMM:backward({hy, Jy, Ehk, EJk}, {gradMean, gradVar}) )
   recogniser:backward(y, {gradHy, gradJy})

   local loss = reconLoss + KLLoss
   return loss, gradParams
end

function Lfeval(params)
   return __, gradLatentParams
end

for epoch = 1, max_Epoch do

   local indices = torch.randperm(N):long():split(batch)

   local recon = torch.Tensor():resizeAs(data):zero()
   local labels = torch.Tensor():resize(N):zero()
   local x_sample = torch.Tensor():resize(N, Dx):zero()
   local Loss = 0.0

   for t,v in ipairs(indices) do
      xlua.progress(t, #indices)
      y = data:index(1,v)

      __, loss = optim[optimiser](feval, params, {learningRate = eta })
      __, __ = optim[latentOptimiser](Lfeval, latentParams, {learningRate = eta_latent})

      recon[{ { batch*(t-1) + 1, batch*t },{}}] = generator.output[1]
      x_sample[{ { batch*(t-1) + 1, batch*t },{}}] = sampler.output
      labels[{{ batch*(t-1) + 1, batch*t }}] = label:assignLabel()

      Loss = Loss + loss[1]
   end

   print("Epoch: " .. epoch .. " Loss: " .. Loss/N )
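
   -- The tensors saved below are polled by the plotting scripts
   -- (plot_full.py etc.), which reload them in a loop to animate training.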
   torch.save('save/label.t7', labels)
   torch.save('save/recon.t7', recon)
   torch.save('save/xs.t7', x_sample)

   local mixing = dir:parameters()[1]
   mixing = mixing/mixing:sum()
   torch.save('save/mixing.t7', mixing)

   -- Latent plot --
   local m, l, a, b = unpack(NG:getBasicParameters())
   local cov = torch.Tensor(K, Dx, Dx)
   for k=1,K do
      cov[k] = torch.diag( torch.cdiv(b,a)[k] )
   end
   torch.save('save/m.t7', m)
   torch.save('save/cov.t7', cov)

end

--------------------------------------------------------------------------------
/NormalGamma.lua:
--------------------------------------------------------------------------------
local NormalGamma, parent = torch.class( 'nn.NormalGamma', 'nn.Module' )

function NormalGamma:__init(K, D)
   parent.__init(self)

   self.K = K
   self.D = D

   self.weight     = torch.Tensor(4, K, D)
   self.gradWeight = torch.Tensor(4, K, D)

   self.prior = torch.Tensor(4, K, D)

   self.stats = torch.Tensor(4, K, D)
end

function NormalGamma:parameters()
   return {self.weight}, {self.gradWeight}
end

function NormalGamma:setPrior(m0, l0, a0, b0)
   --[[
   Expects the same prior across the K clusters:
   m0 = torch.Tensor(D)
   l0 = torch.Tensor(D)
   a0 = torch.Tensor(D)
   b0 = torch.Tensor(D)
   ]]--
   local K = self.K
   local D = self.D

   self.prior[1]:copy( torch.cmul( m0, l0):repeatTensor(K,1) )
   self.prior[2]:copy( torch.cmul( l0, m0):cmul(m0):add(2,b0):repeatTensor(K,1) ) -- 2*b0 + l0*m0*m0
   self.prior[3]:copy( l0:repeatTensor(K,1) )
   self.prior[4]:copy( (2*a0 - 1):repeatTensor(K,1) )

   return self.prior
end

function NormalGamma:setParameters(m1, l1, a1, b1)
   local K = self.K
   local D = self.D

   self.weight[1]:copy( torch.cmul( m1, l1) )
   self.weight[2]:copy( torch.cmul( l1, m1):cmul(m1):add(2,b1) ) -- 2*b1 + l1*m1*m1
   self.weight[3]:copy( l1 )
   self.weight[4]:copy( (2*a1 - 1) )

   return self.weight
end

function NormalGamma:accGradParameters(input, gradOutput, scale)
   -- Accumulate the natural gradient
   self.gradWeight:add(self.weight):add(-1, self.prior):add(-1, input)

   if gradOutput then
      self.gradWeight[1]:add(-1, gradOutput[1])
      self.gradWeight[2]:add(-1, gradOutput[2])
   end
end

function NormalGamma:getBasicParameters()
   -- w1 = ml
   -- w2 = 2b + lm^2
   -- w3 = l
   -- w4 = 2a - 1
   local l = self.weight[3] --[KxD]
   local m = torch.cdiv( self.weight[1], l )
   local a = 0.5 * (self.weight[4] + 1)
   local b = 0.5 * (self.weight[2] - torch.cmul( self.weight[1], m) )

   return {m, l, a, b}
end

require 'cephes'
function NormalGamma:updateExpectedStats()
   --[[
   t(gamma, mu) = [ g * mu,
                    -0.5 * g,
                    -0.5 * g * mu * mu,
                    0.5 * log(g) ]
                = [ (a/b) * m,
                    -0.5 * (a/b),
                    -0.5 * (1/l + m*m*a/b),
                    0.5 * ( digamma(a) - log(b) ) ]
   ]]--
   local K = self.K
   local D = self.D
   local m, l, a, b = unpack( self:getBasicParameters() )
   self.stats[1]:copy(m):cmul(a):cdiv(b + 1e-10)
   self.stats[2]:copy(a):cdiv(b + 1e-10):mul(-0.5)
   self.stats[3]:copy(m):cmul(m):cmul(a):cdiv(b + 1e-10):add(torch.pow(l,-1)):mul(-0.5)

   -- apply digamma in place through a flattened view of a
   local a_flat = a:view(K*D)
   a_flat:copy( cephes.digamma(a_flat) )
   self.stats[4]:copy(a):add(-1, torch.log(b)):mul(0.5)

   return self.stats
end

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

An attempt at replicating SVAE (structured variational autoencoders) in Torch:
https://github.com/mattjj/svae

For the SVAE latent GMM model, run `th LatentGMM.lua` and visualise with `python plot_full.py`.

For the VAE, run `th resVAE.lua` and visualise with `python plot_latent.py` and `python plot_recon.py`.

For standard variational inference on a Gaussian mixture model, run `th gaussianMixture.lua`.
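A typical session (the plotting scripts reload the tensors under `save/` in a
loop, so they can run in a second terminal while training progresses):

```
th LatentGMM.lua      # trains and saves tensors into save/ every epoch
python plot_full.py   # polls save/ and redraws latents and reconstructions
```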

Dependencies

Torch:
- nngraph
- cephes
- nninit
- mnist (optional; only for the commented-out MNIST experiment in LatentGMM.lua)

Python (the plotting scripts are Python 2):
- torchfile
- numpy
- matplotlib

--------------------------------------------------------------------------------
/VAE.lua:
--------------------------------------------------------------------------------
require 'nn'
require 'GaussianCriterion'
require 'optim'
require 'nngraph'
require 'KLCriterion'

torch.manualSeed(1)
data = torch.load('save/spiral.t7')

local N = data:size(1)
local Dy = data:size(2)
local Dx = 2
local batch = 100
local batchScale = N/batch
local eta = 0.001
local optimiser = 'adam'
local max_Epoch = 500

-- Network
function createNetwork(Dy, Dx)
   local hiddenSize = 200
   -- Recogniser
   local input = - nn.View(-1, Dy)
   local hidden = input
            - nn.Linear(Dy, hiddenSize)
            - nn.ReLU(true)
   local mean = hidden
            - nn.Linear(hiddenSize, Dx)
   local logVar = hidden
            - nn.Linear(hiddenSize, Dx)

   local recogniser = nn.gModule( {input}, {mean, logVar})

   -- Sampler
   local mean = - nn.Identity()
   local std  = - nn.Identity()
   local rand = - nn.Identity()
   local noise = {std, rand}
               - nn.CMulTable()
   local x = {mean, noise}
               - nn.CAddTable()

   local sampler = nn.gModule({mean, std, rand}, {x})

   -- Generator
   local X_sample = - nn.Identity()
   local h = X_sample
            - nn.Linear(Dx, hiddenSize)
            - nn.ReLU(true)
   local recon_mean = h
            - nn.Linear(hiddenSize, Dy)
   local recon_logVar = h
            - nn.Linear(hiddenSize, Dy)

   local generator = nn.gModule({X_sample}, {recon_mean, recon_logVar})

   return recogniser, sampler, generator
end

local recogniser, sampler, generator = createNetwork(Dy, Dx)
local container = nn.Container()
                  :add(recogniser)
                  :add(generator)

local ReconCrit = nn.GaussianCriterion()
local KLCrit = nn.KLCriterion(1.0)

local params, gradParams = container:getParameters()

function feval(x)

   if x ~= params then
      params:copy(x)
   end

   container:zeroGradParameters()

   local mean, logVar = unpack( recogniser:forward(y) )

   local std = logVar:clone():mul(0.5):exp()
   local rand = torch.randn(std:size())
   local xs = sampler:forward({mean, std, rand})

   local recon = generator:forward(xs)
   local reconLoss = ReconCrit:forward(recon, y)
   local gradRecon = ReconCrit:backward(recon, y)
   local gradXs = generator:backward(xs, gradRecon)

   local gradMean, gradStd, __ = unpack( sampler:backward({mean, std, rand}, gradXs) )
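   -- Chain rule for the reparameterisation: std = exp(0.5*logVar), so
   -- d(std)/d(logVar) = 0.5*std and the logVar gradient is gradStd * 0.5 * std.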
   local gradLogVar = gradStd:clone():cmul(std):mul(0.5)
   recogniser:backward(y, {gradMean, gradLogVar})

   -- Set prior for VAE
   local mean_prior = torch.Tensor():resizeAs(mean):zero()
   local var_prior  = torch.Tensor():resizeAs(std):fill(1)

   local var = std:pow(2) -- note: pow(2) is in place, std must not be reused after this line
   local KLLoss = KLCrit:forward({mean, var},{mean_prior, var_prior})
   local gradMean, gradVar = unpack( KLCrit:backward({mean, var},{mean_prior, var_prior}) )
   gradLogVar = torch.cmul(gradVar, var) -- d(var)/d(logVar) = var

   recogniser:backward(y, {gradMean, gradLogVar})
   local loss = reconLoss + KLLoss
   return loss, gradParams
end

for epoch = 1, max_Epoch do

   indices = torch.randperm(N):long():split(batch)

   local recon = torch.Tensor():resizeAs(data):zero()
   local x_sample = torch.Tensor():resize(N, Dx):zero()
   local Loss = 0.0

   for t,v in ipairs(indices) do
      xlua.progress(t, #indices)
      y = data:index(1,v)
      __, loss = optim[optimiser](feval, params, {learningRate = eta })
      recon[{ { batch*(t-1) + 1, batch*t },{}}] = generator.output[1]
      x_sample[{ { batch*(t-1) + 1, batch*t },{}}] = sampler.output
      Loss = Loss + loss[1]
   end

   print("Epoch: " .. epoch .. " Loss: " .. Loss/N )

   torch.save('save/recon.t7', recon)
   torch.save('save/xs.t7', x_sample)

end

--------------------------------------------------------------------------------
/gaussianGamma.lua:
--------------------------------------------------------------------------------
require 'nn'
require 'NormalGamma'
require 'Gaussian'
-- Standard variational inference for a Gaussian with a normal-gamma prior

torch.manualSeed(1)
local N = 1000
local D = 2
local K = 1

-- 1.) prepare data and model
local x = torch.Tensor(N,D):zero()
function generateData()
   local xmean = torch.Tensor{{10}, {20}}
   local cov = torch.Tensor{{10,-0.8},{-0.8,10}}
   local chol = torch.potrf(cov,'L')
   for i=1 , N do
      local rand = torch.randn(D,1)
      x[i]:add(xmean)
      x[i]:add(torch.mm(chol,rand))
   end
end
generateData()

local NG = nn.NormalGamma(K, D)
local gaussian = nn.Gaussian(K, D, N)

-- 2.) set prior
local m0 = torch.Tensor(D):zero()
local l0 = torch.Tensor(D):fill(1)
local a0 = torch.Tensor(D):fill(1)
local b0 = torch.Tensor(D):fill(1)
local prior = NG:setPrior(m0, l0, a0, b0)

-- 3.) get statistics from data
local Tx = gaussian:observe(x)

-- 4.) update the variational parameters
NG:backward(Tx)
NG:updateParameters(1.0)
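
-- Note on the two calls above: accGradParameters accumulates the natural
-- gradient gradWeight = weight - prior - Tx, so a single vanilla gradient
-- step with learning rate 1.0 lands exactly on weight = prior + Tx, i.e.
-- the conjugate normal-gamma posterior given the observed statistics.
-- A smaller learning rate would give a damped (stochastic) natural-gradient step.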

local m, l, a, b = unpack( NG:getBasicParameters() )

local x_bar = x:sum(1)/N
print(m)
print(x_bar)

local cov = torch.Tensor(K, D, D)
for k=1,K do
   cov[k] = torch.diag( torch.cdiv(b,a)[k] )
end
print(cov)
torch.save('save/x.t7', x)
torch.save('save/m.t7', m)
torch.save('save/cov.t7', cov)

--------------------------------------------------------------------------------
/gaussianMixture.lua:
--------------------------------------------------------------------------------
require 'nn'
require 'Gaussian'
require 'NormalGamma'
require 'Dirichlet'
require 'Label'

torch.manualSeed(1)

local N = 1000
local D = 2
local K = 20

-- 1.) prepare data
local function generateMixtureOfGaussian(N, xmean, cov, K)
   local D = xmean:size(2)
   local _x = torch.Tensor(N, D):zero()
   local k = torch.ceil( torch.rand(N)*K )
   for i= 1 , N do
      local chol = torch.potrf(cov[k[i]] , 'L')
      local rand = torch.randn(D,1)
      _x[i]:add(xmean[k[i]])
      _x[i]:add(torch.mm(chol,rand))
   end

   return _x
end

local xmean = 100 * torch.rand(K, D, 1)
local cov = 1.0 * torch.Tensor{{1,-0.5},{-0.5,1}}:repeatTensor(K, 1, 1)

local x = generateMixtureOfGaussian(N, xmean, cov, K)

-- ...the definitions below override the synthetic mixture above and fit the spiral data instead
local x = torch.load('save/spiral.t7')
local N = x:size(1)
local D = 2
local K = 5

-- prepare model
local dir = nn.Dirichlet(K)
local NG = nn.NormalGamma(K,D)
local gaussian = nn.Gaussian(K, D, N)
local label = nn.Label(K, N)

-- set prior
local m0 = torch.Tensor(D):zero()
local l0 = torch.Tensor(D):fill(1)
local a0 = torch.Tensor(D):fill(1)
local b0 = torch.Tensor(D):fill(1)
NG:setPrior(m0, l0, a0, b0)
dir:setPrior(1000)

-- Initialise parameters
local m1 = torch.rand(K, D)
local l1 = torch.Tensor(K, D):fill(1)
local a1 = torch.Tensor(K, D):fill(1)
local b1 = torch.Tensor(K, D):fill(1)
local pi1 = torch.rand(K)
NG:setParameters(m1, l1, a1, b1)
dir:setParameters(pi1)

-- prepare x stats
local Tx = { x, torch.cmul(x,x) }

for epoch =1, 100 do
   NG:zeroGradParameters()
   dir:zeroGradParameters()

   -- 1.) Get global expected stats
   local E_NG = NG:updateExpectedStats()
   local E_dir = dir:updateExpectedStats()

   -- 2.) Get gaussian expected llh
   local llh = gaussian:getLogLikelihood(E_NG, Tx)

   -- 3.) Update Label parameters
   local phi = label:setParameters(llh, E_dir)

   -- 4.) Compute Mixture stats
   local Txz = gaussian:getMixtureStats(phi, Tx, 1.0)

   -- 5.) Update global parameters
   NG:backward(Txz)
   NG:updateParameters(1.0)
   dir:backward(Txz)
   dir:updateParameters(1.0)

   -- plot
   local m, l, a, b = unpack(NG:getBasicParameters())
   local cov = torch.Tensor(K, D, D)
   for k=1,K do
      cov[k] = torch.diag( torch.cdiv(b,a)[k] )
   end
   print(cov)
   torch.save('save/x.t7', x)
   torch.save('save/m.t7', m)
   torch.save('save/cov.t7', cov)

end

--------------------------------------------------------------------------------
/plot.py:
--------------------------------------------------------------------------------
import numpy as np
import torchfile
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse

y = torchfile.load('save/x.t7')

# set up
plt.ion()

fig = plt.figure()
ax = fig.add_subplot(111)

mean = torchfile.load('save/m.t7')
colors = {}
for j in xrange(len(mean)):
    colors[j] = np.random.rand(3,1)

plotting = True
count = 0
while plotting:

    ax.cla()
    y = torchfile.load('save/x.t7')
    ax.scatter(y[:,0], y[:, 1], marker='.' , s = 5 )
    mean = torchfile.load('save/m.t7')
    ax.scatter(mean[:,0], mean[:,1], marker='o', s = 10, color='red')

    # draw ellipse
    cov = torchfile.load('save/cov.t7')
    def eigsorted(cov):
        vals, vecs = np.linalg.eigh(cov)
        order = vals.argsort()[::-1]
        return vals[order], vecs[:,order]

    for i in xrange(len(cov)):
        vals, vecs = eigsorted(cov[i])
        theta = np.degrees(np.arctan2(*vecs[:,0][::-1]))
        # Width and height are "full" widths, not radius
        nstd = 2.0
        width, height = 2 * nstd * np.sqrt(vals)
        ellip = Ellipse(xy= mean[i], width=width, height=height, angle=theta, alpha=0.4, color=colors[i])

        ax.add_artist(ellip)

    plt.draw()
    plt.pause(0.05)
    count += 1
    if count > 100000:
        plotting = False
        print('Finish Simulation')

plt.ioff()
plt.show()

--------------------------------------------------------------------------------
/plot_full.py:
--------------------------------------------------------------------------------
import numpy as np
import torchfile
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse

# set up
plt.ion()

fig = plt.figure()
ax = fig.add_subplot(121)
ax2 = fig.add_subplot(122)

mean = torchfile.load('save/m.t7')
colors = {}
for j in xrange(len(mean)):
    colors[j] = np.random.rand(3,1)

plotting = True
count = 0
while plotting:

    ax.cla()
    x = torchfile.load('save/xs.t7')
    ax.scatter(x[:,0], x[:, 1], marker='.' , s = 5 )
    mean = torchfile.load('save/m.t7')
    ax.scatter(mean[:,0], mean[:,1], marker='o', s = 10, color='red')

    ax2.cla()
    y = torchfile.load('save/spiral.t7')
    ax2.scatter(y[:,0], y[:, 1], marker='.' , s = 5 )
    y_recon = torchfile.load('save/recon.t7')
    label = torchfile.load('save/label.t7')
    mix = torchfile.load('save/mixing.t7')

    # draw ellipse
    cov = torchfile.load('save/cov.t7')
    def eigsorted(cov):
        vals, vecs = np.linalg.eigh(cov)
        order = vals.argsort()[::-1]
        return vals[order], vecs[:,order]

    for i in xrange(len(cov)):
        vals, vecs = eigsorted(cov[i])
        theta = np.degrees(np.arctan2(*vecs[:,0][::-1]))
        # Width and height are "full" widths, not radius
        nstd = 2.0
        width, height = 2 * nstd * np.sqrt(vals)
        ellip = Ellipse(xy= mean[i], width=width, height=height, angle=theta, alpha=mix[i], color=colors[i])

        ax.add_artist(ellip)

        # select latent samples assigned to cluster i (torch labels are 1-based)
        key = label - 1 == i
        # plot their reconstructions
        ax2.scatter(y_recon[key,0], y_recon[key, 1], marker='.' , s = 5 , color=colors[i])

    plt.draw()
    plt.pause(0.05)
    count += 1
    if count > 100000:
        plotting = False
        print('Finish Simulation')

plt.ioff()
plt.show()

--------------------------------------------------------------------------------
/plot_latent.py:
--------------------------------------------------------------------------------
import numpy as np
import torchfile
import matplotlib.pyplot as plt

# set up
plt.ion()
fig = plt.figure()
ax = fig.add_subplot(111)

plotting = True
count = 0
while plotting:

    xs = torchfile.load('save/xs.t7')

    ax.cla()
    ax.scatter(xs[:,0], xs[:, 1], marker='.' , s = 5 )

    plt.draw()
    plt.pause(0.05)
    count += 1
    if count > 10000:
        plotting = False
        print('Finish Simulation')

plt.ioff()
plt.show()

--------------------------------------------------------------------------------
/plot_recon.py:
--------------------------------------------------------------------------------
import numpy as np
import torchfile
import matplotlib.pyplot as plt

y = torchfile.load('save/spiral.t7')

# set up
plt.ion()
fig = plt.figure()
ax = fig.add_subplot(111)

plotting = True
count = 0
while plotting:

    y_recon = torchfile.load('save/recon.t7')

    ax.cla()
    ax.scatter(y[:,0], y[:, 1], marker='.' , s = 5 )
    ax.scatter(y_recon[:,0], y_recon[:, 1], marker='.' , s = 5, color='red' )

    plt.draw()
    plt.pause(0.05)
    count += 1
    if count > 10000:
        plotting = False
        print('Finish Simulation')

plt.ioff()
plt.show()

--------------------------------------------------------------------------------
/resVAE.lua:
--------------------------------------------------------------------------------
require 'nn'
require 'GaussianCriterion'
require 'optim'
require 'nngraph'
require 'KLCriterion'
local nninit = require 'nninit'

torch.manualSeed(1)
data = torch.load('save/spiral.t7')

local N = data:size(1)
local Dy = data:size(2)
local Dx = 2
local batch = 100
local batchScale = N/batch
local eta = 0.0001
local optimiser = 'adam'
local max_Epoch = 500

-- ResNet
function resNetBlock(inputSize, hiddenSize)
   local input = - nn.Identity()
   local resBranch = input
            - nn.Linear(inputSize, hiddenSize):init('weight', nninit.normal, 0, 0.001)
                                              :init('bias'  , nninit.normal, 0, 0.001)
            - nn.Tanh()
            - nn.Linear(hiddenSize, inputSize):init('weight', nninit.normal, 0, 0.001)
                                              :init('bias'  , nninit.normal, 0, 0.001)
   local skipBranch = input
            - nn.Identity()
   local output = {resBranch, skipBranch}
            - nn.CAddTable()
   return nn.gModule({input}, {output})
end

-- Network
function createNetwork(Dy, Dx)
   local hiddenSize = 100
   -- Recogniser
   -- NOTE: sized in Dy throughout, which assumes Dy == Dx (true for the spiral data)
   local input = - nn.View(-1, Dy)
   local hidden = input
            - resNetBlock(Dy, hiddenSize)
   local mean = hidden
            - resNetBlock(Dy, hiddenSize)
   local logVar = hidden
            - nn.Linear(Dy, Dy):init('bias', nninit.addConstant, -5)

   local recogniser = nn.gModule( {input}, {mean, logVar})

   -- Sampler
   local mean = - nn.Identity()
   local std  = - nn.Identity()
   local rand = - nn.Identity()
   local noise = {std, rand}
               - nn.CMulTable()
   local x = {mean, noise}
               - nn.CAddTable()

   local sampler = nn.gModule({mean, std, rand}, {x})

   -- Generator
   local X_sample = - nn.Identity()
   local h = X_sample
            - resNetBlock(Dy, hiddenSize)
   local recon_mean = h
            - resNetBlock(Dy, hiddenSize)
   local recon_logVar = h
            - nn.Linear(Dy, Dy):init('bias', nninit.addConstant, -5)

   local generator = nn.gModule({X_sample}, {recon_mean, recon_logVar})

   return recogniser, sampler, generator
end

local recogniser, sampler, generator = createNetwork(Dy, Dx)
local container = nn.Container()
                  :add(recogniser)
                  :add(generator)

local ReconCrit = nn.GaussianCriterion()
local KLCrit = nn.KLCriterion(1.0)

local params, gradParams = container:getParameters()

function feval(x)

   if x ~= params then
      params:copy(x)
   end

   container:zeroGradParameters()

   local mean, logVar = unpack( recogniser:forward(y) )

   local std = logVar:clone():mul(0.5):exp()
   local rand = torch.randn(std:size()):mul(1.0)
   local xs = sampler:forward({mean, std, rand})

   local recon = generator:forward(xs)
   local reconLoss = ReconCrit:forward(recon, y)
   local gradRecon = ReconCrit:backward(recon, y)
   local gradXs = generator:backward(xs, gradRecon)

   local gradMean, gradStd, __ = unpack( sampler:backward({mean, std, rand}, gradXs) )
   local gradLogVar = gradStd:clone():cmul(std):mul(0.5)
   recogniser:backward(y, {gradMean, gradLogVar})

   -- Set prior for VAE
   local mean_prior = torch.Tensor():resizeAs(mean):zero()
   local var_prior  = torch.Tensor():resizeAs(std):fill(1)

   local var = std:pow(2)
   local KLLoss = KLCrit:forward({mean, var},{mean_prior, var_prior})
   local gradMean, gradVar = unpack( KLCrit:backward({mean, var},{mean_prior, var_prior}) )
   gradLogVar = torch.cmul(gradVar, var)

   recogniser:backward(y, {gradMean, gradLogVar})
   local loss = reconLoss + KLLoss
   return loss, gradParams
end

for epoch = 1, max_Epoch do

   indices = torch.randperm(N):long():split(batch)

   local recon = torch.Tensor():resizeAs(data):zero()
   local x_sample = torch.Tensor():resize(N, Dx):zero()
   local Loss = 0.0

   for t,v in ipairs(indices) do
      xlua.progress(t, #indices)
      y = data:index(1,v)
      __, loss = optim[optimiser](feval, params, {learningRate = eta })
      recon[{ { batch*(t-1) + 1, batch*t },{}}] = generator.output[1]
      x_sample[{ { batch*(t-1) + 1, batch*t },{}}] = sampler.output
      Loss = Loss + loss[1]
   end

   print("Epoch: " .. epoch .. " Loss: " .. Loss/N )

   torch.save('save/recon.t7', recon)
   torch.save('save/xs.t7', x_sample)

end

--------------------------------------------------------------------------------
/save/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nat-D/SVAE-Torch/7b53bfb9840fe8f5c515c53f849d8b1c47a7df7b/save/.DS_Store

--------------------------------------------------------------------------------
/save/spiral.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nat-D/SVAE-Torch/7b53bfb9840fe8f5c515c53f849d8b1c47a7df7b/save/spiral.t7

--------------------------------------------------------------------------------