├── LICENSE.md ├── README.md ├── bmnist ├── README.md ├── aux.lua ├── generate.lua ├── interpolate.lua ├── main.lua └── mnist.lua ├── lang ├── README.md ├── doc │ ├── README_oneb.md │ └── README_snli.md ├── experiments │ ├── README.md │ ├── noise.py │ └── vector.py ├── generate.py ├── models.py ├── preprocess_lm.py ├── run_oneb.py ├── run_snli.py ├── snli_preprocessing.py ├── train.py ├── train_rnnlm.py └── utils.py └── yelp ├── README.md ├── data ├── test0.txt ├── test1.txt ├── train1.txt ├── train2.txt ├── valid1.txt └── valid2.txt ├── models.py ├── train.py └── utils.py /LICENSE.md: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2017, Junbo (Jake) Zhao (NYU), Yoon Kim (Harvard), Kelly Zhang (NYU), Alexander M. Rush (Harvard) and Yann LeCun (NYU) All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 6 | 7 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 8 | 9 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 10 | 11 | Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 12 | 13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ARAE 2 | Code for the paper "Adversarially Regularized Autoencoders (ICML 2018)" by Zhao, Kim, Zhang, Rush and LeCun https://arxiv.org/abs/1706.04223 3 | 4 | 5 | ## Disclaimer 6 | Major updates on 06/11/2018: 7 | * WGAN-GP replaced WGAN 8 | * added 1BWord dataset experiment 9 | * added Yelp transfer experiment 10 | * removed unnecessary tricks 11 | * added both RNNLM and ngram-LM evaluation for both forward and reverse PPL. 12 | 13 | ## File structure 14 | * lang: ARAE for language generation, on both 1B word benchmark and SNLI 15 | * yelp: ARAE for language style transfer 16 | * bmnist (in Torch): ARAE for discretized MNIST 17 | 18 | 19 | ## Reference 20 | 21 | ``` 22 | @ARTICLE{2017arXiv170604223J, 23 | author = {{Zhao}, J. and {Kim}, Y. and {Zhang}, K. and {Rush}, A.~M. 
and 24 | {LeCun}, Y.}, 25 | title = "{Adversarially Regularized Autoencoders for Generating Discrete Structures}", 26 | journal = {ArXiv e-prints}, 27 | archivePrefix = "arXiv", 28 | eprint = {1706.04223}, 29 | primaryClass = "cs.LG", 30 | keywords = {Computer Science - Learning, Computer Science - Computation and Language, Computer Science - Neural and Evolutionary Computing}, 31 | year = 2017, 32 | month = jun, 33 | adsurl = {http://adsabs.harvard.edu/abs/2017arXiv170604223J}, 34 | adsnote = {Provided by the SAO/NASA Astrophysics Data System} 35 | } 36 | ``` 37 | 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /bmnist/README.md: -------------------------------------------------------------------------------- 1 | ## binary-MNIST generation 2 | 3 | Torch implementation for binary-MNIST generation experiment. 4 | 5 | ### Requirements 6 | This code runs using torch. It is only tested on GPU. The following torch libraries are required: 7 | nn, cunn, cudnn, optim, image, nngraph, dpnn([https://github.com/Element-Research/dpnn](https://github.com/Element-Research/dpnn)) 8 | 9 | MNIST t7 dataset can be downloaded here: [https://github.com/torch/tutorials/blob/master/A_datasets/mnist.lua](https://github.com/torch/tutorials/blob/master/A_datasets/mnist.lua). 10 | The dataset should be downloaded and placed under this folder, with the structure 11 | ``` 12 | . 13 | +-- mnist 14 | +-- train_28x28.t7 15 | +-- test_28x28.t7 16 | ``` 17 | 18 | 19 | ### Train 20 | ``` 21 | th main.lua 22 | ``` 23 | 24 | The output model: `$savename/model.t7` 25 | 26 | ### Generation 27 | To generate: 28 | ``` 29 | th generate.lua --imgname output.png --modelpath ./$savename/model.t7 30 | ``` 31 | The output exemplar generation: `output.png` 32 | 33 | ### Z-space interpolation 34 | To interpolate: 35 | ``` 36 | th interpolate.lua --imgname int.png --modelpath ./$savename/model.t7 37 | ``` 38 | The output exemplar interpolation: `int.png`. 
39 | -------------------------------------------------------------------------------- /bmnist/aux.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | auxiliary functions 3 | --]] 4 | 5 | --[[ Break a string "25-30-100" into table {25,30,100} ]]-- 6 | function convert_option(s) 7 | local out = {} 8 | local args = string.split(s, '-') 9 | for _, x in pairs(args) do 10 | x = string.gsub(x, 'n', '-') 11 | local y = tonumber(x) 12 | if y == nil then 13 | error("Parsing arguments: " .. s .. " is not well formed") 14 | end 15 | out[1+#out] = y 16 | end 17 | return out 18 | end 19 | 20 | --[[ Table deepcopy ]]-- 21 | function deepcopy(orig) 22 | local orig_type = type(orig) 23 | local copy 24 | if orig_type == 'table' then 25 | copy = {} 26 | for orig_key, orig_value in next, orig, nil do 27 | copy[deepcopy(orig_key)] = deepcopy(orig_value) 28 | end 29 | setmetatable(copy, deepcopy(getmetatable(orig))) 30 | else -- number, string, boolean, etc 31 | copy = orig 32 | end 33 | return copy 34 | end 35 | 36 | --[[ Checking tensor containing Nan ]]-- 37 | function check_nan(t) 38 | return t:sum()~=t:sum() 39 | end 40 | -------------------------------------------------------------------------------- /bmnist/generate.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Generate file 3 | --]] 4 | 5 | require 'cutorch' 6 | require 'cunn' 7 | require 'cudnn' 8 | require 'nngraph' 9 | require 'dpnn' 10 | paths.dofile("aux.lua") 11 | paths.dofile("mnist.lua") 12 | 13 | -------------- Setting ---------------- 14 | local cmd = torch.CmdLine() 15 | cmd:option('--imgname', 'output.png', 'save image name') 16 | cmd:option('--modelpath', '', 'model path') 17 | cmd:option('--nrow', 10, 'nrow') 18 | local opt_eval = cmd:parse(arg or {}) 19 | 20 | -------------- Model reloading ---------------- 21 | local loaded = torch.load(opt_eval.modelpath) 22 | local model, opt = unpack(loaded) 23 | local AE = 
model:get(1) 24 | local AE_ = nn.Sequential():add(AE:get(1)):add(AE:get(3)) 25 | local D = model:get(2) 26 | local G = model:get(3) 27 | model:evaluate() 28 | 29 | -------------- Data ---------------- 30 | dataloader = MNISTLoader(opt) 31 | dataloader:cuda() 32 | function binarize(x, res) 33 | local maxval = 0.995 34 | local minval = -0.995 35 | res:resizeAs(x):copy(x:gt((maxval+minval)/2)) 36 | return res 37 | end 38 | local x_binmnist_in = torch.CudaTensor() 39 | 40 | -------------- Generation--------------- 41 | --[[ generate fake samples ]]-- 42 | local noise = torch.CudaTensor(opt.batchSize, opt.noiseDim) 43 | noise:normal() 44 | local fake_hid = G:forward(noise) 45 | local fake_gen = AE_:get(2):forward(fake_hid):float() 46 | local _, fake_gen_max = fake_gen:max(2) 47 | local fake_gen_max = fake_gen_max:mul(2):add(-3) 48 | local irec = image.toDisplayTensor({input=fake_gen_max, nrow=opt_eval.nrow, padding=1}) 49 | 50 | --[[ generate real samples ]]-- 51 | local x_mnist = dataloader:getBatch(opt.batchSize, "test") 52 | x_binmnist_in = binarize(x_mnist, x_binmnist_in) 53 | x_binmnist_in:mul(2):add(-1) 54 | local x_ = AE_:forward(x_binmnist_in) 55 | local _,x_ = x_:max(2) 56 | local x_ = x_:mul(2):add(-3) 57 | local irec0 = image.toDisplayTensor({input=x_, padding=1, nrow=opt_eval.nrow}) 58 | local splitbar = torch.Tensor(1,4,irec0:size(2)):fill(0.5) 59 | 60 | --[[ concatnation ]]-- 61 | local todisp = torch.cat({irec0,splitbar,irec},2) 62 | 63 | --[[ save ]]-- 64 | image.save(opt_eval.imgname, todisp) 65 | -------------------------------------------------------------------------------- /bmnist/interpolate.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Z space interpolation file 3 | --]] 4 | 5 | require 'cutorch' 6 | require 'cunn' 7 | require 'cudnn' 8 | require 'nngraph' 9 | require 'dpnn' 10 | paths.dofile("aux.lua") 11 | paths.dofile("mnist.lua") 12 | 13 | -------------- Setting ---------------- 14 | local 
cmd = torch.CmdLine() 15 | cmd:option('--imgname', 'output.png', 'save image name') 16 | cmd:option('--modelpath', '', 'model path') 17 | cmd:option('--ninterp', 10, 'number of interpolations') 18 | cmd:option('--batchSize', 10, 'how many images to show') 19 | local opt_eval = cmd:parse(arg or {}) 20 | 21 | -------------- Model reloading ---------------- 22 | local loaded = torch.load(opt_eval.modelpath) 23 | local model, opt = unpack(loaded) 24 | local AE = model:get(1) 25 | local AE_ = nn.Sequential():add(AE:get(1)):add(AE:get(3)) 26 | local D = model:get(2) 27 | local G = model:get(3) 28 | model:evaluate() 29 | 30 | -------------- Data ---------------- 31 | dataloader = MNISTLoader(opt) 32 | dataloader:cuda() 33 | function binarize(x, res) 34 | local maxval = 0.995 35 | local minval = -0.995 36 | res:resizeAs(x):copy(x:gt((maxval+minval)/2)) 37 | return res 38 | end 39 | 40 | -------------- Z-space interpolation --------------- 41 | local noise_v = torch.CudaTensor(opt_eval.ninterp, opt.noiseDim) 42 | local noise_l = torch.CudaTensor(opt.noiseDim) 43 | local noise_r = torch.CudaTensor(opt.noiseDim) 44 | --[[ interpolation on one group of left/right z vectors ]]-- 45 | local function interp() 46 | local line = torch.linspace(0, 1, opt_eval.ninterp) 47 | noise_l:normal() 48 | noise_r:normal() 49 | for i = 1, opt_eval.ninterp do 50 | noise_v:select(1, i):copy(noise_l*line[i] + noise_r*(1-line[i])) 51 | end 52 | local fake_hid = G:forward(noise_v) 53 | local fake_gen = AE_:get(2):forward(fake_hid):float() 54 | local _, fake_gen_max = fake_gen:max(2) 55 | local fake_gen_max = fake_gen_max:mul(2):add(-3) 56 | return fake_gen_max 57 | end 58 | -- calling interplation function 59 | local todisp = {} 60 | for i = 1, opt_eval.batchSize do 61 | local this_todisp = interp() 62 | for j = 1, opt_eval.ninterp do 63 | todisp[#todisp+1] = this_todisp[j] 64 | end -- end for 65 | end -- end for i 66 | -- dumping out 67 | local todisp = image.toDisplayTensor({input=todisp, 
nrow=opt_eval.ninterp, padding=1}) 68 | image.save(opt_eval.imgname, todisp) 69 | -------------------------------------------------------------------------------- /bmnist/main.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Main file for training binarized-MNIST ARAE. 3 | --]] 4 | 5 | require 'cutorch' 6 | require 'cunn' 7 | require 'cudnn' 8 | require 'nngraph' 9 | require 'optim' 10 | require 'dpnn' 11 | paths.dofile("aux.lua") 12 | paths.dofile("mnist.lua") 13 | 14 | --------------------------------------- 15 | -------------- Setting ---------------- 16 | --------------------------------------- 17 | local cmd = torch.CmdLine() 18 | -- general options 19 | cmd:option('--devid', 1, 'gpu id') 20 | cmd:option('--savename', 4001, 'savename of the save') 21 | 22 | -- training settings 23 | cmd:option('--nEpoches', 100, 'Number of "epoches"') 24 | cmd:option('--nIters', 100, 'Number of "iteration"/"minibatches" in each epoch') 25 | cmd:option('--nItersAE', 1, 'Number of iters trained on AE within each iteration') 26 | cmd:option('--nItersGAN', 1, 'Number of iters trained on GAN within each iteration') 27 | cmd:option('--batchSize', 100, 'Batch size') 28 | cmd:option('--learningRateAE', 0.0005, 'Learning rate for auto-encoder') 29 | cmd:option('--learningRateG', 0.0005, 'Learning rate for generator') 30 | cmd:option('--learningRateD', 5e-05, 'Learning rate for discriminator') 31 | -- model settings 32 | cmd:option('--noiseDim', 32, 'Dimension of noise vector') 33 | cmd:option('--hidDim', 100, 'Dimension of code vector') 34 | cmd:option('--archEnc', '800-400', 'architecture of encoder') 35 | cmd:option('--archDec', '400-800-1000', 'architecture of decoder') 36 | cmd:option('--archG', '64-100-150', 'architecture of G') 37 | cmd:option('--archD', '100-60-20', 'architecture of D') 38 | cmd:option('--noiseAE', 0.4, 'std of the noise added to the code vector') 39 | cmd:option('--noiseAnne', 0.99, 'noise exponential decay 
factor') 40 | -- gan settings 41 | cmd:option('--nItersD', 10, 'Number of iterations for training WGAN critic') 42 | cmd:option('--clamp', 0.05, 'WGAN critic clamp') 43 | cmd:option('--gan2enc', -0.2, 'Multiplier to the gradient from GAN to enc') 44 | 45 | local opt = cmd:parse(arg or {}) 46 | torch.setnumthreads(4) 47 | torch.setdefaulttensortype('torch.FloatTensor') 48 | cutorch.setDevice(opt.devid) 49 | 50 | -- logging aux 51 | local filename = paths.concat(opt.savename, "log") 52 | os.execute("mkdir -p " .. opt.savename) 53 | print("Saving folder name: " .. opt.savename) 54 | function print_(cont, to_stdout) 55 | local to_stdout = (to_stdout==nil) and true or to_stdout 56 | local str 57 | if type(cont) == "table" then 58 | str = table.concat(cont, '\n') 59 | elseif type(cont) == "string" then 60 | str = cont 61 | end 62 | if to_stdout then 63 | print(str) 64 | end 65 | file = io.open(filename, 'a') 66 | file:write(str .. "\n") 67 | file:close() 68 | end 69 | 70 | --------------------------------------- 71 | -------------- Data ------------------- 72 | --------------------------------------- 73 | dataloader = MNISTLoader(opt) 74 | dataloader:cuda() 75 | function binarize(x, res) 76 | local maxval = 0.995 77 | local minval = -0.995 78 | res:resizeAs(x):copy(x:gt((maxval+minval)/2)) 79 | return res 80 | end 81 | -- buffers for discretetion 82 | local x_binmnist_in = torch.CudaTensor() 83 | local x_binmnist_ou = torch.CudaTensor() 84 | -- data preparation 85 | function prep_batch(set) 86 | local set = set or "train" 87 | local x_mnist = dataloader:getBatch(opt.batchSize, set) 88 | x_binmnist_in = binarize(x_mnist, x_binmnist_in) 89 | x_binmnist_in:mul(2):add(-1) 90 | x_binmnist_ou = binarize(x_mnist, x_binmnist_ou):squeeze() 91 | x_binmnist_ou:add(1) 92 | end 93 | 94 | --------------------------------------- 95 | -------------- Model ------------------ 96 | --------------------------------------- 97 | --[[ AE ]]-- 98 | AE = nn.Sequential() 99 | local arch_enc = 
convert_option(opt.archEnc) 100 | local Enc = nn.Sequential() 101 | local Dec = nn.Sequential() 102 | AE:add(Enc) 103 | -- encoder 104 | Enc:add(nn.View(784):setNumInputDims(3)) 105 | table.insert(arch_enc, 1, 784) 106 | for i = 1, #arch_enc-1 do 107 | Enc:add(nn.Linear(arch_enc[i],arch_enc[i+1])) 108 | Enc:add(nn.BatchNormalization(arch_enc[i+1])) 109 | Enc:add(cudnn.ReLU()) 110 | end 111 | Enc:add(nn.Linear(arch_enc[#arch_enc], opt.hidDim)) 112 | Enc:add(nn.Normalize(2)) 113 | AE:add(nn.WhiteNoise(0, opt.noiseAE)) 114 | -- decoder 115 | AE:add(Dec) 116 | local arch_dec = convert_option(opt.archDec) 117 | table.insert(arch_dec, 1, opt.hidDim) 118 | for i = 1, #arch_dec-1 do 119 | Dec:add(nn.Linear(arch_dec[i],arch_dec[i+1])) 120 | Dec:add(nn.BatchNormalization(arch_dec[i+1])) 121 | Dec:add(cudnn.ReLU()) 122 | end 123 | Dec:add(nn.Linear(arch_dec[#arch_dec], 784*2)) 124 | Dec:add(nn.View(2,28,28):setNumInputDims(1)) 125 | -- AE with no noise layer 126 | AE_ = nn.Sequential():add(Enc):add(Dec) 127 | -- criterion training AE 128 | criterionAE = cudnn.SpatialCrossEntropyCriterion() 129 | criterionAE:cuda() 130 | 131 | --[[ GAN ]]-- 132 | -- GAN generator 133 | G = nn.Sequential() 134 | local arch_g = convert_option(opt.archG) 135 | table.insert(arch_g, 1, opt.noiseDim) 136 | for i = 1, #arch_g-1 do 137 | G:add(nn.Linear(arch_g[i], arch_g[i+1])) 138 | G:add(nn.BatchNormalization(arch_g[i+1])) 139 | G:add(cudnn.ReLU()) 140 | end 141 | G:add(nn.Linear(arch_g[#arch_g],opt.hidDim)) 142 | G:add(nn.Tanh()) 143 | -- GAN discriminator/critic 144 | D = nn.Sequential() 145 | local arch_d = convert_option(opt.archD) 146 | D:add(nn.Linear(opt.hidDim, arch_d[1])) 147 | D:add(nn.LeakyReLU(0.2)) 148 | for i = 1, #arch_d-1 do 149 | D:add(nn.Linear(arch_d[i], arch_d[i+1])) 150 | D:add(nn.BatchNormalization(arch_d[i+1])) 151 | D:add(nn.LeakyReLU(0.2)) 152 | end 153 | D:add(nn.Linear(arch_d[#arch_d], 1)) 154 | D:add(nn.Mean()) 155 | 156 | --[[ parameter flatterning ]]-- 157 | local model 
= nn.Sequential():add(AE):add(D):add(G):cuda() 158 | param_ae, gparam_ae = AE:getParameters() 159 | ae_config = {learningRate=opt.learningRateAE, beta1=opt.beta1} 160 | param_d, gparam_d = D:getParameters() 161 | d_config = {learningRate=opt.learningRateD, beta1=opt.beta1} 162 | param_g, gparam_g = G:getParameters() 163 | g_config = {learningRate=opt.learningRateG, beta1=opt.beta1} 164 | 165 | --[[ initialization on the models ]]-- 166 | local function initModel(model, std) 167 | for _, m in pairs(model:listModules()) do 168 | local function setWeights(module, std) 169 | weight = module.weight 170 | bias = module.bias 171 | if weight then weight:randn(weight:size()):mul(std) end 172 | if bias then bias:zero() end 173 | end 174 | setWeights(m, std) 175 | end 176 | end 177 | initModel(D, 0.02) 178 | initModel(G, 0.02) 179 | print(Enc, Dec, D, G) 180 | 181 | --------------------------------------- 182 | -------------- Train ------------------ 183 | --------------------------------------- 184 | -- init buffers 185 | local noise = torch.CudaTensor(opt.batchSize, opt.noiseDim) 186 | local function make_noise() 187 | noise:resize(opt.batchSize, opt.noiseDim):normal() 188 | end 189 | local gan_grad = torch.CudaTensor(1) 190 | local loss_real, loss_fake, loss_D, loss_fakeG, lossAE = 0, 0, 0, 0, 0 191 | --[[training callback functions ]]-- 192 | do 193 | --[[ fevalAE ]]-- 194 | function fevalAE(x) 195 | assert(x == param_ae) 196 | gparam_ae:zero() 197 | local output = AE:forward(x_binmnist_in) 198 | lossAE = criterionAE:forward(output, x_binmnist_ou) 199 | local derr_AE = criterionAE:backward(output, x_binmnist_ou) 200 | AE:backward(x_binmnist_in, derr_AE) 201 | return lossAE, gparam_ae 202 | end 203 | 204 | --[[ fevalD ]]-- 205 | function fevalD(x) 206 | assert(x == param_d) 207 | gparam_d:zero() 208 | x:clamp(-opt.clamp, opt.clamp) 209 | -- on real samples 210 | local real = AE:get(1):forward(x_binmnist_in) 211 | loss_real = D:forward(real)[1] 212 | local dloss_real = 
gan_grad:fill(1) 213 | D:backward(real, dloss_real) 214 | -- on fake samples 215 | local fake = G:forward(noise) 216 | loss_fake = D:forward(fake)[1] 217 | local dloss_fake = gan_grad:fill(-1) -- opposite sign to the real branch, matching loss_D = loss_real - loss_fake 218 | loss_D = loss_real - loss_fake 219 | D:backward(fake, dloss_fake) 220 | return loss_D, gparam_d -- return the critic loss computed above together with D's gradient buffer 221 | end 222 | 223 | --[[ fevalAE_fromGAN ]]-- 224 | function fevalAE_fromGAN(x) 225 | assert(x == param_ae) 226 | gparam_ae:zero() 227 | -- on real samples 228 | local real = AE:get(1):forward(x_binmnist_in) 229 | local loss_real_ = D:forward(real)[1] 230 | local dloss_real = gan_grad:fill(1) 231 | local dreal = D:updateGradInput(real, dloss_real) 232 | dreal:mul(-math.abs(opt.gan2enc)) 233 | -- fed back to the encoder 234 | AE:get(1):backward(x_binmnist_in, dreal) 235 | return loss_real, gparam_ae -- NOTE(review): returns the outer loss_real, not the local loss_real_ computed above; confirm intent 236 | end 237 | 238 | --[[ fevalG ]]-- 239 | function fevalG(x) 240 | assert(x == param_g) 241 | gparam_g:zero() 242 | 243 | noise:normal() 244 | local fake = G:forward(noise) 245 | loss_fakeG = D:forward(fake)[1] 246 | local dloss_fake = gan_grad:fill(1) 247 | local dG = D:updateGradInput(fake, dloss_fake) 248 | G:backward(noise, dG) 249 | return loss_fakeG, gparam_g 250 | end 251 | end 252 | -- training loop 253 | for iEpoch = 1, opt.nEpoches do 254 | local tt = tt or torch.Timer() 255 | cutorch.synchronize() 256 | model:training() 257 | loss_real, loss_fake, loss_D, loss_fakeG, lossAE = 0, 0, 0, 0, 0 258 | for iIter = 1, opt.nIters do 259 | ------ training AE ------ 260 | for iAE = 1, opt.nItersAE do 261 | prep_batch() 262 | optim.adam(fevalAE, param_ae, ae_config) 263 | end -- end for iAE = 1, 264 | 265 | ------ training GAN ------ 266 | for iGAN = 1, opt.nItersGAN do 267 | --- pass on D --- 268 | for iD = 1, opt.nItersD do 269 | prep_batch() 270 | make_noise() 271 | optim.adam(fevalD, param_d, d_config) 272 | --- backproping D into Enc --- 273 | optim.adam(fevalAE_fromGAN, param_ae, ae_config) 274 | end -- end for iD = 1 275 | --- pass on G --- 276 | optim.adam(fevalG, 
param_g, g_config) 277 | end -- end for iGAN = 1 278 | end -- end for i = 1, nIters 279 | model:evaluate() 280 | -- nan testing 281 | if check_nan(param_d) or check_nan(param_g) then 282 | error("Nan learnt.") 283 | end 284 | -- noise annealing 285 | AE:get(2).std = AE:get(2).std * opt.noiseAnne 286 | -- print message 287 | cutorch.synchronize() 288 | local tim = tt:time()['real'] 289 | local message = string.format("epo: %d, lossD: %.4f, lossAE: %.4f, lossG: %.4f, elaps: %.2e", 290 | iEpoch, -loss_D, lossAE, loss_fakeG, tim) 291 | print_(message) 292 | -- model saving and generation 293 | if iEpoch == opt.nEpoches then 294 | model:clearState() 295 | torch.save(string.format("%s/model.t7", opt.savename), {model, opt}) 296 | end 297 | tt:reset() 298 | end 299 | -------------------------------------------------------------------------------- /bmnist/mnist.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Data loader for MNIST 3 | 4 | MNIST t7 dataset should be downloaded and placed under this folder, 5 | with the strcuture: 6 | . 
7 | +-- mnist 8 | +-- train_28x28.t7 9 | +-- test_28x28.t7 10 | --]] 11 | 12 | require "nn" 13 | torch.setdefaulttensortype("torch.FloatTensor") 14 | require "image" 15 | paths.dofile("aux.lua") 16 | 17 | local loader = torch.class("MNISTLoader") 18 | 19 | --[[ constructor ]]-- 20 | function loader:__init(config) 21 | -- config parsing 22 | self.config = deepcopy(config) 23 | self.nC = self.config.nC or 1 24 | self.h = 28 25 | self.w = 28 26 | -- dataset loading 27 | self.train = torch.load("./mnist/train_28x28.th7") 28 | self.test = torch.load("./mnist/test_28x28.th7") 29 | -- preprocessing 30 | self.train.data = self.train.data:float() 31 | self.test.data = self.test.data:float() 32 | self.train.data:div(255/2):add(-1):clamp(-0.995,0.995) 33 | self.test.data:div(255/2):add(-1):clamp(-0.995,0.995) 34 | -- tensor buffers cacheing the output 35 | self.output = torch.Tensor() 36 | self.labels = torch.ByteTensor() 37 | collectgarbage() 38 | end 39 | 40 | --[[ shuffle set ]]-- 41 | function loader:shuffle(set) 42 | local this_set 43 | if set == "train" then 44 | this_set = self.train 45 | elseif set == "labeled" then 46 | this_set = self.labeled 47 | else 48 | error("MnistLoader:shuffle(set): set has to be [train|labeled]") 49 | end 50 | local randperm = torch.randperm(this_set.data:size(1)):long() 51 | this_set.data = this_set.data:index(1, randperm) 52 | this_set.labels = this_set.labels:index(1, randperm) 53 | collectgarbage() 54 | end 55 | 56 | --[[ get one batch ]]-- 57 | function loader:getBatch(batchsize, set) 58 | local this_set 59 | if set == "train" then 60 | this_set = self.train 61 | elseif set == "test" then 62 | this_set = self.test 63 | elseif set == "labeled" then 64 | this_set = self.labeled 65 | end 66 | -- preparing buffers 67 | self.output:resize(batchsize, self.nC, self.h, self.w) 68 | self.labels:resize(batchsize) 69 | -- boostrapping samples 70 | for i = 1, batchsize do 71 | local idx = torch.random(this_set.data:size(1)) 72 | 
self.output[i]:copy(this_set.data[idx]) 73 | self.labels[i] = this_set.labels[idx] 74 | end 75 | return self.output, self.labels 76 | end 77 | 78 | --[[ type cast ]]-- 79 | function loader:type(typ) 80 | if typ ~= nil then 81 | self.output = self.output:type(typ) 82 | self.labels = self.labels:type(typ) 83 | collectgarbage() 84 | end 85 | return self 86 | end 87 | --[[ Cudaize ]]-- 88 | function loader:cuda() 89 | self:type("torch.CudaTensor") 90 | end 91 | 92 | -------------------------------------------------------------------------------- /lang/README.md: -------------------------------------------------------------------------------- 1 | # ARAE for Language 2 | 3 | ## Requirements 4 | - Python 3.6.3 on Linux 5 | - PyTorch 0.3.1, JSON, Argparse 6 | - KenLM (https://github.com/kpu/kenlm) 7 | 8 | ### KenLM Installation: 9 | - Download stable release and unzip: http://kheafield.com/code/kenlm.tar.gz 10 | - Need Boost >= 1.42.0 and bjam 11 | - Ubuntu: `sudo apt-get install libboost-all-dev` 12 | - Mac: `brew install boost; brew install bjam` 13 | - Run *within* kenlm directory: 14 | ```bash 15 | mkdir -p build 16 | cd build 17 | cmake .. 18 | make -j 4 19 | ``` 20 | - `pip install https://github.com/kpu/kenlm/archive/master.zip` 21 | - For more information on KenLM see: https://github.com/kpu/kenlm and http://kheafield.com/code/kenlm/ 22 | 23 | ## Train and Pretrain models 24 | * [SNLI](doc/README_snli.md) 25 | * [1BWord benchmark](doc/README_oneb.md) 26 | 27 | ### Your Customized Datasets 28 | If you would like to train a text ARAE on another dataset, simply 29 | 1) Create a data directory with a `train.txt` and `test.txt` files with line delimited sentences. 30 | 2) Run training command with the `--data_path` argument pointing to that data directory. 31 | 32 | ## Evaluation with RNNLM 33 | 34 | To evaluate the reverse PPL with an RNNLM, first preprocess the data with the generated/real text files, e.g. 
35 | 36 | ``` 37 | python preprocess_lm.py --trainfile generated-data.txt --valfile real-val.txt --testfile real-test.txt 38 | ``` 39 | 40 | To train the model 41 | 42 | ``` 43 | python train_rnnlm.py --train_file lm-data-train.hdf5 --val_file lm-data-val.hdf5 --checkpoint_path lm-model.ptb 44 | ``` 45 | 46 | To evaluate on test 47 | 48 | ``` 49 | python train_rnnlm.py --train_file lm-data-train.hdf5 --val_file lm-data-test.hdf5 --train_from lm-model.ptb --test 1 50 | ``` 51 | 52 | -------------------------------------------------------------------------------- /lang/doc/README_oneb.md: -------------------------------------------------------------------------------- 1 | ## Pretrained version 2 | 3 | 1) Download and unzip from [here](https://drive.google.com/file/d/1h4GlTP1iVbQQQfZkSGoNtbcfCp1D2gzB/view?usp=sharing) 4 | 5 | 2) Run: 6 | 7 | ``` 8 | python generate.py --load_path ./oneb_pretrained -steps 8 9 | ``` 10 | 11 | 12 | ## Example generations 13 | 14 | ``` 15 | the store collapsed after two days , but the situation is turning out as this play . 16 | you meet i with myself pushing a . the best solution to that type and mental illness . 17 | the french worker pulled a gas leak at the group , whose eyes shot found . 18 | quite recently held as " one bit . " 19 | the church of manchester , a known member of merseyside parish , claims crime . 20 | the couple will have to flood three wards with the third estate , of course . 21 | it was not giving " hope " against felt and was modest . 22 | the plane was a public marketing . 23 | there are the unfortunately consequences . 24 | the market to keep consumers below the ranks of around 300 , " she never said ." 25 | ``` 26 | 27 | ## Exemplar interpolations 28 | 29 | ``` 30 | it said colorado with its men who . 31 | it said last may of " may remember . " 32 | " it sounds when a lot of refugees may say the word . 33 | it me what i said , what we told the yankees on new tests . 
34 | the three said it 's about " who in the pub of america , " he said . 35 | the three said teachers about the other 90 minutes with new signs of problems , he said . 36 | why if anyone could be allowed to enter the nation 's most comfortable university of its worst ? 37 | why pay millions of the nurses to marry with a job , and don 't think of public life . 38 | 39 | the exhibition followed the set panel to carry its staff to finish their lives about their work hours . 40 | the exhibition followed the set panel with one drink ; half per cent said they were prepared to retire . 41 | the publisher said it believes the offences may be worth in , over eight years older . 42 | the shares slid then traded highs , the euro with $ to settle at $ . 43 | to be guided 's own offense is attempting to the problem over your professional league . 44 | to be handled with paul 's length and constantly is the master sergeant . 45 | he was able to convert his professional 's outfit with these words . 46 | he said a formula could be : jeremy clarkson , 18 , always . 47 | 48 | the bills signed him the candidate and rep. pat . 49 | he then filed the bankruptcy and found friday 5 . 50 | he estimates 100 other the 33 million died in 2007 . 51 | 500 feet far near the bottom a & p . 52 | oil ceo will have the milwaukee brewers ' 53 | oil eventually will continue the last session above 10,000 . 54 | the vast majority will have died last week , the youngest day . 55 | authorities have obviously blamed a policeman holding a pair of the rocks . 56 | 57 | he still , it matters . 58 | he still looks like it , " ali said . 59 | he failed but , she said two errors . 60 | such , short term " fundamentally broken old . " 61 | such a key idea is strictly , " he asked . 62 | he finds this : they sell and destroy the ford body , which they do . 63 | he threw it from now the taliban but virtually so without command and military equipment . 
64 | in the opposite mix it is derived from a modern television , and it caught in all worlds . 65 | 66 | if banks do not walk out , the process was able to drive without running , he said . 67 | to work , they must roll up the distance to get through with something nice walking . 68 | to make " the big cat " games , people finally get over to all profit . 69 | but had not been seriously the problem - finishing duke and with boards . 70 | but , and the us company boss , " csi " between january . 71 | but i was in near the la times during an rather than in 2007 . 72 | but , 's grant the korean peninsula back across . 73 | in 1977 , 's christian parish appealed the letter into ."'" 74 | ``` 75 | 76 | ## Train 77 | 78 | ``` 79 | python run_oneb.py --data_path ./data/oneb --no_earlystopping 80 | ``` 81 | -------------------------------------------------------------------------------- /lang/doc/README_snli.md: -------------------------------------------------------------------------------- 1 | ## Preparation 2 | 3 | ### SNLI Data Preparation 4 | - Download dataset and unzip: 5 | ``` 6 | mkdir data; cd data 7 | wget https://nlp.stanford.edu/projects/snli/snli_1.0.zip 8 | unzip snli_1.0.zip 9 | cd ..; python snli_preprocessing.py --in_path data/snli_1.0 --out_path data/snli_lm 10 | ``` 11 | 12 | 13 | ## Pretrained version 14 | 15 | 1) Download and unzip from [here](https://drive.google.com/file/d/1h66T8UdFuNWWjvmLLcYC9bHExpNv8NR4/view?usp=sharing) 16 | 17 | 2) Run: 18 | 19 | ``` 20 | python generate.py --load_path ./snli_pretrained 21 | ``` 22 | 23 | 24 | ## Example generations 25 | 26 | ``` 27 | young boy bowling to the death . 28 | a woman wearing sunglasses for a woman to talk . 29 | a man is laying down enjoying a mural of the leaning against his eyes . 30 | animals are reading from old . 31 | the women are n't admiring anything . 32 | the man is watching the beach on his phone . 33 | a woman is eating . 34 | a shirtless girl carrying a backpack . 
35 | the couple is holding on the beach . 36 | two women face off a taxi . 37 | ``` 38 | 39 | ## Exemplar interpolations 40 | 41 | ``` 42 | the weather does winter . 43 | the two humans repairing . 44 | two people ride a bus . 45 | two people standing inside a firetruck . 46 | two men sitting down on a stage posing in a room . 47 | 48 | a man has two things up . 49 | a woman has three green . 50 | a woman takes a smoke outside . 51 | a woman takes a bath outside . 52 | a woman takes a bath with her hands . 53 | 54 | an outdoor food cart . 55 | a kid outside is fixing . 56 | a man is holding a painting . 57 | a man is on a scooter 58 | a dog on a leash is in a car . 59 | 60 | a man on a hill 61 | the dog on the shore 62 | the dog is on the stairs . 63 | the car is advertising . 64 | the car is made surgery . 65 | 66 | a man in blue glasses has paint on the wall of lunch . 67 | a boy is playing with toys on the train tracks . 68 | a boy is looking at clothes on display . 69 | a child waits to hold on a shovel to stop . 70 | a child smiles to someone not by a river 71 | ``` 72 | 73 | ## Train 74 | 75 | ``` 76 | python run_snli.py --data_path ./data/snli_lm --no_earlystopping 77 | ``` 78 | -------------------------------------------------------------------------------- /lang/experiments/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Walk 3 | 4 | Computes a random vector and then generates many sentences with a small noise 5 | perturbation. It prints new sentences and highlights how they differ from the original. 6 | 7 | 8 | ## Examples 9 | 10 | 11 | __A man goes in front of a skateboard .__ 12 | 13 | A __snowboarder crashes__ in a __board at the hospital__ . 14 | 15 | A man goes in a __cabin__ . 16 | 17 | A man __pours into the mud__ . 18 | 19 | A man goes in __midair__ . 20 | 21 | A man goes __into the mud__ . 22 | 23 | A man goes in front of a __telescope__ . 24 | 25 | A man goes in __their driveway__ .
26 | 27 | A man goes in . 28 | 29 | A man goes in front of __the__ skateboard . 30 | 31 | A __snowboarder crashes__ in front of a __bed__ . 32 | 33 | A man goes in __the mud__ . 34 | 35 | A man __pours__ in front of a __lake__ . 36 | 37 | ``` 38 | Workers clean a very colorful street . 39 | Workers *enjoy* a very colorful . 40 | *Women enjoy* a very *quiet* . 41 | Workers clean a *long white area* . 42 | Workers *enjoy* a very colorful *hallway* . 43 | Workers *carry* a very colorful *and empty* . 44 | Workers clean a *long white structure* . 45 | Workers *enjoy* a very colorful *narrow* . 46 | Workers clean a very *tall structure* . 47 | *Women enjoy* a *green sunny* . 48 | Workers clean a *long white balloon* . 49 | *Women enjoy* a *green outdoor* . 50 | *Women enjoy* a *green sunny weather* . 51 | Workers *enjoy* a very colorful *environment* . 52 | Workers *enjoy* a *green construction and* . 53 | *Women enjoy* a *green outdoor staircase* . 54 | Workers *enjoy* a very colorful *walkway* . 55 | *Women enjoy* a very *green hallway* . 56 | Workers *enjoy* a very colorful *city* . 57 | ``` 58 | 59 | # vector.py 60 | 61 | An experiment to try to change the content of a sentence by modifying the 62 | latent variable z. The alteration is done by sampling sentences from a fixed 63 | z, then updating towards the mean of the desired feature. We do this repeatedly 64 | until the argmax generation has the correct feature. 65 | 66 | 67 | ## Requirements 68 | 69 | * spacy 70 | 71 | ## Usage 72 | 73 | 1) Generate 1M sample sentences with vectors and features 74 | 75 | > python -m experiments.vector gen --dump features.pkl --load_path maxlen15 76 | 77 | 2) Attempt to alter new sentences based on the mean vector of `VERB_standing` sentences.
78 | 79 | > python -m experiments.vector alter --dump features.pkl --alter VERB_standing --load_path maxlen15 80 | 81 | 82 | ## Examples 83 | 84 | 85 | 86 | ``` 87 | Sent 0 : The man is wearing a scarf 88 | Sent 1 : The man is wearing a hat 89 | Sent 2 : The man is smoking 90 | Sent 3 : man is seated 91 | Sent 4 : man is standing on roof . 92 | 93 | 94 | Sent 0 : A man rode on vacation . 95 | Sent 1 : A man is riding . 96 | Sent 2 : A man is stand on . 97 | Sent 3 : A man is standing on fire . 98 | 99 | 100 | Sent 0 : A brown dog is inside of the house . 101 | Sent 1 : A white bird is sitting at . 102 | Sent 2 : A white boy is working in water . 103 | Sent 3 : A white man is sitting and is in a pool . 104 | Sent 4 : A young dog is standing in a water . 105 | 106 | 107 | Sent 0 : A man in a cowboy hat with sand near a mirror . 108 | Sent 1 : A man in a hand hat with face sits in a piece . 109 | Sent 2 : A man in a hat standing in hand as a piece of works . 110 | 111 | 112 | Sent 0 : A girl has a beach on blanket . 113 | Sent 1 : There is a boys sleeping on . 114 | Sent 2 : There is a standing girl on the shoreline . 115 | 116 | 117 | Sent 0 : A people are sitting on a close trail . 118 | Sent 1 : Boys are sitting with a tank top . 119 | Sent 2 : people are standing up a big . 120 | 121 | 122 | Sent 0 : There is the road . 123 | Sent 1 : There is a . 124 | Sent 2 : There is a man . 125 | Sent 3 : There is a man in a street . 126 | Sent 4 : There is a man in a street . 127 | Sent 5 : There is a man in a street . 128 | Sent 6 : There is a man standing in a street . 129 | 130 | 131 | Sent 0 : A young girls jump fast and a wheelchair of an outside . 132 | Sent 1 : A football teams stand next to a brown building and a street . 133 | Sent 2 : Several girls are standing up and a beautiful of a street . 134 | 135 | 136 | Sent 0 : Two men working on an outdoor bus beside a orange house and orange . 
137 | Sent 1 : Two men sitting and conversing through a tall grass playing together . 138 | Sent 2 : Two men standing outside near a dead and man standing around wearing vests . 139 | 140 | 141 | Sent 0 : A rock plastic glass and biting the branch of it . 142 | Sent 1 : A very young boy hanging out of the camera and tall . 143 | Sent 2 : A very sad boy standing out of the tall grass . 144 | 145 | 146 | Sent 0 : A lady playing hopscotch on her hands . 147 | Sent 1 : A girl playing hopscotch on the sidewalk holding . 148 | Sent 2 : A women is sitting behind the window . 149 | Sent 3 : A woman standing inside of the grass . 150 | 151 | 152 | Sent 0 : The woman is smiling . 153 | Sent 1 : The woman is smiling . 154 | Sent 2 : The woman is standing . 155 | 156 | 157 | Sent 0 : Boy in orange on a sand beach . 158 | Sent 1 : Boy on holding a sand beach in pool . 159 | Sent 2 : Boy on holding a snow ball in is standing . 160 | 161 | 162 | Sent 0 : A man rides the motorcycle or fire coming across the water . 163 | Sent 1 : The man is in winter gear as the traffic lights water from the crowd . 164 | Sent 2 : The man is behind the river or trying to make water into the distance . 165 | Sent 3 : The man is standing behind through the side of or preparing the water . 166 | 167 | 168 | Sent 0 : Two dogs fight 169 | Sent 1 : Two men walk out 170 | Sent 2 : Two men are out 171 | Sent 3 : Two men are walking around a pool . 172 | Sent 4 : Two men are walking a fountain . 173 | Sent 5 : Several men are walking by a fountain . 174 | Sent 6 : Several men are standing by a fountain . 175 | 176 | 177 | Sent 0 : Several dogs play a . 178 | Sent 1 : There are two sitting . 179 | Sent 2 : There are two sitting in a working . 180 | Sent 3 : There are men standing near a . 181 | 182 | 183 | Sent 0 : A boy without his mouth climbs . 184 | Sent 1 : A boy wears sunglasses with some white sunglasses . 185 | Sent 2 : A boy wall is without his orange . 
186 | Sent 3 : A young boy is standing under an orange . 187 | 188 | 189 | Sent 0 : five people 's sliding down a boat . 190 | Sent 1 : five people kicking the ball at a soccer game . 191 | Sent 2 : five people kicking the ball at a soccer game . 192 | Sent 3 : five people looking at the ball through a play . 193 | Sent 4 : five people dressed as the child play . 194 | Sent 5 : five people standing next to the play . 195 | 196 | 197 | Sent 0 : The are blue . 198 | Sent 1 : The white bike are at the court . 199 | Sent 2 : The white are standing on the sidewalk . 200 | 201 | 202 | Sent 0 : Men are sitting at a table . 203 | Sent 1 : Men are sitting on . 204 | Sent 2 : Men are standing on a bench . 205 | 206 | 207 | Sent 0 : A group of men passing them are going a . 208 | Sent 1 : A men of men are walking along a street . 209 | Sent 2 : A man dressed as two are taking pictures to a run . 210 | Sent 3 : A man dressed as three men are over to a river . 211 | Sent 4 : Several men standing next to two bikes are near a line . 212 | 213 | 214 | Sent 0 : A crowd is breaking up on some kind of playing . 215 | Sent 1 : A crowd is standing around on the balcony and making on . 216 | 217 | 218 | Sent 0 : a person hangs on a wall 219 | Sent 1 : A person hangs on a wall on cliff . 220 | Sent 2 : A person sits on a wall on it . 221 | Sent 3 : A person is standing on a rope on roof . 222 | 223 | 224 | Sent 0 : The kid a bucket . 225 | Sent 1 : The kid is nearby . 226 | Sent 2 : The kid is nearby . 227 | Sent 3 : The kid is standing . 228 | 229 | 230 | Sent 0 : A woman is a man in a beach next . 231 | Sent 1 : A man is talking in a field with a glass . 232 | Sent 2 : A man is standing in a beach with a garbage . 233 | 234 | 235 | Sent 0 : The young girl sit along on two wooden , enjoying a wrestling . 236 | Sent 1 : The young girl sit along side on a wet beach bowling . 237 | Sent 2 : Women standing , are sitting on two hanging around a young woman . 
238 | 239 | 240 | Sent 0 : There are men sitting on a ledge . 241 | Sent 1 : There are men sitting on a ledge . 242 | Sent 2 : There are man standing on a ledge . 243 | 244 | 245 | Sent 0 : a blond girl with a painting camera outside 246 | Sent 1 : a women with a purse smoking and reading through her . 247 | Sent 2 : A woman standing with a smoking and a vehicle holding road . 248 | 249 | 250 | Sent 0 : There are a sad woman on a floor . 251 | Sent 1 : There are a sad woman on stairs . 252 | Sent 2 : There are standing on a phone with a standing . 253 | 254 | 255 | Sent 0 : There are two cars at a track . 256 | Sent 1 : There are two people down a river . 257 | Sent 2 : There is two people walking for a sport . 258 | Sent 3 : There is two people walking for a big . 259 | Sent 4 : There is two people walking on a stream . 260 | Sent 5 : people is walking on a stream . 261 | Sent 6 : people is walking on a stream . 262 | Sent 7 : People are standing by a parade . 263 | 264 | 265 | Sent 0 : Young women are sitting and crossing a street performer . 266 | Sent 1 : The colorful ladies are walking toward a couple street . 267 | Sent 2 : The naked standing boy are sitting at a street construction . 268 | 269 | 270 | Sent 0 : There is some men walking walking and looks . 271 | Sent 1 : There is a woman standing walking and looking out . 
def main(args):
    """Print sentences decoded from small perturbations of random latent points.

    Ten rounds are run.  In each round row 0 of ``noise`` is redrawn from a
    standard normal and rows 1..99 are set to that row plus a random
    direction rescaled to L2 norm 0.1.  All rows are decoded with the
    module-level ``gen`` helper.  The sentence for row 0 is printed first,
    then every distinct sentence among the first 40 rows, with the tokens
    that differ from the row-0 sentence wrapped in ``BOLD``/``ENDC``.

    Args:
        args: parsed command-line arguments.  Not read here; the models are
            reached through the module-level ``model_args`` and ``gen``.
    """
    noise = torch.ones(100, model_args['z_size'])

    for _ in range(10):
        # Row 0 is the anchor point; the other rows are anchor + a random
        # direction of norm 1/10.
        noise[0].normal_()
        for i in range(1, 100):
            noise[i].normal_()
            noise[i] = noise[i] / (10 * numpy.linalg.norm(noise[i]))
            noise[i] += noise[0]

        sents = gen(noise)
        print(sents[0])
        base_tokens = sents[0].split()
        seen = {sents[0]}
        for sent in sents[:40]:
            if sent in seen:
                continue
            seen.add(sent)
            tokens = sent.split()
            matcher = difflib.SequenceMatcher(a=base_tokens, b=tokens)
            for tag, _i1, _i2, j1, j2 in matcher.get_opcodes():
                if tag == "equal":
                    print(" ".join(tokens[j1:j2]), end=" ")
                elif tag in ("replace", "insert"):
                    # "insert" spans exist only in the new sentence; without
                    # this branch those words vanish from the printed line.
                    print(BOLD + " ".join(tokens[j1:j2]) + ENDC, end=" ")
                # "delete" spans carry no tokens on the new-sentence side.
            print()
        print()
def get_subj_verb(sent):
    """Collect verbs, subjects and subject modifiers from a parsed sentence.

    A token counts as a subject when its dependency label is ``nsubj`` and
    its head is tagged ``VERB``.  For each such token:

    * the verb recorded is the head's own head when that is also a ``VERB``,
      otherwise the head itself;
    * the token text is recorded as a subject;
    * every child of the token is recorded as a modifier; when it has no
      children, the preceding token is recorded instead (``"NONE"`` when
      the subject is the first token).

    Returns:
        A ``(verbs, subjects, modifiers)`` tuple of sets of strings.
    """
    verbs, subjects, modifiers = set(), set(), set()

    for position, token in enumerate(sent):
        if token.dep != nsubj or token.head.pos != VERB:
            continue

        governor = token.head
        outer = governor.head
        verbs.add(str(outer if outer.pos == VERB else governor))
        subjects.add(str(token))

        children = list(token.children)
        if children:
            modifiers.update(str(child) for child in children)
        elif position != 0:
            modifiers.add(str(sent[position - 1]))
        else:
            modifiers.add("NONE")

    return verbs, subjects, modifiers
def switch(vec, mat, rev, f1, f2):
    """Update vec away from feature1 and towards feature2.

    Args:
        vec: latent vector to move.
        mat: NumPy array whose rows are selected with the integer indices
            stored in ``rev``.
        rev: mapping from feature name to a collection of row indices into
            ``mat``.
        f1: iterable of feature names to move away from.  A name that
            occurs several times contributes several times to the average.
        f2: feature name to move towards.

    Returns:
        ``(val, base)`` where ``base = vec - m1`` and ``val = base + m2``.
        ``m2`` is the mean row of ``mat`` over ``rev[f2]``.  ``m1`` is the
        element-wise average of the per-feature mean rows for ``f1``, or
        zeros when no feature in ``f1`` has any rows.

    Note:
        If ``rev[f2]`` is empty the mean is taken over zero rows, which
        NumPy reports as NaN; callers are expected to pass a feature that
        was actually observed.
    """
    m2 = np.mean(mat[list(rev[f2])], axis=0)

    means = []
    for f in f1:
        rows = list(rev[f])
        if rows:
            means.append(np.mean(mat[rows], axis=0))

    # axis=0 keeps m1 a vector; without it the per-feature means collapse to
    # one scalar.  Testing `means` (not `f1`) also covers the case where
    # every listed feature has no rows, which used to give NaN.
    m1 = np.mean(means, axis=0) if means else np.zeros(m2.shape)

    return vec + (m2 - m1), vec - m1
def dump_samples(args):
    """Construct a large number of samples with features and dump to file.

    Runs ``args.nbatches`` batches of ``args.batch_size`` random latent
    vectors through the module-level ``generate`` with ``sample=True``,
    featurizes every resulting sentence, and pickles the tuple
    ``(sentences, features, rev, zs)`` to ``args.dump``.  ``rev`` maps each
    feature name to the set of sample indices that carry it; ``zs`` holds
    the latent vector for each sample index.

    Relies on module-level globals: ``model_args``, ``autoencoder``,
    ``gan_gen``, ``idx2word``, ``nlp`` and ``featurize``.
    """
    all_features = []
    all_sents = []

    batches = args.nbatches
    batch = args.batch_size
    samples = 1
    z_size = model_args['z_size']
    total = batches * batch * samples
    # Allocated uninitialised; every row is written in the loop below.
    all_zs = torch.FloatTensor(total, z_size)
    rev = defaultdict(set)

    for j in range(batches):
        print("%d / %d batches " % (j, batches))
        noise = torch.ones(batch, z_size)
        noise.normal_()
        # Repeat each latent vector `samples` times (a no-op while samples == 1).
        noise = noise.view(batch, 1, z_size)\
                     .expand(batch, samples, z_size).contiguous()\
                     .view(batch*samples, z_size)
        sentences = generate(autoencoder, gan_gen, z=noise,
                             vocab=idx2word, sample=True,
                             maxlen=model_args['maxlen'])

        for i in range(batch * samples):
            k = len(all_features)
            feats = featurize(nlp(sentences[i]))
            all_sents.append(sentences[i])
            all_features.append(feats)
            for f in feats:
                rev[f].add(k)
            all_zs[k] = noise[i]

    # The context manager guarantees the pickle is flushed and the handle
    # closed even if serialisation fails part-way.
    with open(args.dump, "bw") as out:
        pickle.dump((all_sents, all_features, rev, all_zs), out)
def interpolate(ae, gg, z1, z2, vocab,
                steps=5, sample=None, maxlen=None):
    """Decode sentences along straight lines between pairs of latent points.

    Args:
        ae: autoencoder passed through to ``generate``.
        gg: GAN generator passed through to ``generate``.
        z1, z2: start and end noise batches.  Each may be a ``Variable``,
            a float tensor or a ``numpy.ndarray``; they are combined
            element-wise, so they must have matching shapes.
        vocab: index-to-word mapping passed through to ``generate``.
        steps: number of points per interpolation, endpoints included.
            ``steps == 1`` yields only the ``z1`` endpoint.
        sample: passed through to ``generate``.
        maxlen: passed through to ``generate``.

    Returns:
        One list per row of the input batch, each holding ``steps``
        decoded sentences ordered from ``z1`` to ``z2``.

    Raises:
        ValueError: if ``steps < 1`` or an input has an unsupported type.
    """
    if steps < 1:
        raise ValueError("steps must be >= 1, got {}".format(steps))

    def as_variable(z):
        # isinstance, not exact type equality, so subclasses pass, and so do
        # plain tensors on PyTorch >= 0.4 where Tensor and Variable are merged.
        if isinstance(z, Variable):
            return z
        if isinstance(z, (torch.FloatTensor, torch.cuda.FloatTensor)):
            return Variable(z, volatile=True)
        if isinstance(z, np.ndarray):
            return Variable(torch.from_numpy(z).float(), volatile=True)
        raise ValueError("Unsupported input type (noise): {}".format(type(z)))

    noise1 = as_variable(z1)
    noise2 = as_variable(z2)

    # interpolation weights; a single step has no span to divide, so it
    # stays at the first endpoint and avoids dividing by zero
    if steps == 1:
        lambdas = [0.0]
    else:
        lambdas = [x*1.0/(steps-1) for x in range(steps)]

    gens = [generate(ae, gg, (1-L)*noise1 + L*noise2, vocab, sample, maxlen)
            for L in lambdas]

    # Transpose from per-step lists to per-example lists.
    return [[step_sents[i] for step_sents in gens]
            for i in range(len(gens[0]))]
97 | random.seed(args.seed) 98 | np.random.seed(args.seed) 99 | torch.manual_seed(args.seed) 100 | if torch.cuda.is_available(): 101 | torch.cuda.manual_seed(args.seed) 102 | else: 103 | print("Note that our pre-trained models require CUDA to evaluate.") 104 | 105 | model_args, idx2word, autoencoder, gan_gen, gan_disc \ 106 | = load_models(args.load_path) 107 | 108 | if args.ngenerations > 0: 109 | noise = torch.ones(args.ngenerations, model_args['z_size']) 110 | noise = noise.normal_().cuda() 111 | sentences = generate(autoencoder, gan_gen, z=noise, 112 | vocab=idx2word, sample=args.sample, 113 | maxlen=model_args['maxlen']) 114 | 115 | if not args.noprint: 116 | print("\nSentence generations:\n") 117 | for sent in sentences: 118 | print(sent) 119 | with open(args.outf, "w") as f: 120 | f.write("Sentence generations:\n\n") 121 | for sent in sentences: 122 | f.write(sent+"\n") 123 | 124 | if args.ninterpolations > 0: 125 | noise1 = torch.ones(args.ninterpolations, model_args['z_size']) 126 | noise1 = noise1.normal_().cuda() 127 | noise2 = torch.ones(args.ninterpolations, model_args['z_size']) 128 | noise2 = noise2.normal_().cuda() 129 | interps = interpolate(autoencoder, gan_gen, 130 | z1=noise1, 131 | z2=noise2, 132 | vocab=idx2word, 133 | steps=args.steps, 134 | sample=args.sample, 135 | maxlen=model_args['maxlen']) 136 | 137 | if not args.noprint: 138 | print("\nSentence interpolations:\n") 139 | for interp in interps: 140 | for sent in interp: 141 | print(sent) 142 | print("") 143 | with open(args.outf, "a") as f: 144 | f.write("\nSentence interpolations:\n\n") 145 | for interp in interps: 146 | for sent in interp: 147 | f.write(sent+"\n") 148 | f.write('\n') 149 | 150 | main(args) 151 | -------------------------------------------------------------------------------- /lang/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from 
class MLP_D(nn.Module):
    """MLP critic that reduces a batch of inputs to one scalar.

    ``layers`` is a dash-separated string of hidden sizes (e.g.
    ``"300-300"``).  Each hidden stage is Linear -> (BatchNorm1d, except
    after the first Linear) -> activation, followed by a final Linear to
    ``noutput``.  ``forward`` returns the mean of that final output over
    all elements.

    Submodules are registered as ``layer<i>``, ``bn<i>`` and
    ``activation<i>``.  The final Linear is named after its position in
    ``self.layers``; these names are part of the checkpoint format and
    must not change.
    """

    def __init__(self, ninput, noutput, layers,
                 activation=nn.LeakyReLU(0.2), gpu=True):
        # `gpu` is accepted for signature parity with the other models but
        # is not used in this class.
        super(MLP_D, self).__init__()
        self.ninput = ninput
        self.noutput = noutput

        layer_sizes = [ninput] + [int(x) for x in layers.split('-')]
        self.layers = []

        for i, (n_in, n_out) in enumerate(zip(layer_sizes, layer_sizes[1:])):
            layer = nn.Linear(n_in, n_out)
            self.layers.append(layer)
            self.add_module("layer"+str(i+1), layer)

            # No batch normalization after first layer
            if i != 0:
                bn = nn.BatchNorm1d(n_out, eps=1e-05, momentum=0.1)
                self.layers.append(bn)
                self.add_module("bn"+str(i+1), bn)

            self.layers.append(activation)
            self.add_module("activation"+str(i+1), activation)

        layer = nn.Linear(layer_sizes[-1], noutput)
        self.layers.append(layer)
        self.add_module("layer"+str(len(self.layers)), layer)

        self.init_weights()

    def forward(self, x):
        """Apply every layer in order and return the mean of the result."""
        for layer in self.layers:
            x = layer(x)
        return torch.mean(x)

    def init_weights(self):
        """Set weights to N(0, 0.02) and biases to 0 where a layer has them."""
        init_std = 0.02
        for layer in self.layers:
            try:
                layer.weight.data.normal_(0, init_std)
                layer.bias.data.fill_(0)
            except AttributeError:
                # Activations have no weight/bias; skip only that case
                # so that real errors are not swallowed.
                continue
class Seq2Seq(nn.Module):
    """LSTM autoencoder whose code is the unit-normalised final hidden state.

    The encoder is an ``nlayers``-layer LSTM over token embeddings.  The
    decoder is a single-layer LSTM that receives, at every step, the
    decoder embedding of the input token concatenated with the code.
    """

    def __init__(self, emsize, nhidden, ntokens, nlayers, noise_r=0.2,
                 hidden_init=False, dropout=0, gpu=True):
        super(Seq2Seq, self).__init__()
        self.nhidden = nhidden
        self.emsize = emsize
        self.ntokens = ntokens
        self.nlayers = nlayers
        self.noise_r = noise_r
        self.hidden_init = hidden_init
        self.dropout = dropout
        self.gpu = gpu

        # Reusable buffer of start tokens; resized per batch in generate().
        self.start_symbols = to_gpu(gpu, Variable(torch.ones(10, 1).long()))

        # Vocabulary embedding
        self.embedding = nn.Embedding(ntokens, emsize)
        self.embedding_decoder = nn.Embedding(ntokens, emsize)

        # RNN Encoder and Decoder
        self.encoder = nn.LSTM(input_size=emsize,
                               hidden_size=nhidden,
                               num_layers=nlayers,
                               dropout=dropout,
                               batch_first=True)

        decoder_input_size = emsize+nhidden
        self.decoder = nn.LSTM(input_size=decoder_input_size,
                               hidden_size=nhidden,
                               num_layers=1,
                               dropout=dropout,
                               batch_first=True)

        # Initialize Linear Transformation
        self.linear = nn.Linear(nhidden, ntokens)

        self.init_weights()

    def init_weights(self):
        """Initialise embeddings, LSTMs and output layer uniformly in [-0.1, 0.1]."""
        initrange = 0.1

        # Initialize Vocabulary Matrix Weight
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.embedding_decoder.weight.data.uniform_(-initrange, initrange)

        # Initialize Encoder and Decoder Weights
        for p in self.encoder.parameters():
            p.data.uniform_(-initrange, initrange)
        for p in self.decoder.parameters():
            p.data.uniform_(-initrange, initrange)

        # Initialize Linear Weight
        self.linear.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.fill_(0)

    def init_hidden(self, bsz):
        """Return a zero ``(h, c)`` pair, each of shape ``(1, bsz, nhidden)``."""
        zeros1 = Variable(torch.zeros(1, bsz, self.nhidden))
        zeros2 = Variable(torch.zeros(1, bsz, self.nhidden))
        return (to_gpu(self.gpu, zeros1), to_gpu(self.gpu, zeros2))

    def init_state(self, bsz):
        """Return a single zero state of shape ``(1, bsz, nhidden)``."""
        zeros = Variable(torch.zeros(1, bsz, self.nhidden))
        return to_gpu(self.gpu, zeros)

    def store_grad_norm(self, grad):
        """Backward hook: record the mean per-row L2 norm of ``grad``.

        The value is stored in ``self.grad_norm``; ``grad`` is returned
        unchanged.
        """
        norm = torch.norm(grad, 2, 1)
        self.grad_norm = norm.detach().data.mean()
        return grad

    def forward(self, indices, lengths, noise, encode_only=False):
        """Encode ``indices`` and, unless ``encode_only``, decode them again.

        Args:
            indices: 2-D tensor of token ids, ``(batch, maxlen)``.
            lengths: sequence lengths handed to ``pack_padded_sequence``.
            noise: whether to add Gaussian noise to the code.
            encode_only: return the code instead of decoder logits.

        Returns:
            The code when ``encode_only`` is set, otherwise logits of shape
            ``(batch, maxlen, ntokens)``.
        """
        batch_size, maxlen = indices.size()

        hidden = self.encode(indices, lengths, noise)

        if encode_only:
            return hidden

        if hidden.requires_grad:
            hidden.register_hook(self.store_grad_norm)

        decoded = self.decode(hidden, batch_size, maxlen,
                              indices=indices, lengths=lengths)

        return decoded

    def encode(self, indices, lengths, noise):
        """Return the L2-normalised last-layer hidden state of the encoder.

        When ``noise`` is truthy and ``self.noise_r > 0``, zero-mean
        Gaussian noise with standard deviation ``self.noise_r`` is added
        after normalisation.
        """
        embeddings = self.embedding(indices)
        packed_embeddings = pack_padded_sequence(input=embeddings,
                                                 lengths=lengths,
                                                 batch_first=True)

        packed_output, state = self.encoder(packed_embeddings)
        hidden = state[0][-1]
        hidden = hidden / torch.norm(hidden, p=2, dim=1, keepdim=True)

        if noise and self.noise_r > 0:
            # Positional arguments: the keyword for the mean was renamed
            # between PyTorch releases (`means` -> `mean`).
            gauss_noise = torch.normal(torch.zeros(hidden.size()),
                                       self.noise_r)
            # Follow the code's device instead of forcing CUDA, so that
            # CPU runs work as well.
            if hidden.is_cuda:
                gauss_noise = gauss_noise.cuda()
            hidden = hidden + Variable(gauss_noise)

        return hidden

    def decode(self, hidden, batch_size, maxlen, indices=None, lengths=None):
        """Teacher-forced decoding; returns logits ``(batch, maxlen, ntokens)``.

        ``indices`` and ``lengths`` default to ``None`` but are used
        unconditionally below, so callers must supply both.
        """
        # batch x hidden
        all_hidden = hidden.unsqueeze(1).repeat(1, maxlen, 1)

        if self.hidden_init:
            # initialize decoder hidden state to encoder output
            state = (hidden.unsqueeze(0), self.init_state(batch_size))
        else:
            state = self.init_hidden(batch_size)

        embeddings = self.embedding_decoder(indices)
        augmented_embeddings = torch.cat([embeddings, all_hidden], 2)
        packed_embeddings = pack_padded_sequence(input=augmented_embeddings,
                                                 lengths=lengths,
                                                 batch_first=True)

        packed_output, state = self.decoder(packed_embeddings, state)
        output, lengths = pad_packed_sequence(packed_output, batch_first=True)

        # reshape to batch_size*maxlen x nhidden before linear over vocab
        decoded = self.linear(output.contiguous().view(-1, self.nhidden))
        decoded = decoded.view(batch_size, maxlen, self.ntokens)

        return decoded

    def generate(self, hidden, maxlen, sample=True, temp=1.0):
        """Generate through decoder; no backprop.

        Args:
            hidden: codes of shape ``(batch, nhidden)``.
            maxlen: number of decoding steps.
            sample: draw from the softmax when true, take the argmax
                otherwise.
            temp: softmax temperature used when sampling.

        Returns:
            Tensor of token ids with shape ``(batch, maxlen)``.
        """
        batch_size = hidden.size(0)

        if self.hidden_init:
            # initialize decoder hidden state to encoder output
            state = (hidden.unsqueeze(0), self.init_state(batch_size))
        else:
            state = self.init_hidden(batch_size)

        # Start every sequence from token id 1.
        self.start_symbols.data.resize_(batch_size, 1)
        self.start_symbols.data.fill_(1)

        embedding = self.embedding_decoder(self.start_symbols)
        inputs = torch.cat([embedding, hidden.unsqueeze(1)], 2)

        # unroll
        all_indices = []
        for _ in range(maxlen):
            output, state = self.decoder(inputs, state)
            overvocab = self.linear(output.squeeze(1))
            if not sample:
                _vals, indices = torch.max(overvocab, 1)
            else:
                probs = F.softmax(overvocab / temp, dim=-1)
                indices = torch.multinomial(probs, 1)
            # argmax gives (batch,), multinomial gives (batch, 1); normalise
            # both to (batch, 1) so the embedding and cat below get the
            # same shape on either path.
            indices = indices.view(-1, 1)
            all_indices.append(indices)

            embedding = self.embedding_decoder(indices)
            inputs = torch.cat([embedding, hidden.unsqueeze(1)], 2)

        max_indices = torch.cat(all_indices, 1)
        return max_indices

    def noise_anneal(self, fac):
        """Multiply the encoder noise radius ``noise_r`` by ``fac``."""
        self.noise_r *= fac
class Indexer:
    """Word <-> index vocabulary with four reserved symbols.

    ``symbols`` supplies, in order, the PAD, UNK, BOS and EOS tokens, which
    are assigned indices 0, 1, 2 and 3. If the tokens are not distinct
    (as with the default of four empty strings) they collapse onto a single
    dictionary entry.
    """

    def __init__(self, symbols=("", "", "", "")):
        # Immutable default: avoids sharing one list object across instances.
        self.vocab = defaultdict(int)  # raw word -> count, filled by callers
        self.PAD = symbols[0]
        self.UNK = symbols[1]
        self.BOS = symbols[2]
        self.EOS = symbols[3]
        self.d = {self.PAD: 0, self.UNK: 1, self.BOS: 2, self.EOS: 3}
        self.idx2word = {}

    def add_w(self, ws):
        """Give every unseen word in ``ws`` the next free index."""
        for w in ws:
            if w not in self.d:
                self.d[w] = len(self.d)

    def convert(self, w):
        """Return the index of ``w``, or the UNK index if it is unknown."""
        return self.d[w] if w in self.d else self.d[self.UNK]

    def convert_sequence(self, ls):
        """Map a sequence of words to a list of indices."""
        return [self.convert(l) for l in ls]

    def write(self, outfile):
        """Write ``word index`` lines to ``outfile``, ordered by index."""
        items = sorted((v, k) for k, v in self.d.items())
        # Context manager guarantees the handle is closed even if a write fails.
        with open(outfile, "w") as out:
            for v, k in items:
                out.write(" ".join([k, str(v)]) + "\n")

    def prune_vocab(self, k, cnt=False):
        """Select the vocabulary from ``self.vocab`` and index it.

        With ``cnt`` true, keep words whose count is strictly greater than
        ``k``; otherwise keep the ``k`` most frequent words. The selected
        words are stored in ``self.pruned_vocab``, added to ``self.d`` and
        ``self.idx2word`` is rebuilt from ``self.d``.
        """
        vocab_list = [(word, count) for word, count in self.vocab.items()]
        if cnt:
            self.pruned_vocab = {pair[0]: pair[1] for pair in vocab_list if pair[1] > k}
        else:
            vocab_list.sort(key=lambda x: x[1], reverse=True)
            k = min(k, len(vocab_list))
            self.pruned_vocab = {pair[0]: pair[1] for pair in vocab_list[:k]}
        for word in self.pruned_vocab:
            if word not in self.d:
                self.d[word] = len(self.d)
        for word, idx in self.d.items():
            self.idx2word[idx] = word

    def load_vocab(self, vocab_file):
        """Replace ``self.d`` with the ``word index`` pairs read from ``vocab_file``.

        Each line must split into exactly two whitespace-separated fields.
        """
        self.d = {}
        with open(vocab_file, 'r') as infile:
            for line in infile:
                v, k = line.strip().split()
                self.d[v] = int(k)
        for word, idx in self.d.items():
            self.idx2word[idx] = word
def convert(textfile, batchsize, seqlength, outfile, num_sents, max_sent_l=0,shuffle=0):
    """Index the sentences of ``textfile``, group them by length and write HDF5.

    Each line is wrapped in BOS/EOS, padded to ``seqlength + 2`` and mapped
    to indices through the enclosing ``indexer``. Lines longer than
    ``seqlength`` tokens or with no tokens are dropped. Sentences are
    (optionally shuffled, then) sorted by length and cut into batches of at
    most ``batchsize`` sentences that all share one length.

    Args:
        textfile: path of the whitespace-tokenised input text.
        batchsize: maximum number of sentences per batch.
        seqlength: maximum sentence length before BOS/EOS are added.
        outfile: path of the HDF5 file to create.
        num_sents: number of rows to allocate for the sentence matrix.
        max_sent_l: running maximum sentence length carried across calls.
        shuffle: if 1, permute the sentences before the length sort.

    Returns:
        The updated maximum sentence length (BOS/EOS included, measured
        before dropping).
    """
    newseqlength = seqlength + 2 #add 2 for EOS and BOS
    sents = np.zeros((num_sents, newseqlength), dtype=int)
    sent_lengths = np.zeros((num_sents,), dtype=int)
    dropped = 0
    sent_id = 0
    # Context manager so the input handle is released even on error.
    with open(textfile, 'r') as infile:
        for sent in infile:
            sent = [indexer.BOS] + sent.strip().split() + [indexer.EOS]
            max_sent_l = max(len(sent), max_sent_l)
            if len(sent) > seqlength + 2 or len(sent) < 3:
                dropped += 1
                continue
            sent_pad = pad(sent, newseqlength, indexer.PAD)
            sents[sent_id] = np.array(indexer.convert_sequence(sent_pad), dtype=int)
            # length = number of non-zero ids, i.e. relies on padding mapping to 0
            sent_lengths[sent_id] = (sents[sent_id] != 0).sum()
            sent_id += 1
            if sent_id % 100000 == 0:
                print("{}/{} sentences processed".format(sent_id, num_sents))
    print(sent_id, num_sents)
    if shuffle == 1:
        rand_idx = np.random.permutation(sent_id)
        sents = sents[rand_idx]
        sent_lengths = sent_lengths[rand_idx]

    #break up batches based on source lengths
    sent_lengths = sent_lengths[:sent_id]
    sent_sort = np.argsort(sent_lengths)
    sents = sents[sent_sort]
    sent_l = sent_lengths[sent_sort]
    curr_l = 1
    l_location = [] #idx where sent length changes

    for j,i in enumerate(sent_sort):
        if sent_lengths[i] > curr_l:
            curr_l = sent_lengths[i]
            l_location.append(j)
    l_location.append(len(sents))
    #get batch sizes
    curr_idx = 0
    batch_idx = [0]
    batch_l = []
    batch_w = []
    for i in range(len(l_location)-1):
        # advance in batchsize steps but never across a length boundary
        while curr_idx < l_location[i+1]:
            curr_idx = min(curr_idx + batchsize, l_location[i+1])
            batch_idx.append(curr_idx)
    for i in range(len(batch_idx)-1):
        batch_l.append(batch_idx[i+1] - batch_idx[i])
        batch_w.append(sent_l[batch_idx[i]])

    # Write output; the with-block closes the file even if a write fails.
    with h5py.File(outfile, "w") as f:
        f["source"] = sents
        f["batch_l"] = np.array(batch_l, dtype=int)
        f["source_l"] = np.array(batch_w, dtype=int)
        f["sents_l"] = np.array(sent_l, dtype = int)
        f["batch_idx"] = np.array(batch_idx[:-1], dtype=int)
        f["vocab_size"] = np.array([len(indexer.d)])
        print("Saved {} sentences (dropped {} due to length/unk filter)".format(
            len(f["source"]), dropped))
    return max_sent_l
def main(arguments):
    """Parse ``arguments`` (a list of command-line tokens) and run ``get_data``.

    Returns ``None``; argparse exits the process on ``--help`` or on invalid
    or missing arguments.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--vocabsize', help="Size of source vocabulary, constructed "
                                            "by taking the top X most frequent words. "
                                            " Rest are replaced with special UNK tokens.",
                        type=int, default=70000)
    parser.add_argument('--vocabminfreq', help="Minimum frequency for vocab, if using frequency cutoff",
                        type=int, default=-1)
    parser.add_argument('--trainfile', help="Path to training data.", required=True)
    parser.add_argument('--valfile', help="Path validation data.", required=True)
    parser.add_argument('--testfile', help="Path to test data.", required=True)
    parser.add_argument('--batchsize', help="Size of each minibatch.", type=int, default=32)
    parser.add_argument('--seqlength', help="Maximum source sequence length. Sequences longer "
                                            "than this are dropped.", type=int, default=200)
    parser.add_argument('--outputfile', help="Prefix of the output file names. ", type=str)
    # Help text fixed: the adjacent literals used to join as "use thevocab",
    # and referred to an option name this parser does not define.
    parser.add_argument('--vocabfile', help="If working with a preset vocab, "
                                            "then including this will ignore --vocabsize and use the "
                                            "vocab provided here.",
                        type = str, default='')
    parser.add_argument('--shuffle', help="If = 1, shuffle sentences before sorting (based on "
                                          "source length).",
                        type = int, default = 1)

    args = parser.parse_args(arguments)
    get_data(args)
of layers') 31 | parser.add_argument('--noise_r', type=float, default=0.1, 32 | help='stdev of noise for autoencoder (regularizer)') 33 | parser.add_argument('--noise_anneal', type=float, default=0.9995, 34 | help='anneal noise_r exponentially by this' 35 | 'every 100 iterations') 36 | parser.add_argument('--hidden_init', action='store_true', 37 | help="initialize decoder hidden state with encoder's") 38 | parser.add_argument('--arch_g', type=str, default='500-500', 39 | help='generator architecture (MLP)') 40 | parser.add_argument('--arch_d', type=str, default='500-500', 41 | help='critic/discriminator architecture (MLP)') 42 | parser.add_argument('--z_size', type=int, default=100, 43 | help='dimension of random noise z to feed into generator') 44 | parser.add_argument('--dropout', type=float, default=0.0, 45 | help='dropout applied to layers (0 = no dropout)') 46 | 47 | # Training Arguments 48 | parser.add_argument('--epochs', type=int, default=15, 49 | help='maximum number of epochs') 50 | parser.add_argument('--min_epochs', type=int, default=12, 51 | help="minimum number of epochs to train for") 52 | parser.add_argument('--no_earlystopping', action='store_true', 53 | help="won't use KenLM for early stopping") 54 | parser.add_argument('--patience', type=int, default=2, 55 | help="number of language model evaluations without ppl " 56 | "improvement to wait before early stopping") 57 | parser.add_argument('--batch_size', type=int, default=128, metavar='N', 58 | help='batch size') 59 | parser.add_argument('--niters_ae', type=int, default=1, 60 | help='number of autoencoder iterations in training') 61 | parser.add_argument('--niters_gan_d', type=int, default=5, 62 | help='number of discriminator iterations in training') 63 | parser.add_argument('--niters_gan_g', type=int, default=1, 64 | help='number of generator iterations in training') 65 | parser.add_argument('--niters_gan_ae', type=int, default=1, 66 | help='number of gan-into-ae iterations in training') 67 | 
# Command-line entry point for the SNLI experiment: builds the argument
# parser, parses the command line into ``args``, prints the options and then
# executes train.py in this module's namespace (train.py reads ``args``).
parser = argparse.ArgumentParser(description='ARAE for SNLI')
# Path Arguments
parser.add_argument('--data_path', type=str, required=True,
                    help='location of the data corpus')
parser.add_argument('--kenlm_path', type=str, default='./kenlm',
                    help='path to kenlm directory')
parser.add_argument('--save', type=str, default='snli_example',
                    help='output directory name')

# Data Processing Arguments
parser.add_argument('--maxlen', type=int, default=15,
                    help='maximum length')
parser.add_argument('--vocab_size', type=int, default=11000,
                    help='cut vocabulary down to this size '
                         '(most frequently seen words in train)')
parser.add_argument('--lowercase', dest='lowercase', action='store_true',
                    help='lowercase all text')
# store_false: with store_true this flag also set lowercase=True, so
# lowercasing could never be switched off from the command line.
parser.add_argument('--no-lowercase', dest='lowercase', action='store_false',
                    help='not lowercase all text')
parser.set_defaults(lowercase=True)

# Model Arguments
parser.add_argument('--emsize', type=int, default=300,
                    help='size of word embeddings')
parser.add_argument('--nhidden', type=int, default=300,
                    help='number of hidden units per layer')
parser.add_argument('--nlayers', type=int, default=1,
                    help='number of layers')
parser.add_argument('--noise_r', type=float, default=0.05,
                    help='stdev of noise for autoencoder (regularizer)')
parser.add_argument('--noise_anneal', type=float, default=0.9995,
                    help='anneal noise_r exponentially by this '
                         'every 100 iterations')
parser.add_argument('--hidden_init', action='store_true',
                    help="initialize decoder hidden state with encoder's")
parser.add_argument('--arch_g', type=str, default='300-300',
                    help='generator architecture (MLP)')
parser.add_argument('--arch_d', type=str, default='300-300',
                    help='critic/discriminator architecture (MLP)')
parser.add_argument('--z_size', type=int, default=100,
                    help='dimension of random noise z to feed into generator')
parser.add_argument('--dropout', type=float, default=0.0,
                    help='dropout applied to layers (0 = no dropout)')

# Training Arguments
parser.add_argument('--epochs', type=int, default=15,
                    help='maximum number of epochs')
parser.add_argument('--min_epochs', type=int, default=12,
                    help="minimum number of epochs to train for")
parser.add_argument('--no_earlystopping', action='store_true',
                    help="won't use KenLM for early stopping")
parser.add_argument('--patience', type=int, default=2,
                    help="number of language model evaluations without ppl "
                         "improvement to wait before early stopping")
parser.add_argument('--batch_size', type=int, default=64, metavar='N',
                    help='batch size')
parser.add_argument('--niters_ae', type=int, default=1,
                    help='number of autoencoder iterations in training')
parser.add_argument('--niters_gan_d', type=int, default=5,
                    help='number of discriminator iterations in training')
parser.add_argument('--niters_gan_g', type=int, default=1,
                    help='number of generator iterations in training')
parser.add_argument('--niters_gan_ae', type=int, default=1,
                    help='number of gan-into-ae iterations in training')
parser.add_argument('--niters_gan_schedule', type=str, default='',
                    help='epoch counts to increase number of GAN training '
                         ' iterations (increment by 1 each time)')
parser.add_argument('--lr_ae', type=float, default=1,
                    help='autoencoder learning rate')
parser.add_argument('--lr_gan_g', type=float, default=1e-04,
                    help='generator learning rate')
parser.add_argument('--lr_gan_d', type=float, default=1e-04,
                    help='critic/discriminator learning rate')
parser.add_argument('--beta1', type=float, default=0.5,
                    help='beta1 for adam. default=0.5')
parser.add_argument('--clip', type=float, default=1,
                    help='gradient clipping, max norm')
parser.add_argument('--gan_clamp', type=float, default=0.01,
                    help='WGAN clamp')
parser.add_argument('--gan_gp_lambda', type=float, default=1,
                    help='WGAN GP penalty lambda')
parser.add_argument('--grad_lambda', type=float, default=0.1,
                    help='WGAN into AE lambda')

# Evaluation Arguments
parser.add_argument('--sample', action='store_true',
                    help='sample when decoding for generation')
parser.add_argument('--N', type=int, default=5,
                    help='N-gram order for training n-gram language model')
parser.add_argument('--log_interval', type=int, default=200,
                    help='interval to log autoencoder training results')

# Other
parser.add_argument('--seed', type=int, default=1111,
                    help='random seed')

args = parser.parse_args()
print(vars(args))

# Read train.py through a context manager so the handle is closed before the
# (long-running) training code starts; the path is relative to the current
# working directory.
with open("train.py") as train_script:
    train_source = train_script.read()
exec(train_source)
def transform_data(in_path):
    """Read an SNLI ``.jsonl`` file and return ``(premises, hypotheses)``.

    Both sentences are rebuilt from their binary-parse strings with the
    parse brackets removed. A premise is recorded only when it differs from
    the premise of the previous line; every hypothesis is recorded.
    """
    print("Loading", in_path)

    def _flatten(binary_parse):
        # drop the parse brackets, keep the tokens in their original order
        return " ".join(tok for tok in binary_parse.split(" ")
                        if tok not in ("(", ")"))

    premises, hypotheses = [], []
    previous = None
    with codecs.open(in_path, encoding='utf-8') as handle:
        for raw in handle:
            example = json.loads(raw)
            premise = _flatten(example['sentence1_binary_parse'])
            hypothesis = _flatten(example['sentence2_binary_parse'])

            # make sure not to repeat premises
            if premise != previous:
                premises.append(premise)
                previous = premise
            hypotheses.append(hypothesis)

    return premises, hypotheses
82 | if not os.path.exists(args.out_path): 83 | os.makedirs(args.out_path) 84 | print("Creating directory "+args.out_path) 85 | 86 | # process and write test.txt and train.txt files 87 | premises, hypotheses = \ 88 | transform_data(os.path.join(args.in_path, "snli_1.0_test.jsonl")) 89 | write_sentences(write_path=os.path.join(args.out_path, "test.txt"), 90 | premises=premises, hypotheses=hypotheses) 91 | 92 | premises, hypotheses = \ 93 | transform_data(os.path.join(args.in_path, "snli_1.0_train.jsonl")) 94 | write_sentences(write_path=os.path.join(args.out_path, "train.txt"), 95 | premises=premises, hypotheses=hypotheses) 96 | 97 | premises, hypotheses = \ 98 | transform_data(os.path.join(args.in_path, "snli_1.0_dev.jsonl")) 99 | write_sentences(write_path=os.path.join(args.out_path, "train.txt"), 100 | premises=premises, hypotheses=hypotheses, append=True) 101 | -------------------------------------------------------------------------------- /lang/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import math 5 | import numpy as np 6 | import random 7 | import sys 8 | import shutil 9 | import json 10 | import string 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.optim as optim 15 | import torch.nn.functional as F 16 | from torch.autograd import Variable 17 | 18 | from utils import to_gpu, Corpus, batchify, train_ngram_lm, get_ppl, create_exp_dir 19 | from models import Seq2Seq, MLP_D, MLP_G 20 | 21 | # Set the random seed manually for reproducibility. 
22 | random.seed(args.seed) 23 | np.random.seed(args.seed) 24 | torch.manual_seed(args.seed) 25 | torch.cuda.manual_seed(args.seed) 26 | 27 | ############################################################################### 28 | # Load data 29 | ############################################################################### 30 | # create corpus 31 | corpus = Corpus(args.data_path, 32 | maxlen=args.maxlen, 33 | vocab_size=args.vocab_size, 34 | lowercase=args.lowercase) 35 | 36 | # save arguments 37 | ntokens = len(corpus.dictionary.word2idx) 38 | print("Vocabulary Size: {}".format(ntokens)) 39 | args.ntokens = ntokens 40 | 41 | # exp dir 42 | create_exp_dir(os.path.join(args.save), ['train.py', 'models.py', 'utils.py'], 43 | dict=corpus.dictionary.word2idx, options=args) 44 | 45 | def logging(str, to_stdout=True): 46 | with open(os.path.join(args.save, 'log.txt'), 'a') as f: 47 | f.write(str + '\n') 48 | if to_stdout: 49 | print(str) 50 | logging(str(vars(args))) 51 | 52 | eval_batch_size = 10 53 | test_data = batchify(corpus.test, eval_batch_size, shuffle=False) 54 | train_data = batchify(corpus.train, args.batch_size, shuffle=True) 55 | 56 | print("Loaded data!") 57 | 58 | ############################################################################### 59 | # Build the models 60 | ############################################################################### 61 | autoencoder = Seq2Seq(emsize=args.emsize, 62 | nhidden=args.nhidden, 63 | ntokens=args.ntokens, 64 | nlayers=args.nlayers, 65 | noise_r=args.noise_r, 66 | hidden_init=args.hidden_init, 67 | dropout=args.dropout) 68 | gan_gen = MLP_G(ninput=args.z_size, noutput=args.nhidden, layers=args.arch_g) 69 | gan_disc = MLP_D(ninput=args.nhidden, noutput=1, layers=args.arch_d) 70 | 71 | print(autoencoder) 72 | print(gan_gen) 73 | print(gan_disc) 74 | 75 | optimizer_ae = optim.SGD(autoencoder.parameters(), lr=args.lr_ae) 76 | optimizer_gan_g = optim.Adam(gan_gen.parameters(), 77 | lr=args.lr_gan_g, 78 | 
def load_models():
    """Restore the saved options, vocabulary and model weights from ``args.save``.

    Reads ``options.json``, ``vocab.json`` and ``model.pt`` from the
    experiment directory and loads the three state dicts into the
    module-level ``autoencoder``, ``gan_gen`` and ``gan_disc``.

    Returns:
        ``(model_args, idx2word, autoencoder, gan_gen, gan_disc)`` where
        ``idx2word`` is the inverse of the stored word-to-index mapping.
    """
    # Context managers so both JSON handles are closed deterministically.
    with open(os.path.join(args.save, 'options.json'), 'r') as options_file:
        model_args = json.load(options_file)
    with open(os.path.join(args.save, 'vocab.json'), 'r') as vocab_file:
        word2idx = json.load(vocab_file)
    idx2word = {v: k for k, v in word2idx.items()}

    print('Loading models from {}'.format(args.save))
    loaded = torch.load(os.path.join(args.save, "model.pt"))
    autoencoder.load_state_dict(loaded.get('ae'))
    gan_gen.load_state_dict(loaded.get('gan_g'))
    gan_disc.load_state_dict(loaded.get('gan_d'))
    return model_args, idx2word, autoencoder, gan_gen, gan_disc
116 | autoencoder.eval() 117 | total_loss = 0 118 | ntokens = len(corpus.dictionary.word2idx) 119 | all_accuracies = 0 120 | bcnt = 0 121 | for i, batch in enumerate(data_source): 122 | source, target, lengths = batch 123 | source = Variable(source.cuda(), volatile=True) 124 | target = Variable(target.cuda(), volatile=True) 125 | 126 | mask = target.gt(0) 127 | masked_target = target.masked_select(mask) 128 | # examples x ntokens 129 | output_mask = mask.unsqueeze(1).expand(mask.size(0), ntokens) 130 | 131 | # output: batch x seq_len x ntokens 132 | output = autoencoder(source, lengths, noise=True) 133 | flattened_output = output.view(-1, ntokens) 134 | 135 | masked_output = \ 136 | flattened_output.masked_select(output_mask).view(-1, ntokens) 137 | total_loss += F.cross_entropy(masked_output, masked_target).data 138 | 139 | # accuracy 140 | max_vals, max_indices = torch.max(masked_output, 1) 141 | all_accuracies += \ 142 | torch.mean(max_indices.eq(masked_target).float()).data[0] 143 | bcnt += 1 144 | 145 | aeoutf = os.path.join(args.save, "autoencoder.txt") 146 | with open(aeoutf, "a") as f: 147 | max_values, max_indices = torch.max(output, 2) 148 | max_indices = \ 149 | max_indices.view(output.size(0), -1).data.cpu().numpy() 150 | target = target.view(output.size(0), -1).data.cpu().numpy() 151 | for t, idx in zip(target, max_indices): 152 | # real sentence 153 | chars = " ".join([corpus.dictionary.idx2word[x] for x in t]) 154 | f.write(chars + '\n') 155 | # autoencoder output sentence 156 | chars = " ".join([corpus.dictionary.idx2word[x] for x in idx]) 157 | f.write(chars + '\n'*2) 158 | 159 | return total_loss[0] / len(data_source), all_accuracies/bcnt 160 | 161 | 162 | def gen_fixed_noise(noise, to_save): 163 | gan_gen.eval() 164 | autoencoder.eval() 165 | 166 | fake_hidden = gan_gen(noise) 167 | max_indices = autoencoder.generate(fake_hidden, args.maxlen, sample=args.sample) 168 | 169 | with open(to_save, "w") as f: 170 | max_indices = 
def train_lm(data_path):
    """Compute reverse and forward n-gram perplexities for the current models.

    Generates 1000 batches of 100 samples from the module-level ``gan_gen``
    and ``autoencoder``, writes them to a randomly named file under /tmp,
    then:

    * reverse ppl: trains an n-gram LM on the generated text and scores
      ``test.txt`` from ``args.data_path``;
    * forward ppl: trains an n-gram LM on ``train.txt`` and scores the
      generated text.

    Args:
        data_path: not used by the body; paths are taken from
            ``args.data_path``.

    Returns:
        ``(rev_ppl, for_ppl)``; ``rev_ppl`` is 1e15 when the reverse
        evaluation fails.
    """
    save_path = os.path.join("/tmp", ''.join(random.choice(
        string.ascii_uppercase + string.digits) for _ in range(6)))

    indices = []
    noise = Variable(torch.ones(100, args.z_size).cuda())
    for i in range(1000):
        noise.data.normal_(0, 1)
        fake_hidden = gan_gen(noise)
        max_indices = autoencoder.generate(fake_hidden, args.maxlen, sample=args.sample)
        indices.append(max_indices.data.cpu().numpy())
    indices = np.concatenate(indices, axis=0)

    with open(save_path, "w") as f:
        # laplacian smoothing: every vocabulary word appears at least once
        for word in corpus.dictionary.word2idx.keys():
            f.write(word+'\n')
        for idx in indices:
            words = [corpus.dictionary.idx2word[x] for x in idx]
            # Keep words up to the first sentinel token.
            # NOTE(review): the sentinel literal is an empty string as written
            # here; presumably it should be the end-of-sentence token -- confirm.
            truncated_sent = []
            for w in words:
                if w != '':
                    truncated_sent.append(w)
                else:
                    break
            chars = " ".join(truncated_sent)
            f.write(chars+'\n')
    # reverse ppl
    try:
        rev_lm = train_ngram_lm(kenlm_path=args.kenlm_path,
                                data_path=save_path,
                                output_path=save_path+".arpa",
                                N=args.N)
        with open(os.path.join(args.data_path, 'test.txt'), 'r') as f:
            lines = f.readlines()
        if args.lowercase:
            lines = list(map(lambda x: x.lower(), lines))
        sentences = [l.replace('\n', '') for l in lines]
        rev_ppl = get_ppl(rev_lm, sentences)
    except Exception:
        # Was a bare ``except:``, which also swallowed KeyboardInterrupt and
        # SystemExit; ordinary failures still fall back to the sentinel value.
        print("reverse ppl error: it maybe the generated files aren't valid to obtain an LM")
        rev_ppl = 1e15
    # forward ppl
    for_lm = train_ngram_lm(kenlm_path=args.kenlm_path,
                            data_path=os.path.join(args.data_path, 'train.txt'),
                            output_path=save_path+".arpa",
                            N=args.N)
    with open(save_path, 'r') as f:
        lines = f.readlines()
    sentences = [l.replace('\n', '') for l in lines]
    for_ppl = get_ppl(for_lm, sentences)
    return rev_ppl, for_ppl
def calc_gradient_penalty(netD, real_data, fake_data):
    """Return the WGAN-GP gradient penalty for the critic ``netD``.

    A per-row random coefficient mixes ``real_data`` and ``fake_data``; the
    critic's gradient with respect to the mixed points is taken, and the
    mean squared deviation of its row-wise L2 norm from 1 is scaled by
    ``args.gan_gp_lambda``.
    """
    n_rows = real_data.size(0)
    # one coefficient per row, broadcast over the feature axis (2D inputs only)
    mix = torch.rand(n_rows, 1).expand(n_rows, real_data.size(1)).cuda()

    mixed = alpha_blend = mix * real_data + ((1 - mix) * fake_data)
    mixed = Variable(alpha_blend, requires_grad=True)
    critic_out = netD(mixed)

    grads = torch.autograd.grad(
        outputs=critic_out,
        inputs=mixed,
        grad_outputs=torch.ones(critic_out.size()).cuda(),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    grads = grads.view(grads.size(0), -1)

    row_norms = grads.norm(2, dim=1)
    return ((row_norms - 1) ** 2).mean() * args.gan_gp_lambda
def train_gan_d_into_ae(batch):
    """Push the critic's signal back into the encoder for one step.

    The batch is encoded without noise, ``grad_hook`` is attached to the
    code so its incoming gradient gets rescaled, the critic score is
    back-propagated with ``mone`` as the upstream gradient, and the
    autoencoder optimizer takes a clipped step.

    Returns the critic output on the real codes.
    """
    autoencoder.train()
    optimizer_ae.zero_grad()

    tokens, gold, lengths = batch
    tokens = Variable(tokens.cuda())
    gold = Variable(gold.cuda())  # moved to the GPU as before; not read below

    code = autoencoder(tokens, lengths, noise=False, encode_only=True)
    code.register_hook(grad_hook)

    critic_score = gan_disc(code)
    critic_score.backward(mone)
    torch.nn.utils.clip_grad_norm(autoencoder.parameters(), args.clip)
    optimizer_ae.step()

    return critic_score
range(args.niters_gan_ae): 397 | train_gan_d_into_ae(train_data[random.randint(0, len(train_data)-1)]) 398 | for i in range(args.niters_gan_g): 399 | errG = train_gan_g() 400 | 401 | niter_g += 1 402 | if niter_g % 100 == 0: 403 | autoencoder.noise_anneal(args.noise_anneal) 404 | logging('[{}/{}][{}/{}] Loss_D: {:.8f} (Loss_D_real: {:.8f} ' 405 | 'Loss_D_fake: {:.8f}) Loss_G: {:.8f}'.format( 406 | epoch, args.epochs, niter, len(train_data), 407 | errD.data[0], errD_real.data[0], 408 | errD_fake.data[0], errG.data[0])) 409 | # eval 410 | test_loss, accuracy = evaluate_autoencoder(test_data, epoch) 411 | logging('| end of epoch {:3d} | time: {:5.2f}s | test loss {:5.2f} | ' 412 | 'test ppl {:5.2f} | acc {:3.3f}'.format(epoch, 413 | (time.time() - epoch_start_time), test_loss, 414 | math.exp(test_loss), accuracy)) 415 | gen_fixed_noise(fixed_noise, os.path.join(args.save, 416 | "{:03d}_examplar_gen".format(epoch))) 417 | 418 | # eval with rev_ppl and for_ppl 419 | rev_ppl, for_ppl = train_lm(args.data_path) 420 | logging("Epoch {:03d}, Reverse perplexity {}".format(epoch, rev_ppl)) 421 | logging("Epoch {:03d}, Forward perplexity {}".format(epoch, for_ppl)) 422 | if best_rev_ppl is None or rev_ppl < best_rev_ppl: 423 | impatience = 0 424 | best_rev_ppl = rev_ppl 425 | logging("New saving model: epoch {:03d}.".format(epoch)) 426 | save_model() 427 | else: 428 | if not args.no_earlystopping and epoch >= args.min_epochs: 429 | impatience += 1 430 | if impatience > args.patience: 431 | logging("Ending training") 432 | sys.exit() 433 | 434 | train() 435 | -------------------------------------------------------------------------------- /lang/train_rnnlm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | import os 5 | 6 | import argparse 7 | import json 8 | import random 9 | import shutil 10 | import copy 11 | 12 | import torch 13 | from torch import cuda 14 | import torch.nn as nn 15 | from 
# Command-line options for training / evaluating the RNN language model.
parser = argparse.ArgumentParser()

# Input data (HDF5 files read by Dataset; --train_from resumes a checkpoint)
parser.add_argument('--train_file', default='')
parser.add_argument('--val_file', default='')
parser.add_argument('--train_from', default='')

# Model options
parser.add_argument('--word_dim', default=300, type=int)
parser.add_argument('--h_dim', default=300, type=int)
parser.add_argument('--num_layers', default=1, type=int)
parser.add_argument('--dropout', default=0.2, type=float)

# Optimization options
parser.add_argument('--checkpoint_path', default='baseline.pt')
parser.add_argument('--num_epochs', default=15, type=int)
parser.add_argument('--lr', default=1, type=float)
parser.add_argument('--max_grad_norm', default=5, type=float)
# --test 1 evaluates on --val_file and exits; the default was missing
# (a syntax error), so 0 keeps the normal training path.
parser.add_argument('--test', default=0, type=int)
# Device index passed to cuda.set_device; the model moves to GPU when >= 0.
parser.add_argument('--gpu', default=2, type=int)
parser.add_argument('--seed', default=3435, type=int)
parser.add_argument('--print_every', type=int, default=500)
class RNNLM(nn.Module):
    """LSTM language model that returns per-position log-probabilities.

    The input is embedded, passed through dropout and a batch-first LSTM,
    then projected onto the vocabulary with dropout, a linear layer and a
    log-softmax.
    """

    def __init__(self, vocab_size=10000,
                 word_dim=300,
                 h_dim=300,
                 num_layers=1,
                 dropout=0):
        super().__init__()
        self.h_dim, self.num_layers = h_dim, num_layers
        self.word_vecs = nn.Embedding(vocab_size, word_dim)
        self.dropout = nn.Dropout(dropout)
        self.rnn = nn.LSTM(word_dim, h_dim, num_layers=num_layers,
                           dropout=dropout, batch_first=True)
        # Kept as a Sequential so parameters stay under vocab_linear.<i>.
        output_stack = [
            nn.Dropout(dropout),
            nn.Linear(h_dim, vocab_size),
            nn.LogSoftmax(dim=-1),
        ]
        self.vocab_linear = nn.Sequential(*output_stack)

    def forward(self, sent):
        """Score ``sent`` (batch x length) and return batch x (length-1) x vocab.

        The final token of every row is never fed to the LSTM; it is only
        used as a prediction target by the caller.
        """
        history = sent[:, :-1]
        embedded = self.dropout(self.word_vecs(history))
        states = self.rnn(embedded)[0]
        return self.vocab_linear(states)
126 | 127 | if args.gpu >= 0: 128 | model.cuda() 129 | criterion.cuda() 130 | 131 | best_val_ppl = 1e5 132 | epoch = 0 133 | if args.test == 1: 134 | print('Evaluating on test') 135 | eval(val_data, model, criterion) 136 | exit() 137 | while epoch < args.num_epochs: 138 | start_time = time.time() 139 | epoch += 1 140 | print('Starting epoch %d' % epoch) 141 | train_nll = 0. 142 | num_sents = 0 143 | num_words = 0 144 | b = 0 145 | 146 | for i in np.random.permutation(len(train_data)): 147 | sents, length, batch_size = train_data[i] 148 | if args.gpu >= 0: 149 | sents = sents.cuda() 150 | b += 1 151 | optimizer.zero_grad() 152 | preds = model(sents) 153 | nll = sum([criterion(preds[:, l], sents[:, l+1]) for l in range(length)]) 154 | train_nll += nll.data[0]*batch_size 155 | nll.backward() 156 | if args.max_grad_norm > 0: 157 | torch.nn.utils.clip_grad_norm(model.parameters(), args.max_grad_norm) 158 | optimizer.step() 159 | 160 | num_sents += batch_size 161 | num_words += batch_size * length 162 | 163 | if b % args.print_every == 0: 164 | param_norm = sum([p.norm()**2 for p in model.parameters()]).data[0]**0.5 165 | print('Epoch: %d, Batch: %d/%d, LR: %.4f, TrainPPL: %.2f, |Param|: %.4f, BestValPerf: %.2f, Throughput: %.2f examples/sec' % 166 | (epoch, b, len(train_data), args.lr, np.exp(train_nll / num_words), 167 | param_norm, best_val_ppl, num_sents / (time.time() - start_time))) 168 | print('--------------------------------') 169 | print('Checking validation perf...') 170 | val_ppl = eval(val_data, model, criterion) 171 | if val_ppl < best_val_ppl: 172 | best_val_ppl = val_ppl 173 | checkpoint = { 174 | 'args': args.__dict__, 175 | 'model': model, 176 | 'optimizer': optimizer 177 | } 178 | print('Saving checkpoint to %s' % args.checkpoint_path) 179 | torch.save(checkpoint, args.checkpoint_path) 180 | 181 | def eval(data, model, criterion): 182 | model.eval() 183 | num_sents = 0 184 | num_words = 0 185 | total_nll = 0. 
class Dictionary(object):
    """Vocabulary built from word counts.

    Words are first counted with ``add_word``; ``prune_vocab`` then selects
    which of them receive an index in ``word2idx`` / ``idx2word``.
    """

    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        # NOTE(review): all four reserved entries use the same empty-string
        # key here, so only the last assignment (index 3) survives.  These
        # look like distinct special tokens whose names were lost -- confirm
        # against the upstream file before relying on them.
        self.word2idx[''] = 0
        self.word2idx[''] = 1
        self.word2idx[''] = 2
        self.word2idx[''] = 3
        self.wordcounts = {}

    def add_word(self, word):
        """Increment the occurrence count of ``word``."""
        self.wordcounts[word] = self.wordcounts.get(word, 0) + 1

    def prune_vocab(self, k=5, cnt=False):
        """Choose the vocabulary and assign indices to the chosen words.

        With ``cnt=True`` every word seen more than ``k`` times is kept;
        otherwise the ``k`` most frequent words are kept (ties broken by the
        word itself).  The kept words are stored, sorted, in
        ``self.pruned_vocab`` and appended to ``word2idx`` after any entries
        already present; ``idx2word`` is rebuilt as the inverse mapping.
        """
        if cnt:
            # Prune by count.  This used to build a dict, on which the
            # subsequent .sort() raised AttributeError.
            kept = [word for word, count in self.wordcounts.items()
                    if count > k]
        else:
            # Prune by most frequently seen words.
            ranked = sorted(self.wordcounts.items(),
                            key=lambda pair: (pair[1], pair[0]),
                            reverse=True)
            kept = [word for word, _ in ranked[:min(k, len(ranked))]]
        # Sort to make the vocabulary deterministic.
        kept.sort()
        self.pruned_vocab = kept

        # Add all chosen words to the new vocabulary.
        for word in self.pruned_vocab:
            if word not in self.word2idx:
                self.word2idx[word] = len(self.word2idx)
        print("original vocab {}; pruned to {}".
              format(len(self.wordcounts), len(self.word2idx)))
        self.idx2word = {idx: word for word, idx in self.word2idx.items()}

    def __len__(self):
        return len(self.word2idx)
def length_sort(items, lengths, descending=True):
    """Sort ``items`` and ``lengths`` together by length.

    Needed for PyTorch's packed variable-length sequences, which expect the
    longest sequence first.

    Args:
        items: sequences to reorder.
        lengths: length of each entry of ``items``, in the same order.
        descending: longest first when True (the default), shortest first
            otherwise.  Previously this flag was ignored.

    Returns:
        ``(sorted_items, sorted_lengths)`` as two lists.  Entries of equal
        length keep their relative order.  Empty input yields ``([], [])``.
    """
    pairs = list(zip(items, lengths))
    if not pairs:
        # zip(*[]) cannot be unpacked into two names.
        return [], []
    pairs.sort(key=lambda pair: pair[1], reverse=descending)
    sorted_items, sorted_lengths = zip(*pairs)
    return list(sorted_items), list(sorted_lengths)
def get_ppl(lm, sentences):
    """Perplexity of ``sentences`` under the language model ``lm``.

    Args:
        lm: object whose ``full_scores(sentence, bos=..., eos=...)`` yields
            ``(score, _, _)`` triples, with ``score`` treated as a base-10
            log-probability.
        sentences: list of space-delimited sentence strings.

    Returns:
        ``exp(total negative log-likelihood / total word count)``, or
        ``inf`` when the sentences contain no words at all.
    """
    ln10 = math.log(10.0)
    total_nll = 0
    total_wc = 0
    for sent in sentences:
        # -ln(10 ** score) == -score * ln(10).  The direct form underflows
        # to log(0) and raises for very negative scores.
        nll = np.sum([-score * ln10
                      for score, _, _ in lm.full_scores(sent, bos=True,
                                                        eos=False)])
        total_wc += len(sent.strip().split())
        total_nll += nll
    if total_wc == 0:
        # Nothing to normalise by: report the worst possible perplexity
        # rather than dividing by zero.
        return float('inf')
    return np.exp(total_nll / total_wc)
class MLP_Classify(nn.Module):
    """Feed-forward classifier ending in a sigmoid.

    ``layers`` is a dash-separated list of hidden sizes (e.g. ``"100-50"``).
    Each hidden layer is Linear -> (BatchNorm1d, except for the first
    layer) -> ``activation``; a final Linear maps to ``noutput``.

    Sub-modules are kept in a plain list and also registered by name
    (``layer<i>``, ``bn<i>``, ``activation<i>``); those names determine the
    parameter keys in the state dict.
    """

    def __init__(self, ninput, noutput, layers,
                 activation=nn.ReLU(), gpu=False):
        super(MLP_Classify, self).__init__()
        self.ninput = ninput
        self.noutput = noutput

        layer_sizes = [ninput] + [int(x) for x in layers.split('-')]
        self.layers = []

        for i, (n_in, n_out) in enumerate(zip(layer_sizes[:-1],
                                              layer_sizes[1:])):
            layer = nn.Linear(n_in, n_out)
            self.layers.append(layer)
            self.add_module("layer"+str(i+1), layer)

            # No batch normalization in first layer
            if i != 0:
                bn = nn.BatchNorm1d(n_out)
                self.layers.append(bn)
                self.add_module("bn"+str(i+1), bn)

            self.layers.append(activation)
            self.add_module("activation"+str(i+1), activation)

        layer = nn.Linear(layer_sizes[-1], noutput)
        self.layers.append(layer)
        # The output layer is numbered by its position in self.layers.
        self.add_module("layer"+str(len(self.layers)), layer)

        self.init_weights()

    def forward(self, x):
        """Apply every layer in order, then squash with a sigmoid."""
        for layer in self.layers:
            x = layer(x)
        x = F.sigmoid(x)
        return x

    def init_weights(self):
        """Set weights to N(0, 0.02) and biases to 0 where a layer has them."""
        init_std = 0.02
        for layer in self.layers:
            try:
                layer.weight.data.normal_(0, init_std)
                layer.bias.data.fill_(0)
            except AttributeError:
                # Activations (and bias-free layers) lack these attributes;
                # anything else should surface instead of being swallowed.
                pass
initrange) 117 | for p in self.decoder1.parameters(): 118 | p.data.uniform_(-initrange, initrange) 119 | for p in self.decoder2.parameters(): 120 | p.data.uniform_(-initrange, initrange) 121 | 122 | # Initialize Linear Weight 123 | self.linear.weight.data.uniform_(-initrange, initrange) 124 | self.linear.bias.data.fill_(0) 125 | 126 | def init_hidden(self, bsz): 127 | zeros1 = Variable(torch.zeros(self.nlayers, bsz, self.nhidden)) 128 | zeros2 = Variable(torch.zeros(self.nlayers, bsz, self.nhidden)) 129 | return (to_gpu(self.gpu, zeros1), to_gpu(self.gpu, zeros2)) 130 | 131 | def init_state(self, bsz): 132 | zeros = Variable(torch.zeros(self.nlayers, bsz, self.nhidden)) 133 | return to_gpu(self.gpu, zeros) 134 | 135 | def store_grad_norm(self, grad): 136 | norm = torch.norm(grad, 2, 1) 137 | self.grad_norm = norm.detach().data.mean() 138 | return grad 139 | 140 | def forward(self, whichdecoder, indices, lengths, noise=False, encode_only=False): 141 | batch_size, maxlen = indices.size() 142 | 143 | hidden = self.encode(indices, lengths, noise) 144 | 145 | if hidden.requires_grad: 146 | hidden.register_hook(self.store_grad_norm) 147 | 148 | if encode_only: 149 | return hidden 150 | 151 | decoded = self.decode(whichdecoder, hidden, batch_size, maxlen, 152 | indices=indices, lengths=lengths) 153 | 154 | return decoded 155 | 156 | def encode(self, indices, lengths, noise): 157 | embeddings = self.embedding(indices) 158 | packed_embeddings = pack_padded_sequence(input=embeddings, 159 | lengths=lengths, 160 | batch_first=True) 161 | 162 | # Encode 163 | packed_output, state = self.encoder(packed_embeddings) 164 | hidden, cell = state 165 | 166 | # batch_size x nhidden 167 | hidden = hidden[-1] # get hidden state of last layer of encoder 168 | 169 | # normalize to unit ball (l2 norm of 1) - p=2, dim=1 170 | norms = torch.norm(hidden, 2, 1) 171 | 172 | # For older versions of PyTorch use: 173 | hidden = torch.div(hidden, norms.unsqueeze(1).expand_as(hidden)) 174 | # For 
newest version of PyTorch (as of 8/25) use this: 175 | # hidden = torch.div(hidden, norms.unsqueeze(1).expand_as(hidden)) 176 | 177 | if noise and self.noise_r > 0: 178 | gauss_noise = torch.normal(means=torch.zeros(hidden.size()), 179 | std=self.noise_r) 180 | hidden = hidden + to_gpu(self.gpu, Variable(gauss_noise)) 181 | 182 | return hidden 183 | 184 | def decode(self, whichdecoder, hidden, batch_size, maxlen, indices=None, lengths=None): 185 | # batch x hidden 186 | all_hidden = hidden.unsqueeze(1).repeat(1, maxlen, 1) 187 | 188 | if self.hidden_init: 189 | # initialize decoder hidden state to encoder output 190 | state = (hidden.unsqueeze(0), self.init_state(batch_size)) 191 | else: 192 | state = self.init_hidden(batch_size) 193 | 194 | if whichdecoder == 1: 195 | embeddings = self.embedding_decoder1(indices) 196 | else: 197 | embeddings = self.embedding_decoder2(indices) 198 | 199 | augmented_embeddings = torch.cat([embeddings, all_hidden], 2) 200 | packed_embeddings = pack_padded_sequence(input=augmented_embeddings, 201 | lengths=lengths, 202 | batch_first=True) 203 | 204 | if whichdecoder == 1: 205 | packed_output, state = self.decoder1(packed_embeddings, state) 206 | else: 207 | packed_output, state = self.decoder2(packed_embeddings, state) 208 | output, lengths = pad_packed_sequence(packed_output, batch_first=True) 209 | 210 | # reshape to batch_size*maxlen x nhidden before linear over vocab 211 | decoded = self.linear(output.contiguous().view(-1, self.nhidden)) 212 | decoded = decoded.view(batch_size, maxlen, self.ntokens) 213 | 214 | return decoded 215 | 216 | def generate(self, whichdecoder, hidden, maxlen, sample=False, temp=1.0): 217 | """Generate through decoder; no backprop""" 218 | 219 | batch_size = hidden.size(0) 220 | 221 | if self.hidden_init: 222 | # initialize decoder hidden state to encoder output 223 | state = (hidden.unsqueeze(0), self.init_state(batch_size)) 224 | else: 225 | state = self.init_hidden(batch_size) 226 | 227 | # 228 | 
self.start_symbols.data.resize_(batch_size, 1) 229 | self.start_symbols.data.fill_(1) 230 | self.start_symbols = to_gpu(self.gpu, self.start_symbols) 231 | 232 | if whichdecoder == 1: 233 | embedding = self.embedding_decoder1(self.start_symbols) 234 | else: 235 | embedding = self.embedding_decoder2(self.start_symbols) 236 | 237 | inputs = torch.cat([embedding, hidden.unsqueeze(1)], 2) 238 | 239 | # unroll 240 | all_indices = [] 241 | for i in range(maxlen): 242 | if whichdecoder == 1: 243 | output, state = self.decoder1(inputs, state) 244 | else: 245 | output, state = self.decoder2(inputs, state) 246 | overvocab = self.linear(output.squeeze(1)) 247 | 248 | if not sample: 249 | vals, indices = torch.max(overvocab, 1) 250 | indices = indices.unsqueeze(1) 251 | else: 252 | assert 1 == 0 253 | # sampling 254 | probs = F.softmax(overvocab/temp) 255 | indices = torch.multinomial(probs, 1) 256 | 257 | all_indices.append(indices) 258 | 259 | if whichdecoder == 1: 260 | embedding = self.embedding_decoder1(indices) 261 | else: 262 | embedding = self.embedding_decoder2(indices) 263 | inputs = torch.cat([embedding, hidden.unsqueeze(1)], 2) 264 | 265 | max_indices = torch.cat(all_indices, 1) 266 | 267 | return max_indices 268 | 269 | 270 | class MLP_D(nn.Module): 271 | def __init__(self, ninput, noutput, layers, 272 | activation=nn.LeakyReLU(0.2), gpu=False): 273 | super(MLP_D, self).__init__() 274 | self.ninput = ninput 275 | self.noutput = noutput 276 | 277 | layer_sizes = [ninput] + [int(x) for x in layers.split('-')] 278 | self.layers = [] 279 | 280 | for i in range(len(layer_sizes)-1): 281 | layer = nn.Linear(layer_sizes[i], layer_sizes[i+1]) 282 | self.layers.append(layer) 283 | self.add_module("layer"+str(i+1), layer) 284 | 285 | # No batch normalization after first layer 286 | if i != 0: 287 | bn = nn.BatchNorm1d(layer_sizes[i+1], eps=1e-05, momentum=0.1) 288 | self.layers.append(bn) 289 | self.add_module("bn"+str(i+1), bn) 290 | 291 | self.layers.append(activation) 292 
class Seq2Seq(nn.Module):
    """LSTM autoencoder with a single decoder.

    The encoder's last-layer final hidden state is L2-normalised and used as
    the code.  The decoder is a single-layer LSTM that receives that code
    concatenated to every input embedding.
    """

    def __init__(self, emsize, nhidden, ntokens, nlayers, noise_r=0.2,
                 hidden_init=False, dropout=0, gpu=False):
        super(Seq2Seq, self).__init__()
        self.nhidden = nhidden
        self.emsize = emsize
        self.ntokens = ntokens
        self.nlayers = nlayers
        self.noise_r = noise_r
        self.hidden_init = hidden_init
        self.dropout = dropout
        self.gpu = gpu

        # Scratch buffer of start-token indices; resized per call in
        # generate().
        self.start_symbols = to_gpu(gpu, Variable(torch.ones(10, 1).long()))

        # Vocabulary embedding
        self.embedding = nn.Embedding(ntokens, emsize)
        self.embedding_decoder = nn.Embedding(ntokens, emsize)

        # RNN Encoder and Decoder
        self.encoder = nn.LSTM(input_size=emsize,
                               hidden_size=nhidden,
                               num_layers=nlayers,
                               dropout=dropout,
                               batch_first=True)

        decoder_input_size = emsize+nhidden
        self.decoder = nn.LSTM(input_size=decoder_input_size,
                               hidden_size=nhidden,
                               num_layers=1,
                               dropout=dropout,
                               batch_first=True)

        # Initialize Linear Transformation
        self.linear = nn.Linear(nhidden, ntokens)

        self.init_weights()

    def init_weights(self):
        """Initialise all weights uniformly in [-0.1, 0.1]; output bias to 0."""
        initrange = 0.1

        # Initialize Vocabulary Matrix Weight
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.embedding_decoder.weight.data.uniform_(-initrange, initrange)

        # Initialize Encoder and Decoder Weights
        for p in self.encoder.parameters():
            p.data.uniform_(-initrange, initrange)
        for p in self.decoder.parameters():
            p.data.uniform_(-initrange, initrange)

        # Initialize Linear Weight
        self.linear.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.fill_(0)

    def init_hidden(self, bsz):
        """Return a zero ``(h, c)`` pair, each nlayers x bsz x nhidden.

        NOTE(review): this is fed to the single-layer decoder, so it only
        matches when nlayers == 1 -- confirm whether nlayers > 1 is meant
        to be supported.
        """
        zeros1 = Variable(torch.zeros(self.nlayers, bsz, self.nhidden))
        zeros2 = Variable(torch.zeros(self.nlayers, bsz, self.nhidden))
        return (to_gpu(self.gpu, zeros1), to_gpu(self.gpu, zeros2))

    def init_state(self, bsz):
        """Return one zero state tensor of size nlayers x bsz x nhidden."""
        zeros = Variable(torch.zeros(self.nlayers, bsz, self.nhidden))
        return to_gpu(self.gpu, zeros)

    def store_grad_norm(self, grad):
        """Backward hook: record the mean per-row L2 norm of ``grad``.

        The gradient is returned unchanged.
        """
        norm = torch.norm(grad, 2, 1)
        self.grad_norm = norm.detach().data.mean()
        return grad

    def forward(self, indices, lengths, noise, encode_only=False):
        """Encode ``indices`` and, unless ``encode_only``, decode them again.

        Returns the code (batch x nhidden) when ``encode_only`` is set,
        otherwise logits of size batch x maxlen x ntokens.
        """
        batch_size, maxlen = indices.size()

        hidden = self.encode(indices, lengths, noise)

        if encode_only:
            return hidden

        if hidden.requires_grad:
            hidden.register_hook(self.store_grad_norm)

        decoded = self.decode(hidden, batch_size, maxlen,
                              indices=indices, lengths=lengths)

        return decoded

    def encode(self, indices, lengths, noise):
        """Return the unit-norm code for each sequence (batch x nhidden).

        Gaussian noise with std ``noise_r`` is added after normalisation
        when ``noise`` is truthy and ``noise_r`` > 0.
        """
        embeddings = self.embedding(indices)
        packed_embeddings = pack_padded_sequence(input=embeddings,
                                                 lengths=lengths,
                                                 batch_first=True)

        # Encode
        packed_output, state = self.encoder(packed_embeddings)

        hidden, cell = state
        # batch_size x nhidden
        hidden = hidden[-1]  # get hidden state of last layer of encoder

        # normalize to unit ball (l2 norm of 1) - p=2, dim=1
        norms = torch.norm(hidden, 2, 1)

        # The norm drops the reduced dimension, so it has to be restored
        # before broadcasting over the hidden units (same form as
        # Seq2Seq2Decoder.encode).
        hidden = torch.div(hidden, norms.unsqueeze(1).expand_as(hidden))

        if noise and self.noise_r > 0:
            gauss_noise = torch.normal(means=torch.zeros(hidden.size()),
                                       std=self.noise_r)
            hidden = hidden + to_gpu(self.gpu, Variable(gauss_noise))

        return hidden

    def decode(self, hidden, batch_size, maxlen, indices=None, lengths=None):
        """Teacher-forced decoding of ``indices`` conditioned on ``hidden``.

        Returns logits of size batch_size x maxlen x ntokens.
        """
        # batch x hidden
        all_hidden = hidden.unsqueeze(1).repeat(1, maxlen, 1)

        if self.hidden_init:
            # initialize decoder hidden state to encoder output
            state = (hidden.unsqueeze(0), self.init_state(batch_size))
        else:
            state = self.init_hidden(batch_size)

        embeddings = self.embedding_decoder(indices)
        augmented_embeddings = torch.cat([embeddings, all_hidden], 2)
        packed_embeddings = pack_padded_sequence(input=augmented_embeddings,
                                                 lengths=lengths,
                                                 batch_first=True)

        packed_output, state = self.decoder(packed_embeddings, state)
        output, lengths = pad_packed_sequence(packed_output, batch_first=True)

        # reshape to batch_size*maxlen x nhidden before linear over vocab
        decoded = self.linear(output.contiguous().view(-1, self.nhidden))
        decoded = decoded.view(batch_size, maxlen, self.ntokens)

        return decoded

    def generate(self, hidden, maxlen, sample=False, temp=1.0):
        """Generate through decoder; no backprop.

        Unrolls the decoder for ``maxlen`` steps from the start token
        (index 1), feeding each chosen token back in.  Tokens are taken
        greedily unless ``sample`` is set, in which case they are drawn
        from the softmax of the logits divided by ``temp``.

        Returns a batch x maxlen tensor of token indices.
        """
        batch_size = hidden.size(0)

        if self.hidden_init:
            # initialize decoder hidden state to encoder output
            state = (hidden.unsqueeze(0), self.init_state(batch_size))
        else:
            state = self.init_hidden(batch_size)

        # One start token (index 1) per sequence.
        self.start_symbols.data.resize_(batch_size, 1)
        self.start_symbols.data.fill_(1)

        embedding = self.embedding_decoder(self.start_symbols)
        inputs = torch.cat([embedding, hidden.unsqueeze(1)], 2)

        # unroll
        all_indices = []
        for i in range(maxlen):
            output, state = self.decoder(inputs, state)
            overvocab = self.linear(output.squeeze(1))

            if not sample:
                vals, indices = torch.max(overvocab, 1)
                # max() drops the reduced dimension; restore batch x 1 so
                # the embedding and both concatenations line up (as in
                # Seq2Seq2Decoder.generate).
                indices = indices.unsqueeze(1)
            else:
                # sampling
                probs = F.softmax(overvocab/temp, dim=1)
                indices = torch.multinomial(probs, 1)

            all_indices.append(indices)

            embedding = self.embedding_decoder(indices)
            inputs = torch.cat([embedding, hidden.unsqueeze(1)], 2)

        max_indices = torch.cat(all_indices, 1)

        return max_indices
def generate(autoencoder, gan_gen, z, vocab, sample, maxlen):
    """Generate sentences from noise vectors.

    Assume noise is batch_size x z_size.

    Args:
        autoencoder: model exposing ``eval()`` and
            ``generate(hidden=..., maxlen=..., sample=...)``.
        gan_gen: generator exposing ``eval()`` and mapping noise to codes.
        z: noise as a ``Variable``, a float tensor or a ``numpy.ndarray``.
        vocab: mapping from token index to word string.
        sample: forwarded to ``autoencoder.generate``.
        maxlen: forwarded to ``autoencoder.generate``.

    Returns:
        List with one space-joined sentence per row of generated indices.

    Raises:
        ValueError: if ``z`` is none of the supported types.
    """
    # isinstance (not exact type equality) so subclasses are accepted too.
    if isinstance(z, Variable):
        noise = z
    elif isinstance(z, (torch.FloatTensor, torch.cuda.FloatTensor)):
        noise = Variable(z, volatile=True)
    elif isinstance(z, np.ndarray):
        noise = Variable(torch.from_numpy(z).float(), volatile=True)
    else:
        raise ValueError("Unsupported input type (noise): {}".format(type(z)))

    gan_gen.eval()
    autoencoder.eval()

    # generate from random noise
    fake_hidden = gan_gen(noise)
    max_indices = autoencoder.generate(hidden=fake_hidden,
                                       maxlen=maxlen,
                                       sample=sample)

    max_indices = max_indices.data.cpu().numpy()
    sentences = []
    for idx in max_indices:
        # generated sentence
        words = [vocab[x] for x in idx]
        # Truncate at the first stop token.
        # NOTE(review): the stop token is compared against '' here;
        # presumably the end-of-sentence marker -- confirm against the
        # vocabulary's special tokens.
        truncated_sent = []
        for w in words:
            if w == '':
                break
            truncated_sent.append(w)
        sentences.append(" ".join(truncated_sent))

    return sentences
# --no-lowercase must clear the flag that --lowercase sets; the default
# (lowercasing on) comes from set_defaults below.
parser.add_argument('--no-lowercase', dest='lowercase', action='store_false',
                    help='not lowercase all text')
parser.set_defaults(lowercase=True)

# Model Arguments
parser.add_argument('--emsize', type=int, default=128,
                    help='size of word embeddings')
parser.add_argument('--nhidden', type=int, default=128,
                    help='number of hidden units per layer')
parser.add_argument('--nlayers', type=int, default=1,
                    help='number of layers')
parser.add_argument('--noise_r', type=float, default=0.1,
                    help='stdev of noise for autoencoder (regularizer)')
parser.add_argument('--noise_anneal', type=float, default=0.9995,
                    help='anneal noise_r exponentially by this '
                         'every 100 iterations')
parser.add_argument('--hidden_init', action='store_true',
                    help="initialize decoder hidden state with encoder's")
parser.add_argument('--arch_g', type=str, default='128-128',
                    help='generator architecture (MLP)')
parser.add_argument('--arch_d', type=str, default='128-128',
                    help='critic/discriminator architecture (MLP)')
parser.add_argument('--arch_classify', type=str, default='128-128',
                    help='classifier architecture')
parser.add_argument('--z_size', type=int, default=32,
                    help='dimension of random noise z to feed into generator')
parser.add_argument('--temp', type=float, default=1,
                    help='softmax temperature (lower --> more discrete)')
parser.add_argument('--dropout', type=float, default=0.0,
                    help='dropout applied to layers (0 = no dropout)')

# Training Arguments
parser.add_argument('--epochs', type=int, default=25,
                    help='maximum number of epochs')
parser.add_argument('--batch_size', type=int, default=64, metavar='N',
                    help='batch size')
parser.add_argument('--niters_ae', type=int, default=1,
                    help='number of autoencoder iterations in training')
parser.add_argument('--niters_gan_d', type=int, default=5,
                    help='number of discriminator iterations in training')
parser.add_argument('--niters_gan_g', type=int, default=1,
                    help='number of generator iterations in training')
parser.add_argument('--niters_gan_ae', type=int, default=1,
                    help='number of gan-into-ae iterations in training')
parser.add_argument('--niters_gan_schedule', type=str, default='',
                    help='epoch counts to increase number of GAN training '
                         ' iterations (increment by 1 each time)')
parser.add_argument('--lr_ae', type=float, default=1,
                    help='autoencoder learning rate')
parser.add_argument('--lr_gan_g', type=float, default=1e-04,
                    help='generator learning rate')
parser.add_argument('--lr_gan_d', type=float, default=1e-04,
                    help='critic/discriminator learning rate')
parser.add_argument('--lr_classify', type=float, default=1e-04,
                    help='classifier learning rate')
parser.add_argument('--beta1', type=float, default=0.5,
                    help='beta1 for adam. default=0.5')
parser.add_argument('--clip', type=float, default=1,
                    help='gradient clipping, max norm')
parser.add_argument('--gan_gp_lambda', type=float, default=0.1,
                    help='WGAN GP penalty lambda')
parser.add_argument('--grad_lambda', type=float, default=0.01,
                    help='WGAN into AE lambda')
parser.add_argument('--lambda_class', type=float, default=1,
                    help='lambda on classifier')

# Evaluation Arguments
parser.add_argument('--sample', action='store_true',
                    help='sample when decoding for generation')
parser.add_argument('--log_interval', type=int, default=200,
                    help='interval to log autoencoder training results')

# Other
parser.add_argument('--seed', type=int, default=1111,
                    help='random seed')
parser.add_argument('--cuda', dest='cuda', action='store_true',
                    help='use CUDA')
# --no-cuda must store False into the shared `cuda` destination; with
# store_true it could never switch CUDA off.
parser.add_argument('--no-cuda', dest='cuda', action='store_false',
                    help='not using CUDA')
###############################################################################
# Load data
###############################################################################

label_ids = {"pos": 1, "neg": 0}
id2label = {1:"pos", 0:"neg"}

# (Path to textfile, Name, Use4Vocab)
datafiles = [(os.path.join(args.data_path, "valid1.txt"), "valid1", False),
             (os.path.join(args.data_path, "valid2.txt"), "valid2", False),
             (os.path.join(args.data_path, "train1.txt"), "train1", True),
             (os.path.join(args.data_path, "train2.txt"), "train2", True)]
vocabdict = None
if args.load_vocab != "":
    # --load_vocab holds the path of a JSON word->index mapping. The parser
    # defines no `vocab` option, and json.load needs an open file object,
    # so open the path explicitly (and close it again).
    with open(args.load_vocab, 'r') as f:
        vocabdict = json.load(f)
    # Normalise indices to int in case they were serialised as strings.
    vocabdict = {k: int(v) for k, v in vocabdict.items()}
corpus = Corpus(datafiles,
                maxlen=args.maxlen,
                vocab_size=args.vocab_size,
                lowercase=args.lowercase,
                vocab=vocabdict)
# Batch the corpus. Note the evaluation sets named test*_data are built from
# the 'valid1'/'valid2' splits and use a fixed batch size without shuffling.
eval_batch_size = 100
test1_data = batchify(corpus.data['valid1'], eval_batch_size, shuffle=False)
test2_data = batchify(corpus.data['valid2'], eval_batch_size, shuffle=False)
train1_data = batchify(corpus.data['train1'], args.batch_size, shuffle=True)
train2_data = batchify(corpus.data['train2'], args.batch_size, shuffle=True)

print("Loaded data!")

###############################################################################
# Build the models
###############################################################################

ntokens = len(corpus.dictionary.word2idx)
# One shared encoder with two decoders (one per style corpus).
autoencoder = Seq2Seq2Decoder(emsize=args.emsize,
                      nhidden=args.nhidden,
                      ntokens=ntokens,
                      nlayers=args.nlayers,
                      noise_r=args.noise_r,
                      hidden_init=args.hidden_init,
                      dropout=args.dropout,
                      gpu=args.cuda)

# Generator maps z_size noise to nhidden codes; critic and classifier both
# score nhidden codes with a single output.
gan_gen = MLP_G(ninput=args.z_size, noutput=args.nhidden, layers=args.arch_g)
gan_disc = MLP_D(ninput=args.nhidden, noutput=1, layers=args.arch_d)
classifier = MLP_Classify(ninput=args.nhidden, noutput=1, layers=args.arch_classify)
g_factor = None

print(autoencoder)
print(gan_gen)
print(gan_disc)
print(classifier)

# Autoencoder uses plain SGD; generator, critic and classifier use Adam with
# the shared --beta1.
optimizer_ae = optim.SGD(autoencoder.parameters(), lr=args.lr_ae)
optimizer_gan_g = optim.Adam(gan_gen.parameters(),
                             lr=args.lr_gan_g,
                             betas=(args.beta1, 0.999))
optimizer_gan_d = optim.Adam(gan_disc.parameters(),
                             lr=args.lr_gan_d,
                             betas=(args.beta1, 0.999))
#### classify
optimizer_classify = optim.Adam(classifier.parameters(),
                                lr=args.lr_classify,
                                betas=(args.beta1, 0.999))

# Token-level reconstruction loss for the autoencoder.
criterion_ce = nn.CrossEntropyLoss()
def classifier_regularize(whichclass, batch):
    """Update the autoencoder so its codes fool the style classifier.

    The batch is encoded (no noise), scored by ``classifier`` and compared
    against the *flipped* label ``abs(2 - whichclass)``. The gradient
    reaching the code is scaled by ``grad_hook_cla`` before it flows into
    the autoencoder, whose gradients are then clipped to ``args.clip`` and
    applied with ``optimizer_ae``.

    Returns:
        The binary cross-entropy loss against the flipped labels.
    """
    autoencoder.train()
    autoencoder.zero_grad()

    tokens, targets, seq_lengths = batch
    tokens = to_gpu(args.cuda, Variable(tokens))
    targets = to_gpu(args.cuda, Variable(targets))

    # class 1 -> label 1, class 2 -> label 0
    wrong_label = abs(2 - whichclass)
    wrong_labels = torch.zeros(tokens.size(0)).fill_(wrong_label)
    wrong_labels = to_gpu(args.cuda, Variable(wrong_labels))

    # Encode only; decoder index 0 is passed as in the other encode-only calls.
    latent = autoencoder(0, tokens, seq_lengths, noise=False, encode_only=True)
    latent.register_hook(grad_hook_cla)
    probs = classifier(latent)
    reg_loss = F.binary_cross_entropy(probs.squeeze(1), wrong_labels)
    reg_loss.backward()

    torch.nn.utils.clip_grad_norm(autoencoder.parameters(), args.clip)
    optimizer_ae.step()

    return reg_loss
def evaluate_generator(whichdecoder, noise, epoch):
    """Decode sentences from generator noise and write them to a text file.

    Both networks are put in eval mode, ``noise`` is mapped to codes by
    ``gan_gen`` and decoded for 50 steps with decoder ``whichdecoder``.
    One sentence per line is written to
    ``<args.outf>/<epoch>_generated<whichdecoder>.txt``; each sentence is
    cut at the first token whose vocabulary string is ``''``.
    """
    gan_gen.eval()
    autoencoder.eval()

    # generate from fixed random noise
    codes = gan_gen(noise)
    decoded = autoencoder.generate(whichdecoder, codes,
                                   maxlen=50, sample=args.sample)

    out_path = "%s/%s_generated%d.txt" % (args.outf, epoch, whichdecoder)
    with open(out_path, "w") as f:
        rows = decoded.data.cpu().numpy()
        for row in rows:
            words = [corpus.dictionary.idx2word[x] for x in row]
            # keep everything before the first stop token (or all words)
            end = words.index('') if '' in words else len(words)
            f.write(" ".join(words[:end]))
            f.write("\n")
def calc_gradient_penalty(netD, real_data, fake_data, gp_lambda=None):
    """Compute the WGAN-GP gradient penalty on random interpolates.

    Adapted from https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py

    Args:
        netD: critic; called on the interpolated batch.
        real_data: 2-D tensor of real codes (batch x features).
        fake_data: tensor of generated codes with the same shape.
        gp_lambda: penalty weight; defaults to ``args.gan_gp_lambda`` when
            omitted, which is what existing callers rely on.

    Returns:
        ``mean((||grad_x netD(x)||_2 - 1)^2) * gp_lambda`` evaluated at
        per-sample random interpolations between real and fake data.
    """
    if gp_lambda is None:
        gp_lambda = args.gan_gp_lambda

    bsz = real_data.size(0)
    alpha = torch.rand(bsz, 1)
    alpha = alpha.expand(bsz, real_data.size(1))  # only works for 2D XXX
    # Follow the device of the inputs instead of assuming CUDA, so the
    # penalty also works for CPU-only runs.
    use_cuda = real_data.is_cuda
    if use_cuda:
        alpha = alpha.cuda()
    interpolates = alpha * real_data + ((1 - alpha) * fake_data)
    interpolates = Variable(interpolates, requires_grad=True)
    disc_interpolates = netD(interpolates)

    grad_outputs = torch.ones(disc_interpolates.size())
    if use_cuda:
        grad_outputs = grad_outputs.cuda()
    # create_graph=True so the penalty itself can be back-propagated.
    gradients = torch.autograd.grad(outputs=disc_interpolates,
                                    inputs=interpolates,
                                    grad_outputs=grad_outputs,
                                    create_graph=True, retain_graph=True,
                                    only_inputs=True)[0]
    gradients = gradients.view(gradients.size(0), -1)

    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * gp_lambda
    return gradient_penalty
def train_gan_d_into_ae(whichdecoder, batch):
    """Push the critic's signal back into the encoder for one batch.

    The batch is encoded without noise, the code's incoming gradient is
    scaled through ``grad_hook``, and the critic output is back-propagated
    with ``mone`` as the upstream gradient. Autoencoder gradients are
    clipped to ``args.clip`` before ``optimizer_ae`` takes a step.

    Returns:
        The critic output for the real codes.
    """
    autoencoder.train()
    optimizer_ae.zero_grad()

    tokens, targets, seq_lengths = batch
    tokens = to_gpu(args.cuda, Variable(tokens))
    targets = to_gpu(args.cuda, Variable(targets))

    codes = autoencoder(whichdecoder, tokens, seq_lengths,
                        noise=False, encode_only=True)
    codes.register_hook(grad_hook)

    critic_out = gan_disc(codes)
    critic_out.backward(mone)

    torch.nn.utils.clip_grad_norm(autoencoder.parameters(), args.clip)
    optimizer_ae.step()

    return critic_out
args.epochs+1): 550 | # update gan training schedule 551 | if epoch in gan_schedule: 552 | niter_gan += 1 553 | print("GAN training loop schedule increased to {}".format(niter_gan)) 554 | with open("{}/log.txt".format(args.outf), 'a') as f: 555 | f.write("GAN training loop schedule increased to {}\n". 556 | format(niter_gan)) 557 | 558 | total_loss_ae1 = 0 559 | total_loss_ae2 = 0 560 | classify_loss = 0 561 | epoch_start_time = time.time() 562 | start_time = time.time() 563 | niter = 0 564 | niter_global = 1 565 | 566 | # loop through all batches in training data 567 | while niter < len(train1_data) and niter < len(train2_data): 568 | 569 | # train autoencoder ---------------------------- 570 | for i in range(args.niters_ae): 571 | if niter == len(train1_data): 572 | break # end of epoch 573 | total_loss_ae1, start_time = \ 574 | train_ae(1, train1_data[niter], total_loss_ae1, start_time, niter) 575 | total_loss_ae2, _ = \ 576 | train_ae(2, train2_data[niter], total_loss_ae2, start_time, niter) 577 | 578 | # train classifier ---------------------------- 579 | classify_loss1, classify_acc1 = train_classifier(1, train1_data[niter]) 580 | classify_loss2, classify_acc2 = train_classifier(2, train2_data[niter]) 581 | classify_loss = (classify_loss1 + classify_loss2) / 2 582 | classify_acc = (classify_acc1 + classify_acc2) / 2 583 | # reverse to autoencoder 584 | classifier_regularize(1, train1_data[niter]) 585 | classifier_regularize(2, train2_data[niter]) 586 | 587 | niter += 1 588 | 589 | # train gan ---------------------------------- 590 | for k in range(niter_gan): 591 | 592 | # train discriminator/critic 593 | for i in range(args.niters_gan_d): 594 | # feed a seen sample within this epoch; good for early training 595 | if i % 2 == 0: 596 | batch = train1_data[random.randint(0, len(train1_data)-1)] 597 | whichdecoder = 1 598 | else: 599 | batch = train2_data[random.randint(0, len(train2_data)-1)] 600 | whichdecoder = 2 601 | errD, errD_real, errD_fake = 
train_gan_d(whichdecoder, batch) 602 | 603 | # train generator 604 | for i in range(args.niters_gan_g): 605 | errG = train_gan_g() 606 | 607 | # train autoencoder from d 608 | for i in range(args.niters_gan_ae): 609 | if i % 2 == 0: 610 | batch = train1_data[random.randint(0, len(train1_data)-1)] 611 | whichdecoder = 1 612 | else: 613 | batch = train2_data[random.randint(0, len(train2_data)-1)] 614 | whichdecoder = 2 615 | errD_ = train_gan_d_into_ae(whichdecoder, batch) 616 | 617 | niter_global += 1 618 | if niter_global % 100 == 0: 619 | print('[%d/%d][%d/%d] Loss_D: %.4f (Loss_D_real: %.4f ' 620 | 'Loss_D_fake: %.4f) Loss_G: %.4f' 621 | % (epoch, args.epochs, niter, len(train1_data), 622 | errD.data[0], errD_real.data[0], 623 | errD_fake.data[0], errG.data[0])) 624 | print("Classify loss: {:5.2f} | Classify accuracy: {:3.3f}\n".format( 625 | classify_loss, classify_acc)) 626 | with open("{}/log.txt".format(args.outf), 'a') as f: 627 | f.write('[%d/%d][%d/%d] Loss_D: %.4f (Loss_D_real: %.4f ' 628 | 'Loss_D_fake: %.4f) Loss_G: %.4f\n' 629 | % (epoch, args.epochs, niter, len(train1_data), 630 | errD.data[0], errD_real.data[0], 631 | errD_fake.data[0], errG.data[0])) 632 | f.write("Classify loss: {:5.2f} | Classify accuracy: {:3.3f}\n".format( 633 | classify_loss, classify_acc)) 634 | 635 | # exponentially decaying noise on autoencoder 636 | autoencoder.noise_r = \ 637 | autoencoder.noise_r*args.noise_anneal 638 | 639 | 640 | # end of epoch ---------------------------- 641 | # evaluation 642 | test_loss, accuracy = evaluate_autoencoder(1, test1_data[:1000], epoch) 643 | print('-' * 89) 644 | print('| end of epoch {:3d} | time: {:5.2f}s | test loss {:5.2f} | ' 645 | 'test ppl {:5.2f} | acc {:3.3f}'. 
646 | format(epoch, (time.time() - epoch_start_time), 647 | test_loss, math.exp(test_loss), accuracy)) 648 | print('-' * 89) 649 | with open("{}/log.txt".format(args.outf), 'a') as f: 650 | f.write('-' * 89) 651 | f.write('\n| end of epoch {:3d} | time: {:5.2f}s | test loss {:5.2f} |' 652 | ' test ppl {:5.2f} | acc {:3.3f}\n'. 653 | format(epoch, (time.time() - epoch_start_time), 654 | test_loss, math.exp(test_loss), accuracy)) 655 | f.write('-' * 89) 656 | f.write('\n') 657 | 658 | test_loss, accuracy = evaluate_autoencoder(2, test2_data[:1000], epoch) 659 | print('-' * 89) 660 | print('| end of epoch {:3d} | time: {:5.2f}s | test loss {:5.2f} | ' 661 | 'test ppl {:5.2f} | acc {:3.3f}'. 662 | format(epoch, (time.time() - epoch_start_time), 663 | test_loss, math.exp(test_loss), accuracy)) 664 | print('-' * 89) 665 | with open("{}/log.txt".format(args.outf), 'a') as f: 666 | f.write('-' * 89) 667 | f.write('\n| end of epoch {:3d} | time: {:5.2f}s | test loss {:5.2f} |' 668 | ' test ppl {:5.2f} | acc {:3.3f}\n'. 669 | format(epoch, (time.time() - epoch_start_time), 670 | test_loss, math.exp(test_loss), accuracy)) 671 | f.write('-' * 89) 672 | f.write('\n') 673 | 674 | evaluate_generator(1, fixed_noise, "end_of_epoch_{}".format(epoch)) 675 | evaluate_generator(2, fixed_noise, "end_of_epoch_{}".format(epoch)) 676 | 677 | # shuffle between epochs 678 | train1_data = batchify(corpus.data['train1'], args.batch_size, shuffle=True) 679 | train2_data = batchify(corpus.data['train2'], args.batch_size, shuffle=True) 680 | 681 | 682 | test_loss, accuracy = evaluate_autoencoder(1, test1_data, epoch+1) 683 | print('-' * 89) 684 | print('| end of epoch {:3d} | time: {:5.2f}s | test loss {:5.2f} | ' 685 | 'test ppl {:5.2f} | acc {:3.3f}'. 
686 | format(epoch, (time.time() - epoch_start_time), 687 | test_loss, math.exp(test_loss), accuracy)) 688 | print('-' * 89) 689 | with open("{}/log.txt".format(args.outf), 'a') as f: 690 | f.write('-' * 89) 691 | f.write('\n| end of epoch {:3d} | time: {:5.2f}s | test loss {:5.2f} |' 692 | ' test ppl {:5.2f} | acc {:3.3f}\n'. 693 | format(epoch, (time.time() - epoch_start_time), 694 | test_loss, math.exp(test_loss), accuracy)) 695 | f.write('-' * 89) 696 | f.write('\n') 697 | 698 | test_loss, accuracy = evaluate_autoencoder(2, test2_data, epoch+1) 699 | print('-' * 89) 700 | print('| end of epoch {:3d} | time: {:5.2f}s | test loss {:5.2f} | ' 701 | 'test ppl {:5.2f} | acc {:3.3f}'. 702 | format(epoch, (time.time() - epoch_start_time), 703 | test_loss, math.exp(test_loss), accuracy)) 704 | print('-' * 89) 705 | with open("{}/log.txt".format(args.outf), 'a') as f: 706 | f.write('-' * 89) 707 | f.write('\n| end of epoch {:3d} | time: {:5.2f}s | test loss {:5.2f} |' 708 | ' test ppl {:5.2f} | acc {:3.3f}\n'. 
import os
import random
import subprocess

import numpy as np
import torch

# Special vocabulary tokens.  They must be four distinct strings: Dictionary
# assigns them the reserved indices 0-3 and batchify() pads with index 0.
PAD_WORD = "<pad>"
EOS_WORD = "<eos>"
BOS_WORD = "<bos>"
UNK = "<unk>"


def load_kenlm():
    """Import ``kenlm`` lazily and bind it as a module-level global.

    The import is deferred so that the rest of this module can be used
    without the optional ``kenlm`` package installed.
    """
    global kenlm
    import kenlm


def to_gpu(gpu, var):
    """Return ``var.cuda()`` when ``gpu`` is truthy, otherwise ``var`` itself."""
    if gpu:
        return var.cuda()
    return var


class Dictionary(object):
    """Bidirectional word <-> index mapping with word-frequency tracking."""

    def __init__(self, word2idx=None):
        """Create a dictionary.

        Args:
            word2idx: optional existing ``{word: index}`` mapping.  When
                omitted, a fresh mapping holding only the four special
                tokens (indices 0-3) is created.
        """
        # Always present, so add_word() also works on a dictionary that was
        # built from a pre-existing vocabulary.
        self.wordcounts = {}
        if word2idx is None:
            self.word2idx = {
                PAD_WORD: 0,
                BOS_WORD: 1,
                EOS_WORD: 2,
                UNK: 3,
            }
        else:
            self.word2idx = word2idx
        self.idx2word = {v: k for k, v in self.word2idx.items()}

    # to track word counts
    def add_word(self, word):
        """Increment the occurrence count of ``word``."""
        self.wordcounts[word] = self.wordcounts.get(word, 0) + 1

    # prune vocab based on count k cutoff or most frequently seen k words
    def prune_vocab(self, k=5, cnt=False):
        """Select the words to keep and assign them indices.

        Args:
            k: with ``cnt=True`` the count cutoff (words seen strictly more
                than ``k`` times are kept); with ``cnt=False`` the number of
                most frequent words to keep.
            cnt: choose between the two pruning modes described above.

        The kept words are stored, alphabetically sorted, in
        ``self.pruned_vocab`` and appended to ``self.word2idx`` after the
        entries that are already there.
        """
        if cnt:
            # prune by count
            kept = [word for word, count in self.wordcounts.items()
                    if count > k]
        else:
            # prune by most frequently seen words; ties are broken by the
            # word itself so the selection is deterministic
            ranked = sorted(self.wordcounts.items(),
                            key=lambda pair: (pair[1], pair[0]),
                            reverse=True)
            kept = [word for word, _ in ranked[:k]]
        # sort to make vocabulary deterministic
        kept.sort()
        self.pruned_vocab = kept

        # add all chosen words to new vocabulary/dict
        for word in self.pruned_vocab:
            if word not in self.word2idx:
                self.word2idx[word] = len(self.word2idx)
        print("Original vocab {}; Pruned to {}".
              format(len(self.wordcounts), len(self.word2idx)))
        self.idx2word = {v: k for k, v in self.word2idx.items()}

    def __len__(self):
        return len(self.word2idx)


class Corpus(object):
    """Loads text files and converts each line to a list of word indices."""

    def __init__(self, datafiles, maxlen, vocab_size=11000, lowercase=False,
                 vocab=None, debug=False):
        """Build (or reuse) a vocabulary and tokenize every data file.

        Args:
            datafiles: iterable of ``(path, name, use_for_vocab)`` triples.
            maxlen: lines with more than this many words are dropped;
                a value <= 0 disables the limit.
            vocab_size: number of most frequent words kept in the vocabulary.
            lowercase: lowercase every line before splitting.
            vocab: optional existing ``{word: index}`` mapping; when given,
                no vocabulary is built from the files.
            debug: when true, every file contributes to the vocabulary.
        """
        self.dictionary = Dictionary(vocab)
        self.maxlen = maxlen
        self.lowercase = lowercase
        self.vocab_size = vocab_size
        self.datafiles = datafiles
        self.forvocab = []
        self.data = {}

        if vocab is None:
            for path, name, fvocab in datafiles:
                if fvocab or debug:
                    self.forvocab.append(path)
            self.make_vocab()

        for path, name, _ in datafiles:
            self.data[name] = self.tokenize(path)

    def _split(self, line):
        """Apply optional lowercasing and split a raw line on single spaces."""
        text = line.lower() if self.lowercase else line
        return text.strip().split(" ")

    def make_vocab(self):
        """Count the words of every file in ``self.forvocab`` and prune."""
        for path in self.forvocab:
            assert os.path.exists(path)
            # Add words to the dictionary
            with open(path, 'r') as f:
                for line in f:
                    for word in self._split(line):
                        self.dictionary.add_word(word)

        # prune the vocabulary
        self.dictionary.prune_vocab(k=self.vocab_size, cnt=False)

    def tokenize(self, path):
        """Tokenizes a text file.

        Returns a list with one entry per kept line: the indices of
        ``BOS_WORD``, the line's words, then ``EOS_WORD``.  Words missing
        from the vocabulary map to the index of ``UNK``.
        """
        vocab = self.dictionary.word2idx
        unk_idx = vocab[UNK]
        dropped = 0
        linecount = 0
        lines = []
        with open(path, 'r') as f:
            for line in f:
                linecount += 1
                words = self._split(line)
                if self.maxlen > 0 and len(words) > self.maxlen:
                    dropped += 1
                    continue
                words = [BOS_WORD] + words + [EOS_WORD]
                # vectorize
                lines.append([vocab.get(w, unk_idx) for w in words])

        print("Number of sentences dropped from {}: {} out of {} total".
              format(path, dropped, linecount))
        return lines


def batchify(data, bsz, shuffle=False, gpu=False):
    """Group index sequences into padded, length-sorted batches.

    Args:
        data: list of index lists, each starting with the start symbol and
            ending with the end symbol.
        bsz: batch size; a trailing partial batch is discarded.
        shuffle: shuffle the sequences before batching.  The shuffle is done
            on a copy, so the caller's list is left untouched.
        gpu: accepted for interface compatibility; not used here.

    Returns:
        A list of ``(source, target, lengths)`` tuples.  ``source`` is a
        ``(bsz, maxlen)`` LongTensor without the end symbol, ``target`` is
        the flattened LongTensor without the start symbol, and ``lengths``
        holds the unpadded lengths in decreasing order.
    """
    if shuffle:
        data = list(data)
        random.shuffle(data)

    nbatch = len(data) // bsz
    batches = []

    for i in range(nbatch):
        batch = data[i*bsz:(i+1)*bsz]

        # subtract 1 from lengths b/c includes BOTH starts & end symbols
        lengths = [len(x)-1 for x in batch]

        # sort items by length (decreasing)
        batch, lengths = length_sort(batch, lengths)

        # Pad batches to maximum sequence length in batch
        maxlen = max(lengths)
        source = []
        target = []
        for seq, length in zip(batch, lengths):
            # 0 is the padding value (index of PAD_WORD in a fresh Dictionary)
            zeros = (maxlen-length)*[0]
            # source has no end symbol
            source.append(list(seq[:-1]) + zeros)
            # target has no start symbol
            target.append(list(seq[1:]) + zeros)

        source = torch.LongTensor(np.array(source))
        target = torch.LongTensor(np.array(target)).view(-1)

        batches.append((source, target, lengths))
    print('{} batches'.format(len(batches)))
    return batches


def length_sort(items, lengths, descending=True):
    """In order to use pytorch variable length sequence package

    Sorts ``items`` by the matching entry of ``lengths`` (stable sort) and
    returns ``(sorted_items, sorted_lengths)`` as two lists.  ``descending``
    selects longest-first (default) or shortest-first order.
    """
    pairs = list(zip(items, lengths))
    if not pairs:
        # zip(*[]) yields nothing to unpack, so handle the empty case here
        return [], []
    pairs.sort(key=lambda x: x[1], reverse=descending)
    items, lengths = zip(*pairs)
    return list(items), list(lengths)


def truncate(words):
    """Join ``words`` with spaces, stopping before the first ``EOS_WORD``."""
    # truncate sentences to first occurrence of <eos>
    truncated_sent = []
    for w in words:
        if w == EOS_WORD:
            break
        truncated_sent.append(w)
    return " ".join(truncated_sent)


def train_ngram_lm(kenlm_path, data_path, output_path, N):
    """
    Trains a modified Kneser-Ney n-gram KenLM from a text file.
    Creates a .arpa file to store n-grams.

    Args:
        kenlm_path: KenLM checkout; ``build/bin/lmplz`` below it is run.
        data_path: training text, relative to the current directory
            (or absolute).
        output_path: destination of the .arpa file, same convention.
        N: n-gram order passed to ``lmplz -o``.

    Returns:
        The ``kenlm.Model`` loaded from ``output_path``.

    Raises:
        subprocess.CalledProcessError: if ``lmplz`` exits with a non-zero
            status.
    """
    # create .arpa file of n-grams
    curdir = os.path.abspath(os.path.curdir)
    build_dir = os.path.abspath(os.path.join(kenlm_path, 'build'))
    lmplz = os.path.join(build_dir, 'bin', 'lmplz')

    # Run lmplz with an argument list and explicit stdin/stdout redirection
    # instead of a shell string, so paths containing spaces or shell
    # metacharacters are handled correctly.
    with open(os.path.join(curdir, data_path), 'rb') as src, \
            open(os.path.join(curdir, output_path), 'wb') as dst:
        subprocess.check_call([lmplz, '-o', str(N)],
                              stdin=src, stdout=dst, cwd=build_dir)

    load_kenlm()
    # create language model
    model = kenlm.Model(output_path)

    return model


def get_ppl(lm, sentences):
    """
    Assume sentences is a list of strings (space delimited sentences)

    Sums ``lm.score(sent, bos=True, eos=False)`` over all sentences and
    returns ``10 ** -(total_score / total_word_count)``, i.e. the scores are
    treated as base-10 log probabilities.

    Raises:
        ZeroDivisionError: if the sentences contain no words at all.
    """
    total_nll = 0
    total_wc = 0
    for sent in sentences:
        total_wc += len(sent.strip().split())
        total_nll += lm.score(sent, bos=True, eos=False)
    ppl = 10**-(total_nll/total_wc)
    return ppl