├── layers
│   ├── EMaskedClassNLLCriterion.lua
│   ├── MaskedClassNLLCriterion.lua
│   ├── NCEMaskedLoss.lua
│   ├── Embedding.lua
│   ├── NCE2.lua
│   ├── NCE0.lua
│   ├── NCE1.lua
│   └── NCE.lua
├── utils
│   ├── linkedtable.lua
│   ├── wordembedding.lua
│   ├── xqueue.lua
│   ├── alias_method.lua
│   └── shortcut.lua
├── conllx_scripts
│   ├── test.in
│   ├── test.out
│   ├── run_1st.sh
│   ├── run_1st-t5.sh
│   ├── run_shuffle1.sh
│   ├── eval.lua
│   ├── run_shuffle.sh
│   ├── shuffle.lua
│   └── conllx_eval.lua
├── scripts
│   ├── dgedge.lua
│   ├── run_msr_sort.sh
│   ├── run_msr_sort_bid.sh
│   ├── run_apw_sort2.sh
│   ├── run_dep.sh
│   ├── WE_norm.lua
│   ├── run_msr_test.sh
│   ├── run_dep_sort.sh
│   ├── run_conllx_sort.sh
│   ├── run_dep_sort20.sh
│   ├── run_conllx_rnd_nosort.sh
│   ├── run_conllx5_sort.sh
│   ├── run_msr_test_bid.sh
│   ├── run_conllx_sort_bid.sh
│   ├── run_conllx_nnum_sort.sh
│   ├── run.sh
│   ├── dgvertex.lua
│   ├── run_msr.sh
│   ├── run_msr_bid.sh
│   ├── test_sort_large.lua
│   ├── run_apw_sort_all.sh
│   ├── run_apw_sort.sh
│   ├── run_dep_bid.sh
│   ├── run_apw_sort_bid.sh
│   ├── testsetdeptree2hdf5.lua
│   ├── sort_large_hdf5.lua
│   ├── WE_txt2torch.lua
│   ├── sort_large_hdf5_bid.lua
│   ├── conllxutils.lua
│   ├── conllx2hdf5.lua
│   ├── sorthdf5bid.lua
│   ├── deptree2hdf5.lua
│   ├── depgraph.lua
│   ├── words2hdf5.lua
│   └── sorthdf5.lua
├── dataset
│   ├── LM_DatasetGPU.lua
│   ├── testTreeLM_Dataset.lua
│   ├── LM_Dataset0.lua
│   ├── NCEDataGenerator.lua
│   ├── NCEDataGenerator_lu.lua
│   ├── LM_Dataset.lua
│   ├── TreeLM_Dataset.lua
│   ├── TreeLM_NCE_Dataset.lua
│   ├── BidTreeLM_Dataset.lua
│   └── BidTreeLM_NCE_Dataset.lua
├── init.lua
├── msr_scripts
│   ├── bestof5.pl
│   └── score.pl
├── LICENSE
├── nnets
│   ├── basic.lua
│   ├── MLP.lua
│   ├── GPULSTMLM.lua
│   └── LSTMLM.lua
├── experiments
│   ├── depparse
│   │   ├── treelstm_we_train.sh
│   │   ├── ldtreelstm_train.sh
│   │   ├── treelstm_we_rerank_test.sh
│   │   ├── ldtreelstm_rerank_test.sh
│   │   ├── treelstm_we_rerank_valid.sh
│   │   ├── ldtreelstm_rerank_valid.sh
│   │   └── gpu_lock.py
│   └── msr
│       ├── treelstm_h400.sh
│       ├── ldtreelstm_h400.sh
│       └── gpu_lock.py
├── OLD_VERSION.md
├── gpu_lock.py
├── train_mlp.lua
├── main.lua
├── rerank.lua
├── README.md
└── main_nce.lua

/conllx_scripts/test.in: -------------------------------------------------------------------------------- 1 | 1 1 1 1 1 2 | 1 1 1 1 1 3 | 4 | 2 2 2 2 2 5 | 2 2 2 6 | 2 2 7 | 8 | 3 3 9 | 3 3 3 10 | 3 11 | 3 12 | 13 | -------------------------------------------------------------------------------- /conllx_scripts/test.out: -------------------------------------------------------------------------------- 1 | 2 2 2 2 2 2 | 2 2 2 3 | 2 2 4 | 5 | 3 3 6 | 3 3 3 7 | 3 8 | 3 9 | 10 | 1 1 1 1 1 11 | 1 1 1 1 1 12 | 13 | -------------------------------------------------------------------------------- /scripts/dgedge.lua: -------------------------------------------------------------------------------- 1 | 2 | local DGEdge = torch.class('DGEdge') 3 | function DGEdge:__init(u, v, name) 4 | self.u = u 5 | self.v = v 6 | self.name = name 7 | end 8 | -------------------------------------------------------------------------------- /scripts/run_msr_sort.sh: 
-------------------------------------------------------------------------------- 1 | 2 | dataset=/disk/scratch1/xingxing.zhang/xTreeRNN/dataset/msr/msr.dep.100.h5 3 | th sort_large_hdf5.lua --dataset $dataset --sort -1 --batchSize 64 4 | -------------------------------------------------------------------------------- /scripts/run_msr_sort_bid.sh: -------------------------------------------------------------------------------- 1 | 2 | dataset=/disk/scratch/XingxingZhang/treelstm/dataset/msr/msr.dep.100.bid.h5 3 | th sort_large_hdf5_bid.lua --dataset $dataset --sort -1 --batchSize 64 --bidirectional 4 | 5 | -------------------------------------------------------------------------------- /utils/linkedtable.lua: -------------------------------------------------------------------------------- 1 | 2 | local LTable = torch.class('LinkedTable') 3 | 4 | function LTable:__init() 5 | self.table = {} 6 | self.keys = {} 7 | end 8 | 9 | function LTable:put(key, value) 10 | end 11 | 12 | 13 | -------------------------------------------------------------------------------- /conllx_scripts/run_1st.sh: -------------------------------------------------------------------------------- 1 | 2 | th eval.lua --sysFile /disk/scratch/s1270921/depparse/iornn-depparse-1st/tools/mstparser-2/experiment/dev-1-best-mst1storder.conll --goldFile /disk/scratch/s1270921/depparse/iornn-depparse-1st/data/valid 3 | -------------------------------------------------------------------------------- /scripts/run_apw_sort2.sh: -------------------------------------------------------------------------------- 1 | 2 | dataset=/disk/scratch1/xingxing.zhang/xTreeRNN/dataset/apw/ap.dep.h5 3 | th sort_large_hdf5_bid.lua --dataset $dataset --sort -1 --batchSize 64 4 | 5 | 6 | -------------------------------------------------------------------------------- /conllx_scripts/run_1st-t5.sh: -------------------------------------------------------------------------------- 1 | 2 | sysFile=/disk/scratch/s1270921/depparse/iornn-depparse-1st-t5/tools/mstparser-2/experiment/dev-mst1storder.conll 3 | goldFile=/disk/scratch/s1270921/depparse/iornn-depparse-1st-t5/data/valid 4 | th eval.lua --sysFile $sysFile --goldFile $goldFile 5 | -------------------------------------------------------------------------------- /dataset/LM_DatasetGPU.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'LM_Dataset' 3 | 4 | local LM_DatasetGPU, parent = torch.class('LM_DatasetGPU', 'LM_Dataset') 5 | 6 | function LM_DatasetGPU:__init(datasetPath, preFetchCount) 7 | parent.__init(self, datasetPath) 8 | self.preFetchCount = preFetchCount or 100 9 | end 10 | -------------------------------------------------------------------------------- /init.lua: -------------------------------------------------------------------------------- 1 | 2 | function importdir(dpath) 3 | package.path = dpath .. "/?.lua;" .. 
package.path 4 | end 5 | 6 | -- load all local packages 7 | for dir in paths.iterdirs(".") do 8 | -- print(dir) 9 | importdir(dir) 10 | end 11 | 12 | require 'torch' 13 | require 'nn' 14 | require 'nngraph' 15 | require 'optim' -------------------------------------------------------------------------------- /conllx_scripts/run_shuffle1.sh: -------------------------------------------------------------------------------- 1 | 2 | seed=1 3 | infile=/afs/inf.ed.ac.uk/group/project/img2txt/xTreeLSTM/xtreelstm/dataset/depparse/train.autopos 4 | outfile=/afs/inf.ed.ac.uk/group/project/img2txt/xTreeLSTM/xtreelstm/dataset/depparse/train.autopos.s$seed 5 | th shuffle.lua --inFile $infile --outFile $outfile --seed $seed 6 | 7 | -------------------------------------------------------------------------------- /scripts/run_dep.sh: -------------------------------------------------------------------------------- 1 | 2 | source ~/.profile 3 | 4 | train=../../dataset/penn_wsj.dep.train 5 | valid=../../dataset/penn_wsj.dep.valid 6 | test=../../dataset/penn_wsj.dep.test 7 | dataset=../../dataset/penn_wsj.dep.h5 8 | th deptree2hdf5.lua --train $train --valid $valid --test $test --dataset $dataset --freq 5 --keepFreq --maxLen 150 --ignoreCase 9 | 10 | -------------------------------------------------------------------------------- /scripts/WE_norm.lua: -------------------------------------------------------------------------------- 1 | 2 | local function normalize(infile, outfile) 3 | local embed, word2idx, idx2word = unpack( torch.load(infile) ) 4 | local N = embed:size(1) 5 | for i = 1, N do 6 | local norm = embed[i]:norm() 7 | embed[i]:div(norm) 8 | end 9 | 10 | torch.save(outfile, {embed, word2idx, idx2word}) 11 | end 12 | 13 | normalize(arg[1], arg[2]) 14 | -------------------------------------------------------------------------------- /scripts/run_msr_test.sh: -------------------------------------------------------------------------------- 1 | 2 | vocab=/disk/scratch1/xingxing.zhang/xTreeRNN/dataset/msr/msr.dep.100.sort20.vocab.t7 3 | test=/afs/inf.ed.ac.uk/group/project/img2txt/deptree_rnnlm/data/msr_sent_compl/question.dep 4 | testdataset=/disk/scratch1/xingxing.zhang/xTreeRNN/dataset/msr/msr.dep.100.question.h5 5 | th testsetdeptree2hdf5.lua --vocab $vocab --test $test --testdataset $testdataset 6 | -------------------------------------------------------------------------------- /scripts/run_dep_sort.sh: -------------------------------------------------------------------------------- 1 | 2 | source ~/.profile 3 | 4 | train=../../dataset/penn_wsj.dep.train 5 | valid=../../dataset/penn_wsj.dep.valid 6 | test=../../dataset/penn_wsj.dep.test 7 | dataset=../../dataset/penn_wsj.dep.h5 8 | th deptree2hdf5.lua --train $train --valid $valid --test $test --dataset $dataset --freq 5 --keepFreq --maxLen 150 --ignoreCase --sort -1 --batchSize 64 9 | 10 | -------------------------------------------------------------------------------- /scripts/run_conllx_sort.sh: -------------------------------------------------------------------------------- 1 | 2 | source ~/.profile 3 | 4 | train=../../dataset/depparse/train.autopos 5 | valid=../../dataset/depparse/valid.autopos 6 | test=../../dataset/depparse/test.autopos 7 | dataset=../../dataset/penn_wsj.conllx.h5 8 | th conllx2hdf5.lua --train $train --valid $valid --test $test --dataset $dataset --freq 1 --keepFreq --maxLen 150 --sort -1 --batchSize 64 9 | 10 | -------------------------------------------------------------------------------- /scripts/run_dep_sort20.sh: 
-------------------------------------------------------------------------------- 1 | 2 | source ~/.profile 3 | 4 | train=../../dataset/penn_wsj.dep.train 5 | valid=../../dataset/penn_wsj.dep.valid 6 | test=../../dataset/penn_wsj.dep.test 7 | dataset=../../dataset/penn_wsj.dep.h5 8 | th deptree2hdf5.lua --train $train --valid $valid --test $test --dataset $dataset --freq 5 --keepFreq --maxLen 150 --ignoreCase --sort 20 --batchSize 64 9 | 10 | -------------------------------------------------------------------------------- /scripts/run_conllx_rnd_nosort.sh: -------------------------------------------------------------------------------- 1 | 2 | source ~/.profile 3 | 4 | train=../../dataset/depparse/train.autopos.s1 5 | valid=../../dataset/depparse/valid.autopos 6 | test=../../dataset/depparse/test.autopos 7 | dataset=../../dataset/penn_wsj.conllx.s1.h5 8 | th conllx2hdf5.lua --train $train --valid $valid --test $test --dataset $dataset --freq 1 --keepFreq --maxLen 150 --sort 0 --batchSize 64 9 | 10 | -------------------------------------------------------------------------------- /scripts/run_conllx5_sort.sh: -------------------------------------------------------------------------------- 1 | 2 | source ~/.profile 3 | 4 | train=../../dataset/depparse/train.autopos 5 | valid=../../dataset/depparse/valid.autopos 6 | test=../../dataset/depparse/test.autopos 7 | dataset=../../dataset/penn_wsj.conllx5.h5 8 | th conllx2hdf5.lua --train $train --valid $valid --test $test --dataset $dataset --freq 5 --keepFreq --ignoreCase --maxLen 150 --sort -1 --batchSize 64 9 | 10 | -------------------------------------------------------------------------------- /scripts/run_msr_test_bid.sh: -------------------------------------------------------------------------------- 1 | 2 | vocab=/disk/scratch/XingxingZhang/treelstm/dataset/msr/msr.dep.100.bid.sort20.vocab.t7 3 | test=/afs/inf.ed.ac.uk/group/project/img2txt/deptree_rnnlm/data/msr_sent_compl/question.dep 4 | testdataset=/disk/scratch/XingxingZhang/treelstm/dataset/msr/msr.dep.100.bid.question.h5 5 | th testsetdeptree2hdf5.lua --vocab $vocab --test $test --testdataset $testdataset --bidirectional 6 | 7 | -------------------------------------------------------------------------------- /scripts/run_conllx_sort_bid.sh: -------------------------------------------------------------------------------- 1 | 2 | source ~/.profile 3 | 4 | train=../../dataset/depparse/train.autopos 5 | valid=../../dataset/depparse/valid.autopos 6 | test=../../dataset/depparse/test.autopos 7 | dataset=../../dataset/penn_wsj.conllx.bid.h5 8 | th conllx2hdf5.lua --train $train --valid $valid --test $test --dataset $dataset --freq 1 --keepFreq --maxLen 150 --sort -1 --batchSize 64 --bidirectional 9 | 10 | -------------------------------------------------------------------------------- /scripts/run_conllx_nnum_sort.sh: -------------------------------------------------------------------------------- 1 | 2 | source ~/.profile 3 | 4 | train=../../dataset/depparse/train.autopos 5 | valid=../../dataset/depparse/valid.autopos 6 | test=../../dataset/depparse/test.autopos 7 | dataset=../../dataset/penn_wsj.conllx.nnum.h5 8 | th conllx2hdf5.lua --train $train --valid $valid --test $test --dataset $dataset --freq 1 --keepFreq --maxLen 150 --sort -1 --batchSize 64 --normalizeNumber 9 | 10 | -------------------------------------------------------------------------------- /dataset/testTreeLM_Dataset.lua: -------------------------------------------------------------------------------- 1 | 2 | require 
'TreeLM_Dataset' 3 | 4 | local function main() 5 | local treelmData = TreeLM_Dataset('/disk/scratch1/xingxing.zhang/xTreeRNN/dataset/msr/msr.dep.100.h5') 6 | local label, batchSize = 'train', 64 7 | for x, y in treelmData:createBatch(label, batchSize) do 8 | print('x = ') 9 | print(x) 10 | print('y = ') 11 | print(y) 12 | end 13 | end 14 | 15 | main() 16 | -------------------------------------------------------------------------------- /msr_scripts/bestof5.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl -w 2 | 3 | $count = 0; 4 | $best = ""; 5 | $bestlp = -1000000; 6 | 7 | while (<>) { 8 | s/\s+[\r\n]+$//o; 9 | @fields = split(/\t/); 10 | if ($fields[1] > $bestlp) { 11 | $bestlp = $fields[1]; 12 | $best = $fields[0]; 13 | } 14 | 15 | if ((++$count % 5) == 0) { 16 | print "$best\n"; 17 | $bestlp = -1000000; 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /scripts/run.sh: -------------------------------------------------------------------------------- 1 | 2 | source ~/.profile 3 | 4 | train=../../dataset/penn_wsj.train 5 | valid=../../dataset/penn_wsj.valid 6 | test=../../dataset/penn_wsj.test 7 | dataset=../../dataset/penn_wsj.h5 8 | # th words2hdf5.lua --train $train --valid $valid --test $test --dataset $dataset --freq 0 --ignorecase --keepfreq --maxlen 1000 9 | th words2hdf5.lua --train $train --valid $valid --test $test --dataset $dataset --freq 0 --keepfreq --maxlen 100 10 | 11 | -------------------------------------------------------------------------------- /scripts/dgvertex.lua: -------------------------------------------------------------------------------- 1 | 2 | local DGVertex = torch.class('DGVertex') 3 | function DGVertex:__init() 4 | self.v = -1 5 | self.tok = nil 6 | self.adjList = {} 7 | self.leftChildren = {} 8 | self.rightChildren = {} 9 | self.dependencyVertex = {} 10 | self.bfsID = -1 11 | self.action = -1 12 | 13 | -- this is for bidirectional model 14 | self.leftCxtPos = 0 15 | end 16 | 17 | function DGVertex:isEmpty() 18 | return self.v == -1 and self.tok == nil 19 | end 20 | -------------------------------------------------------------------------------- /scripts/run_msr.sh: -------------------------------------------------------------------------------- 1 | 2 | source ~/.profile 3 | 4 | datadir=/afs/inf.ed.ac.uk/group/project/img2txt/deptree_rnnlm/data/msr_sent_compl 5 | train=$datadir/msr.dep.100.train 6 | valid=$datadir/msr.dep.100.valid 7 | test=$datadir/msr.dep.100.test 8 | dataset=/disk/scratch1/xingxing.zhang/xTreeRNN/dataset/msr/msr.dep.100.h5 9 | th deptree2hdf5.lua --train $train --valid $valid --test $test --dataset $dataset --freq 5 --keepFreq --maxLen 100 --ignoreCase --sort 20 --batchSize 64 10 | 11 | -------------------------------------------------------------------------------- /scripts/run_msr_bid.sh: -------------------------------------------------------------------------------- 1 | 2 | source ~/.profile 3 | 4 | datadir=/afs/inf.ed.ac.uk/group/project/img2txt/deptree_rnnlm/data/msr_sent_compl 5 | train=$datadir/msr.dep.100.train 6 | valid=$datadir/msr.dep.100.valid 7 | test=$datadir/msr.dep.100.test 8 | dataset=/disk/scratch/XingxingZhang/treelstm/dataset/msr/msr.dep.100.bid.h5 9 | th deptree2hdf5.lua --train $train --valid $valid --test $test --dataset $dataset --freq 5 --keepFreq --maxLen 100 --ignoreCase --sort 20 --batchSize 64 --bidirectional 10 | 11 | -------------------------------------------------------------------------------- 
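A note on the --sort flag passed to deptree2hdf5.lua in the scripts above (and exercised by test_sort_large.lua just below): according to the option help in sort_large_hdf5.lua, --sort 0 leaves the training data in corpus order, --sort -1 sorts the whole training set by sentence length, and --sort k (k > 0) sorts lengths only within each run of k consecutive batches, which keeps the sentences in a batch close in length without globally reordering the corpus. The snippet below is a minimal Lua sketch of that windowed ordering, not the actual SortHDF5 implementation; `lens` is a hypothetical array of sentence lengths.

-- Sketch: order sentence indices so that each window of k * batchSize
-- consecutive sentences is sorted by length (the "--sort k" behaviour).
local function windowedOrder(lens, batchSize, k)
  local idx = {}
  for i = 1, #lens do idx[i] = i end
  local window = batchSize * k
  for s = 1, #lens, window do
    local e = math.min(s + window - 1, #lens)
    -- copy this window of indices, sort it by sentence length, write it back
    local slice = {}
    for i = s, e do slice[#slice + 1] = idx[i] end
    table.sort(slice, function(a, b) return lens[a] < lens[b] end)
    for i = s, e do idx[i] = slice[i - s + 1] end
  end
  return idx
end

-- e.g. windowedOrder({5, 2, 9, 1, 7, 3}, 2, 1) --> {2, 1, 4, 3, 6, 5}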
/scripts/test_sort_large.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'hdf5' 3 | require 'torch' 4 | include '../utils/shortcut.lua' 5 | require 'deptreeutils' 6 | 7 | require 'sorthdf5' 8 | 9 | local opts = {} 10 | opts.dataset = '/disk/scratch1/xingxing.zhang/xTreeRNN/dataset/msr/msr.dep.100.h5' 11 | opts.sort = 20 12 | opts.batchSize = 64 13 | local h5sortFile = '/disk/scratch1/xingxing.zhang/xTreeRNN/dataset/msr/msr.dep.100.sort20.h5' 14 | local h5sorter = SortHDF5(opts.dataset, h5sortFile) 15 | h5sorter:sortHDF5(opts.sort, opts.batchSize) -------------------------------------------------------------------------------- /conllx_scripts/eval.lua: -------------------------------------------------------------------------------- 1 | 2 | local conllx_eval = require('conllx_eval') 3 | 4 | local function getOpts() 5 | local cmd = torch.CmdLine() 6 | cmd:text('====== Evaluation Script for Dependency Parser ======') 7 | cmd:option('--sysFile', '', 'system output') 8 | cmd:option('--goldFile', '', 'gold standard') 9 | 10 | return cmd:parse(arg) 11 | end 12 | 13 | local function main() 14 | local opts = getOpts() 15 | conllx_eval.eval(opts.sysFile, opts.goldFile) 16 | end 17 | 18 | main() 19 | -------------------------------------------------------------------------------- /scripts/run_apw_sort_all.sh: -------------------------------------------------------------------------------- 1 | 2 | source ~/.profile 3 | 4 | train=/disk/scratch1/xingxing.zhang/xTreeRNN/dataset/apw/ap.dep.train 5 | valid=/disk/scratch1/xingxing.zhang/xTreeRNN/dataset/apw/ap.dep.valid 6 | test=/disk/scratch1/xingxing.zhang/xTreeRNN/dataset/apw/ap.dep.test 7 | dataset=/disk/scratch1/xingxing.zhang/xTreeRNN/dataset/apw/ap.dep.h5 8 | th deptree2hdf5.lua --train $train --valid $valid --test $test --dataset $dataset --freq 20 --keepFreq --maxLen 100 --ignoreCase --sort 0 --batchSize 64 9 | 10 | # due to a strange bug, we must sort the h5 file with a separate program 11 | th sort_large_hdf5_bid.lua --dataset $dataset --sort -1 --batchSize 64 12 | 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2015-Present University of Edinburgh (author: Xingxing Zhang) 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | 15 | -------------------------------------------------------------------------------- /scripts/run_apw_sort.sh: -------------------------------------------------------------------------------- 1 | 2 | source ~/.profile 3 | 4 | train=../../dataset/penn_wsj.dep.train 5 | valid=../../dataset/penn_wsj.dep.valid 6 | test=../../dataset/penn_wsj.dep.test 7 | dataset=../../dataset/penn_wsj.dep.h5 8 | 9 | train=/disk/scratch1/xingxing.zhang/xTreeRNN/dataset/apw/ap.dep.train 10 | valid=/disk/scratch1/xingxing.zhang/xTreeRNN/dataset/apw/ap.dep.valid 11 | test=/disk/scratch1/xingxing.zhang/xTreeRNN/dataset/apw/ap.dep.test 12 | dataset=/disk/scratch1/xingxing.zhang/xTreeRNN/dataset/apw/ap.dep.h5 13 | th deptree2hdf5.lua --train $train --valid $valid --test $test --dataset $dataset --freq 20 --keepFreq --maxLen 100 --ignoreCase --sort 0 --batchSize 64 14 | 15 | -------------------------------------------------------------------------------- /scripts/run_dep_bid.sh: -------------------------------------------------------------------------------- 1 | 2 | source ~/.profile 3 | 4 | datadir=/afs/inf.ed.ac.uk/group/project/img2txt/deptree_rnnlm/data/msr_sent_compl 5 | train=$datadir/msr.dep.100.train 6 | valid=$datadir/msr.dep.100.valid 7 | test=$datadir/msr.dep.100.test 8 | dataset=/disk/scratch/XingxingZhang/treelstm/dataset/msr/msr.dep.100.bid.h5 9 | 10 | train=../../dataset/penn_wsj.dep.train 11 | valid=../../dataset/penn_wsj.dep.valid 12 | test=../../dataset/penn_wsj.dep.test 13 | dataset=../../dataset/penn_wsj.dep.bid.h5 14 | th deptree2hdf5.lua --train $train --valid $valid --test $test --dataset $dataset --freq 5 --keepFreq --maxLen 150 --ignoreCase --sort -1 --batchSize 64 --bidirectional 15 | 16 | 17 | -------------------------------------------------------------------------------- /scripts/run_apw_sort_bid.sh: -------------------------------------------------------------------------------- 1 | 2 | source ~/.profile 3 | 4 | train=/disk/scratch1/xingxing.zhang/xTreeRNN/dataset/apw/ap.dep.train 5 | valid=/disk/scratch1/xingxing.zhang/xTreeRNN/dataset/apw/ap.dep.valid 6 | test=/disk/scratch1/xingxing.zhang/xTreeRNN/dataset/apw/ap.dep.test 7 | dataset=/disk/scratch1/xingxing.zhang/xTreeRNN/dataset/apw/ap.dep.bid.h5 8 | th deptree2hdf5.lua --train $train --valid $valid --test $test --dataset $dataset --freq 20 --keepFreq --maxLen 100 --ignoreCase --sort 0 --batchSize 64 --bidirectional 9 | 10 | # due to a strange bug, we must sort the h5 file with a separate program 11 | th sort_large_hdf5_bid.lua --dataset $dataset --sort -1 --batchSize 64 --bidirectional 12 | 13 | -------------------------------------------------------------------------------- /conllx_scripts/run_shuffle.sh: -------------------------------------------------------------------------------- 1 | 2 | seed=1 3 | infile=/disk/scratch/s1270921/depparse/iornn-depparse-1st-t5-1/data/train.autopos 4 | outfile=/disk/scratch/s1270921/depparse/iornn-depparse-1st-t5-1/data/train.autopos.s$seed 5 | th shuffle.lua --inFile $infile --outFile $outfile --seed $seed 6 | 7 | seed=2 8 | infile=/disk/scratch/s1270921/depparse/iornn-depparse-1st-t5-2/data/train.autopos 9 | outfile=/disk/scratch/s1270921/depparse/iornn-depparse-1st-t5-2/data/train.autopos.s$seed 10 | th shuffle.lua --inFile $infile --outFile $outfile --seed $seed 11 | 12 | seed=3 13 | infile=/disk/scratch/s1270921/depparse/iornn-depparse-1st-t5-3/data/train.autopos 14 | outfile=/disk/scratch/s1270921/depparse/iornn-depparse-1st-t5-3/data/train.autopos.s$seed 15 | th shuffle.lua --inFile 
$infile --outFile $outfile --seed $seed 16 | 17 | -------------------------------------------------------------------------------- /nnets/basic.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'torch' 3 | 4 | local BModel = torch.class('BModel') 5 | 6 | function BModel:__init() 7 | self.name = 'Basic Model, name needed!' 8 | end 9 | 10 | function BModel:save(modelPath, saveOpts) 11 | if not modelPath:ends('.t7') then 12 | modelPath = modelPath .. '.t7' 13 | end 14 | 15 | if self.params:type() == 'torch.CudaTensor' then 16 | torch.save(modelPath, self.params:float()) 17 | else 18 | torch.save(modelPath, self.params) 19 | end 20 | 21 | if saveOpts then 22 | local optPath = modelPath:sub(1, -4) .. '.state.t7' 23 | torch.save(optPath, self.opts) 24 | end 25 | end 26 | 27 | function BModel:load(modelPath) 28 | self.params:copy( torch.load(modelPath) ) 29 | end 30 | 31 | function BModel:setModel(params) 32 | self.params:copy(params) 33 | end 34 | 35 | function BModel:getModel(outModel) 36 | return outModel:copy(self.params) 37 | end 38 | 39 | function BModel:print(msg) 40 | if msg == nil then 41 | xprint('the model is [%s]\n', self.name) 42 | else 43 | xprintln('[%s] %s', self.name, msg) 44 | end 45 | end 46 | 47 | 48 | -------------------------------------------------------------------------------- /experiments/depparse/treelstm_we_train.sh: -------------------------------------------------------------------------------- 1 | 2 | ID=`./gpu_lock.py --id-to-hog 1` 3 | echo $ID 4 | if [ $ID -eq -1 ]; then 5 | echo "no gpu is free" 6 | exit 7 | fi 8 | ./gpu_lock.py 9 | 10 | curdir=`pwd` 11 | codedir=/afs/inf.ed.ac.uk/group/project/img2txt/xTreeLSTM/xtreelstm/td-treelstm-release 12 | dataset=/disk/scratch/XingxingZhang/treelstm/dataset/depparse/dataset/penn_wsj.conllx.sort.h5 13 | wembed=/disk/scratch/XingxingZhang/treelstm/dataset/res/glove/glove.6B.100d.t7 14 | lr=1.0 15 | label=.w200 16 | model=model_$lr$label.t7 17 | log=$lr$label.txt 18 | 19 | cd $codedir 20 | 21 | CUDA_VISIBLE_DEVICES=$ID th $codedir/main_nce.lua \ 22 | --model TreeLSTM \ 23 | --dataset $dataset \ 24 | --wordEmbedding $wembed \ 25 | --useGPU \ 26 | --nin 100 \ 27 | --nhid 200 \ 28 | --nlayers 2 \ 29 | --lr $lr \ 30 | --batchSize 64 \ 31 | --maxEpoch 50 \ 32 | --save $curdir/$model \ 33 | --gradClip 5 \ 34 | --optimMethod SGD \ 35 | --patience 1 \ 36 | --dropout 0.2 \ 37 | | tee $curdir/$log 38 | 39 | cd $curdir 40 | 41 | ./gpu_lock.py --free $ID 42 | ./gpu_lock.py 43 | 44 | -------------------------------------------------------------------------------- /experiments/depparse/ldtreelstm_train.sh: -------------------------------------------------------------------------------- 1 | 2 | ID=`./gpu_lock.py --id-to-hog 1` 3 | echo $ID 4 | if [ $ID -eq -1 ]; then 5 | echo "no gpu is free" 6 | exit 7 | fi 8 | ./gpu_lock.py 9 | 10 | curdir=`pwd` 11 | codedir=/afs/inf.ed.ac.uk/group/project/img2txt/xTreeLSTM/xtreelstm/td-treelstm-release 12 | dataset=/disk/scratch/XingxingZhang/treelstm/dataset/depparse/dataset/penn_wsj.conllx.bid.sort.h5 13 | wembed=/disk/scratch/XingxingZhang/treelstm/dataset/res/glove/glove.6B.100d.t7 14 | lr=1.0 15 | label=.200.lc200 16 | model=model_$lr$label.t7 17 | log=$lr$label.txt 18 | 19 | cd $codedir 20 | 21 | CUDA_VISIBLE_DEVICES=$ID th $codedir/main_bid.lua \ 22 | --model BiTreeLSTM \ 23 | --dataset $dataset \ 24 | --useGPU \ 25 | --nin 100 \ 26 | --nhid 200 \ 27 | --nlayers 2 \ 28 | --lr $lr \ 29 | --batchSize 64 \ 30 | --maxEpoch 50 \ 31 | --save 
$curdir/$model \ 32 | --gradClip 5 \ 33 | --optimMethod SGD \ 34 | --patience 1 \ 35 | --nlclayers 1 \ 36 | --nlchid 200 \ 37 | --dropout 0.2 \ 38 | | tee $curdir/$log 39 | 40 | cd $curdir 41 | 42 | ./gpu_lock.py --free $ID 43 | ./gpu_lock.py 44 | 45 | -------------------------------------------------------------------------------- /OLD_VERSION.md: -------------------------------------------------------------------------------- 1 | 2 | ### Using an Old Version of Torch 3 | Torch is under active development. Unfortunately, some APIs in later versions are not compatible with those in earlier versions, which is terrible!!! In order to run the code, 4 | you need to go back to the version around 2015-07-22 by running the following commands: 5 | ``` 6 | # create a directory for torch source code 7 | mkdir src 8 | cd src 9 | git clone https://github.com/torch/torch7.git 10 | git clone https://github.com/torch/nn.git 11 | git clone https://github.com/torch/cutorch.git 12 | git clone https://github.com/torch/nngraph.git 13 | git clone https://github.com/torch/cunn.git 14 | 15 | cd torch7/ 16 | git checkout 80a545e 17 | luarocks make rocks/torch-scm-1.rockspec 18 | cd .. 19 | 20 | cd nn 21 | git checkout c503fb8 22 | luarocks make rocks/nn-scm-1.rockspec 23 | cd .. 24 | 25 | cd cutorch 26 | git checkout 2eddb66 27 | luarocks make rocks/cutorch-scm-1.rockspec 28 | cd .. 29 | 30 | cd cunn 31 | git checkout 4f66456 32 | luarocks make rocks/cunn-scm-1.rockspec 33 | cd .. 34 | 35 | cd nngraph 36 | git checkout 1c43c98 37 | luarocks make nngraph-scm-1.rockspec 38 | cd .. 39 | ``` 40 | 41 | [README.md](README.md) -------------------------------------------------------------------------------- /scripts/testsetdeptree2hdf5.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'hdf5' 3 | require 'torch' 4 | include '../utils/shortcut.lua' 5 | require 'deptreeutils' 6 | 7 | local function getOpts() 8 | local cmd = torch.CmdLine() 9 | cmd:text('== convert testset dependency trees to hdf5 format ==') 10 | cmd:text() 11 | cmd:text('Options:') 12 | cmd:option('--vocab', '', 'vocabulary file (.t7)') 13 | cmd:option('--test', '', 'test text file (plain text)') 14 | cmd:option('--testdataset', '', 'the output file of the test set (.h5)') 15 | 16 | cmd:option('--bidirectional', false, 'create bidirectional model') 17 | 18 | return cmd:parse(arg) 19 | end 20 | 21 | local function main() 22 | local opts = getOpts() 23 | print(opts) 24 | 25 | print('load vocab ...') 26 | local vocab = torch.load(opts.vocab) 27 | print('load vocab done!') 28 | local h5out = hdf5.open(opts.testdataset, 'w') 29 | if opts.bidirectional then 30 | DepTreeUtils.deptree2hdf5Bidirectional(opts.test, h5out, 'test', vocab, 123456789) 31 | else 32 | DepTreeUtils.deptree2hdf5(opts.test, h5out, 'test', vocab, 123456789) 33 | end 34 | print('create testset done!') 35 | h5out:close() 36 | end 37 | 38 | main() 39 | -------------------------------------------------------------------------------- /utils/wordembedding.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'shortcut' 3 | 4 | local WordEmbed = torch.class('WordEmbedding') 5 | 6 | -- input should be torch7 file 7 | function WordEmbed:__init(embedFile) 8 | self.embed, self.word2idx, self.idx2word = unpack( torch.load(embedFile) ) 9 | xprintln('load embedding done!') 10 | self.lowerCase = true 11 | for _, word in ipairs(self.idx2word) do 12 | if word ~= word:lower() then 13 | self.lowerCase = false 14 | end 
15 | end 16 | print('lower case: ') 17 | print(self.lowerCase) 18 | end 19 | 20 | function WordEmbed:releaseMemory() 21 | self.embed = nil 22 | self.word2idx = nil 23 | self.idx2word = nil 24 | collectgarbage() 25 | end 26 | 27 | function WordEmbed:initMat(mat, vocab) 28 | assert(mat:size(2) == self.embed:size(2)) 29 | local idx2word = vocab.idx2word 30 | local nvocab = #idx2word 31 | local cnt = 0 32 | for wid = 1, nvocab do 33 | local word = idx2word[wid] 34 | word = self.lowerCase and word:lower() or word 35 | local wid_ = self.word2idx[word] 36 | if wid_ ~= nil then 37 | mat[wid] = self.embed[wid_] 38 | cnt = cnt + 1 39 | end 40 | end 41 | print(string.format('word embedding coverage: %d / %d = %f', cnt, nvocab, cnt / nvocab)) 42 | end 43 | 44 | -------------------------------------------------------------------------------- /layers/EMaskedClassNLLCriterion.lua: -------------------------------------------------------------------------------- 1 | 2 | --[[ 3 | -- when y contains zeros 4 | --]] 5 | 6 | local EMaskedClassNLLCriterion, parent = torch.class('EMaskedClassNLLCriterion', 'nn.Module') 7 | 8 | function EMaskedClassNLLCriterion:__init() 9 | parent.__init(self) 10 | end 11 | 12 | function EMaskedClassNLLCriterion:updateOutput(input_) 13 | local input, target, div = unpack(input_) 14 | if input:dim() == 2 then 15 | local nll = 0 16 | local n = target:size(1) 17 | for i = 1, n do 18 | if target[i] ~= 0 then 19 | nll = nll - input[i][target[i]] 20 | end 21 | end 22 | self.output = nll / div 23 | return self.output 24 | else 25 | error('input must be matrix! Note only batch mode is supported!') 26 | end 27 | end 28 | 29 | function EMaskedClassNLLCriterion:updateGradInput(input_) 30 | local input, target, div = unpack(input_) 31 | -- print('self.gradInput', self.gradInput) 32 | self.gradInput:resizeAs(input) 33 | self.gradInput:zero() 34 | local er = -1 / div 35 | if input:dim() == 2 then 36 | local n = target:size(1) 37 | for i = 1, n do 38 | if target[i] ~= 0 then 39 | self.gradInput[i][target[i]] = er 40 | end 41 | end 42 | return self.gradInput 43 | else 44 | error('input must be matrix! Note only batch mode is supported!') 45 | end 46 | end 47 | 48 | -------------------------------------------------------------------------------- /layers/MaskedClassNLLCriterion.lua: -------------------------------------------------------------------------------- 1 | 2 | --[[ 3 | -- when y contains zeros 4 | --]] 5 | 6 | local MaskedClassNLLCriterion, parent = torch.class('MaskedClassNLLCriterion', 'nn.Module') 7 | 8 | function MaskedClassNLLCriterion:__init() 9 | parent.__init(self) 10 | end 11 | 12 | function MaskedClassNLLCriterion:updateOutput(input_) 13 | local input, target = unpack(input_) 14 | if input:dim() == 2 then 15 | local nll = 0 16 | local n = target:size(1) 17 | for i = 1, n do 18 | if target[i] ~= 0 then 19 | nll = nll - input[i][target[i]] 20 | end 21 | end 22 | self.output = nll / target:size(1) 23 | return self.output 24 | else 25 | error('input must be matrix! 
Note only batch mode is supported!') 26 | end 27 | end 28 | 29 | function MaskedClassNLLCriterion:updateGradInput(input_) 30 | local input, target = unpack(input_) 31 | -- print('self.gradInput', self.gradInput) 32 | self.gradInput:resizeAs(input) 33 | self.gradInput:zero() 34 | local er = -1 / target:size(1) 35 | if input:dim() == 2 then 36 | local n = target:size(1) 37 | for i = 1, n do 38 | if target[i] ~= 0 then 39 | self.gradInput[i][target[i]] = er 40 | end 41 | end 42 | return self.gradInput 43 | else 44 | error('input must be matrix! Note only batch mode is supported!') 45 | end 46 | end 47 | 48 | -------------------------------------------------------------------------------- /utils/xqueue.lua: -------------------------------------------------------------------------------- 1 | 2 | local XQueue = torch.class('XQueue') 3 | 4 | function XQueue:__init(maxn) 5 | self.MAXN = (maxn or 4294967294) + 1 -- 2^32 - 1, max unsigned int32 in other languages, should be big enough 6 | self.front = 0 7 | self.rear = 0 8 | self.Q = {} 9 | end 10 | 11 | function XQueue:push(v) 12 | local nextRear = self:nextPos(self.rear) 13 | if nextRear == self.front then error('queue is full!!!') end 14 | self.rear = nextRear 15 | self.Q[nextRear] = v 16 | end 17 | 18 | function XQueue:pop() 19 | self.front = self:nextPos(self.front) 20 | local rval = self.Q[self.front] 21 | self.Q[self.front] = nil 22 | return rval 23 | end 24 | 25 | function XQueue:top() 26 | return self.Q[self:nextPos(self.front)] 27 | end 28 | 29 | function XQueue:isEmpty() 30 | return self.front == self.rear 31 | end 32 | 33 | function XQueue:nextPos(x) 34 | local pos = x + 1 35 | return pos == self.MAXN and 0 or pos 36 | end 37 | 38 | function XQueue:isFull() 39 | return self:nextPos(self.rear) == self.front 40 | end 41 | 42 | function XQueue:printAll() 43 | if not self:isEmpty() then 44 | print '==queue elements==' 45 | local i = self:nextPos(self.front) 46 | while true do 47 | print(self.Q[i]) 48 | if i == self.rear then break end 49 | i = self:nextPos(i) 50 | end 51 | print '==queue elements end==' 52 | end 53 | end 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /experiments/depparse/treelstm_we_rerank_test.sh: -------------------------------------------------------------------------------- 1 | 2 | ID=`./gpu_lock.py --id-to-hog 1` 3 | echo $ID 4 | if [ $ID -eq -1 ]; then 5 | echo "no gpu is free" 6 | exit 7 | fi 8 | ./gpu_lock.py 9 | 10 | curdir=`pwd` 11 | codedir=/afs/inf.ed.ac.uk/group/project/img2txt/xTreeLSTM/xtreelstm/td-treelstm-release 12 | 13 | vocab=/disk/scratch/XingxingZhang/treelstm/dataset/depparse/dataset/penn_wsj.conllx.sort.vocab.t7 14 | basefile=/disk/scratch/XingxingZhang/treelstm/experiments/ptb_depparse/rerank_data/test-20-best-mst2ndorder.conll.g 15 | model=/disk/scratch/XingxingZhang/treelstm/experiments/ptb_depparse/2layer_w_wo_we/model_1.0.w200.t7 16 | label=.w200.2l.test 17 | log=log$label.txt 18 | scorefile=out$label.txt 19 | 20 | goldfile=/disk/scratch/XingxingZhang/treelstm/experiments/ptb_depparse/rerank_data/test 21 | basescorefile=/disk/scratch/XingxingZhang/treelstm/experiments/ptb_depparse/rerank_data/test-20-best-mst2ndorder.conll.mstscores 22 | 23 | cd $codedir 24 | 25 | CUDA_VISIBLE_DEVICES=$ID th $codedir/dep_rerank.lua \ 26 | --vocab $vocab \ 27 | --baseFile $basefile \ 28 | --modelPath $model \ 29 | --useGPU \ 30 | --batchSize 64 \ 31 | --scoreFile $curdir/$scorefile \ 32 | --baseScoreFile $basescorefile \ 33 | --goldFile $goldfile \ 34 | --k 4 
\ 35 | --standard stanford \ 36 | | tee $curdir/$log 37 | 38 | cd $curdir 39 | 40 | ./gpu_lock.py --free $ID 41 | ./gpu_lock.py 42 | 43 | -------------------------------------------------------------------------------- /experiments/depparse/ldtreelstm_rerank_test.sh: -------------------------------------------------------------------------------- 1 | 2 | ID=`./gpu_lock.py --id-to-hog 1` 3 | echo $ID 4 | if [ $ID -eq -1 ]; then 5 | echo "no gpu is free" 6 | exit 7 | fi 8 | ./gpu_lock.py 9 | 10 | curdir=`pwd` 11 | codedir=/afs/inf.ed.ac.uk/group/project/img2txt/xTreeLSTM/xtreelstm/td-treelstm-release 12 | 13 | vocab=/disk/scratch/XingxingZhang/treelstm/dataset/depparse/dataset/penn_wsj.conllx.bid.sort.vocab.t7 14 | basefile=/disk/scratch/XingxingZhang/treelstm/experiments/ptb_depparse/rerank_data/test-20-best-mst2ndorder.conll.g 15 | model=/disk/scratch/XingxingZhang/treelstm/experiments/ptb_depparse/bid_flush_best/model_1.0.200.lc200.t7 16 | label=.200.2l.lc200.test 17 | log=log$label.txt 18 | scorefile=out$label.txt 19 | 20 | goldfile=/disk/scratch/XingxingZhang/treelstm/experiments/ptb_depparse/rerank_data/test 21 | basescorefile=/disk/scratch/XingxingZhang/treelstm/experiments/ptb_depparse/rerank_data/test-20-best-mst2ndorder.conll.mstscores 22 | 23 | cd $codedir 24 | 25 | CUDA_VISIBLE_DEVICES=$ID th $codedir/dep_rerank.lua \ 26 | --vocab $vocab \ 27 | --baseFile $basefile \ 28 | --modelPath $model \ 29 | --useGPU \ 30 | --batchSize 64 \ 31 | --scoreFile $curdir/$scorefile \ 32 | --baseScoreFile $basescorefile \ 33 | --goldFile $goldfile \ 34 | --k 4 \ 35 | --standard stanford \ 36 | | tee $curdir/$log 37 | 38 | cd $curdir 39 | 40 | ./gpu_lock.py --free $ID 41 | ./gpu_lock.py 42 | 43 | -------------------------------------------------------------------------------- /experiments/depparse/treelstm_we_rerank_valid.sh: -------------------------------------------------------------------------------- 1 | 2 | ID=`./gpu_lock.py --id-to-hog 1` 3 | echo $ID 4 | if [ $ID -eq -1 ]; then 5 | echo "no gpu is free" 6 | exit 7 | fi 8 | ./gpu_lock.py 9 | 10 | curdir=`pwd` 11 | codedir=/afs/inf.ed.ac.uk/group/project/img2txt/xTreeLSTM/xtreelstm/td-treelstm-release 12 | 13 | vocab=/disk/scratch/XingxingZhang/treelstm/dataset/depparse/dataset/penn_wsj.conllx.sort.vocab.t7 14 | basefile=/disk/scratch/XingxingZhang/treelstm/experiments/ptb_depparse/rerank_data/dev-20-best-mst2ndorder.conll.g 15 | model=/disk/scratch/XingxingZhang/treelstm/experiments/ptb_depparse/2layer_w_wo_we/model_1.0.w200.t7 16 | label=.w200.2l 17 | log=log$label.txt 18 | scorefile=out$label.txt 19 | 20 | goldfile=/disk/scratch/XingxingZhang/treelstm/experiments/ptb_depparse/rerank_data/valid 21 | basescorefile=/disk/scratch/XingxingZhang/treelstm/experiments/ptb_depparse/rerank_data/dev-20-best-mst2ndorder.conll.mstscores 22 | 23 | cd $codedir 24 | 25 | CUDA_VISIBLE_DEVICES=$ID th $codedir/dep_rerank.lua \ 26 | --vocab $vocab \ 27 | --baseFile $basefile \ 28 | --modelPath $model \ 29 | --useGPU \ 30 | --batchSize 64 \ 31 | --scoreFile $curdir/$scorefile \ 32 | --baseScoreFile $basescorefile \ 33 | --goldFile $goldfile \ 34 | --k 20 \ 35 | --searchk \ 36 | --standard stanford \ 37 | | tee $curdir/$log 38 | 39 | cd $curdir 40 | 41 | ./gpu_lock.py --free $ID 42 | ./gpu_lock.py 43 | 44 | -------------------------------------------------------------------------------- /layers/NCEMaskedLoss.lua: -------------------------------------------------------------------------------- 1 | 2 | --[[ 3 | This is used to compute the loss function of NCE 4 | 
--]] 5 | 6 | local NCEMaskedLoss, parent = torch.class('NCEMaskedLoss', 'nn.Module') 7 | 8 | function NCEMaskedLoss:__init() 9 | parent.__init(self) 10 | end 11 | 12 | function NCEMaskedLoss:updateOutput(input_) 13 | local input, mask, div = unpack(input_) 14 | self.output = -( input:view(-1) * mask:view(-1) ) / div 15 | 16 | return self.output 17 | end 18 | 19 | function NCEMaskedLoss:updateGradInput(input_) 20 | local input, mask, div = unpack(input_) 21 | 22 | --[[ 23 | print('size of input') 24 | print(input:size()) 25 | 26 | print('size of self.gradInput') 27 | print(self.gradInput:size()) 28 | --]] 29 | 30 | self.gradInput:resizeAs(input) 31 | self.gradInput:zero() 32 | -- print(self.gradInput) 33 | 34 | local mask_ 35 | if mask:dim() == 2 then 36 | mask_ = mask:view(mask:size(1) * mask:size(2), 1) 37 | elseif mask:dim() == 1 then 38 | mask_ = mask:view(mask:size(1), 1) 39 | else 40 | error('mask must be matrix or vector!') 41 | end 42 | 43 | --[[ 44 | print('mask_ size') 45 | print(mask_:size()) 46 | print(mask_:type()) 47 | --]] 48 | 49 | self.gradInput:copy(-mask_ / div) 50 | 51 | --[[ 52 | print('size gradInput') 53 | print(self.gradInput:size()) 54 | print(self.gradInput[{ {-10, -1} }]) 55 | --]] 56 | 57 | return {self.gradInput} 58 | end 59 | 60 | -------------------------------------------------------------------------------- /experiments/msr/treelstm_h400.sh: -------------------------------------------------------------------------------- 1 | 2 | ID=`./gpu_lock.py --id-to-hog 0` 3 | echo $ID 4 | if [ $ID -eq -1 ]; then 5 | echo "this gpu is not free" 6 | exit 7 | fi 8 | ./gpu_lock.py 9 | 10 | curdir=`pwd` 11 | codedir=/afs/inf.ed.ac.uk/group/project/img2txt/xTreeLSTM/xtreelstm/td-treelstm-release 12 | dataset=/disk/scratch/XingxingZhang/treelstm/dataset/msr/msr.dep.100.sort20.h5 13 | lr=1.0 14 | label=.tree.400 15 | model=model_$lr$label.t7 16 | log=$lr$label.txt 17 | 18 | cd $codedir 19 | 20 | CUDA_VISIBLE_DEVICES=$ID th $codedir/main_nce.lua \ 21 | --model TreeLSTMNCE \ 22 | --dataset $dataset \ 23 | --useGPU \ 24 | --nin 200 \ 25 | --nhid 400 \ 26 | --nlayers 1 \ 27 | --lr $lr \ 28 | --batchSize 64 \ 29 | --validBatchSize 16 \ 30 | --maxEpoch 50 \ 31 | --save $curdir/$model \ 32 | --gradClip 5 \ 33 | --optimMethod SGD \ 34 | --patience 1 \ 35 | --nneg 20 \ 36 | --power 0.75 \ 37 | --lnZ 9 \ 38 | --learnZ \ 39 | --savePerEpoch \ 40 | --saveBeforeLrDiv \ 41 | --seqLen 101 \ 42 | | tee $curdir/$log 43 | 44 | ./gpu_lock.py --free $ID 45 | ./gpu_lock.py 46 | 47 | testfile=/disk/scratch/XingxingZhang/treelstm/dataset/msr/msr.dep.100.question.h5 48 | outfile=out.txt 49 | rlog=rerank.$log 50 | 51 | time th $codedir/rerank.lua --modelPath $curdir/$model \ 52 | --testFile $testfile \ 53 | --outFile $curdir/$outfile | tee $curdir/$rlog 54 | 55 | cd $curdir 56 | 57 | -------------------------------------------------------------------------------- /experiments/depparse/ldtreelstm_rerank_valid.sh: -------------------------------------------------------------------------------- 1 | 2 | ID=`./gpu_lock.py --id-to-hog 1` 3 | echo $ID 4 | if [ $ID -eq -1 ]; then 5 | echo "no gpu is free" 6 | exit 7 | fi 8 | ./gpu_lock.py 9 | 10 | curdir=`pwd` 11 | codedir=/afs/inf.ed.ac.uk/group/project/img2txt/xTreeLSTM/xtreelstm/td-treelstm-release 12 | 13 | vocab=/disk/scratch/XingxingZhang/treelstm/dataset/depparse/dataset/penn_wsj.conllx.bid.sort.vocab.t7 14 | basefile=/disk/scratch/XingxingZhang/treelstm/experiments/ptb_depparse/rerank_data/dev-20-best-mst2ndorder.conll.g 15 | 
model=/disk/scratch/XingxingZhang/treelstm/experiments/ptb_depparse/bid_flush_best/model_1.0.200.lc200.t7 16 | label=.200.2l.lc200 17 | log=log$label.txt 18 | scorefile=out$label.txt 19 | 20 | goldfile=/disk/scratch/XingxingZhang/treelstm/experiments/ptb_depparse/rerank_data/valid 21 | basescorefile=/disk/scratch/XingxingZhang/treelstm/experiments/ptb_depparse/rerank_data/dev-20-best-mst2ndorder.conll.mstscores 22 | 23 | cd $codedir 24 | 25 | CUDA_VISIBLE_DEVICES=$ID th $codedir/dep_rerank.lua \ 26 | --vocab $vocab \ 27 | --baseFile $basefile \ 28 | --modelPath $model \ 29 | --useGPU \ 30 | --batchSize 64 \ 31 | --scoreFile $curdir/$scorefile \ 32 | --baseScoreFile $basescorefile \ 33 | --goldFile $goldfile \ 34 | --k 20 \ 35 | --searchk \ 36 | --standard stanford \ 37 | | tee $curdir/$log 38 | 39 | cd $curdir 40 | 41 | ./gpu_lock.py --free $ID 42 | ./gpu_lock.py 43 | 44 | -------------------------------------------------------------------------------- /msr_scripts/score.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl -w 2 | 3 | # scoring script for Holmes dataset 4 | 5 | die "usage: score file reference\n" unless ($#ARGV==1); 6 | $infile = $ARGV[0]; 7 | $reffile = $ARGV[1]; 8 | 9 | open(FIN, $infile) or die "unable to open $infile\n"; 10 | open(RIN, $reffile) or die "unable to open $reffile\n"; 11 | 12 | @hyp = <FIN>; 13 | @ref = <RIN>; 14 | 15 | die "mismatched number of hypotheses and references\n" unless ($#hyp==$#ref); 16 | $nlines = $#hyp+1; 17 | $ndev = int($nlines/2); 18 | $ntest = $nlines - $ndev; 19 | 20 | sub normalize { 21 | my $str = shift; 22 | $str =~ s/[\r\n]//og; 23 | $str =~ s/^\s+//o; 24 | $str =~ s/\s+$//o; 25 | $str =~ s/\s+/ /og; 26 | return $str; 27 | } 28 | 29 | $correct = 0; 30 | $correct_dev = 0; 31 | $correct_test = 0; 32 | $tot = 0; 33 | for ($i=0; $i<=$#ref; $i++) { 34 | $ref[$i] = normalize($ref[$i]); 35 | $hyp[$i] = normalize($hyp[$i]); 36 | print "$ref[$i]\n\t=> $hyp[$i]\n********\n" unless ($ref[$i] eq $hyp[$i]); 37 | $correct++ if ($ref[$i] eq $hyp[$i]); 38 | $correct_dev++ if ($ref[$i] eq $hyp[$i] and $i<$ndev); 39 | $correct_test++ if ($ref[$i] eq $hyp[$i] and $i>=$ndev); 40 | $tot++; 41 | } 42 | 43 | print "$correct of $tot correct\n"; 44 | 45 | $ave = 100 * $correct / $tot; 46 | $dave = 100 * $correct_dev / $ndev; 47 | $tave = 100 * $correct_test / $ntest; 48 | 49 | print "Overall average: $ave%\n"; 50 | print "dev: $dave%\ntest: $tave%\n"; -------------------------------------------------------------------------------- /experiments/msr/ldtreelstm_h400.sh: -------------------------------------------------------------------------------- 1 | 2 | ID=`./gpu_lock.py --id-to-hog 0` 3 | echo $ID 4 | if [ $ID -eq -1 ]; then 5 | echo "no gpu is free" 6 | exit 7 | fi 8 | ./gpu_lock.py 9 | 10 | curdir=`pwd` 11 | codedir=/afs/inf.ed.ac.uk/group/project/img2txt/xTreeLSTM/xtreelstm/td-treelstm-release 12 | dataset=/disk/scratch/XingxingZhang/treelstm/dataset/msr/msr.dep.100.bid.sort20.h5 13 | lr=1.0 14 | label=.ldtree.400 15 | model=model_$lr$label.t7 16 | log=$lr$label.txt 17 | 18 | cd $codedir 19 | 20 | CUDA_VISIBLE_DEVICES=$ID th $codedir/main_bid.lua \ 21 | --model BiTreeLSTMNCE \ 22 | --dataset $dataset \ 23 | --useGPU \ 24 | --nin 200 \ 25 | --nhid 400 \ 26 | --nlayers 1 \ 27 | --lr $lr \ 28 | --batchSize 64 \ 29 | --validBatchSize 16 \ 30 | --maxEpoch 50 \ 31 | --save $curdir/$model \ 32 | --gradClip 5 \ 33 | --optimMethod SGD \ 34 | --patience 1 \ 35 | --nneg 20 \ 36 | --power 0.75 \ 37 | --lnZ 9 \ 38 | 
--learnZ \ 39 | --savePerEpoch \ 40 | --saveBeforeLrDiv \ 41 | --seqLen 101 \ 42 | --nlclayers 1 \ 43 | --nlchid 400 \ 44 | | tee $curdir/$log 45 | 46 | ./gpu_lock.py --free $ID 47 | ./gpu_lock.py 48 | 49 | testfile=/disk/scratch/XingxingZhang/treelstm/dataset/msr/msr.dep.100.bid.question.h5 50 | outfile=out.txt 51 | rlog=rerank.$log 52 | 53 | time th $codedir/rerank.lua --modelPath $curdir/$model \ 54 | --testFile $testfile \ 55 | --outFile $curdir/$outfile | tee $curdir/$rlog 56 | 57 | cd $curdir 58 | 59 | -------------------------------------------------------------------------------- /scripts/sort_large_hdf5.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'hdf5' 3 | require 'torch' 4 | include '../utils/shortcut.lua' 5 | require 'deptreeutils' 6 | require 'sorthdf5' 7 | 8 | local function getOpts() 9 | local cmd = torch.CmdLine() 10 | cmd:text('== sort a dependency tree dataset (.h5) by sentence length ==') 11 | cmd:text() 12 | cmd:text('Options:') 13 | cmd:option('--dataset', '', 'the resulting dataset (.h5)') 14 | cmd:option('--sort', 0, '0: no sorting of the training data; -1: sort training data by their length; k (k > 0): sort the consecutive k batches by their length') 15 | cmd:option('--batchSize', 64, 'batch size when --sort > 0 or --sort == -1') 16 | 17 | return cmd:parse(arg) 18 | end 19 | 20 | local function main() 21 | local opts = getOpts() 22 | local dataPrefix = opts.dataset:sub(1, -4) 23 | local vocabPath = dataPrefix .. '.vocab.t7' 24 | if opts.sort ~= 0 then 25 | assert(opts.sort == -1 or opts.sort > 0, 'valid values [0, -1, > 0]') 26 | print '========begin to sort dataset========' 27 | local h5sorter = nil 28 | local mid = opts.sort == -1 and 'sort' or string.format('sort%d', opts.sort) 29 | local h5sortFile = opts.dataset:sub(1, -4) .. string.format('.%s.h5', mid) 30 | local h5sortVocabFile = opts.dataset:sub(1, -4) .. 
string.format('.%s.vocab.t7', mid) 31 | local cmd = string.format('cp %s %s', vocabPath, h5sortVocabFile) 32 | print(cmd) 33 | os.execute(cmd) 34 | h5sorter = SortHDF5(opts.dataset, h5sortFile) 35 | h5sorter:sortHDF5(opts.sort, opts.batchSize) 36 | printf('save dataset to %s\n', h5sortFile) 37 | end 38 | end 39 | 40 | main() 41 | -------------------------------------------------------------------------------- /scripts/WE_txt2torch.lua: -------------------------------------------------------------------------------- 1 | 2 | include '../utils/shortcut.lua' 3 | 4 | local function getShape(embedPath) 5 | local nrow, ncol = 0, 0 6 | local fin = io.open(embedPath, 'r') 7 | while true do 8 | local line = fin:read() 9 | if line == nil then break end 10 | if nrow == 0 then 11 | local fields = line:splitc(' \t') 12 | ncol = #fields 13 | end 14 | nrow = nrow + 1 15 | end 16 | fin:close() 17 | xprintln('word embedding file: nrow = %d, ncol = %d', nrow, ncol) 18 | 19 | return nrow, ncol 20 | end 21 | 22 | local function toTorch7(txtFile, torchFile) 23 | local nrow, ncol = getShape(txtFile) 24 | local embed = torch.FloatTensor(nrow, ncol-1) 25 | print('get embedding done!') 26 | local word2idx = {} 27 | local idx2word = {} 28 | local fin = io.open(txtFile, 'r') 29 | local cnt = 0 30 | while true do 31 | local line = fin:read() 32 | if line == nil then break end 33 | cnt = cnt + 1 34 | local fields = line:splitc(' \t') 35 | assert(#fields == ncol) 36 | table.insert(idx2word, fields[1]) 37 | local vec = {} 38 | for i = 2, #fields do 39 | vec[#vec + 1] = tonumber(fields[i]) 40 | end 41 | embed[cnt] = torch.FloatTensor(vec) 42 | if cnt % 10000 == 0 then print(cnt) end 43 | end 44 | xprintln('totally %d lines\n', cnt) 45 | for idx, word in ipairs(idx2word) do 46 | word2idx[word] = idx 47 | end 48 | fin:close() 49 | xprintln('begin to save ...') 50 | torch.save(torchFile, {embed, word2idx, idx2word}) 51 | xprintln('save done!') 52 | end 53 | 54 | local function main() 55 | toTorch7(arg[1], arg[2]) 56 | end 57 | 58 | main() 59 | -------------------------------------------------------------------------------- /dataset/LM_Dataset0.lua: -------------------------------------------------------------------------------- 1 | 2 | include '../utils/shortcut.lua' 3 | 4 | require 'hdf5' 5 | 6 | local LM_Dataset = {} 7 | 8 | function LM_Dataset.toBatch(sents, eos) 9 | local maxn = 0 10 | for _, sent in ipairs(sents) do 11 | if sent:size(1) > maxn then 12 | maxn = sent:size(1) 13 | end 14 | end 15 | maxn = maxn + 1 16 | local batchSize = #sents 17 | -- by default, x is filled with EOS tokens 18 | local x = torch.ones(maxn, batchSize):type('torch.IntTensor') 19 | x:mul(eos) 20 | local y = torch.zeros(maxn, batchSize):type('torch.IntTensor') 21 | for i = 1, batchSize do 22 | local senlen = sents[i]:size(1) 23 | x[{ {2, senlen + 1}, i }] = sents[i] 24 | y[{ {1, senlen}, i }] = sents[i] 25 | y[{ senlen + 1, i }] = eos 26 | end 27 | 28 | return x, y 29 | end 30 | 31 | function LM_Dataset.createBatch(h5in, label, batchSize, eos) 32 | -- local h5in = hdf5.open(h5InFile, 'r') 33 | local x_data = h5in:read(string.format('/%s/x_data', label)) 34 | local index = h5in:read(string.format('/%s/index', label)) 35 | local N = index:dataspaceSize()[1] 36 | 37 | local istart = 1 38 | 39 | return function() 40 | if istart <= N then 41 | local iend = math.min(istart + batchSize - 1, N) 42 | local sents = {} 43 | for i = istart, iend do 44 | local idx = index:partial({i, i}, {1, 2}) 45 | local start, len = idx[1][1], idx[1][2] 46 | local sent = 
x_data:partial({start, start + len - 1}) 47 | table.insert(sents, sent) 48 | end 49 | 50 | istart = iend + 1 51 | 52 | return LM_Dataset.toBatch(sents, eos) 53 | else 54 | h5in:close() 55 | end 56 | end 57 | end 58 | 59 | return LM_Dataset 60 | -------------------------------------------------------------------------------- /conllx_scripts/shuffle.lua: -------------------------------------------------------------------------------- 1 | 2 | include '../utils/shortcut.lua' 3 | 4 | local function conllxLineIterator(infile) 5 | local fin = io.open(infile) 6 | local bufs = {} 7 | 8 | return function() 9 | while true do 10 | local line = fin:read() 11 | if line == nil then 12 | fin:close() 13 | break 14 | end 15 | line = line:trim() 16 | if line:len() == 0 then 17 | local rlines = {} 18 | for i, buf in ipairs(bufs) do 19 | rlines[i] = buf 20 | end 21 | table.clear(bufs) 22 | 23 | return rlines 24 | else 25 | table.insert(bufs, line) 26 | end 27 | end 28 | end 29 | 30 | end 31 | 32 | local function shuffle(inFile, outFile) 33 | local inIter = conllxLineIterator(inFile) 34 | local depTrees = {} 35 | for lines in inIter do 36 | depTrees[#depTrees + 1] = lines 37 | end 38 | local newIdxs = torch.randperm(#depTrees) 39 | local fout = io.open(outFile, 'w') 40 | for i = 1, newIdxs:size(1) do 41 | local idx = newIdxs[i] 42 | local lines = depTrees[idx] 43 | for _, line in ipairs(lines) do 44 | fout:write(string.format('%s\n', line)) 45 | end 46 | fout:write('\n') 47 | end 48 | fout:close() 49 | end 50 | 51 | local function getOpts() 52 | local cmd = torch.CmdLine() 53 | cmd:text('====== shuffle training set (CoNLL-X format) ======') 54 | cmd:option('--inFile', '', 'input training file (CoNLL-X format)') 55 | cmd:option('--outFile', '', 'output file for the shuffled training set') 56 | cmd:option('--seed', 123, 'random seed for shuffle') 57 | 58 | return cmd:parse(arg) 59 | end 60 | 61 | local function main() 62 | local opts = getOpts() 63 | print(opts) 64 | torch.manualSeed(opts.seed) 65 | shuffle(opts.inFile, opts.outFile) 66 | end 67 | 68 | main() 69 | 70 | 71 | -------------------------------------------------------------------------------- /layers/Embedding.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright 2014 Google Inc. All Rights Reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | ]]-- 16 | 17 | local Embedding, parent = torch.class('Embedding', 'nn.Module') 18 | 19 | function Embedding:__init(inputSize, outputSize) 20 | parent.__init(self) 21 | self.outputSize = outputSize 22 | self.weight = torch.Tensor(inputSize, outputSize) 23 | self.gradWeight = torch.Tensor(inputSize, outputSize) 24 | end 25 | 26 | function Embedding:updateOutput(input) 27 | self.output:resize(input:size(1), self.outputSize) 28 | for i = 1, input:size(1) do 29 | self.output[i]:copy(self.weight[input[i]]) 30 | end 31 | return self.output 32 | end 33 | 34 | function Embedding:updateGradInput(input, gradOutput) 35 | if self.gradInput then 36 | self.gradInput:resize(input:size()) 37 | return self.gradInput 38 | end 39 | end 40 | 41 | function Embedding:accGradParameters(input, gradOutput, scale) 42 | scale = scale or 1 43 | if scale == 0 then 44 | self.gradWeight:zero() 45 | end 46 | for i = 1, input:size(1) do 47 | local word = input[i] 48 | self.gradWeight[word]:add(gradOutput[i]) 49 | end 50 | end 51 | 52 | -- we do not need to accumulate parameters when sharing 53 | Embedding.sharedAccUpdateGradParameters = Embedding.accUpdateGradParameters -------------------------------------------------------------------------------- /utils/alias_method.lua: -------------------------------------------------------------------------------- 1 | 2 | local AliasMethod = torch.class('AliasMethod') 3 | 4 | function AliasMethod:__init(probs) 5 | local function initArray(N, val) 6 | local arr = {} 7 | for i = 1, N do 8 | arr[i] = val 9 | end 10 | return arr 11 | end 12 | 13 | local N = #probs 14 | local probTable = initArray(N, 0) 15 | local aliasTable = initArray(N, 1) 16 | 17 | local smaller, larger = {}, {} 18 | for i, p in ipairs(probs) do 19 | probTable[i] = N * p 20 | if probTable[i] < 1.0 then 21 | smaller[#smaller + 1] = i 22 | else 23 | larger[#larger + 1] = i 24 | end 25 | end 26 | 27 | local smallerSize, largerSize = #smaller, #larger 28 | 29 | while smallerSize > 0 and largerSize > 0 do 30 | local small = smaller[smallerSize] 31 | smallerSize = smallerSize - 1 32 | local large = larger[largerSize] 33 | largerSize = largerSize - 1 34 | 35 | aliasTable[small] = large 36 | probTable[large] = probTable[large] - (1.0 - probTable[small]) 37 | if probTable[large] < 1.0 then 38 | smallerSize = smallerSize + 1 39 | smaller[smallerSize] = large 40 | else 41 | largerSize = largerSize + 1 42 | larger[largerSize] = large 43 | end 44 | end 45 | 46 | self.probTable = torch.DoubleTensor(probTable) 47 | self.aliasTable = torch.LongTensor(aliasTable) 48 | self.size = N 49 | end 50 | 51 | function AliasMethod:drawBatch(N) 52 | local rndIdxs = (torch.DoubleTensor(N):uniform(0, 1) * self.size + 1):long() 53 | local probs = self.probTable:index(1, rndIdxs) 54 | local coins = torch.DoubleTensor(N):uniform(0, 1) 55 | local rndOut = torch.LongTensor(N) 56 | local bl = torch.lt(coins, probs) 57 | rndOut[bl] = rndIdxs[bl] 58 | local nbl = (-bl + 1) 59 | rndOut[nbl] = self.aliasTable:index(1, rndIdxs[nbl]) 60 | 61 | return rndOut 62 | end 63 | -------------------------------------------------------------------------------- /scripts/sort_large_hdf5_bid.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'hdf5' 3 | require 'torch' 4 | include '../utils/shortcut.lua' 5 | require 'deptreeutils' 6 | require 'sorthdf5' 7 | 8 | local function getOpts() 9 | local cmd = torch.CmdLine() 10 | cmd:text('== sort a dependency tree dataset (.h5) by sentence length ==') 11 | cmd:text() 12 | 
cmd:text('Options:') 13 | cmd:option('--dataset', '', 'the resulting dataset (.h5)') 14 | cmd:option('--sort', 0, '0: no sorting of the training data; -1: sort training data by their length; k (k > 0): sort the consecutive k batches by their length') 15 | cmd:option('--batchSize', 64, 'batch size when --sort > 0 or --sort == -1') 16 | 17 | cmd:option('--bidirectional', false, 'create bidirectional model') 18 | 19 | return cmd:parse(arg) 20 | end 21 | 22 | local function main() 23 | local opts = getOpts() 24 | local dataPrefix = opts.dataset:sub(1, -4) 25 | local vocabPath = dataPrefix .. '.vocab.t7' 26 | if opts.sort ~= 0 then 27 | assert(opts.sort == -1 or opts.sort > 0, 'valid values [0, -1, > 0]') 28 | print '========begin to sort dataset========' 29 | local h5sorter = nil 30 | local mid = opts.sort == -1 and 'sort' or string.format('sort%d', opts.sort) 31 | local h5sortFile = opts.dataset:sub(1, -4) .. string.format('.%s.h5', mid) 32 | local h5sortVocabFile = opts.dataset:sub(1, -4) .. string.format('.%s.vocab.t7', mid) 33 | local cmd = string.format('cp %s %s', vocabPath, h5sortVocabFile) 34 | print(cmd) 35 | os.execute(cmd) 36 | -- h5sorter = SortHDF5(opts.dataset, h5sortFile) 37 | if opts.bidirectional then 38 | require 'sorthdf5bid' 39 | h5sorter = SortHDF5Bidirectional(opts.dataset, h5sortFile) 40 | else 41 | h5sorter = SortHDF5(opts.dataset, h5sortFile) 42 | end 43 | h5sorter:sortHDF5(opts.sort, opts.batchSize) 44 | printf('save dataset to %s\n', h5sortFile) 45 | end 46 | end 47 | 48 | main() 49 | -------------------------------------------------------------------------------- /dataset/NCEDataGenerator.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'alias_method' 3 | 4 | local NCEDataGenerator = torch.class('NCEDataGenerator') 5 | 6 | function NCEDataGenerator:__init(vocab, nneg, power, normalizeUNK) 7 | power = power or 1.0 8 | if normalizeUNK == nil then normalizeUNK = true end 9 | self.nneg = nneg 10 | 11 | self.unigramProbs = self:initUnigramProbs(vocab, power, normalizeUNK) 12 | self.size = self.unigramProbs:size(1) 13 | print('probs sum') 14 | print(self.unigramProbs:sum()) 15 | print(self.unigramProbs:size(1)) 16 | self.aliasMethod = AliasMethod(self.unigramProbs:totable()) 17 | end 18 | 19 | function NCEDataGenerator:initUnigramProbs(vocab, power, normalizeUNK) 20 | print('power', power) 21 | print('normalizeUNK', normalizeUNK) 22 | 23 | local freqs = vocab.freqs 24 | local uniqUNK = vocab.uniqUNK 25 | local unkID = vocab.UNK 26 | local word2idx = vocab.word2idx 27 | local vocabSize = vocab.nvocab 28 | 29 | if normalizeUNK then freqs[unkID] = math.ceil( freqs[unkID] / uniqUNK ) end 30 | 31 | local ifreqs = torch.LongTensor(freqs) 32 | local pfreqs = ifreqs:double():pow(power) 33 | 34 | local uprobs = pfreqs:div( pfreqs:sum() ) 35 | while uprobs:sum() ~= 1 do 36 | uprobs = pfreqs:div( pfreqs:sum() ) 37 | end 38 | 39 | return uprobs 40 | end 41 | 42 | function NCEDataGenerator:getYNegProbs(y, useGPU) 43 | local probs = self.unigramProbs 44 | local nneg = self.nneg 45 | 46 | assert(y:dim() == 2) 47 | 48 | local y_neg = self.aliasMethod:drawBatch(y:size(1) * y:size(2) * nneg) 49 | 50 | local y_ = y:reshape(y:size(1) * y:size(2)) 51 | -- I think I should do this earlier 52 | -- y_[y_:eq(0)] = 1 53 | local y_prob = probs:index(1, y_):reshape(y:size(1), y:size(2)) 54 | local y_neg_prob = probs:index(1, y_neg):reshape(y:size(1), y:size(2), nneg) 55 | y_neg = y_neg:reshape(y:size(1), y:size(2), nneg) 56 | 57 | if useGPU then 
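--[[
  Shapes at this point: with y of size (seqlen, batchSize) and nneg noise
  samples per target, the function returns
    y_neg      (seqlen, batchSize, nneg) -- noise word ids from the alias table
    y_prob     (seqlen, batchSize)       -- noise probability of each target
    y_neg_prob (seqlen, batchSize, nneg) -- noise probability of each sample
  A minimal usage sketch; the arguments 25 (noise samples) and 0.75 (power)
  are illustrative assumptions, not values taken from this repository:

    local gen = NCEDataGenerator(vocab, 25, 0.75, true)
    local y_neg, y_prob, y_neg_prob = gen:getYNegProbs(y, false)
--]]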
58 | return y_neg:cuda(), y_prob:cuda(), y_neg_prob:cuda() 59 | else 60 | return y_neg, y_prob:float(), y_neg_prob:float() 61 | end 62 | end 63 | -------------------------------------------------------------------------------- /conllx_scripts/conllx_eval.lua: -------------------------------------------------------------------------------- 1 | 2 | include '../utils/shortcut.lua' 3 | 4 | local CoNLLXEval = {} 5 | 6 | function CoNLLXEval.conllxLineIterator(infile) 7 | local fin = io.open(infile) 8 | local bufs = {} 9 | 10 | return function() 11 | while true do 12 | local line = fin:read() 13 | if line == nil then 14 | fin:close() 15 | break 16 | end 17 | line = line:trim() 18 | if line:len() == 0 then 19 | local rlines = {} 20 | for i, buf in ipairs(bufs) do 21 | rlines[i] = buf 22 | end 23 | table.clear(bufs) 24 | 25 | return rlines 26 | else 27 | table.insert(bufs, line) 28 | end 29 | end 30 | end 31 | 32 | end 33 | 34 | function CoNLLXEval.eval(sysFile, goldFile) 35 | local punctTags ={ "``", "''", ".", ",", ":"} 36 | local punctTagSet = {} 37 | for _, pt in ipairs(punctTags) do 38 | punctTagSet[pt] = true 39 | end 40 | -- print(punctTagSet) 41 | 42 | local sysIter = CoNLLXEval.conllxLineIterator(sysFile) 43 | local goldIter = CoNLLXEval.conllxLineIterator(goldFile) 44 | local sen_cnt = 0 45 | local total, noPunctTotal = 0, 0 46 | local nUA, noPunctNUA = 0, 0 47 | local nLA, noPunctNLA = 0, 0 48 | for sysLines in sysIter do 49 | local goldLines = goldIter() 50 | assert(#sysLines == #goldLines, 'the sys sentence and the gold sentence should contain the same number of words') 51 | for i = 1, #sysLines do 52 | local sfields = sysLines[i]:splitc('\t ') 53 | local gfields = goldLines[i]:splitc('\t ') 54 | local sAid, gAid = tonumber(sfields[7]), tonumber(gfields[7]) 55 | local sDep, gDep = sfields[8], gfields[8] 56 | if sAid == gAid then 57 | nUA = nUA + 1 58 | if sDep == gDep then nLA = nLA + 1 end 59 | end 60 | 61 | total = total + 1 62 | 63 | local gtag = gfields[5] 64 | if not punctTagSet[gtag] then 65 | noPunctTotal = noPunctTotal + 1 66 | if sAid == gAid then 67 | noPunctNUA = noPunctNUA + 1 68 | if sDep == gDep then noPunctNLA = noPunctNLA + 1 end 69 | end 70 | end 71 | end 72 | 73 | sen_cnt = sen_cnt + 1 74 | end 75 | 76 | xprintln('totally %d sentences', sen_cnt) 77 | local LAS, UAS = nLA / total * 100, nUA / total * 100 78 | local noPunctLAS, noPunctUAS = noPunctNLA / noPunctTotal * 100, noPunctNUA / noPunctTotal * 100 79 | 80 | --[[ 81 | xprintln('==no punct==') 82 | xprintln('LAS = %.2f, UAS = %.2f', noPunctLAS, noPunctUAS) 83 | xprintln('==with punct==') 84 | xprintln('LAS = %.2f, UAS = %.2f', LAS, UAS) 85 | --]] 86 | 87 | return LAS, UAS, noPunctLAS, noPunctUAS 88 | end 89 | 90 | return CoNLLXEval -------------------------------------------------------------------------------- /dataset/NCEDataGenerator_lu.lua: -------------------------------------------------------------------------------- 1 | 2 | local NCEDataGenerator = torch.class('NCEDataGenerator') 3 | 4 | function NCEDataGenerator:__init(vocab, nneg, power, normalizeUNK, tableSize) 5 | power = power or 1.0 6 | if normalizeUNK == nil then normalizeUNK = true end 7 | self.nneg = nneg 8 | self.tableSize = tableSize or 1e8 9 | 10 | self.unigramProbs, self.unigramBins = self:initUnigramProbs(vocab, self.tableSize, power, normalizeUNK) 11 | print('probs sum') 12 | print(self.unigramProbs:sum()) 13 | print(self.unigramProbs:size(1)) 14 | print('bins last') 15 | print(self.unigramBins[#self.unigramBins]) 16 | 
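--[[
  This variant replaces the alias method of NCEDataGenerator.lua with the
  word2vec-style unigram table: initUnigramProbs below fills tableSize slots
  (1e8 by default) in which word i occupies a share proportional to
  freq(i)^power, so drawing a negative sample is just indexing the table at a
  uniformly random slot. A one-sample sketch of what getYNegProbs does in bulk:

    local slot = math.random(self.tableSize) -- uniform slot, illustration only
    local wordId = self.unigramBins[slot]
--]]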
print(#self.unigramBins) 17 | end 18 | 19 | function NCEDataGenerator:initUnigramProbs(vocab, tableSize, power, normalizeUNK) 20 | print('power', power) 21 | print('normalizeUNK', normalizeUNK) 22 | 23 | local freqs = vocab.freqs 24 | local uniqUNK = vocab.uniqUNK 25 | local unkID = vocab.UNK 26 | local word2idx = vocab.word2idx 27 | local vocabSize = vocab.nvocab 28 | 29 | if normalizeUNK then freqs[unkID] = math.ceil( freqs[unkID] / uniqUNK ) end 30 | 31 | local ifreqs = torch.LongTensor(freqs) 32 | local pfreqs = ifreqs:double():pow(power) 33 | 34 | local total = pfreqs:sum() 35 | local acc, i = pfreqs[1], 1 36 | local thres = acc / total 37 | local tableBins = {} 38 | local maxBinSize = 1e5 39 | assert(tableSize % maxBinSize == 0) 40 | local bins = torch.IntTensor(tableSize) 41 | local offset = 0 42 | for a = 1, tableSize do 43 | if a / tableSize > thres then 44 | i = i + 1 45 | if i > vocabSize then i = vocabSize end 46 | acc = acc + pfreqs[i] 47 | thres = acc / total 48 | end 49 | tableBins[a - offset] = i 50 | if a % maxBinSize == 0 then 51 | bins[{ {offset + 1, a} }] = torch.IntTensor(tableBins) 52 | offset = offset + maxBinSize 53 | end 54 | end 55 | 56 | local uprobs = pfreqs:div( pfreqs:sum() ) 57 | while uprobs:sum() ~= 1 do 58 | uprobs = pfreqs:div( pfreqs:sum() ) 59 | end 60 | 61 | return uprobs, bins 62 | end 63 | 64 | function NCEDataGenerator:getYNegProbs(y, useGPU) 65 | local probs = self.unigramProbs 66 | local bins = self.unigramBins 67 | local nneg = self.nneg 68 | 69 | assert(y:dim() == 2) 70 | local rnds = (torch.DoubleTensor(y:size(1) * y:size(2) * nneg):uniform(0, 1) * self.tableSize + 1):long() 71 | local y_neg = bins:index(1, rnds):long() 72 | 73 | local y_ = y:reshape(y:size(1) * y:size(2)) 74 | y_[y_:eq(0)] = 1 75 | local y_prob = probs:index(1, y_):reshape(y:size(1), y:size(2)) 76 | local y_neg_prob = probs:index(1, y_neg):reshape(y:size(1), y:size(2), nneg) 77 | y_neg = y_neg:reshape(y:size(1), y:size(2), nneg) 78 | 79 | if useGPU then 80 | return y_neg:cuda(), y_prob:cuda(), y_neg_prob:cuda() 81 | else 82 | return y_neg, y_prob:float(), y_neg_prob:float() 83 | end 84 | end 85 | -------------------------------------------------------------------------------- /dataset/LM_Dataset.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'torch' 3 | require 'hdf5' 4 | require 'shortcut' 5 | 6 | local LM_Dataset = torch.class('LM_Dataset') 7 | 8 | function LM_Dataset:__init(datasetPath) 9 | self.vocab = torch.load(datasetPath:sub(1, -3) .. 
'vocab.t7') 10 | xprintln('load vocab done!') 11 | self.h5in = hdf5.open(datasetPath, 'r') 12 | 13 | local function getLength(label) 14 | local index = self.h5in:read(string.format('/%s/index', label)) 15 | return index:dataspaceSize()[1] 16 | end 17 | self.trainSize = getLength('train') 18 | self.validSize = getLength('valid') 19 | self.testSize = getLength('test') 20 | xprintln('train size %d, valid size %d, test size %d', self.trainSize, self.validSize, self.testSize) 21 | -- print(table.keys(self.vocab)) 22 | xprintln('vocab size %d', self.vocab.nvocab) 23 | self.eos = self.vocab.word2idx['###eos###'] 24 | xprintln('EOS id %d', self.eos) 25 | end 26 | 27 | function LM_Dataset:getVocabSize() 28 | return self.vocab.nvocab 29 | end 30 | 31 | function LM_Dataset:getTrainSize() 32 | return self.trainSize 33 | end 34 | 35 | function LM_Dataset:getValidSize() 36 | return self.validSize 37 | end 38 | 39 | function LM_Dataset:getTestSize() 40 | return self.testSize 41 | end 42 | 43 | function LM_Dataset:toBatch(sents, eos, bs) 44 | local maxn = 0 45 | for _, sent in ipairs(sents) do 46 | if sent:size(1) > maxn then 47 | maxn = sent:size(1) 48 | end 49 | end 50 | maxn = maxn + 1 51 | local nsent = #sents 52 | -- for x, in default x contains EOS tokens 53 | local x = torch.ones(maxn, bs):type('torch.IntTensor') 54 | -- local x = torch.ones(maxn, batchSize) 55 | x:mul(eos) 56 | local y = torch.zeros(maxn, bs):type('torch.IntTensor') 57 | -- local y = torch.zeros(maxn, batchSize) 58 | for i = 1, nsent do 59 | local senlen = sents[i]:size(1) 60 | x[{ {2, senlen + 1}, i }] = sents[i] 61 | y[{ {1, senlen}, i }] = sents[i] 62 | y[{ senlen + 1, i }] = eos 63 | end 64 | 65 | return x, y 66 | end 67 | 68 | function LM_Dataset:createBatch(label, batchSize) 69 | local h5in = self.h5in 70 | local x_data = h5in:read(string.format('/%s/x_data', label)) 71 | local index = h5in:read(string.format('/%s/index', label)) 72 | local N = index:dataspaceSize()[1] 73 | local eos = self.eos 74 | 75 | local istart = 1 76 | 77 | return function() 78 | if istart <= N then 79 | local iend = math.min(istart + batchSize - 1, N) 80 | local sents = {} 81 | for i = istart, iend do 82 | local idx = index:partial({i, i}, {1, 2}) 83 | local start, len = idx[1][1], idx[1][2] 84 | local sent = x_data:partial({start, start + len - 1}) 85 | table.insert(sents, sent) 86 | end 87 | 88 | istart = iend + 1 89 | 90 | return self:toBatch(sents, eos, batchSize) 91 | end 92 | end 93 | end 94 | 95 | 96 | function LM_Dataset:close() 97 | self.h5in:close() 98 | end 99 | -------------------------------------------------------------------------------- /nnets/MLP.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'basic' 3 | require 'shortcut' 4 | require 'optim' 5 | 6 | -- local nninit = require 'nninit' 7 | 8 | local mlp = torch.class('MLP', 'BModel') 9 | 10 | local function transferData(useGPU, data) 11 | if useGPU then 12 | return data:cuda() 13 | else 14 | return data 15 | end 16 | end 17 | 18 | function mlp:__init(opts) 19 | self.opts = opts 20 | self.model = nn.Sequential() 21 | local nhids = opts.snhids:splitc(',') 22 | opts.nhids = {} 23 | for _, snhid in ipairs(nhids) do 24 | table.insert(opts.nhids, tonumber(snhid)) 25 | end 26 | 27 | self.model = nn.Sequential() 28 | local nlayers = #opts.nhids 29 | for i = 2, nlayers do 30 | self.model:add(nn.Linear(opts.nhids[i-1], opts.nhids[i])) 31 | if i ~= nlayers then 32 | if opts.activ == 'tanh' then 33 | self.model:add(nn.Tanh()) 34 | elseif 
opts.activ == 'relu' then 35 | self.model:add(nn.ReLU()) 36 | else 37 | error(opts.activ .. ' not supported!') 38 | end 39 | 40 | if opts.dropout > 0 then 41 | self.model:add(nn.Dropout(opts.dropout)) 42 | end 43 | end 44 | end 45 | self.model:add(nn.LogSoftMax()) 46 | print(self.model) 47 | self.criterion = nn.ClassNLLCriterion() 48 | 49 | if opts.useGPU then 50 | self.model = self.model:cuda() 51 | self.criterion = self.criterion:cuda() 52 | end 53 | 54 | self.params, self.grads = self.model:getParameters() 55 | printf('#param %d\n', self.params:size(1)) 56 | 57 | if opts.optimMethod == 'AdaGrad' then 58 | self.optimMethod = optim.adagrad 59 | elseif opts.optimMethod == 'SGD' then 60 | self.optimMethod = optim.sgd 61 | end 62 | 63 | end 64 | 65 | function mlp:trainBatch(x, y, sgd_params) 66 | self.model:training() 67 | if self.opts.useGPU then 68 | x = x:cuda() 69 | y = y:cuda() 70 | end 71 | 72 | local feval = function(newParam) 73 | if self.params ~= newParam then 74 | self.params:copy(newParam) 75 | end 76 | 77 | self.grads:zero() 78 | local output = self.model:forward(x) 79 | local loss = self.criterion:forward(output, y) 80 | local df = self.criterion:backward(output, y) 81 | self.model:backward(x, df) 82 | 83 | return loss, self.grads 84 | end 85 | 86 | local _, loss_ = self.optimMethod(feval, self.params, sgd_params) 87 | 88 | return loss_[1] 89 | end 90 | 91 | function mlp:validBatch(x, y) 92 | if self.opts.useGPU then 93 | y = y:cuda() 94 | end 95 | 96 | local yPred = self:predictBatch(x) 97 | local maxv, maxi = yPred:max(2) 98 | return torch.sum( torch.eq(maxi, y) ), x:size(1) 99 | end 100 | 101 | function mlp:predictBatch(x) 102 | self.model:evaluate() 103 | if self.opts.useGPU then 104 | x = x:cuda() 105 | end 106 | 107 | return torch.exp(self.model:forward(x)) 108 | end 109 | 110 | -------------------------------------------------------------------------------- /scripts/conllxutils.lua: -------------------------------------------------------------------------------- 1 | 2 | local CoNLLXUtils = torch.class('CoNLLXUtils') 3 | 4 | include '../utils/shortcut.lua' 5 | 6 | CoNLLXUtils.ROOT_MARK = '###root###' 7 | 8 | function CoNLLXUtils.normalizeNumber(str) 9 | local function match(str, pat) 10 | local istart, iend = str:find(pat) 11 | return istart ~= nil and iend ~= nil and iend - istart + 1 == str:len() 12 | end 13 | 14 | if match(str, '%d+') then 15 | return '' 16 | elseif match(str, '%d+%.%d+') then 17 | return '' 18 | elseif match(str, '%d[%d,]+') then 19 | return '' 20 | else 21 | return str 22 | end 23 | end 24 | 25 | function CoNLLXUtils.conllxLines2dwords(lines, normalize) 26 | local words = {} 27 | for _, line in ipairs(lines) do 28 | local fields = line:splitc(' \t') 29 | assert(#fields == 10, 'MUST have ten columns') 30 | if normalize then 31 | words[tonumber(fields[1])] = CoNLLXUtils.normalizeNumber(fields[2]) 32 | else 33 | words[tonumber(fields[1])] = fields[2] 34 | end 35 | end 36 | 37 | local dwords = {} 38 | for _, line in ipairs(lines) do 39 | local fields = line:splitc(' \t') 40 | local p2 = tonumber(fields[1]) 41 | local p1 = tonumber(fields[7]) 42 | local w2 = words[p2] 43 | local w1 = p1 == 0 and CoNLLXUtils.ROOT_MARK or words[p1] 44 | table.insert(dwords, {rel = fields[8], w1 = w1, 45 | p1 = p1, w2 = w2, p2 = p2}) 46 | end 47 | 48 | -- print('dwords') 49 | -- print(dwords) 50 | 51 | return dwords 52 | end 53 | 54 | function CoNLLXUtils.conllxIterator(infile, normalize) 55 | if normalize then 56 | print('Note Normalize Number') 57 | end 58 | 59 | local fin 
= io.open(infile) 60 | local bufs = {} 61 | 62 | return function() 63 | 64 | while true do 65 | local line = fin:read() 66 | if line == nil then 67 | fin:close() 68 | break 69 | end 70 | line = line:trim() 71 | if line:len() == 0 then 72 | local dwords = CoNLLXUtils.conllxLines2dwords(bufs, normalize) 73 | table.clear(bufs) 74 | 75 | return dwords 76 | else 77 | table.insert(bufs, line) 78 | end 79 | end 80 | 81 | end 82 | end 83 | 84 | function CoNLLXUtils.conllxLineIterator(infile) 85 | local fin = io.open(infile) 86 | local bufs = {} 87 | 88 | return function() 89 | while true do 90 | local line = fin:read() 91 | if line == nil then 92 | fin:close() 93 | break 94 | end 95 | line = line:trim() 96 | if line:len() == 0 then 97 | local rlines = {} 98 | for i, buf in ipairs(bufs) do 99 | rlines[i] = buf 100 | end 101 | table.clear(bufs) 102 | 103 | return rlines 104 | else 105 | table.insert(bufs, line) 106 | end 107 | end 108 | end 109 | 110 | end 111 | -------------------------------------------------------------------------------- /dataset/TreeLM_Dataset.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'torch' 3 | require 'hdf5' 4 | require 'shortcut' 5 | 6 | local TreeLM_Dataset = torch.class('TreeLM_Dataset') 7 | 8 | function TreeLM_Dataset:__init(datasetPath) 9 | self.vocab = torch.load(datasetPath:sub(1, -3) .. 'vocab.t7') 10 | xprintln('load vocab done!') 11 | self.h5in = hdf5.open(datasetPath, 'r') 12 | 13 | local function getLength(label) 14 | local index = self.h5in:read(string.format('/%s/index', label)) 15 | return index:dataspaceSize()[1] 16 | end 17 | self.trainSize = getLength('train') 18 | self.validSize = getLength('valid') 19 | self.testSize = getLength('test') 20 | xprintln('train size %d, valid size %d, test size %d', self.trainSize, self.validSize, self.testSize) 21 | xprintln('vocab size %d', self.vocab.nvocab) 22 | self.UNK = self.vocab.UNK 23 | xprintln('unknown word token is %d', self.UNK) 24 | end 25 | 26 | function TreeLM_Dataset:getVocabSize() 27 | return self.vocab.nvocab 28 | end 29 | 30 | function TreeLM_Dataset:getTrainSize() 31 | return self.trainSize 32 | end 33 | 34 | function TreeLM_Dataset:getValidSize() 35 | return self.validSize 36 | end 37 | 38 | function TreeLM_Dataset:getTestSize() 39 | return self.testSize 40 | end 41 | 42 | function TreeLM_Dataset:toBatch(xs, ys, batchSize) 43 | local dtype = 'torch.LongTensor' 44 | local maxn = 0 45 | for _, y_ in ipairs(ys) do 46 | if y_:size(1) > maxn then 47 | maxn = y_:size(1) 48 | end 49 | end 50 | local x = torch.ones(maxn, batchSize, 4):type(dtype) 51 | x:mul(self.UNK) 52 | x[{ {}, {}, 4 }] = torch.linspace(2, maxn + 1, maxn):resize(maxn, 1):expand(maxn, batchSize) 53 | local nsent = #ys 54 | local y = torch.zeros(maxn, batchSize):type(dtype) 55 | for i = 1, nsent do 56 | local sx, sy = xs[i], ys[i] 57 | x[{ {1, sx:size(1)}, i, {} }] = sx 58 | y[{ {1, sy:size(1)}, i }] = sy 59 | end 60 | 61 | return x, y 62 | end 63 | 64 | function TreeLM_Dataset:createBatch(label, batchSize) 65 | local h5in = self.h5in 66 | local x_data = h5in:read(string.format('/%s/x_data', label)) 67 | local y_data = h5in:read(string.format('/%s/y_data', label)) 68 | local index = h5in:read(string.format('/%s/index', label)) 69 | local N = index:dataspaceSize()[1] 70 | 71 | local istart = 1 72 | 73 | return function() 74 | if istart <= N then 75 | local iend = math.min(istart + batchSize - 1, N) 76 | local xs = {} 77 | local ys = {} 78 | for i = istart, iend do 79 | local idx = 
index:partial({i, i}, {1, 2}) 80 | local start, len = idx[1][1], idx[1][2] 81 | local x = x_data:partial({start, start + len - 1}, {1, 4}) 82 | local y = y_data:partial({start, start + len - 1}) 83 | table.insert(xs, x) 84 | table.insert(ys, y) 85 | end 86 | 87 | istart = iend + 1 88 | 89 | return self:toBatch(xs, ys, batchSize) 90 | end 91 | end 92 | end 93 | 94 | function TreeLM_Dataset:close() 95 | self.h5in:close() 96 | end 97 | -------------------------------------------------------------------------------- /dataset/TreeLM_NCE_Dataset.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'torch' 3 | require 'hdf5' 4 | require 'shortcut' 5 | require 'NCEDataGenerator' 6 | 7 | local TreeLM_NCE_Dataset = torch.class('TreeLM_NCE_Dataset') 8 | 9 | function TreeLM_NCE_Dataset:__init(datasetPath, nneg, power, normalizeUNK) 10 | self.vocab = torch.load(datasetPath:sub(1, -3) .. 'vocab.t7') 11 | xprintln('load vocab done!') 12 | self.h5in = hdf5.open(datasetPath, 'r') 13 | 14 | local function getLength(label) 15 | local index = self.h5in:read(string.format('/%s/index', label)) 16 | return index:dataspaceSize()[1] 17 | end 18 | self.trainSize = getLength('train') 19 | self.validSize = getLength('valid') 20 | self.testSize = getLength('test') 21 | xprintln('train size %d, valid size %d, test size %d', self.trainSize, self.validSize, self.testSize) 22 | xprintln('vocab size %d', self.vocab.nvocab) 23 | self.UNK = self.vocab.UNK 24 | xprintln('unknown word token is %d', self.UNK) 25 | 26 | self.ncedata = NCEDataGenerator(self.vocab, nneg, power, normalizeUNK) 27 | end 28 | 29 | function TreeLM_NCE_Dataset:getVocabSize() 30 | return self.vocab.nvocab 31 | end 32 | 33 | function TreeLM_NCE_Dataset:getTrainSize() 34 | return self.trainSize 35 | end 36 | 37 | function TreeLM_NCE_Dataset:getValidSize() 38 | return self.validSize 39 | end 40 | 41 | function TreeLM_NCE_Dataset:getTestSize() 42 | return self.testSize 43 | end 44 | 45 | function TreeLM_NCE_Dataset:toBatch(xs, ys, batchSize) 46 | local dtype = 'torch.LongTensor' 47 | local maxn = 0 48 | for _, y_ in ipairs(ys) do 49 | if y_:size(1) > maxn then 50 | maxn = y_:size(1) 51 | end 52 | end 53 | local x = torch.ones(maxn, batchSize, 4):type(dtype) 54 | x:mul(self.UNK) 55 | x[{ {}, {}, 4 }] = torch.linspace(2, maxn + 1, maxn):resize(maxn, 1):expand(maxn, batchSize) 56 | local nsent = #ys 57 | local y = torch.zeros(maxn, batchSize):type(dtype) 58 | for i = 1, nsent do 59 | local sx, sy = xs[i], ys[i] 60 | x[{ {1, sx:size(1)}, i, {} }] = sx 61 | y[{ {1, sy:size(1)}, i }] = sy 62 | end 63 | 64 | return x, y 65 | end 66 | 67 | function TreeLM_NCE_Dataset:createBatch(label, batchSize, useNCE) 68 | local h5in = self.h5in 69 | local x_data = h5in:read(string.format('/%s/x_data', label)) 70 | local y_data = h5in:read(string.format('/%s/y_data', label)) 71 | local index = h5in:read(string.format('/%s/index', label)) 72 | local N = index:dataspaceSize()[1] 73 | 74 | local istart = 1 75 | 76 | return function() 77 | if istart <= N then 78 | local iend = math.min(istart + batchSize - 1, N) 79 | local xs = {} 80 | local ys = {} 81 | for i = istart, iend do 82 | local idx = index:partial({i, i}, {1, 2}) 83 | local start, len = idx[1][1], idx[1][2] 84 | local x = x_data:partial({start, start + len - 1}, {1, 4}) 85 | local y = y_data:partial({start, start + len - 1}) 86 | table.insert(xs, x) 87 | table.insert(ys, y) 88 | end 89 | 90 | istart = iend + 1 91 | 92 | local x, y = self:toBatch(xs, ys, batchSize) 93 | if 
useNCE then 94 | local mask = y:ne(0):float() 95 | y[y:eq(0)] = 1 96 | local y_neg, y_prob, y_neg_prob = self.ncedata:getYNegProbs(y) 97 | return x, y, y_neg, y_prob, y_neg_prob, mask 98 | else 99 | return x, y 100 | end 101 | end 102 | end 103 | end 104 | 105 | function TreeLM_NCE_Dataset:close() 106 | self.h5in:close() 107 | end 108 | -------------------------------------------------------------------------------- /utils/shortcut.lua: -------------------------------------------------------------------------------- 1 | 2 | function printf(s, ...) 3 | return io.write(s:format(...)) 4 | end 5 | 6 | function xprint(s, ...) 7 | local ret = io.write(s:format(...)) 8 | io.flush() 9 | return ret 10 | end 11 | 12 | function xprintln(s, ...) 13 | return xprint(s .. '\n', ...) 14 | end 15 | 16 | -- time -- 17 | function readableTime(stime) 18 | local intervals = {1, 60, 3600} 19 | local units = {"s", "min", "h"} 20 | local i = 2 21 | while i <= #intervals do 22 | if stime < intervals[i] then 23 | break 24 | end 25 | i = i + 1 26 | end 27 | return string.format( '%.2f%s', stime/intervals[i-1], units[i-1] ) 28 | end 29 | 30 | -- for tables -- 31 | function table.keys(t) 32 | local ks = {} 33 | for k, _ in pairs(t) do 34 | ks[#ks + 1] = k 35 | end 36 | return ks 37 | end 38 | 39 | function table.len(t) 40 | local size = 0 41 | for _ in pairs(t) do 42 | size = size + 1 43 | end 44 | return size 45 | end 46 | 47 | -- for strings 48 | function xtoCharSet(s) 49 | local set = {} 50 | local i = 1 51 | local ch = '' 52 | while true do 53 | ch = s:sub(i, i) 54 | if ch == '' then break end 55 | set[ch] = true 56 | i = i + 1 57 | end 58 | return set 59 | end 60 | 61 | function string.splitc(s, cseps) 62 | local strs = {} 63 | local cset = xtoCharSet(cseps) 64 | local i, ch = 1, ' ' 65 | while ch ~= '' do 66 | while true do 67 | ch = s:sub(i, i) 68 | if ch == '' or not cset[ch] then break end 69 | i = i + 1 70 | end 71 | local chs = {} 72 | while true do 73 | ch = s:sub(i, i) 74 | if ch == '' or cset[ch] then break end 75 | chs[#chs + 1] = ch 76 | i = i + 1 77 | end 78 | if #chs > 0 then strs[#strs + 1] = table.concat(chs) end 79 | end 80 | 81 | return strs 82 | end 83 | 84 | function string.starts(s, pat) 85 | return s:sub(1, string.len(pat)) == pat 86 | end 87 | 88 | function string.ends(s, pat) 89 | return pat == '' or s:sub(-string.len(pat)) == pat 90 | end 91 | 92 | function string.trim(s) 93 | return s:match'^()%s*$' and '' or s:match'^%s*(.*%S)' 94 | end 95 | 96 | -- the following is for arrays -- 97 | function table.extend(a, b) 98 | for _, v in ipairs(b) do 99 | a[#a + 1] = v 100 | end 101 | return a 102 | end 103 | 104 | function table.subtable(t, istart, iend) 105 | local N = #t 106 | assert(istart <= iend and istart >= 1 and iend <= N, 107 | 'invalid istart or iend') 108 | local subT = {} 109 | for i = istart, iend do 110 | subT[#subT + 1] = t[i] 111 | end 112 | 113 | return subT 114 | end 115 | 116 | function table.contains(t, key) 117 | for _, v in ipairs(t) do 118 | if v == key then return true end 119 | end 120 | return false 121 | end 122 | 123 | function table.clear(t) 124 | for i, _ in ipairs(t) do 125 | t[i] = nil 126 | end 127 | end 128 | 129 | -- the following is for IOs -- 130 | function xreadlines(infile) 131 | local fin = io.open(infile, 'r') 132 | local lines = {} 133 | while true do 134 | local line = fin:read() 135 | if line == nil then break end 136 | lines[#lines + 1] = line 137 | end 138 | fin:close() 139 | 140 | return lines 141 | end 142 | 143 | function xcountlines(infile) 144 | 
local fin = io.open(infile, 'r') 145 | local cnt = 0 146 | while true do 147 | local line = fin:read() 148 | if line == nil then break end 149 | cnt = cnt + 1 150 | end 151 | fin:close() 152 | 153 | return cnt 154 | end 155 | 156 | 157 | 158 | 159 | -------------------------------------------------------------------------------- /scripts/conllx2hdf5.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'hdf5' 3 | require 'torch' 4 | include '../utils/shortcut.lua' 5 | require 'deptreeutils' 6 | 7 | local function getOpts() 8 | local cmd = torch.CmdLine() 9 | cmd:text('== convert CoNLL-X dependency trees to hdf5 format ==') 10 | cmd:text() 11 | cmd:text('Options:') 12 | cmd:option('--train', '', 'train CoNLL X file') 13 | cmd:option('--valid', '', 'valid CoNLL X file') 14 | cmd:option('--test', '', 'test CoNLL X file') 15 | cmd:option('--dataset', '', 'the resulting dataset (.h5)') 16 | cmd:option('--freq', 0, 'words less than or equals to \"freq\" times will be replaced with UNK token') 17 | cmd:option('--ignoreCase', false, 'case will be ignored') 18 | cmd:option('--normalizeNumber', false, 'normalize numbers to ') 19 | cmd:option('--keepFreq', false, 'keep frequency information during creating vocabulary') 20 | cmd:option('--maxLen', 100, 'sentences longer than maxlen will be ignored!') 21 | cmd:option('--sort', 0, '0: no sorting of the training data; -1: sort training data by their length; k (k > 0): sort the consecutive k batches by their length') 22 | cmd:option('--batchSize', 64, 'batch size when --sort > 0 or --sort == -1') 23 | 24 | cmd:option('--bidirectional', false, 'create bidirectional model') 25 | 26 | return cmd:parse(arg) 27 | end 28 | 29 | local function main() 30 | local opts = getOpts() 31 | print(opts) 32 | local vocab = DepTreeUtils.createVocabCoNLLX(opts.train, opts.freq, opts.ignoreCase, opts.keepFreq, opts.normalizeNumber) 33 | assert(opts.dataset:ends('.h5'), 'dataset must be hdf5 file .h5') 34 | local dataPrefix = opts.dataset:sub(1, -4) 35 | local vocabPath = dataPrefix .. '.vocab.t7' 36 | printf('save vocab to %s\n', vocabPath) 37 | torch.save(vocabPath, vocab) 38 | 39 | local h5out = hdf5.open(opts.dataset, 'w') 40 | if opts.bidirectional then 41 | DepTreeUtils.conllx2hdf5Bidirectional(opts.train, h5out, 'train', vocab, opts.maxLen) 42 | else 43 | DepTreeUtils.conllx2hdf5(opts.train, h5out, 'train', vocab, opts.maxLen) 44 | end 45 | print('create training set done!') 46 | if opts.bidirectional then 47 | DepTreeUtils.conllx2hdf5Bidirectional(opts.valid, h5out, 'valid', vocab, opts.maxLen) 48 | else 49 | DepTreeUtils.conllx2hdf5(opts.valid, h5out, 'valid', vocab, opts.maxLen) 50 | end 51 | print('create validating set done!') 52 | if opts.bidirectional then 53 | DepTreeUtils.conllx2hdf5Bidirectional(opts.test, h5out, 'test', vocab, opts.maxLen) 54 | else 55 | DepTreeUtils.conllx2hdf5(opts.test, h5out, 'test', vocab, opts.maxLen) 56 | end 57 | print('create testing set done!') 58 | h5out:close() 59 | printf('save dataset to %s\n', opts.dataset) 60 | 61 | if opts.sort ~= 0 then 62 | assert(opts.sort == -1 or opts.sort > 0, 'valid values [0, -1, > 0]') 63 | print '========begin to sort dataset========' 64 | require 'sorthdf5' 65 | local h5sorter = nil 66 | local mid = opts.sort == -1 and 'sort' or string.format('sort%d', opts.sort) 67 | local h5sortFile = opts.dataset:sub(1, -4) .. string.format('.%s.h5', mid) 68 | local h5sortVocabFile = opts.dataset:sub(1, -4) .. 
string.format('.%s.vocab.t7', mid) 69 | local cmd = string.format('cp %s %s', vocabPath, h5sortVocabFile) 70 | print(cmd) 71 | os.execute(cmd) 72 | if opts.bidirectional then 73 | require 'sorthdf5bid' 74 | h5sorter = SortHDF5Bidirectional(opts.dataset, h5sortFile, h5sortVocabFile) 75 | else 76 | h5sorter = SortHDF5(opts.dataset, h5sortFile, h5sortVocabFile) 77 | end 78 | h5sorter:sortHDF5(opts.sort, opts.batchSize) 79 | printf('save dataset to %s\n', h5sortFile) 80 | end 81 | 82 | end 83 | 84 | main() 85 | -------------------------------------------------------------------------------- /scripts/sorthdf5bid.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'torch' 3 | require 'hdf5' 4 | include '../utils/shortcut.lua' 5 | 6 | local SortHDF5Bid = torch.class('SortHDF5Bidirectional', 'SortHDF5') 7 | 8 | function SortHDF5Bid:writeSplit(splitLabel, idxs) 9 | print 'write bidirectional model!' 10 | local index = self.h5in:read(string.format('/%s/index', splitLabel)) 11 | local x_data = self.h5in:read(string.format('/%s/x_data', splitLabel)) 12 | local y_data = self.h5in:read(string.format('/%s/y_data', splitLabel)) 13 | local l_data = self.h5in:read(string.format('/%s/l_data', splitLabel)) 14 | local lindex = self.h5in:read(string.format('/%s/lindex', splitLabel)) 15 | 16 | local offset, isFirst = 1, true 17 | local x_ts = {} 18 | local y_ts = {} 19 | local i_ts = {} 20 | 21 | local loffset = 1 22 | local l_ts = {} 23 | local li_ts = {} 24 | 25 | local gxdata = string.format('/%s/x_data', splitLabel) 26 | local gydata = string.format('/%s/y_data', splitLabel) 27 | local gindex = string.format('/%s/index', splitLabel) 28 | 29 | local gldata = string.format('/%s/l_data', splitLabel) 30 | local glindex = string.format('/%s/lindex', splitLabel) 31 | 32 | local xOpt = hdf5.DataSetOptions() 33 | xOpt:setChunked(1024*50*10, 5) 34 | -- xOpt:setDeflate(1) 35 | local yOpt = hdf5.DataSetOptions() 36 | yOpt:setChunked(1024*50*10) 37 | -- yOpt:setDeflate(1) 38 | local iOpt = hdf5.DataSetOptions() 39 | iOpt:setChunked(1024*10, 2) 40 | -- iOpt:setDeflate(1) 41 | local lOpt = hdf5.DataSetOptions() 42 | lOpt:setChunked(1024*50*10, 2) 43 | local liOpt = hdf5.DataSetOptions(1024*10, 2) 44 | liOpt:setChunked(1024*10, 2) 45 | 46 | local function appendData() 47 | local x_data_ = torch.IntTensor(x_ts) 48 | local y_data_ = torch.IntTensor(y_ts) 49 | local index_ = torch.IntTensor(i_ts) 50 | local l_data_ = torch.IntTensor(l_ts) 51 | local lindex_ = torch.IntTensor(li_ts) 52 | if not isFirst then 53 | self.h5out:append(gxdata, x_data_, xOpt) 54 | self.h5out:append(gydata, y_data_, yOpt) 55 | self.h5out:append(gindex, index_, iOpt) 56 | self.h5out:append(gldata, l_data_, lOpt) 57 | self.h5out:append(glindex, lindex_, liOpt) 58 | else 59 | self.h5out:write(gxdata, x_data_, xOpt) 60 | self.h5out:write(gydata, y_data_, yOpt) 61 | self.h5out:write(gindex, index_, iOpt) 62 | self.h5out:write(gldata, l_data_, lOpt) 63 | self.h5out:write(glindex, lindex_, liOpt) 64 | isFirst = false 65 | end 66 | end 67 | 68 | for sentCount, i in ipairs(idxs) do 69 | local idx = index:partial({i, i}, {1, 2}) 70 | local start, len = idx[1][1], idx[1][2] 71 | local x = x_data:partial({start, start + len - 1}, {1, 5}) 72 | local y = y_data:partial({start, start + len - 1}) 73 | table.extend(x_ts, x:totable()) 74 | table.extend(y_ts, y:totable()) 75 | table.insert(i_ts, {offset, len}) 76 | 77 | local lidx = lindex:partial({i, i}, {1, 2}) 78 | local lstart, llen = lidx[1][1], lidx[1][2] 79 | if 
llen ~= 0 then 80 | local lc = l_data:partial({lstart, lstart + llen - 1}, {1, 2}) 81 | table.extend(l_ts, lc:totable()) 82 | end 83 | table.insert(li_ts, {loffset, llen}) 84 | 85 | if sentCount % 50000 == 0 then 86 | appendData() 87 | x_ts = {} 88 | y_ts = {} 89 | i_ts = {} 90 | l_ts = {} 91 | li_ts = {} 92 | printf('write [%s] line count = %d\n', splitLabel, sentCount) 93 | collectgarbage() 94 | end 95 | 96 | offset = offset + len 97 | loffset = loffset + llen 98 | end 99 | 100 | if #x_ts > 0 then 101 | appendData() 102 | end 103 | end 104 | 105 | -------------------------------------------------------------------------------- /scripts/deptree2hdf5.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'hdf5' 3 | require 'torch' 4 | include '../utils/shortcut.lua' 5 | require 'deptreeutils' 6 | 7 | local function getOpts() 8 | local cmd = torch.CmdLine() 9 | cmd:text('== convert dependency trees to hdf5 format ==') 10 | cmd:text() 11 | cmd:text('Options:') 12 | cmd:option('--train', '', 'train text file') 13 | cmd:option('--valid', '', 'valid text file') 14 | cmd:option('--test', '', 'test text file') 15 | cmd:option('--dataset', '', 'the resulting dataset (.h5)') 16 | cmd:option('--freq', 0, 'words less than or equals to \"freq\" times will be replaced with UNK token') 17 | cmd:option('--ignoreCase', false, 'case will be ignored') 18 | cmd:option('--keepFreq', false, 'keep frequency information during creating vocabulary') 19 | cmd:option('--maxLen', 100, 'sentences longer than maxlen will be ignored!') 20 | cmd:option('--sort', 0, '0: no sorting of the training data; -1: sort training data by their length; k (k > 0): sort the consecutive k batches by their length') 21 | cmd:option('--batchSize', 64, 'batch size when --sort > 0 or --sort == -1') 22 | 23 | cmd:option('--bidirectional', false, 'create bidirectional model') 24 | 25 | return cmd:parse(arg) 26 | end 27 | 28 | local function main() 29 | local opts = getOpts() 30 | print(opts) 31 | local vocab = DepTreeUtils.createVocab(opts.train, opts.freq, opts.ignoreCase, opts.keepFreq) 32 | assert(opts.dataset:ends('.h5'), 'dataset must be hdf5 file .h5') 33 | local dataPrefix = opts.dataset:sub(1, -4) 34 | local vocabPath = dataPrefix .. '.vocab.t7' 35 | printf('save vocab to %s\n', vocabPath) 36 | torch.save(vocabPath, vocab) 37 | 38 | local h5out = hdf5.open(opts.dataset, 'w') 39 | if opts.bidirectional then 40 | DepTreeUtils.deptree2hdf5Bidirectional(opts.train, h5out, 'train', vocab, opts.maxLen) 41 | else 42 | DepTreeUtils.deptree2hdf5(opts.train, h5out, 'train', vocab, opts.maxLen) 43 | end 44 | print('create training set done!') 45 | if opts.bidirectional then 46 | DepTreeUtils.deptree2hdf5Bidirectional(opts.valid, h5out, 'valid', vocab, opts.maxLen) 47 | else 48 | DepTreeUtils.deptree2hdf5(opts.valid, h5out, 'valid', vocab, opts.maxLen) 49 | end 50 | print('create validating set done!') 51 | if opts.bidirectional then 52 | DepTreeUtils.deptree2hdf5Bidirectional(opts.test, h5out, 'test', vocab, opts.maxLen) 53 | else 54 | DepTreeUtils.deptree2hdf5(opts.test, h5out, 'test', vocab, opts.maxLen) 55 | end 56 | print('create testing set done!') 57 | h5out:close() 58 | printf('save dataset to %s\n', opts.dataset) 59 | 60 | return opts 61 | end 62 | 63 | local function sort(opts) 64 | local dataPrefix = opts.dataset:sub(1, -4) 65 | local vocabPath = dataPrefix .. 
'.vocab.t7' 66 | 67 | if opts.sort ~= 0 then 68 | assert(opts.sort == -1 or opts.sort > 0, 'valid values [0, -1, > 0]') 69 | print '========begin to sort dataset========' 70 | require 'sorthdf5' 71 | local h5sorter = nil 72 | local mid = opts.sort == -1 and 'sort' or string.format('sort%d', opts.sort) 73 | local h5sortFile = opts.dataset:sub(1, -4) .. string.format('.%s.h5', mid) 74 | local h5sortVocabFile = opts.dataset:sub(1, -4) .. string.format('.%s.vocab.t7', mid) 75 | local cmd = string.format('cp %s %s', vocabPath, h5sortVocabFile) 76 | print(cmd) 77 | os.execute(cmd) 78 | if opts.bidirectional then 79 | require 'sorthdf5bid' 80 | h5sorter = SortHDF5Bidirectional(opts.dataset, h5sortFile, h5sortVocabFile) 81 | else 82 | h5sorter = SortHDF5(opts.dataset, h5sortFile, h5sortVocabFile) 83 | end 84 | h5sorter:sortHDF5(opts.sort, opts.batchSize) 85 | printf('save dataset to %s\n', h5sortFile) 86 | end 87 | end 88 | 89 | local opts = main() 90 | -- local opts = getOpts() 91 | print(opts) 92 | sort(opts) 93 | -------------------------------------------------------------------------------- /dataset/BidTreeLM_Dataset.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'torch' 3 | require 'hdf5' 4 | require 'shortcut' 5 | 6 | local BidTreeLM_Dataset = torch.class('BidTreeLM_Dataset') 7 | 8 | function BidTreeLM_Dataset:__init(datasetPath) 9 | self.vocab = torch.load(datasetPath:sub(1, -3) .. 'vocab.t7') 10 | xprintln('load vocab done!') 11 | self.h5in = hdf5.open(datasetPath, 'r') 12 | 13 | local function getLength(label) 14 | local index = self.h5in:read(string.format('/%s/index', label)) 15 | return index:dataspaceSize()[1] 16 | end 17 | self.trainSize = getLength('train') 18 | self.validSize = getLength('valid') 19 | self.testSize = getLength('test') 20 | xprintln('train size %d, valid size %d, test size %d', self.trainSize, self.validSize, self.testSize) 21 | xprintln('vocab size %d', self.vocab.nvocab) 22 | self.UNK = self.vocab.UNK 23 | xprintln('unknown word token is %d', self.UNK) 24 | end 25 | 26 | function BidTreeLM_Dataset:getVocabSize() 27 | return self.vocab.nvocab 28 | end 29 | 30 | function BidTreeLM_Dataset:getTrainSize() 31 | return self.trainSize 32 | end 33 | 34 | function BidTreeLM_Dataset:getValidSize() 35 | return self.validSize 36 | end 37 | 38 | function BidTreeLM_Dataset:getTestSize() 39 | return self.testSize 40 | end 41 | 42 | function BidTreeLM_Dataset:toBatch(xs, ys, lcs, batchSize) 43 | local dtype = 'torch.LongTensor' 44 | local maxn = 0 45 | for _, y_ in ipairs(ys) do 46 | if y_:size(1) > maxn then 47 | maxn = y_:size(1) 48 | end 49 | end 50 | local x = torch.ones(maxn, batchSize, 5):type(dtype) 51 | x:mul(self.UNK) 52 | x[{ {}, {}, 4 }] = torch.linspace(2, maxn + 1, maxn):resize(maxn, 1):expand(maxn, batchSize) 53 | x[{ {}, {}, 5 }] = 0 -- in default, I don't want them to have 54 | local nsent = #ys 55 | local y = torch.zeros(maxn, batchSize):type(dtype) 56 | for i = 1, nsent do 57 | local sx, sy = xs[i], ys[i] 58 | x[{ {1, sx:size(1)}, i, {} }] = sx 59 | y[{ {1, sy:size(1)}, i }] = sy 60 | end 61 | 62 | -- for left children 63 | assert(#lcs == #xs, 'should be the same!') 64 | local lcBatchSize = 0 65 | local maxLcSeqLen = 0 66 | for _, lc in ipairs(lcs) do 67 | if lc:dim() ~= 0 then 68 | lcBatchSize = lcBatchSize + 1 69 | maxLcSeqLen = math.max(maxLcSeqLen, lc:size(1)) 70 | end 71 | end 72 | local lchild = torch.Tensor():type(dtype) 73 | local lc_mask = torch.FloatTensor() 74 | 75 | if lcBatchSize ~= 0 then 76 | 
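--[[
  Packing scheme in this branch, inferred from the code and stated here for
  clarity: the non-empty left-children sequences are packed column-wise into
  lchild of size (maxLcSeqLen, lcBatchSize), with lc_mask marking real entries.
  Column 5 of x then stores, for each token with left children, a pointer into
  lchild flattened in row-major order: slot (row, column j) becomes
  (row - 1) * lcBatchSize + j, which is exactly the update
  xcol[idxs] = (xcol[idxs] - 1) * lcBatchSize + j performed below.
--]]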
lchild:resize(maxLcSeqLen, lcBatchSize):fill(self.UNK) 77 | lc_mask:resize(maxLcSeqLen, lcBatchSize):fill(0) 78 | local j = 0 79 | for i, lc in ipairs(lcs) do 80 | if lc:dim() ~= 0 then 81 | j = j + 1 82 | lchild[{ {1, lc:size(1)}, j }] = lc[{ {}, 1 }] 83 | lc_mask[{ {1, lc:size(1)}, j }] = lc[{ {}, 2 }] + 1 84 | local xcol = x[{ {}, i, 5 }] 85 | local idxs = xcol:ne(0) 86 | xcol[idxs] = (xcol[idxs] - 1) * lcBatchSize + j 87 | end 88 | end 89 | end 90 | 91 | return x, y, lchild, lc_mask 92 | end 93 | 94 | function BidTreeLM_Dataset:createBatch(label, batchSize) 95 | local h5in = self.h5in 96 | local x_data = h5in:read(string.format('/%s/x_data', label)) 97 | local y_data = h5in:read(string.format('/%s/y_data', label)) 98 | local index = h5in:read(string.format('/%s/index', label)) 99 | local l_data = h5in:read( string.format('/%s/l_data', label) ) 100 | local lindex = h5in:read( string.format('/%s/lindex', label) ) 101 | local N = index:dataspaceSize()[1] 102 | 103 | local istart = 1 104 | 105 | return function() 106 | if istart <= N then 107 | local iend = math.min(istart + batchSize - 1, N) 108 | local xs = {} 109 | local ys = {} 110 | local lcs = {} 111 | for i = istart, iend do 112 | local idx = index:partial({i, i}, {1, 2}) 113 | local start, len = idx[1][1], idx[1][2] 114 | local x = x_data:partial({start, start + len - 1}, {1, 5}) 115 | local y = y_data:partial({start, start + len - 1}) 116 | table.insert(xs, x) 117 | table.insert(ys, y) 118 | 119 | local lidx = lindex:partial({i, i}, {1, 2}) 120 | local lstart, llen = lidx[1][1], lidx[1][2] 121 | local lc 122 | if llen == 0 then 123 | lc = torch.IntTensor() -- to be the same type as l_data 124 | else 125 | lc = l_data:partial({lstart, lstart + llen - 1}, {1, 2}) 126 | end 127 | -- print(lc) 128 | table.insert(lcs, lc) 129 | end 130 | 131 | istart = iend + 1 132 | 133 | return self:toBatch(xs, ys, lcs, batchSize) 134 | end 135 | end 136 | end 137 | 138 | function BidTreeLM_Dataset:close() 139 | self.h5in:close() 140 | end 141 | -------------------------------------------------------------------------------- /dataset/BidTreeLM_NCE_Dataset.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'torch' 3 | require 'hdf5' 4 | require 'shortcut' 5 | require 'NCEDataGenerator' 6 | 7 | local BidTreeLM_NCE_Dataset = torch.class('BidTreeLM_NCE_Dataset') 8 | 9 | function BidTreeLM_NCE_Dataset:__init(datasetPath, nneg, power, normalizeUNK) 10 | self.vocab = torch.load(datasetPath:sub(1, -3) .. 
'vocab.t7') 11 | xprintln('load vocab done!') 12 | self.h5in = hdf5.open(datasetPath, 'r') 13 | 14 | local function getLength(label) 15 | local index = self.h5in:read(string.format('/%s/index', label)) 16 | return index:dataspaceSize()[1] 17 | end 18 | self.trainSize = getLength('train') 19 | self.validSize = getLength('valid') 20 | self.testSize = getLength('test') 21 | xprintln('train size %d, valid size %d, test size %d', self.trainSize, self.validSize, self.testSize) 22 | xprintln('vocab size %d', self.vocab.nvocab) 23 | self.UNK = self.vocab.UNK 24 | xprintln('unknown word token is %d', self.UNK) 25 | 26 | self.ncedata = NCEDataGenerator(self.vocab, nneg, power, normalizeUNK) 27 | end 28 | 29 | function BidTreeLM_NCE_Dataset:getVocabSize() 30 | return self.vocab.nvocab 31 | end 32 | 33 | function BidTreeLM_NCE_Dataset:getTrainSize() 34 | return self.trainSize 35 | end 36 | 37 | function BidTreeLM_NCE_Dataset:getValidSize() 38 | return self.validSize 39 | end 40 | 41 | function BidTreeLM_NCE_Dataset:getTestSize() 42 | return self.testSize 43 | end 44 | 45 | function BidTreeLM_NCE_Dataset:toBatch(xs, ys, lcs, batchSize) 46 | local dtype = 'torch.LongTensor' 47 | local maxn = 0 48 | for _, y_ in ipairs(ys) do 49 | if y_:size(1) > maxn then 50 | maxn = y_:size(1) 51 | end 52 | end 53 | local x = torch.ones(maxn, batchSize, 5):type(dtype) 54 | x:mul(self.UNK) 55 | x[{ {}, {}, 4 }] = torch.linspace(2, maxn + 1, maxn):resize(maxn, 1):expand(maxn, batchSize) 56 | x[{ {}, {}, 5 }] = 0 -- in default, I don't want them to have 57 | local nsent = #ys 58 | local y = torch.zeros(maxn, batchSize):type(dtype) 59 | for i = 1, nsent do 60 | local sx, sy = xs[i], ys[i] 61 | x[{ {1, sx:size(1)}, i, {} }] = sx 62 | y[{ {1, sy:size(1)}, i }] = sy 63 | end 64 | 65 | -- for left children 66 | assert(#lcs == #xs, 'should be the same!') 67 | local lcBatchSize = 0 68 | local maxLcSeqLen = 0 69 | for _, lc in ipairs(lcs) do 70 | if lc:dim() ~= 0 then 71 | lcBatchSize = lcBatchSize + 1 72 | maxLcSeqLen = math.max(maxLcSeqLen, lc:size(1)) 73 | end 74 | end 75 | local lchild = torch.Tensor():type(dtype) 76 | local lc_mask = torch.FloatTensor() 77 | 78 | if lcBatchSize ~= 0 then 79 | lchild:resize(maxLcSeqLen, lcBatchSize):fill(self.UNK) 80 | lc_mask:resize(maxLcSeqLen, lcBatchSize):fill(0) 81 | local j = 0 82 | for i, lc in ipairs(lcs) do 83 | if lc:dim() ~= 0 then 84 | j = j + 1 85 | lchild[{ {1, lc:size(1)}, j }] = lc[{ {}, 1 }] 86 | lc_mask[{ {1, lc:size(1)}, j }] = lc[{ {}, 2 }] + 1 87 | local xcol = x[{ {}, i, 5 }] 88 | local idxs = xcol:ne(0) 89 | xcol[idxs] = (xcol[idxs] - 1) * lcBatchSize + j 90 | end 91 | end 92 | end 93 | 94 | return x, y, lchild, lc_mask 95 | end 96 | 97 | function BidTreeLM_NCE_Dataset:createBatch(label, batchSize, useNCE) 98 | local h5in = self.h5in 99 | local x_data = h5in:read(string.format('/%s/x_data', label)) 100 | local y_data = h5in:read(string.format('/%s/y_data', label)) 101 | local index = h5in:read(string.format('/%s/index', label)) 102 | local l_data = h5in:read( string.format('/%s/l_data', label) ) 103 | local lindex = h5in:read( string.format('/%s/lindex', label) ) 104 | local N = index:dataspaceSize()[1] 105 | 106 | local istart = 1 107 | 108 | return function() 109 | if istart <= N then 110 | local iend = math.min(istart + batchSize - 1, N) 111 | local xs = {} 112 | local ys = {} 113 | local lcs = {} 114 | for i = istart, iend do 115 | local idx = index:partial({i, i}, {1, 2}) 116 | local start, len = idx[1][1], idx[1][2] 117 | local x = x_data:partial({start, start + 
len - 1}, {1, 5}) 118 | local y = y_data:partial({start, start + len - 1}) 119 | table.insert(xs, x) 120 | table.insert(ys, y) 121 | 122 | local lidx = lindex:partial({i, i}, {1, 2}) 123 | local lstart, llen = lidx[1][1], lidx[1][2] 124 | local lc 125 | if llen == 0 then 126 | lc = torch.IntTensor() -- to be the same type as l_data 127 | else 128 | lc = l_data:partial({lstart, lstart + llen - 1}, {1, 2}) 129 | end 130 | table.insert(lcs, lc) 131 | end 132 | 133 | istart = iend + 1 134 | 135 | local x, y, lchild, lc_mask = self:toBatch(xs, ys, lcs, batchSize) 136 | if useNCE then 137 | local mask = y:ne(0):float() 138 | y[y:eq(0)] = 1 139 | local y_neg, y_prob, y_neg_prob = self.ncedata:getYNegProbs(y) 140 | return x, y, lchild, lc_mask, y_neg, y_prob, y_neg_prob, mask 141 | else 142 | return x, y, lchild, lc_mask 143 | end 144 | end 145 | end 146 | end 147 | 148 | function BidTreeLM_NCE_Dataset:close() 149 | self.h5in:close() 150 | end 151 | -------------------------------------------------------------------------------- /scripts/depgraph.lua: -------------------------------------------------------------------------------- 1 | ------------------------------------ 2 | ---- Inside the dependency tree ---- 3 | ------------------------------------ 4 | 5 | include '../utils/xqueue.lua' 6 | require 'dgvertex' 7 | require 'dgedge' 8 | 9 | local DepGraph = torch.class('DepGraph') 10 | 11 | DepGraph.actions = {JL = 1, JR = 2, JLF = 3, JRF = 4} 12 | 13 | -- build dependency graph 14 | -- return true mean build successfully 15 | function DepGraph:build(depWords) 16 | -- local depWords = DepTreeUtils.parseDepStr(depStr) 17 | local maxID = -1 18 | for _, dword in ipairs(depWords) do 19 | maxID = math.max(maxID, math.max(dword.p1, dword.p2)) 20 | end 21 | self.size = maxID + 1 22 | self.vertices = {} 23 | for i = 1, self.size do self.vertices[i] = DGVertex() end 24 | for _, dword in ipairs(depWords) do 25 | local name, u, v = dword.rel, dword.p1 + 1, dword.p2 + 1 26 | local e = DGEdge(u, v, name) 27 | self.vertices[u].v = u 28 | self.vertices[u].tok = dword.w1 29 | self.vertices[v].v = v 30 | self.vertices[v].tok = dword.w2 31 | table.insert(self.vertices[u].adjList, e) 32 | 33 | if u == 1 and dword.w1 == DepTreeUtils.ROOT_MARK then 34 | self.root = u 35 | end 36 | end 37 | 38 | if self.root ~= 1 then return false end 39 | for i = 1, self.size do 40 | if self.vertices[i]:isEmpty() then return false end 41 | end 42 | 43 | return true 44 | end 45 | 46 | -- sort children of one node according to their distance 47 | function DepGraph:sortChildren() 48 | local Q = XQueue(self.size) 49 | Q:push({self.root, 0}) 50 | while not Q:isEmpty() do 51 | local u, level = unpack( Q:pop() ) 52 | local curV = self.vertices[u] 53 | -- print(curV.v, curV.tok, level) 54 | curV.leftChildren = {} 55 | curV.rightChildren = {} 56 | for _, e in ipairs( curV.adjList ) do 57 | local v = e.v 58 | local nxV = self.vertices[v] 59 | if nxV.v < u then 60 | table.insert(curV.leftChildren, nxV) 61 | else 62 | table.insert(curV.rightChildren, nxV) 63 | end 64 | end 65 | table.sort(curV.leftChildren, function(vet1, vet2) 66 | return vet1.v > vet2.v 67 | end) 68 | table.sort(curV.rightChildren, function(vet1, vet2) 69 | return vet1.v < vet2.v 70 | end) 71 | for _, ch in ipairs(curV.leftChildren) do 72 | Q:push({ch.v, level + 1}) 73 | end 74 | for _, ch in ipairs(curV.rightChildren) do 75 | Q:push({ch.v, level + 1}) 76 | end 77 | end 78 | end 79 | 80 | function DepGraph:getLinearRepr() 81 | local lrepr = {} 82 | local Q = XQueue(self.size) 83 
| Q:push(self.vertices[self.root]) 84 | local bfsOrder = 1 85 | while not Q:isEmpty() do 86 | local u = Q:pop() 87 | u.bfsID = bfsOrder 88 | bfsOrder = bfsOrder + 1 89 | if u.v ~= self.root then 90 | table.insert(lrepr, {u.dependencyVertex.tok, 91 | u.action, u.dependencyVertex.bfsID, u.bfsID, u.tok}) 92 | end 93 | -- assign dependency path for left children 94 | for i, v in ipairs(u.leftChildren) do 95 | if i == 1 then 96 | v.dependencyVertex = u 97 | v.action = DepGraph.actions.JL 98 | else 99 | v.dependencyVertex = u.leftChildren[i - 1] 100 | v.action = DepGraph.actions.JLF 101 | end 102 | Q:push(v) 103 | end 104 | -- assign dependency path for right children 105 | for i, v in ipairs(u.rightChildren) do 106 | if i == 1 then 107 | v.dependencyVertex = u 108 | v.action = DepGraph.actions.JR 109 | else 110 | v.dependencyVertex = u.rightChildren[i - 1] 111 | v.action = DepGraph.actions.JRF 112 | end 113 | Q:push(v) 114 | end 115 | end 116 | 117 | return lrepr 118 | end 119 | 120 | -- get bi-directional linear representation 121 | function DepGraph:getBidirectionalLinearRepr() 122 | local lrepr = {} 123 | local lchildren = {} 124 | 125 | local Q = XQueue(self.size) 126 | Q:push(self.vertices[self.root]) 127 | local bfsOrder = 1 128 | while not Q:isEmpty() do 129 | local u = Q:pop() 130 | u.bfsID = bfsOrder 131 | bfsOrder = bfsOrder + 1 132 | if u.v ~= self.root then 133 | table.insert(lrepr, {u.dependencyVertex.tok, 134 | u.action, u.dependencyVertex.bfsID, u.bfsID, u.leftCxtPos, u.tok}) 135 | end 136 | -- assign dependency path for left children 137 | for i, v in ipairs(u.leftChildren) do 138 | if i == 1 then 139 | v.dependencyVertex = u 140 | v.action = DepGraph.actions.JL 141 | else 142 | v.dependencyVertex = u.leftChildren[i - 1] 143 | v.action = DepGraph.actions.JLF 144 | end 145 | Q:push(v) 146 | end 147 | -- assign dependency path for right children 148 | for i, v in ipairs(u.rightChildren) do 149 | if i == 1 then 150 | v.dependencyVertex = u 151 | v.action = DepGraph.actions.JR 152 | if #u.leftChildren ~= 0 then 153 | for j = #u.leftChildren, 1, -1 do 154 | local lc = u.leftChildren[j] 155 | lchildren[#lchildren + 1] = {lc.tok, j == #u.leftChildren and 1 or 0} 156 | end 157 | v.leftCxtPos = #lchildren 158 | end 159 | else 160 | v.dependencyVertex = u.rightChildren[i - 1] 161 | v.action = DepGraph.actions.JRF 162 | end 163 | Q:push(v) 164 | end 165 | end 166 | 167 | return lrepr, lchildren 168 | end 169 | 170 | 171 | -------------------------------------------------------------------------------- /layers/NCE2.lua: -------------------------------------------------------------------------------- 1 | 2 | -- normalizing constant is learned automatically 3 | local NCE, parent = torch.class('NCE', 'nn.LookupTable') 4 | 5 | function NCE:__init(inputSize, outputSize, Z) 6 | parent.__init(self, outputSize, inputSize) 7 | 8 | self.bias = torch.Tensor(1) 9 | self.gradBias = torch.Tensor(1) 10 | 11 | print('self.bias is actually Z') 12 | print('NCE is from nn.LookupTable') 13 | end 14 | 15 | function NCE:updateOutput(input) 16 | -- hs: 2D, y: 1D, y_neg: 2D (x, n_neg), y_prob: 1D, y_neg_prob: 2D (x, n_neg) 17 | local hs, y, y_neg, y_prob, y_neg_prob = unpack(input) 18 | local Who = self.weight 19 | if hs:dim() == 2 then 20 | 21 | self.Z = self.bias[1] 22 | 23 | -- compute non-normalized softmax for y 24 | self.We_out = Who:index(1, y) 25 | local We_out = self.We_out 26 | local pos_a = torch.cmul(We_out, hs):sum(2) 27 | local p_rnn_pos = pos_a:exp():div(self.Z) -- unnormalized model prob exp(h . w_y) / Z, with the constant Z learned through self.bias; below it becomes the NCE posterior p / (p + k * q) 28 | local k = y_neg:size(2) 29 | -- 
local y_prob_2d = y_prob:view(y_prob:size(1), 1) 30 | self.P_pos = torch.cdiv( p_rnn_pos, (p_rnn_pos + y_prob * k) ) -- P_pos shape (seqlen * bs, 1) 31 | local P_pos = self.P_pos 32 | local log_P_pos = torch.log(P_pos) 33 | 34 | -- compute non-normalized softmax for negative examples of y, y_neg 35 | local y_neg_ = y_neg:view(y_neg:size(1) * y_neg:size(2)) 36 | local We_out_n_ = Who:index(1, y_neg_) 37 | local n_hid = Who:size(2) 38 | self.We_out_n = We_out_n_:view( y_neg:size(1), y_neg:size(2), n_hid ) 39 | local We_out_n = self.We_out_n 40 | local neg_a = torch.cmul( We_out_n, hs:view(hs:size(1), 1, hs:size(2)):expand(hs:size(1), y_neg:size(2), hs:size(2)) ):sum(3) 41 | local p_rnn_neg = neg_a:exp():div(self.Z) 42 | local k_y_neg_prob = y_neg_prob * k 43 | self.P_neg = torch.cdiv( k_y_neg_prob, (p_rnn_neg + k_y_neg_prob) ) 44 | local P_neg = self.P_neg 45 | local log_P_neg = torch.log(P_neg) 46 | 47 | self.output = log_P_pos + log_P_neg:sum(2) 48 | 49 | return self.output 50 | else 51 | error('input must be 2D matrix, currently only support batch mode') 52 | end 53 | end 54 | 55 | function NCE:updateGradInput(input, gradOutput) 56 | -- hs: 2D, y: 1D, y_neg: 2D (x, n_neg), y_prob: 1D, y_neg_prob: 2D (x, n_neg) 57 | local hs, y, y_neg, y_prob, y_neg_prob = unpack(input) 58 | 59 | -- gradOutput: is the scale of the gradients, gradOutput can contain 0s; 60 | -- that is to say gradOutput can also be served as mask; shape: (bs*seq, 1) 61 | 62 | if self.gradInput then 63 | -- I can't see why self.gradInput:zero() is useful 64 | local nElement = self.gradInput:nElement() 65 | self.gradInput:resizeAs(hs) 66 | if self.gradInput:nElement() ~= nElement then 67 | self.gradInput:zero() 68 | end 69 | 70 | if hs:dim() == 2 then 71 | -- gradients from the positive samples 72 | -- take mask (gradOutput) into account 73 | 74 | self.d_P_pos = torch.cmul( (-self.P_pos + 1), gradOutput ) 75 | local d_P_pos = self.d_P_pos 76 | self.gradInput:cmul( self.We_out, d_P_pos:expand(self.P_pos:size(1), hs:size(2)) ) 77 | 78 | -- gradients from the negative samples 79 | -- take (gradOutput) into account 80 | self.d_P_neg = torch.cmul( (self.P_neg - 1), gradOutput:expand(gradOutput:size(1), self.P_neg:size(2)) ) 81 | local d_P_neg = self.d_P_neg 82 | local d_hs = self.We_out_n:cmul( 83 | d_P_neg:view(d_P_neg:size(1), d_P_neg:size(2), 1):expand(d_P_neg:size(1), d_P_neg:size(2), hs:size(2)) 84 | ) 85 | self.gradInput:add(d_hs:sum(2)) 86 | 87 | return {self.gradInput} 88 | else 89 | error('input must be 2D matrix, currently only support batch mode') 90 | end 91 | end 92 | 93 | end 94 | 95 | function NCE:accGradParameters(input, gradOutput) 96 | -- hs: 2D, y: 1D, y_neg: 2D (x, n_neg), y_prob: 1D, y_neg_prob: 2D (x, n_neg) 97 | local hs, y, y_neg, y_prob, y_neg_prob = unpack(input) 98 | 99 | self:backCompatibility() 100 | 101 | if hs:dim() == 2 then 102 | local d_P_pos = self.d_P_pos 103 | self.gradBias:add( (-d_P_pos / self.Z):sum() ) 104 | local gradWeight_pos = torch.cmul( hs, d_P_pos:expand(self.P_pos:size(1), hs:size(2)) ) 105 | 106 | y = self:makeInputContiguous(y) 107 | y = self.copiedInput and self._input or y 108 | self.gradWeight.nn.LookupTable_accGradParameters(self, y, gradWeight_pos, 1) 109 | 110 | local d_P_neg = self.d_P_neg 111 | self.gradBias:add( (-d_P_neg / self.Z):sum() ) 112 | local gradWeight_neg = torch.cmul( 113 | hs:view(hs:size(1), 1, hs:size(2)):expand(hs:size(1), y_neg:size(2), hs:size(2)), 114 | d_P_neg:view(d_P_neg:size(1), d_P_neg:size(2), 1):expand(d_P_neg:size(1), d_P_neg:size(2), hs:size(2)) 
115 | ) 116 | 117 | y_neg = self:makeInputContiguous(y_neg) 118 | y_neg = self.copiedInput and self._input or y_neg 119 | self.gradWeight.nn.LookupTable_accGradParameters(self, y_neg:view(-1), gradWeight_neg, 1) 120 | 121 | self.d_P_pos = nil 122 | self.d_P_neg = nil 123 | else 124 | error('input must be 2D matrix, currently only support batch mode') 125 | end 126 | 127 | end 128 | 129 | 130 | -- we do not need to accumulate parameters when sharing 131 | NCE.sharedAccUpdateGradParameters = NCE.accUpdateGradParameters 132 | 133 | function NCE:__tostring__() 134 | return torch.type(self) .. 135 | string.format('(%d -> %d)', self.weight:size(2), self.weight:size(1)) 136 | end 137 | -------------------------------------------------------------------------------- /layers/NCE0.lua: -------------------------------------------------------------------------------- 1 | 2 | local NCE, parent = torch.class('NCE', 'nn.Module') 3 | 4 | function NCE:__init(inputSize, outputSize, Z) 5 | parent.__init(self) 6 | 7 | self.weight = torch.Tensor(outputSize, inputSize) 8 | -- self.bias = torch.Tensor(outputSize) 9 | self.gradWeight = torch.Tensor(outputSize, inputSize) 10 | -- self.gradBias = torch.Tensor(outputSize) 11 | self.Z = Z 12 | end 13 | 14 | function NCE:updateOutput(input) 15 | -- hs: 2D, y: 1D, y_neg: 2D (x, n_neg), y_prob: 1D, y_neg_prob: 2D (x, n_neg) 16 | local hs, y, y_neg, y_prob, y_neg_prob = unpack(input) 17 | local Who = self.weight 18 | if hs:dim() == 2 then 19 | -- compute non-normalized softmax for y 20 | self.We_out = Who:index(1, y) 21 | local We_out = self.We_out 22 | local pos_a = torch.cmul(We_out, hs):sum(2) 23 | local p_rnn_pos = pos_a:exp():div(self.Z) 24 | local k = y_neg:size(2) 25 | -- local y_prob_2d = y_prob:view(y_prob:size(1), 1) 26 | self.P_pos = torch.cdiv( p_rnn_pos, (p_rnn_pos + y_prob * k) ) -- P_pos shape (seqlen * bs, 1) 27 | local P_pos = self.P_pos 28 | local log_P_pos = torch.log(P_pos) 29 | 30 | -- compute non-normalized softmax for negative examples of y, y_neg 31 | local y_neg_ = y_neg:view(y_neg:size(1) * y_neg:size(2)) 32 | local We_out_n_ = Who:index(1, y_neg_) 33 | local n_hid = Who:size(2) 34 | self.We_out_n = We_out_n_:view( y_neg:size(1), y_neg:size(2), n_hid ) 35 | local We_out_n = self.We_out_n 36 | local neg_a = torch.cmul( We_out_n, hs:view(hs:size(1), 1, hs:size(2)):expand(hs:size(1), y_neg:size(2), hs:size(2)) ):sum(3) 37 | local p_rnn_neg = neg_a:exp():div(self.Z) 38 | local k_y_neg_prob = y_neg_prob * k 39 | self.P_neg = torch.cdiv( k_y_neg_prob, (p_rnn_neg + k_y_neg_prob) ) 40 | local P_neg = self.P_neg 41 | local log_P_neg = torch.log(P_neg) 42 | 43 | self.output = log_P_pos + log_P_neg:sum(2) 44 | -- print('NCE updateOutput self.output size') 45 | -- print(self.output:size()) 46 | 47 | return self.output 48 | else 49 | error('input must be 2D matrix, currently only support batch mode') 50 | end 51 | end 52 | 53 | function NCE:updateGradInput(input, gradOutput) 54 | -- hs: 2D, y: 1D, y_neg: 2D (x, n_neg), y_prob: 1D, y_neg_prob: 2D (x, n_neg) 55 | local hs, y, y_neg, y_prob, y_neg_prob = unpack(input) 56 | 57 | -- gradOutput: is the scale of the gradients, gradOutput can contain 0s; 58 | -- that is to say, gradOutput can also serve as a mask; shape: (bs*seq, 1) 59 | 60 | if self.gradInput then 61 | -- I can't see why self.gradInput:zero() is useful 62 | local nElement = self.gradInput:nElement() 63 | self.gradInput:resizeAs(hs) 64 | if self.gradInput:nElement() ~= nElement then 65 | self.gradInput:zero() 66 | end 67 | 68 | if hs:dim() == 2
then 69 | -- gradients from the positive samples 70 | -- take mask (gradOutput) into account 71 | 72 | --[[ 73 | print('size of gradOutput') 74 | print(gradOutput:size()) 75 | print(gradOutput) 76 | print('size of self.P_pos') 77 | print(self.P_pos:size()) 78 | --]] 79 | 80 | self.d_P_pos = torch.cmul( (-self.P_pos + 1), gradOutput ) 81 | local d_P_pos = self.d_P_pos 82 | self.gradInput:cmul( self.We_out, d_P_pos:expand(self.P_pos:size(1), hs:size(2)) ) 83 | 84 | -- gradients from the negative samples 85 | -- take (gradOutput) into account 86 | self.d_P_neg = torch.cmul( (self.P_neg - 1), gradOutput:expand(gradOutput:size(1), self.P_neg:size(2)) ) 87 | local d_P_neg = self.d_P_neg 88 | local d_hs = self.We_out_n:cmul( 89 | d_P_neg:view(d_P_neg:size(1), d_P_neg:size(2), 1):expand(d_P_neg:size(1), d_P_neg:size(2), hs:size(2)) 90 | ) 91 | self.gradInput:add(d_hs:sum(2)) 92 | 93 | return {self.gradInput} 94 | else 95 | error('input must be 2D matrix, currently only support batch mode') 96 | end 97 | end 98 | 99 | end 100 | 101 | function NCE:accGradParameters(input, gradOutput) 102 | -- hs: 2D, y: 1D, y_neg: 2D (x, n_neg), y_prob: 1D, y_neg_prob: 2D (x, n_neg) 103 | local hs, y, y_neg, y_prob, y_neg_prob = unpack(input) 104 | 105 | if hs:dim() == 2 then 106 | local d_P_pos = self.d_P_pos 107 | local gradWeight_pos = torch.cmul( hs, d_P_pos:expand(self.P_pos:size(1), hs:size(2)) ) 108 | local n_pos = y:size(1) 109 | 110 | for i = 1, n_pos do 111 | self.gradWeight[{ y[i], {} }]:add( gradWeight_pos[{ i, {} }] ) 112 | end 113 | 114 | local d_P_neg = self.d_P_neg 115 | local gradWeight_neg = torch.cmul( 116 | hs:view(hs:size(1), 1, hs:size(2)):expand(hs:size(1), y_neg:size(2), hs:size(2)), 117 | d_P_neg:view(d_P_neg:size(1), d_P_neg:size(2), 1):expand(d_P_neg:size(1), d_P_neg:size(2), hs:size(2)) 118 | ) 119 | local n1 = y_neg:size(1) 120 | local n2 = y_neg:size(2) 121 | 122 | for i = 1, n1 do 123 | for j = 1, n2 do 124 | self.gradWeight[{ y_neg[{ i, j }], {} }]:add( gradWeight_neg[{ i, j, {} }] ) 125 | end 126 | end 127 | 128 | self.d_P_pos = nil 129 | self.d_P_neg = nil 130 | else 131 | error('input must be 2D matrix, currently only support batch mode') 132 | end 133 | 134 | -- print('accGradParameters safe!') 135 | 136 | end 137 | 138 | -- we do not need to accumulate parameters when sharing 139 | NCE.sharedAccUpdateGradParameters = NCE.accUpdateGradParameters 140 | 141 | function NCE:__tostring__() 142 | return torch.type(self) .. 143 | string.format('(%d -> %d)', self.weight:size(2), self.weight:size(1)) 144 | end 145 | -------------------------------------------------------------------------------- /gpu_lock.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | """ 4 | A simple discretionary locking system for /dev/nvidia devices. 5 | 6 | Iain Murray, November 2009, January 2010, January 2011. 7 | """ 8 | 9 | import os, os.path 10 | 11 | _dev_prefix = '/dev/nvidia' 12 | #URL = 'http://www.cs.toronto.edu/~murray/code/gpu_monitoring/' 13 | URL = 'http://homepages.inf.ed.ac.uk/imurray2/code/gpu_monitoring/' 14 | 15 | 16 | # Get ID's of NVIDIA boards. 
Should do this through a CUDA call, but this is 17 | # a quick and dirty way that works for now: 18 | def board_ids(): 19 | """Returns integer board ids available on this machine.""" 20 | from glob import glob 21 | board_devs = glob(_dev_prefix + '[0-9]*') 22 | return range(len(board_devs)) 23 | 24 | def _lock_file(id): 25 | """lock file from integer id""" 26 | # /tmp is cleared on reboot on many systems, but it doesn't have to be 27 | if os.path.exists('/dev/shm'): 28 | # /dev/shm on linux machines is a RAM disk, so is definitely cleared 29 | return '/dev/shm/gpu_lock_%d' % id 30 | else: 31 | return '/tmp/gpu_lock_%d' % id 32 | 33 | def owner_of_lock(id): 34 | """Username that has locked the device id. (Empty string if no lock).""" 35 | import pwd 36 | try: 37 | statinfo = os.lstat(_lock_file(id)) 38 | return pwd.getpwuid(statinfo.st_uid).pw_name 39 | except: 40 | return "" 41 | 42 | def _obtain_lock(id): 43 | """Attempts to lock id, returning success as True/False.""" 44 | try: 45 | # On POSIX systems symlink creation is atomic, so this should be a 46 | # robust locking operation: 47 | os.symlink('/dev/null', _lock_file(id)) 48 | return True 49 | except: 50 | return False 51 | 52 | def _launch_reaper(id, pid): 53 | """Start a process that will free a lock when process pid terminates""" 54 | from subprocess import Popen, PIPE 55 | me = __file__ 56 | if me.endswith('.pyc'): 57 | me = me[:-1] 58 | myloc = os.path.dirname(me) 59 | if not myloc: 60 | myloc = os.getcwd() 61 | reaper_cmd = os.path.join(myloc, 'run_on_me_or_pid_quit') 62 | Popen([reaper_cmd, str(pid), me, '--free', str(id)], 63 | stdout=open('/dev/null', 'w')) 64 | 65 | def obtain_lock_id(pid = None): 66 | """ 67 | Finds a free id, locks it and returns integer id, or -1 if none free. 68 | 69 | A process is spawned that will free the lock automatically when the 70 | process pid (by default the current python process) terminates. 71 | """ 72 | id = -1 73 | id = obtain_lock_id_to_hog() 74 | try: 75 | if id >= 0: 76 | if pid is None: 77 | pid = os.getpid() 78 | _launch_reaper(id, pid) 79 | except: 80 | free_lock(id) 81 | id = -1 82 | return id 83 | 84 | def obtain_lock_id_to_hog(gpu_id = None): 85 | """ 86 | Finds a free id, locks it and returns integer id, or -1 if none free. 87 | 88 | * Lock must be freed manually * 89 | """ 90 | if gpu_id: 91 | id = gpu_id 92 | if _obtain_lock(id): 93 | return id 94 | else: 95 | for id in board_ids(): 96 | if _obtain_lock(id): 97 | return id 98 | return -1 99 | 100 | def free_lock(id): 101 | """Attempts to free lock id, returning success as True/False.""" 102 | try: 103 | filename = _lock_file(id) 104 | # On POSIX systems os.rename is an atomic operation, so this is the safe 105 | # way to delete a lock: 106 | os.rename(filename, filename + '.redundant') 107 | os.remove(filename + '.redundant') 108 | return True 109 | except: 110 | return False 111 | 112 | 113 | # If run as a program: 114 | if __name__ == "__main__": 115 | import sys 116 | me = sys.argv[0] 117 | # Report 118 | if '--id' in sys.argv: 119 | if len(sys.argv) > 2: 120 | try: 121 | pid = int(sys.argv[2]) 122 | assert(os.path.exists('/proc/%d' % pid)) 123 | except: 124 | print 'Usage: %s --id [pid_to_wait_on]' % me 125 | print 'The optional process id must exist if specified.' 126 | print 'Otherwise the id of the parent process is used.' 
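# (editor's note) Typical usage sketch from a job script; the way the printed
# id is consumed downstream is an assumption here, not part of this file:
#   gpu=$(python gpu_lock.py --id)       # prints -1 if no board is free
#   if [ "$gpu" -ge 0 ]; then th main.lua ...; fi
# The lock itself is just an atomic symlink under /dev/shm (or /tmp), and the
# spawned run_on_me_or_pid_quit reaper frees it when the calling process exits.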
127 | sys.exit(1) 128 | else: 129 | pid = os.getppid() 130 | print obtain_lock_id(pid) 131 | elif '--id-to-hog' in sys.argv: 132 | gpu_id = int(sys.argv[2]) if len(sys.argv) > 2 else None 133 | print obtain_lock_id_to_hog(gpu_id) 134 | elif '--free' in sys.argv: 135 | try: 136 | id = int(sys.argv[2]) 137 | except: 138 | print 'Usage: %s --free ' % me 139 | sys.exit(1) 140 | if free_lock(id): 141 | print "Lock freed" 142 | else: 143 | owner = owner_of_lock(id) 144 | if owner: 145 | print "Failed to free lock id=%d owned by %s" % (id, owner) 146 | else: 147 | print "Failed to free lock, but it wasn't actually set?" 148 | else: 149 | print '\n Usage instructions:\n' 150 | print ' To obtain and lock an id: %s --id' % me 151 | print ' The lock is automatically freed when the parent terminates' 152 | print 153 | print " To get an id that won't be freed: %s --id-to-hog" % me 154 | print " You *must* manually free these ids: %s --free \n" % me 155 | print ' More info: %s\n' % URL 156 | div = ' ' + "-"*60 157 | print '\n' + div 158 | print " NVIDIA board users:" 159 | print div 160 | for id in board_ids(): 161 | print " Board %d: %s" % (id, owner_of_lock(id)) 162 | print div + '\n' 163 | -------------------------------------------------------------------------------- /experiments/msr/gpu_lock.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | """ 4 | A simple discretionary locking system for /dev/nvidia devices. 5 | 6 | Iain Murray, November 2009, January 2010, January 2011. 7 | """ 8 | 9 | import os, os.path 10 | 11 | _dev_prefix = '/dev/nvidia' 12 | #URL = 'http://www.cs.toronto.edu/~murray/code/gpu_monitoring/' 13 | URL = 'http://homepages.inf.ed.ac.uk/imurray2/code/gpu_monitoring/' 14 | 15 | 16 | # Get ID's of NVIDIA boards. Should do this through a CUDA call, but this is 17 | # a quick and dirty way that works for now: 18 | def board_ids(): 19 | """Returns integer board ids available on this machine.""" 20 | from glob import glob 21 | board_devs = glob(_dev_prefix + '[0-9]*') 22 | return range(len(board_devs)) 23 | 24 | def _lock_file(id): 25 | """lock file from integer id""" 26 | # /tmp is cleared on reboot on many systems, but it doesn't have to be 27 | if os.path.exists('/dev/shm'): 28 | # /dev/shm on linux machines is a RAM disk, so is definitely cleared 29 | return '/dev/shm/gpu_lock_%d' % id 30 | else: 31 | return '/tmp/gpu_lock_%d' % id 32 | 33 | def owner_of_lock(id): 34 | """Username that has locked the device id. 
(Empty string if no lock).""" 35 | import pwd 36 | try: 37 | statinfo = os.lstat(_lock_file(id)) 38 | return pwd.getpwuid(statinfo.st_uid).pw_name 39 | except: 40 | return "" 41 | 42 | def _obtain_lock(id): 43 | """Attempts to lock id, returning success as True/False.""" 44 | try: 45 | # On POSIX systems symlink creation is atomic, so this should be a 46 | # robust locking operation: 47 | os.symlink('/dev/null', _lock_file(id)) 48 | return True 49 | except: 50 | return False 51 | 52 | def _launch_reaper(id, pid): 53 | """Start a process that will free a lock when process pid terminates""" 54 | from subprocess import Popen, PIPE 55 | me = __file__ 56 | if me.endswith('.pyc'): 57 | me = me[:-1] 58 | myloc = os.path.dirname(me) 59 | if not myloc: 60 | myloc = os.getcwd() 61 | reaper_cmd = os.path.join(myloc, 'run_on_me_or_pid_quit') 62 | Popen([reaper_cmd, str(pid), me, '--free', str(id)], 63 | stdout=open('/dev/null', 'w')) 64 | 65 | def obtain_lock_id(pid = None): 66 | """ 67 | Finds a free id, locks it and returns integer id, or -1 if none free. 68 | 69 | A process is spawned that will free the lock automatically when the 70 | process pid (by default the current python process) terminates. 71 | """ 72 | id = -1 73 | id = obtain_lock_id_to_hog() 74 | try: 75 | if id >= 0: 76 | if pid is None: 77 | pid = os.getpid() 78 | _launch_reaper(id, pid) 79 | except: 80 | free_lock(id) 81 | id = -1 82 | return id 83 | 84 | def obtain_lock_id_to_hog(gpu_id = None): 85 | """ 86 | Finds a free id, locks it and returns integer id, or -1 if none free. 87 | 88 | * Lock must be freed manually * 89 | """ 90 | if gpu_id: 91 | id = gpu_id 92 | if _obtain_lock(id): 93 | return id 94 | else: 95 | for id in board_ids(): 96 | if _obtain_lock(id): 97 | return id 98 | return -1 99 | 100 | def free_lock(id): 101 | """Attempts to free lock id, returning success as True/False.""" 102 | try: 103 | filename = _lock_file(id) 104 | # On POSIX systems os.rename is an atomic operation, so this is the safe 105 | # way to delete a lock: 106 | os.rename(filename, filename + '.redundant') 107 | os.remove(filename + '.redundant') 108 | return True 109 | except: 110 | return False 111 | 112 | 113 | # If run as a program: 114 | if __name__ == "__main__": 115 | import sys 116 | me = sys.argv[0] 117 | # Report 118 | if '--id' in sys.argv: 119 | if len(sys.argv) > 2: 120 | try: 121 | pid = int(sys.argv[2]) 122 | assert(os.path.exists('/proc/%d' % pid)) 123 | except: 124 | print 'Usage: %s --id [pid_to_wait_on]' % me 125 | print 'The optional process id must exist if specified.' 126 | print 'Otherwise the id of the parent process is used.' 127 | sys.exit(1) 128 | else: 129 | pid = os.getppid() 130 | print obtain_lock_id(pid) 131 | elif '--id-to-hog' in sys.argv: 132 | gpu_id = int(sys.argv[2]) if len(sys.argv) > 2 else None 133 | print obtain_lock_id_to_hog(gpu_id) 134 | elif '--free' in sys.argv: 135 | try: 136 | id = int(sys.argv[2]) 137 | except: 138 | print 'Usage: %s --free ' % me 139 | sys.exit(1) 140 | if free_lock(id): 141 | print "Lock freed" 142 | else: 143 | owner = owner_of_lock(id) 144 | if owner: 145 | print "Failed to free lock id=%d owned by %s" % (id, owner) 146 | else: 147 | print "Failed to free lock, but it wasn't actually set?" 
148 | else: 149 | print '\n Usage instructions:\n' 150 | print ' To obtain and lock an id: %s --id' % me 151 | print ' The lock is automatically freed when the parent terminates' 152 | print 153 | print " To get an id that won't be freed: %s --id-to-hog" % me 154 | print " You *must* manually free these ids: %s --free \n" % me 155 | print ' More info: %s\n' % URL 156 | div = ' ' + "-"*60 157 | print '\n' + div 158 | print " NVIDIA board users:" 159 | print div 160 | for id in board_ids(): 161 | print " Board %d: %s" % (id, owner_of_lock(id)) 162 | print div + '\n' 163 | -------------------------------------------------------------------------------- /experiments/depparse/gpu_lock.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | """ 4 | A simple discretionary locking system for /dev/nvidia devices. 5 | 6 | Iain Murray, November 2009, January 2010, January 2011. 7 | """ 8 | 9 | import os, os.path 10 | 11 | _dev_prefix = '/dev/nvidia' 12 | #URL = 'http://www.cs.toronto.edu/~murray/code/gpu_monitoring/' 13 | URL = 'http://homepages.inf.ed.ac.uk/imurray2/code/gpu_monitoring/' 14 | 15 | 16 | # Get ID's of NVIDIA boards. Should do this through a CUDA call, but this is 17 | # a quick and dirty way that works for now: 18 | def board_ids(): 19 | """Returns integer board ids available on this machine.""" 20 | from glob import glob 21 | board_devs = glob(_dev_prefix + '[0-9]*') 22 | return range(len(board_devs)) 23 | 24 | def _lock_file(id): 25 | """lock file from integer id""" 26 | # /tmp is cleared on reboot on many systems, but it doesn't have to be 27 | if os.path.exists('/dev/shm'): 28 | # /dev/shm on linux machines is a RAM disk, so is definitely cleared 29 | return '/dev/shm/gpu_lock_%d' % id 30 | else: 31 | return '/tmp/gpu_lock_%d' % id 32 | 33 | def owner_of_lock(id): 34 | """Username that has locked the device id. (Empty string if no lock).""" 35 | import pwd 36 | try: 37 | statinfo = os.lstat(_lock_file(id)) 38 | return pwd.getpwuid(statinfo.st_uid).pw_name 39 | except: 40 | return "" 41 | 42 | def _obtain_lock(id): 43 | """Attempts to lock id, returning success as True/False.""" 44 | try: 45 | # On POSIX systems symlink creation is atomic, so this should be a 46 | # robust locking operation: 47 | os.symlink('/dev/null', _lock_file(id)) 48 | return True 49 | except: 50 | return False 51 | 52 | def _launch_reaper(id, pid): 53 | """Start a process that will free a lock when process pid terminates""" 54 | from subprocess import Popen, PIPE 55 | me = __file__ 56 | if me.endswith('.pyc'): 57 | me = me[:-1] 58 | myloc = os.path.dirname(me) 59 | if not myloc: 60 | myloc = os.getcwd() 61 | reaper_cmd = os.path.join(myloc, 'run_on_me_or_pid_quit') 62 | Popen([reaper_cmd, str(pid), me, '--free', str(id)], 63 | stdout=open('/dev/null', 'w')) 64 | 65 | def obtain_lock_id(pid = None): 66 | """ 67 | Finds a free id, locks it and returns integer id, or -1 if none free. 68 | 69 | A process is spawned that will free the lock automatically when the 70 | process pid (by default the current python process) terminates. 71 | """ 72 | id = -1 73 | id = obtain_lock_id_to_hog() 74 | try: 75 | if id >= 0: 76 | if pid is None: 77 | pid = os.getpid() 78 | _launch_reaper(id, pid) 79 | except: 80 | free_lock(id) 81 | id = -1 82 | return id 83 | 84 | def obtain_lock_id_to_hog(gpu_id = None): 85 | """ 86 | Finds a free id, locks it and returns integer id, or -1 if none free. 
87 | 88 | * Lock must be freed manually * 89 | """ 90 | if gpu_id: 91 | id = gpu_id 92 | if _obtain_lock(id): 93 | return id 94 | else: 95 | for id in board_ids(): 96 | if _obtain_lock(id): 97 | return id 98 | return -1 99 | 100 | def free_lock(id): 101 | """Attempts to free lock id, returning success as True/False.""" 102 | try: 103 | filename = _lock_file(id) 104 | # On POSIX systems os.rename is an atomic operation, so this is the safe 105 | # way to delete a lock: 106 | os.rename(filename, filename + '.redundant') 107 | os.remove(filename + '.redundant') 108 | return True 109 | except: 110 | return False 111 | 112 | 113 | # If run as a program: 114 | if __name__ == "__main__": 115 | import sys 116 | me = sys.argv[0] 117 | # Report 118 | if '--id' in sys.argv: 119 | if len(sys.argv) > 2: 120 | try: 121 | pid = int(sys.argv[2]) 122 | assert(os.path.exists('/proc/%d' % pid)) 123 | except: 124 | print 'Usage: %s --id [pid_to_wait_on]' % me 125 | print 'The optional process id must exist if specified.' 126 | print 'Otherwise the id of the parent process is used.' 127 | sys.exit(1) 128 | else: 129 | pid = os.getppid() 130 | print obtain_lock_id(pid) 131 | elif '--id-to-hog' in sys.argv: 132 | gpu_id = int(sys.argv[2]) if len(sys.argv) > 2 else None 133 | print obtain_lock_id_to_hog(gpu_id) 134 | elif '--free' in sys.argv: 135 | try: 136 | id = int(sys.argv[2]) 137 | except: 138 | print 'Usage: %s --free ' % me 139 | sys.exit(1) 140 | if free_lock(id): 141 | print "Lock freed" 142 | else: 143 | owner = owner_of_lock(id) 144 | if owner: 145 | print "Failed to free lock id=%d owned by %s" % (id, owner) 146 | else: 147 | print "Failed to free lock, but it wasn't actually set?" 148 | else: 149 | print '\n Usage instructions:\n' 150 | print ' To obtain and lock an id: %s --id' % me 151 | print ' The lock is automatically freed when the parent terminates' 152 | print 153 | print " To get an id that won't be freed: %s --id-to-hog" % me 154 | print " You *must* manually free these ids: %s --free \n" % me 155 | print ' More info: %s\n' % URL 156 | div = ' ' + "-"*60 157 | print '\n' + div 158 | print " NVIDIA board users:" 159 | print div 160 | for id in board_ids(): 161 | print " Board %d: %s" % (id, owner_of_lock(id)) 162 | print div + '\n' 163 | -------------------------------------------------------------------------------- /scripts/words2hdf5.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'hdf5' 3 | require 'torch' 4 | include '../utils/shortcut.lua' 5 | 6 | local SENT_END = {'###eos###'} 7 | 8 | local function createVocab(inputFile, freqCut, ignoreCase, keepFreq) 9 | local fin = io.open(inputFile, 'r') 10 | local wordVector = {} 11 | local wordFreq = {} 12 | while true do 13 | local line = fin:read() 14 | if line == nil then break end 15 | local words = line:splitc(' \t\r\n') 16 | for _, word in ipairs(words) do 17 | if ignoreCase then word = word:lower() end 18 | if wordFreq[word] then 19 | wordFreq[word] = wordFreq[word] + 1 20 | else 21 | wordFreq[word] = 1 22 | wordVector[#wordVector + 1] = word 23 | end 24 | end 25 | end 26 | fin:close() 27 | 28 | local wid = 1 29 | local word2idx = {} 30 | if not wordFreq['UNK'] then 31 | word2idx = {UNK = wid} 32 | wid = wid + 1 33 | end 34 | local uniqUNK = 0 35 | local freqs = { 0 } 36 | for _, wd in ipairs(wordVector) do 37 | if wordFreq[wd] > freqCut then 38 | word2idx[wd] = wid 39 | freqs[wid] = wordFreq[wd] 40 | wid = wid + 1 41 | else 42 | uniqUNK = uniqUNK + 1 43 | if not wordFreq['UNK'] then 
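-- (editor's note) rare words (frequency <= freqCut) are not given their own
-- id: they are mapped to UNK, and when the corpus has no literal 'UNK' token
-- their counts are accumulated into freqs[1], the synthetic UNK slot, so the
-- kept frequency table still covers the whole corpus (presumably for the NCE
-- noise distribution built via alias_method.lua).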
44 | freqs[1] = freqs[1] + wordFreq[wd] 45 | end 46 | end 47 | end 48 | word2idx[SENT_END[1]] = wid 49 | local vocabSize = wid 50 | -- wid = wid + 1 51 | 52 | local idx2word = {} 53 | for wd, i in pairs(word2idx) do 54 | idx2word[i] = wd 55 | end 56 | 57 | local vocab = {word2idx = word2idx, idx2word = idx2word, 58 | freqCut = freqCut, ignoreCase = ignoreCase, 59 | keepFreq = keepFreq, UNK = word2idx['UNK'], 60 | EOS = word2idx['###eos###']} 61 | if keepFreq then 62 | vocab['freqs'] = freqs 63 | vocab['uniqUNK'] = uniqUNK 64 | printf('freqs size %d\n', #freqs) 65 | end 66 | 67 | assert(vocabSize == table.len(word2idx)) 68 | printf('original #words %d, after cut = %d, #words %d\n', #wordVector, freqCut, vocabSize) 69 | vocab['nvocab'] = vocabSize 70 | -- print(table.keys(vocab)) 71 | for k, v in pairs(vocab) do 72 | printf('%s -- ', k) 73 | if type(v) ~= 'table' then 74 | print(v) 75 | else 76 | print('table') 77 | end 78 | end 79 | 80 | return vocab 81 | end 82 | 83 | local function words2hdf5(vocab, h5out, splitLabel, splitFile, maxlen) 84 | local gdata = string.format('/%s/x_data', splitLabel) 85 | local gindex = string.format('/%s/index', splitLabel) 86 | local fin = io.open(splitFile, 'r') 87 | local word2idx = vocab.word2idx 88 | 89 | local dOpt = hdf5.DataSetOptions() 90 | dOpt:setChunked(1024*50*10) 91 | -- dOpt:setDeflate() 92 | local iOpt = hdf5.DataSetOptions() 93 | iOpt:setChunked(1024*10, 2) 94 | -- iOpt:setDeflate() 95 | 96 | local lineNo, offset = 0, 1 97 | local x_data = {} 98 | local index = {} 99 | local isFirst = true 100 | 101 | function appendData() 102 | local data_ = torch.IntTensor(x_data) 103 | local ind_ = torch.IntTensor(index) 104 | if not isFirst then 105 | h5out:append(gdata, data_, dOpt) 106 | h5out:append(gindex, ind_, iOpt) 107 | else 108 | h5out:write(gdata, data_, dOpt) 109 | h5out:write(gindex, ind_, iOpt) 110 | isFirst = false 111 | end 112 | end 113 | 114 | local ndel = 0 115 | 116 | while true do 117 | local line = fin:read() 118 | if line == nil then break end 119 | local words = line:splitc(' \t\r\n') 120 | -- print(#words, maxlen) 121 | if #words <= maxlen then 122 | local xs = {} 123 | local idx = {} 124 | for _, word in ipairs(words) do 125 | word = vocab.ignoreCase and word:lower() or word 126 | local wid = word2idx[word] or vocab.UNK 127 | xs[#xs + 1] = wid 128 | end 129 | local vlen = #xs 130 | idx = {offset, vlen} 131 | table.extend(x_data, xs) 132 | index[#index + 1] = idx 133 | 134 | offset = offset + vlen 135 | lineNo = lineNo + 1 136 | if lineNo % 50000 == 0 then 137 | appendData() 138 | x_data = {} 139 | index = {} 140 | collectgarbage() 141 | end 142 | else 143 | ndel = ndel + 1 144 | end 145 | end 146 | 147 | if #x_data > 0 then 148 | appendData() 149 | end 150 | 151 | printf('[%s] deleted %d sentences\n', splitLabel, ndel) 152 | 153 | fin:close() 154 | end 155 | 156 | local function getOpts() 157 | local cmd = torch.CmdLine() 158 | cmd:text('== convert text to hdf5 format ==') 159 | cmd:text() 160 | cmd:option('--train', '', 'train text file') 161 | cmd:option('--valid', '', 'valid text file') 162 | cmd:option('--test', '', 'test text file') 163 | cmd:option('--dataset', '', 'the resulting dataset (.h5)') 164 | cmd:option('--freq', 0, 'words occurring \"freq\" times or fewer will be mapped to UNK') 165 | cmd:option('--ignorecase', false, 'case will be ignored') 166 | cmd:option('--keepfreq', false, 'keep frequency information when creating the vocabulary') 167 | cmd:option('--maxlen', 100, 'sentences longer than maxlen will be ignored!')
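--[[ (editor's note) a minimal invocation sketch; the file names below are
hypothetical, the flags are the ones defined above:
  th words2hdf5.lua --train train.txt --valid valid.txt --test test.txt \
    --dataset ptb.h5 --freq 1 --ignorecase --keepfreq
this writes ptb.h5 and saves the vocabulary next to it as ptb.vocab.t7 --]]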
168 | 169 | return cmd:parse(arg) 170 | end 171 | 172 | local function main() 173 | local opts = getOpts() 174 | print(opts) 175 | local vocab = createVocab(opts.train, opts.freq, opts.ignorecase, opts.keepfreq) 176 | assert(opts.dataset:ends('.h5'), 'dataset must be hdf5 file .h5') 177 | local dataPrefix = opts.dataset:sub(1, -4) 178 | local vocabPath = dataPrefix .. '.vocab.t7' 179 | printf('save vocab to %s\n', vocabPath) 180 | torch.save(dataPrefix .. '.vocab.t7', vocab) 181 | 182 | local h5out = hdf5.open(opts.dataset, 'w') 183 | words2hdf5(vocab, h5out, 'train', opts.train, opts.maxlen) 184 | print('create training set done!') 185 | words2hdf5(vocab, h5out, 'valid', opts.valid, opts.maxlen) 186 | print('create validation set done!') 187 | words2hdf5(vocab, h5out, 'test', opts.test, opts.maxlen) 188 | print('create test set done!') 189 | h5out:close() 190 | printf('save dataset to %s\n', opts.dataset) 191 | end 192 | 193 | main() 194 | -------------------------------------------------------------------------------- /nnets/GPULSTMLM.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'torch' 3 | require 'nn' 4 | require 'optim' 5 | require 'nngraph' 6 | require 'Embedding' 7 | require 'MaskedClassNLLCriterion' 8 | 9 | require 'basic' 10 | 11 | local model_utils = require 'model_utils' 12 | 13 | local GPULSTMLM = torch.class('GPULSTMLM', 'BModel') 14 | 15 | local function transferData(useGPU, data) 16 | if useGPU then 17 | return data:cuda() 18 | else 19 | return data 20 | end 21 | end 22 | 23 | function GPULSTMLM:__init(opts) 24 | self.opts = opts 25 | self.name = 'GPULSTMLM' 26 | self:print( 'build LSTMLM ...' ) 27 | -- torch.manualSeed(opts.seed) 28 | -- build model 29 | opts.nivocab = opts.nivocab or opts.nvocab 30 | opts.novocab = opts.novocab or opts.nvocab 31 | opts.seqlen = opts.seqlen or 10 32 | self.coreNetwork = self:createNetwork(opts) 33 | self.params, self.grads = self.coreNetwork:getParameters() 34 | self.params:uniform(-opts.initRange, opts.initRange) 35 | print(self.params:size()) 36 | print(self.params[{ {1, 10} }]) 37 | 38 | self:print( 'Begin to clone model' ) 39 | self.networks = model_utils.clone_many_times(self.coreNetwork, opts.seqLen) 40 | self:print( 'Clone model done!' ) 41 | 42 | self:print('init states') 43 | self:setup(opts) 44 | self:print('init states done!') 45 | 46 | self:print( 'build LSTMLM done!'
) 47 | end 48 | 49 | function GPULSTMLM:setup(opts) 50 | self.hidStates = {} -- including all h_t and c_t 51 | self.initStates = {} 52 | self.df_hidStates = {} 53 | self.df_StatesT = {} 54 | 55 | for i = 1, 2*opts.nlayers do 56 | self.initStates[i] = transferData(opts.useGPU, torch.ones(opts.batchSize, opts.nhid) * opts.initHidVal) 57 | self.df_StatesT[i] = transferData(opts.useGPU, torch.zeros(opts.batchSize, opts.nhid)) 58 | end 59 | self.hidStates[0] = self.initStates 60 | self.err = transferData(opts.useGPU, torch.zeros(opts.seqLen)) 61 | end 62 | 63 | function GPULSTMLM:createLSTM(x_t, c_tm1, h_tm1, nin, nhid) 64 | -- compute activations of four gates all together 65 | local x2h = nn.Linear(nin, nhid * 4)(x_t) 66 | local h2h = nn.Linear(nhid, nhid * 4)(h_tm1) 67 | local allGatesActs = nn.CAddTable()({x2h, h2h}) 68 | local allGatesActsSplits = nn.SplitTable(2)( nn.Reshape(4, nhid)(allGatesActs) ) 69 | -- unpack all gate activations 70 | local i_t = nn.Sigmoid()( nn.SelectTable(1)( allGatesActsSplits ) ) 71 | local f_t = nn.Sigmoid()( nn.SelectTable(2)( allGatesActsSplits ) ) 72 | local o_t = nn.Sigmoid()( nn.SelectTable(3)( allGatesActsSplits ) ) 73 | local n_t = nn.Tanh()( nn.SelectTable(4)( allGatesActsSplits ) ) 74 | -- compute new cell 75 | local c_t = nn.CAddTable()({ 76 | nn.CMulTable()({ i_t, n_t }), 77 | nn.CMulTable()({ f_t, c_tm1 }) 78 | }) 79 | -- compute new hidden state 80 | local h_t = nn.CMulTable()({ o_t, nn.Tanh()( c_t ) }) 81 | 82 | return c_t, h_t 83 | end 84 | 85 | function GPULSTMLM:createNetwork(opts) 86 | local x_t = nn.Identity()() 87 | local y_t = nn.Identity()() 88 | local s_tm1 = nn.Identity()() 89 | local in_t = {[0] = nn.LookupTable(opts.nivocab, opts.nin)(x_t)} 90 | -- local in_t = {[0] = Embedding(opts.nivocab, opts.nin)(x_t)} 91 | local s_t = {} 92 | local splits_tm1 = {s_tm1:split(2 * opts.nlayers)} 93 | 94 | for i = 1, opts.nlayers do 95 | local c_tm1_i = splits_tm1[i + i - 1] 96 | local h_tm1_i = splits_tm1[i + i] 97 | local x_t_i = in_t[i - 1] 98 | local c_t_i, h_t_i = nil, nil 99 | if i == 1 then 100 | c_t_i, h_t_i = self:createLSTM(x_t_i, c_tm1_i, h_tm1_i, opts.nin, opts.nhid) 101 | else 102 | c_t_i, h_t_i = self:createLSTM(x_t_i, c_tm1_i, h_tm1_i, opts.nhid, opts.nhid) 103 | end 104 | s_t[i+i-1] = c_t_i 105 | s_t[i+i] = h_t_i 106 | in_t[i] = h_t_i 107 | end 108 | 109 | local h2y = nn.Linear(opts.nhid, opts.novocab)(in_t[opts.nlayers]) 110 | local y_pred = nn.LogSoftMax()(h2y) 111 | local err = MaskedClassNLLCriterion()({y_pred, y_t}) 112 | 113 | local model = nn.gModule({x_t, y_t, s_tm1}, {nn.Identity()(s_t), err}) 114 | if opts.useGPU then 115 | return model:cuda() 116 | else 117 | return model 118 | end 119 | end 120 | 121 | function GPULSTMLM:trainBatch(x, y, sgdParam) 122 | --[[ 123 | x = x:type('torch.DoubleTensor') 124 | y = y:type('torch.DoubleTensor') 125 | --]] 126 | if self.opts.useGPU then 127 | x = x:cuda() 128 | y = y:cuda() 129 | end 130 | local function feval(params_) 131 | if self.params ~= params_ then 132 | self.params:copy(params_) 133 | end 134 | self.grads:zero() 135 | -- forward pass 136 | local loss = 0 137 | local T = x:size(1) 138 | for t = 1, T do 139 | local s_tm1 = self.hidStates[t - 1] 140 | self.hidStates[t], self.err[t] = 141 | unpack( self.networks[t]:forward({ x[{ t, {} }], y[{ t, {} }], s_tm1 }) ) 142 | loss = loss + self.err[t] 143 | end 144 | 145 | for i = 1, 2*self.opts.nlayers do 146 | self.df_StatesT[i]:zero() 147 | end 148 | self.df_hidStates[T] = self.df_StatesT 149 | 150 | for t = T, 1, -1 do 151 | local 
s_tm1 = self.hidStates[t - 1] 152 | local derr = transferData(self.opts.useGPU, torch.ones(1)) 153 | local _, _, df_hidStates_tm1 = unpack( 154 | self.networks[t]:backward( 155 | {x[{ t, {} }], y[{ t, {} }], s_tm1}, 156 | {self.df_hidStates[t], derr} 157 | ) 158 | ) 159 | self.df_hidStates[t-1] = df_hidStates_tm1 160 | 161 | if self.opts.useGPU then 162 | cutorch.synchronize() 163 | end 164 | end 165 | 166 | -- clip the gradients 167 | self.grads:clamp(-5, 5) 168 | 169 | return loss, self.grads 170 | end 171 | 172 | local _, loss_ = optim.adagrad(feval, self.params, sgdParam) 173 | return loss_[1] 174 | end 175 | 176 | function GPULSTMLM:validBatch(x, y) 177 | local loss = 0 178 | local T = x:size(1) 179 | for t = 1, T do 180 | local s_tm1 = self.hidStates[t - 1] 181 | self.hidStates[t], self.err[t] = 182 | unpack( self.networks[t]:forward({ x[{ t, {} }], y[{ t, {} }], s_tm1 }) ) 183 | loss = loss + self.err[t] 184 | end 185 | 186 | return loss 187 | end 188 | 189 | 190 | -------------------------------------------------------------------------------- /layers/NCE1.lua: -------------------------------------------------------------------------------- 1 | 2 | -- note the lnZ is fixed 3 | local NCE, parent = torch.class('NCE', 'nn.LookupTable') 4 | 5 | function NCE:__init(inputSize, outputSize, Z) 6 | parent.__init(self, outputSize, inputSize) 7 | 8 | -- self.weight = torch.Tensor(outputSize, inputSize) 9 | -- self.bias = torch.Tensor(outputSize) 10 | -- self.gradWeight = torch.Tensor(outputSize, inputSize) 11 | -- self.gradBias = torch.Tensor(outputSize) 12 | self.Z = Z 13 | 14 | print('NCE is from nn.LookupTable') 15 | end 16 | 17 | function NCE:updateOutput(input) 18 | -- hs: 2D, y: 1D, y_neg: 2D (x, n_neg), y_prob: 1D, y_neg_prob: 2D (x, n_neg) 19 | local hs, y, y_neg, y_prob, y_neg_prob = unpack(input) 20 | local Who = self.weight 21 | if hs:dim() == 2 then 22 | -- compute non-normalized softmax for y 23 | self.We_out = Who:index(1, y) 24 | local We_out = self.We_out 25 | local pos_a = torch.cmul(We_out, hs):sum(2) 26 | local p_rnn_pos = pos_a:exp():div(self.Z) 27 | local k = y_neg:size(2) 28 | -- local y_prob_2d = y_prob:view(y_prob:size(1), 1) 29 | self.P_pos = torch.cdiv( p_rnn_pos, (p_rnn_pos + y_prob * k) ) -- P_pos shape (seqlen * bs, 1) 30 | local P_pos = self.P_pos 31 | local log_P_pos = torch.log(P_pos) 32 | 33 | -- compute non-normalized softmax for negative examples of y, y_neg 34 | local y_neg_ = y_neg:view(y_neg:size(1) * y_neg:size(2)) 35 | local We_out_n_ = Who:index(1, y_neg_) 36 | local n_hid = Who:size(2) 37 | self.We_out_n = We_out_n_:view( y_neg:size(1), y_neg:size(2), n_hid ) 38 | local We_out_n = self.We_out_n 39 | local neg_a = torch.cmul( We_out_n, hs:view(hs:size(1), 1, hs:size(2)):expand(hs:size(1), y_neg:size(2), hs:size(2)) ):sum(3) 40 | local p_rnn_neg = neg_a:exp():div(self.Z) 41 | local k_y_neg_prob = y_neg_prob * k 42 | self.P_neg = torch.cdiv( k_y_neg_prob, (p_rnn_neg + k_y_neg_prob) ) 43 | local P_neg = self.P_neg 44 | local log_P_neg = torch.log(P_neg) 45 | 46 | self.output = log_P_pos + log_P_neg:sum(2) 47 | -- print('NCE updateOutput self.output size') 48 | -- print(self.output:size()) 49 | 50 | return self.output 51 | else 52 | error('input must be 2D matrix, currently only support batch mode') 53 | end 54 | end 55 | 56 | function NCE:updateGradInput(input, gradOutput) 57 | -- hs: 2D, y: 1D, y_neg: 2D (x, n_neg), y_prob: 1D, y_neg_prob: 2D (x, n_neg) 58 | local hs, y, y_neg, y_prob, y_neg_prob = unpack(input) 59 | 60 | -- gradOutput: is the scale of 
the gradients, gradOutput can contain 0s; 61 | -- that is to say, gradOutput can also serve as a mask; shape: (bs*seq, 1) 62 | 63 | if self.gradInput then 64 | -- I can't see why self.gradInput:zero() is useful 65 | local nElement = self.gradInput:nElement() 66 | self.gradInput:resizeAs(hs) 67 | if self.gradInput:nElement() ~= nElement then 68 | self.gradInput:zero() 69 | end 70 | 71 | if hs:dim() == 2 then 72 | -- gradients from the positive samples 73 | -- take mask (gradOutput) into account 74 | 75 | --[[ 76 | print('size of gradOutput') 77 | print(gradOutput:size()) 78 | print(gradOutput) 79 | print('size of self.P_pos') 80 | print(self.P_pos:size()) 81 | --]] 82 | 83 | self.d_P_pos = torch.cmul( (-self.P_pos + 1), gradOutput ) 84 | local d_P_pos = self.d_P_pos 85 | self.gradInput:cmul( self.We_out, d_P_pos:expand(self.P_pos:size(1), hs:size(2)) ) 86 | 87 | -- gradients from the negative samples 88 | -- take (gradOutput) into account 89 | self.d_P_neg = torch.cmul( (self.P_neg - 1), gradOutput:expand(gradOutput:size(1), self.P_neg:size(2)) ) 90 | local d_P_neg = self.d_P_neg 91 | local d_hs = self.We_out_n:cmul( 92 | d_P_neg:view(d_P_neg:size(1), d_P_neg:size(2), 1):expand(d_P_neg:size(1), d_P_neg:size(2), hs:size(2)) 93 | ) 94 | self.gradInput:add(d_hs:sum(2)) 95 | 96 | return {self.gradInput} 97 | else 98 | error('input must be 2D matrix, currently only support batch mode') 99 | end 100 | end 101 | 102 | end 103 | 104 | function NCE:accGradParameters(input, gradOutput) 105 | -- hs: 2D, y: 1D, y_neg: 2D (x, n_neg), y_prob: 1D, y_neg_prob: 2D (x, n_neg) 106 | local hs, y, y_neg, y_prob, y_neg_prob = unpack(input) 107 | 108 | self:backCompatibility() 109 | 110 | if hs:dim() == 2 then 111 | local d_P_pos = self.d_P_pos 112 | local gradWeight_pos = torch.cmul( hs, d_P_pos:expand(self.P_pos:size(1), hs:size(2)) ) 113 | 114 | y = self:makeInputContiguous(y) 115 | y = self.copiedInput and self._input or y 116 | self.gradWeight.nn.LookupTable_accGradParameters(self, y, gradWeight_pos, 1) 117 | 118 | --[[ 119 | local n_pos = y:size(1) 120 | for i = 1, n_pos do 121 | self.gradWeight[{ y[i], {} }]:add( gradWeight_pos[{ i, {} }] ) 122 | end 123 | --]] 124 | 125 | local d_P_neg = self.d_P_neg 126 | local gradWeight_neg = torch.cmul( 127 | hs:view(hs:size(1), 1, hs:size(2)):expand(hs:size(1), y_neg:size(2), hs:size(2)), 128 | d_P_neg:view(d_P_neg:size(1), d_P_neg:size(2), 1):expand(d_P_neg:size(1), d_P_neg:size(2), hs:size(2)) 129 | ) 130 | 131 | y_neg = self:makeInputContiguous(y_neg) 132 | y_neg = self.copiedInput and self._input or y_neg 133 | self.gradWeight.nn.LookupTable_accGradParameters(self, y_neg:view(-1), gradWeight_neg, 1) 134 | 135 | --[[ 136 | local n1 = y_neg:size(1) 137 | local n2 = y_neg:size(2) 138 | 139 | for i = 1, n1 do 140 | for j = 1, n2 do 141 | self.gradWeight[{ y_neg[{ i, j }], {} }]:add( gradWeight_neg[{ i, j, {} }] ) 142 | end 143 | end 144 | --]] 145 | 146 | self.d_P_pos = nil 147 | self.d_P_neg = nil 148 | else 149 | error('input must be 2D matrix, currently only support batch mode') 150 | end 151 | 152 | -- print('accGradParameters safe!') 153 | 154 | end 155 | 156 | -- we do not need to accumulate parameters when sharing 157 | NCE.sharedAccUpdateGradParameters = NCE.accUpdateGradParameters 158 | 159 | function NCE:__tostring__() 160 | return torch.type(self) ..
161 | string.format('(%d -> %d)', self.weight:size(2), self.weight:size(1)) 162 | end 163 | -------------------------------------------------------------------------------- /layers/NCE.lua: -------------------------------------------------------------------------------- 1 | 2 | -- normalizing constant is learned automatically 3 | -- (set learnZ to make Z a learnable parameter) 4 | -- this version is also meant to be faster 5 | 6 | local NCE, parent = torch.class('NCE', 'nn.LookupTable') 7 | 8 | function NCE:__init(inputSize, outputSize, Z, learnZ) 9 | parent.__init(self, outputSize, inputSize) 10 | 11 | self.learnZ = learnZ 12 | if learnZ then 13 | self.bias = torch.Tensor(1) 14 | self.gradBias = torch.Tensor(1) 15 | print('learning Z') 16 | print('self.bias is actually Z') 17 | else 18 | self.Z = Z 19 | print('Z is a hyper-parameter') 20 | end 21 | 22 | print('NCE is from nn.LookupTable') 23 | end 24 | 25 | function NCE:updateOutput(input) 26 | -- hs: 2D, y: 1D, y_neg: 2D (x, n_neg), y_prob: 1D, y_neg_prob: 2D (x, n_neg) 27 | local hs, y, y_neg, y_prob, y_neg_prob = unpack(input) 28 | local Who = self.weight 29 | if hs:dim() == 2 then 30 | 31 | if self.learnZ then 32 | self.Z = self.bias[1] 33 | end 34 | 35 | -- compute non-normalized softmax for y 36 | self.We_out = Who:index(1, y) 37 | local We_out = self.We_out 38 | local pos_a = torch.cmul(We_out, hs):sum(2) 39 | local p_rnn_pos = pos_a:exp():div(self.Z) 40 | local k = y_neg:size(2) 41 | -- local y_prob_2d = y_prob:view(y_prob:size(1), 1) 42 | self.P_pos = torch.cdiv( p_rnn_pos, (p_rnn_pos + y_prob * k) ) -- P_pos shape (seqlen * bs, 1) 43 | local P_pos = self.P_pos 44 | local log_P_pos = torch.log(P_pos) 45 | 46 | -- compute non-normalized softmax for negative examples of y, y_neg 47 | local y_neg_ = y_neg:view(y_neg:size(1) * y_neg:size(2)) 48 | local We_out_n_ = Who:index(1, y_neg_) 49 | local n_hid = Who:size(2) 50 | self.We_out_n = We_out_n_:view( y_neg:size(1), y_neg:size(2), n_hid ) 51 | local We_out_n = self.We_out_n 52 | local neg_a = torch.cmul( We_out_n, hs:view(hs:size(1), 1, hs:size(2)):expand(hs:size(1), y_neg:size(2), hs:size(2)) ):sum(3) 53 | local p_rnn_neg = neg_a:exp():div(self.Z) 54 | local k_y_neg_prob = y_neg_prob * k 55 | self.P_neg = torch.cdiv( k_y_neg_prob, (p_rnn_neg + k_y_neg_prob) ) 56 | local P_neg = self.P_neg 57 | local log_P_neg = torch.log(P_neg) 58 | 59 | self.output = log_P_pos + log_P_neg:sum(2) 60 | 61 | return self.output 62 | else 63 | error('input must be 2D matrix, currently only support batch mode') 64 | end 65 | end 66 | 67 | function NCE:updateGradInput(input, gradOutput) 68 | -- hs: 2D, y: 1D, y_neg: 2D (x, n_neg), y_prob: 1D, y_neg_prob: 2D (x, n_neg) 69 | local hs, y, y_neg, y_prob, y_neg_prob = unpack(input) 70 | 71 | -- gradOutput: is the scale of the gradients, gradOutput can contain 0s; 72 | -- that is to say, gradOutput can also serve as a mask; shape: (bs*seq, 1) 73 | 74 | if self.gradInput then 75 | -- I can't see why self.gradInput:zero() is useful 76 | local nElement = self.gradInput:nElement() 77 | self.gradInput:resizeAs(hs) 78 | if self.gradInput:nElement() ~= nElement then 79 | self.gradInput:zero() 80 | end 81 | 82 | if hs:dim() == 2 then 83 | -- gradients from the positive samples 84 | -- take mask (gradOutput) into account 85 | 86 | self.d_P_pos = torch.cmul( (-self.P_pos + 1), gradOutput ) 87 | local d_P_pos = self.d_P_pos 88 | self.gradInput:cmul( self.We_out, d_P_pos:expand(self.P_pos:size(1), hs:size(2)) ) 89 | 90 | -- gradients from the negative samples 91 | -- take (gradOutput) into account
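-- (editor's note) gradient sketch, matching the code above and below: writing
-- a = w_y . h and p_model = exp(a)/Z, the NCE objective gives
--   d log P_pos / d a = 1 - P_pos    (the data word)
--   d log P_neg / d a = P_neg - 1    (each of the k noise words)
-- multiplying by gradOutput both scales the gradient and applies the
-- sequence mask, since padded positions carry a 0 in gradOutput.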
92 | self.d_P_neg = torch.cmul( (self.P_neg - 1), gradOutput:expand(gradOutput:size(1), self.P_neg:size(2)) ) 93 | local d_P_neg = self.d_P_neg 94 | local d_hs = self.We_out_n:cmul( 95 | d_P_neg:view(d_P_neg:size(1), d_P_neg:size(2), 1):expand(d_P_neg:size(1), d_P_neg:size(2), hs:size(2)) 96 | ) 97 | self.gradInput:add(d_hs:sum(2)) 98 | 99 | return {self.gradInput} 100 | else 101 | error('input must be 2D matrix, currently only support batch mode') 102 | end 103 | end 104 | 105 | end 106 | 107 | function NCE:accGradParameters(input, gradOutput) 108 | -- hs: 2D, y: 1D, y_neg: 2D (x, n_neg), y_prob: 1D, y_neg_prob: 2D (x, n_neg) 109 | local hs, y, y_neg, y_prob, y_neg_prob = unpack(input) 110 | 111 | self:backCompatibility() 112 | 113 | if hs:dim() == 2 then 114 | local d_P_pos = self.d_P_pos 115 | 116 | if self.learnZ then 117 | self.gradBias:add( (-d_P_pos / self.Z):sum() ) 118 | end 119 | 120 | local gradWeight_pos = torch.cmul( hs, d_P_pos:expand(self.P_pos:size(1), hs:size(2)) ) 121 | 122 | y = self:makeInputContiguous(y) 123 | y = self.copiedInput and self._input or y 124 | self.gradWeight.nn.LookupTable_accGradParameters(self, y, gradWeight_pos, 1) 125 | --[[ 126 | self.gradWeight.THNN.LookupTable_accGradParameters( 127 | y:cdata(), 128 | gradWeight_pos:cdata(), 129 | self.gradWeight:cdata() 130 | ) 131 | --]] 132 | 133 | local d_P_neg = self.d_P_neg 134 | 135 | if self.learnZ then 136 | self.gradBias:add( (-d_P_neg / self.Z):sum() ) 137 | end 138 | 139 | local gradWeight_neg = torch.cmul( 140 | hs:view(hs:size(1), 1, hs:size(2)):expand(hs:size(1), y_neg:size(2), hs:size(2)), 141 | d_P_neg:view(d_P_neg:size(1), d_P_neg:size(2), 1):expand(d_P_neg:size(1), d_P_neg:size(2), hs:size(2)) 142 | ) 143 | 144 | y_neg = self:makeInputContiguous(y_neg) 145 | y_neg = self.copiedInput and self._input or y_neg 146 | self.gradWeight.nn.LookupTable_accGradParameters(self, y_neg:view(-1), gradWeight_neg, 1) 147 | --[[ 148 | self.gradWeight.THNN.LookupTable_accGradParameters( 149 | y_neg:view(-1):cdata(), 150 | gradWeight_neg:cdata(), 151 | self.gradWeight:cdata() 152 | ) 153 | --]] 154 | 155 | self.d_P_pos = nil 156 | self.d_P_neg = nil 157 | else 158 | error('input must be 2D matrix, currently only support batch mode') 159 | end 160 | 161 | end 162 | 163 | -- we do not need to accumulate parameters when sharing 164 | NCE.sharedAccUpdateGradParameters = NCE.accUpdateGradParameters 165 | 166 | function NCE:__tostring__() 167 | return torch.type(self) .. 
168 | string.format('(%d -> %d)', self.weight:size(2), self.weight:size(1)) 169 | end 170 | -------------------------------------------------------------------------------- /scripts/sorthdf5.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'torch' 3 | require 'hdf5' 4 | include '../utils/shortcut.lua' 5 | 6 | local SortHDF5 = torch.class('SortHDF5') 7 | 8 | function SortHDF5:__init(h5infile, h5outfile) 9 | self.h5in = hdf5.open(h5infile, 'r') 10 | self.h5out = hdf5.open(h5outfile, 'w') 11 | end 12 | 13 | -- sortCmd is an integer 14 | -- 0 means no sorting 15 | -- sortCmd > 0 means sort by length within every sortCmd batches 16 | -- sortCmd < 0 means group same-length sentences into batches, ordering 17 | -- batches by the original position of their first sentence 18 | function SortHDF5:sortHDF5(sortCmd, batchSize) 19 | assert(sortCmd ~= 0, 'If there is no sorting, don\'t use this class!') 20 | assert(batchSize ~= nil, 'batchSize MUST be specified!') 21 | local lens = self:getLengths('train') 22 | -- print('get lengths of dataset done!') 23 | local idxs 24 | if sortCmd > 0 then 25 | local Kbatch = sortCmd * batchSize 26 | idxs = self:sortKBatches(lens, Kbatch) 27 | print '[sort algorithm] = sort k batches!' 28 | else 29 | idxs = self:sortBatches(lens, batchSize) 30 | print '[sort algorithm] = sort batches!' 31 | end 32 | self:writeSplit('train', idxs) 33 | print('sort training set done!') 34 | 35 | local function sortAll(splitLabel) 36 | local vlens = self:getLengths(splitLabel) 37 | local vidxs = self:sortKBatches(vlens, #vlens) 38 | self:writeSplit(splitLabel, vidxs) 39 | end 40 | 41 | -- for valid and test partition 42 | sortAll('valid') 43 | print('sort valid set done!') 44 | sortAll('test') 45 | print('sort test set done!') 46 | 47 | self.h5in:close() 48 | self.h5out:close() 49 | end 50 | 51 | function SortHDF5:getLengths(splitLabel) 52 | local index = self.h5in:read(string.format('/%s/index', splitLabel)) 53 | local N = index:dataspaceSize()[1] 54 | local lens = {} 55 | for i = 1, N do 56 | local idx = index:partial({i, i}, {1, 2}) 57 | local start, len = idx[1][1], idx[1][2] 58 | table.insert(lens, len) 59 | end 60 | assert(#lens == N, 'Number of lengths should be consistent!') 61 | 62 | return lens 63 | end 64 | 65 | function SortHDF5:writeSplit(splitLabel, idxs) 66 | local index = self.h5in:read(string.format('/%s/index', splitLabel)) 67 | local x_data = self.h5in:read(string.format('/%s/x_data', splitLabel)) 68 | local y_data = self.h5in:read(string.format('/%s/y_data', splitLabel)) 69 | 70 | local offset, isFirst = 1, true 71 | local x_ts = {} 72 | local y_ts = {} 73 | local i_ts = {} 74 | 75 | local gxdata = string.format('/%s/x_data', splitLabel) 76 | local gydata = string.format('/%s/y_data', splitLabel) 77 | local gindex = string.format('/%s/index', splitLabel) 78 | local xOpt = hdf5.DataSetOptions() 79 | xOpt:setChunked(1024*50*10, 4) 80 | -- xOpt:setDeflate(1) 81 | local yOpt = hdf5.DataSetOptions() 82 | yOpt:setChunked(1024*50*10) 83 | -- yOpt:setDeflate(1) 84 | local iOpt = hdf5.DataSetOptions() 85 | iOpt:setChunked(1024*10, 2) 86 | -- iOpt:setDeflate(1) 87 | 88 | local function appendData() 89 | local x_data_ = torch.IntTensor(x_ts) 90 | local y_data_ = torch.IntTensor(y_ts) 91 | local index_ = torch.IntTensor(i_ts) 92 | if not isFirst then 93 | self.h5out:append(gxdata, x_data_, xOpt) 94 | self.h5out:append(gydata, y_data_, yOpt) 95 | self.h5out:append(gindex, index_, iOpt) 96 | else 97 | self.h5out:write(gxdata, x_data_, xOpt) 98 | self.h5out:write(gydata,
y_data_, yOpt) 99 | self.h5out:write(gindex, index_, iOpt) 100 | isFirst = false 101 | end 102 | end 103 | 104 | for sentCount, i in ipairs(idxs) do 105 | local idx = index:partial({i, i}, {1, 2}) 106 | local start, len = idx[1][1], idx[1][2] 107 | local x = x_data:partial({start, start + len - 1}, {1, 4}) 108 | local y = y_data:partial({start, start + len - 1}) 109 | table.extend(x_ts, x:totable()) 110 | table.extend(y_ts, y:totable()) 111 | table.insert(i_ts, {offset, len}) 112 | if sentCount % 50000 == 0 then 113 | appendData() 114 | x_ts = {} 115 | y_ts = {} 116 | i_ts = {} 117 | printf('write [%s] line count = %d\n', splitLabel, sentCount) 118 | collectgarbage() 119 | end 120 | 121 | offset = offset + len 122 | end 123 | 124 | if #x_ts > 0 then 125 | appendData() 126 | end 127 | end 128 | 129 | function SortHDF5:sortKBatches(lens, Kbatch) 130 | local idxs = {} 131 | local N = #lens 132 | for istart = 1, N, Kbatch do 133 | local iend = math.min(istart + Kbatch - 1, N) 134 | local subIdxs = {} 135 | for i = istart, iend do 136 | table.insert(subIdxs, i) 137 | end 138 | table.sort(subIdxs, function(a, b) 139 | return lens[b] < lens[a] 140 | end) 141 | table.extend(idxs, subIdxs) 142 | end 143 | assert(#idxs == #lens) 144 | return idxs 145 | end 146 | 147 | function SortHDF5:sortBatches(lens, batchSize) 148 | local newIdxs = {} 149 | 150 | local len2idxs = {} 151 | local len2idxs_lens = {} 152 | for i, len in ipairs(lens) do 153 | local idxs = len2idxs[len] 154 | if idxs then 155 | table.insert(idxs, i) 156 | else 157 | len2idxs[len] = {i} 158 | table.insert(len2idxs_lens, len) 159 | end 160 | end 161 | 162 | local len2pos = {} 163 | for _, len in ipairs(len2idxs_lens) do 164 | len2pos[len] = 1 165 | end 166 | 167 | local pad = {} 168 | while true do 169 | local selectLen = -1 170 | local selectIdx = #lens + 1 171 | local istart, iend = -1, -1 172 | 173 | for _, len in ipairs(len2idxs_lens) do 174 | local pos = len2pos[len] 175 | local idxs = len2idxs[len] 176 | if pos <= #idxs and idxs[pos] < selectIdx then 177 | selectIdx = idxs[pos] 178 | selectLen = len 179 | istart, iend = pos, math.min(pos + batchSize - 1, #idxs) 180 | end 181 | end 182 | 183 | if selectLen == -1 then break end 184 | local sIdxs = len2idxs[selectLen] 185 | if iend - istart + 1 == batchSize then 186 | for i = istart, iend do 187 | newIdxs[#newIdxs + 1] = sIdxs[i] 188 | end 189 | else 190 | for i = istart, iend do 191 | pad[#pad + 1] = sIdxs[i] 192 | end 193 | end 194 | len2pos[selectLen] = iend + 1 195 | end -- end while 196 | 197 | table.sort(pad, function(a, b) 198 | return lens[b] < lens[a] 199 | end) 200 | table.extend(newIdxs, pad) 201 | 202 | return newIdxs 203 | end 204 | -------------------------------------------------------------------------------- /nnets/LSTMLM.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'torch' 3 | require 'nn' 4 | require 'optim' 5 | require 'nngraph' 6 | require 'Embedding' 7 | require 'MaskedClassNLLCriterion' 8 | 9 | local model_utils = require 'model_utils' 10 | 11 | local LSTMLM = torch.class('LSTMLM') 12 | 13 | function LSTMLM:__init(opts) 14 | print 'build LSTMLM ...' 15 | torch.manualSeed(opts.seed) 16 | -- build model 17 | opts.nivocab = opts.nivocab or opts.nvocab 18 | opts.novocab = opts.novocab or opts.nvocab 19 | opts.seqlen = opts.seqlen or 10 20 | self.emb = Embedding(opts.nivocab, opts.nin) 21 | -- self.lstm = self:createLSTM(opts.nin, opts.nhid) 22 | print 'using faster LSTM implementation'
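-- (editor's note) createLSTMFaster below fuses the four gate projections into
-- a single nn.Linear(nin, 4*nhid) plus an nn.Linear(nhid, 4*nhid) for the
-- recurrent part: two large matrix multiplies per step instead of eight
-- smaller ones, after which the fused activations are reshaped to (4, nhid)
-- and split back into the input/forget/output gates and the candidate cell.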
23 | self.lstm = self:createLSTMFaster(opts.nin, opts.nhid) 24 | 25 | self.softmax = nn.Sequential():add( nn.Linear(opts.nhid, opts.novocab) ):add( nn.LogSoftMax() ) 26 | self.params, self.grads = model_utils.combine_all_parameters(self.emb, self.lstm, self.softmax) 27 | print('init range', -opts.initRange, opts.initRange) 28 | self.params:uniform(-opts.initRange, opts.initRange) 29 | 30 | -- clone everything -- 31 | print 'begin to clone emb' 32 | self.embs = model_utils.clone_many_times(self.emb, opts.seqLen) 33 | print 'clone emb done' 34 | print 'begin to clone lstm' 35 | self.lstms = model_utils.clone_many_times(self.lstm, opts.seqLen) 36 | print 'clone lstm done!' 37 | print 'begin to clone softmax' 38 | self.softmaxs = model_utils.clone_many_times(self.softmax, opts.seqLen) 39 | print 'clone softmax done!' 40 | 41 | self.h0s = torch.ones(opts.batchSize, opts.nhid) * opts.initHidVal 42 | self.c0s = self.h0s:clone() 43 | self.d_hTs = torch.zeros(opts.batchSize, opts.nhid) 44 | self.d_cTs = self.d_hTs:clone() 45 | 46 | -- self.criterion = nn.ClassNLLCriterion() 47 | self.criterion = MaskedClassNLLCriterion() 48 | self.criterions = model_utils.clone_many_times(self.criterion, opts.seqLen) 49 | print 'build LSTMLM done!' 50 | end 51 | 52 | function LSTMLM:createLSTM(nin, nhid) 53 | -- inputs 54 | local x_t = nn.Identity()() 55 | local c_tm1 = nn.Identity()() 56 | local h_tm1 = nn.Identity()() 57 | 58 | local function newHidLinear() 59 | local i2h = nn.Linear(nin, nhid)(x_t) 60 | local h2h = nn.Linear(nhid, nhid)(h_tm1) 61 | 62 | return nn.CAddTable()({i2h, h2h}) 63 | end 64 | 65 | local i_t = nn.Sigmoid()( newHidLinear() ) 66 | local f_t = nn.Sigmoid()( newHidLinear() ) 67 | local o_t = nn.Sigmoid()( newHidLinear() ) 68 | local n_t = nn.Tanh()( newHidLinear() ) 69 | 70 | local c_t = nn.CAddTable()({ 71 | nn.CMulTable()({ f_t, c_tm1 }), 72 | nn.CMulTable()({ i_t, n_t }) 73 | }) 74 | 75 | local h_t = nn.CMulTable()({ o_t, nn.Tanh()(c_t) }) 76 | 77 | return nn.gModule({x_t, c_tm1, h_tm1}, {c_t, h_t}) 78 | end 79 | 80 | function LSTMLM:createLSTMFaster(nin, nhid) 81 | local x_t = nn.Identity()() 82 | local c_tm1 = nn.Identity()() 83 | local h_tm1 = nn.Identity()() 84 | -- compute four gates together 85 | local x2h = nn.Linear(nin, 4*nhid)(x_t) 86 | local h2h = nn.Linear(nhid, 4*nhid)(h_tm1) 87 | local gateActs = nn.CAddTable()({x2h, h2h}) 88 | -- split the activations of four gates into four 89 | local reshapedGateActs = nn.Reshape(4, nhid)(gateActs) 90 | local gateActsSplits = nn.SplitTable(2)(reshapedGateActs) 91 | -- unpack all gates 92 | local i_t = nn.Sigmoid()( nn.SelectTable(1)(gateActsSplits) ) 93 | local f_t = nn.Sigmoid()( nn.SelectTable(2)(gateActsSplits) ) 94 | local o_t = nn.Sigmoid()( nn.SelectTable(3)(gateActsSplits) ) 95 | local n_t = nn.Tanh()( nn.SelectTable(4)(gateActsSplits) ) 96 | 97 | local c_t = nn.CAddTable()({ 98 | nn.CMulTable()({i_t, n_t}), 99 | nn.CMulTable()({f_t, c_tm1}) 100 | }) 101 | local h_t = nn.CMulTable()({ o_t, nn.Tanh()(c_t) }) 102 | 103 | return nn.gModule({x_t, c_tm1, h_tm1}, {c_t, h_t}) 104 | end 105 | 106 | function LSTMLM:validBatch(x, y) 107 | local bs = x:size(2) 108 | local embeds = {} 109 | local hs = {[0] = self.h0s[{ {1, bs}, {} }]} 110 | local cs = {[0] = self.c0s[{ {1, bs}, {} }]} 111 | local log_y_preds= {} 112 | local loss = 0 113 | local T = x:size(1) 114 | for t = 1, T do 115 | embeds[t] = self.embs[t]:forward(x[{ t, {} }]) 116 | cs[t], hs[t] = unpack( 117 | self.lstms[t]:forward({embeds[t], cs[t-1], hs[t-1]}) 118 | ) 119 | 
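-- (editor's note) project h_t to vocabulary log-probabilities and
-- accumulate the masked NLL for this timestep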
log_y_preds[t] = self.softmaxs[t]:forward(hs[t]) 120 | loss = loss + self.criterions[t]:forward(log_y_preds[t], y[{ t, {} }]) 121 | end 122 | 123 | return loss 124 | end 125 | 126 | function LSTMLM:trainBatch(x, y, sgd_param) 127 | -- x: (seqlen, bs) 128 | -- y: (seqlen, bs) 129 | local function feval(params_) 130 | if self.params ~= params_ then 131 | self.params:copy(params_) 132 | end 133 | 134 | self.grads:zero() 135 | -- forward pass 136 | local bs = x:size(2) 137 | local embeds = {} 138 | local hs = {[0] = self.h0s[{ {1, bs}, {} }]} 139 | local cs = {[0] = self.c0s[{ {1, bs}, {} }]} 140 | local log_y_preds= {} 141 | local loss = 0 142 | local T = x:size(1) 143 | for t = 1, T do 144 | embeds[t] = self.embs[t]:forward(x[{ t, {} }]) 145 | cs[t], hs[t] = unpack( 146 | self.lstms[t]:forward({embeds[t], cs[t-1], hs[t-1]}) 147 | ) 148 | log_y_preds[t] = self.softmaxs[t]:forward(hs[t]) 149 | loss = loss + self.criterions[t]:forward(log_y_preds[t], y[{ t, {} }]) 150 | end 151 | 152 | -- backward pass 153 | local d_embeds = {} 154 | self.d_hTs:zero() 155 | local d_hs = {[T] = self.d_hTs[{ {1, bs}, {} }]} 156 | -- local d_hs = {} 157 | local d_cs = {[T] = self.d_cTs[{ {1, bs}, {} }]} 158 | for t = T, 1, -1 do 159 | local d_log_y_preds_t = self.criterions[t]:backward(log_y_preds[t], y[{ t, {} }]) 160 | --[[ 161 | if t == T then 162 | assert(d_hs[t] == nil) 163 | d_hs[t] = self.softmaxs[t]:backward(hs[t], d_log_y_preds_t) 164 | else 165 | d_hs[t]:add( self.softmaxs[t]:backward(hs[t], d_log_y_preds_t) ) 166 | end 167 | --]] 168 | d_hs[t]:add( self.softmaxs[t]:backward(hs[t], d_log_y_preds_t) ) 169 | d_embeds[t], d_cs[t-1], d_hs[t-1] = unpack(self.lstms[t]:backward( 170 | {embeds[t], cs[t-1], hs[t-1]}, 171 | {d_cs[t], d_hs[t]}) 172 | ) 173 | self.embs[t]:backward(x[{ t, {} }], d_embeds[t]) 174 | end 175 | 176 | --[[ 177 | self.h0s:copy(hs[T]) 178 | self.c0s:copy(cs[T]) 179 | --]] 180 | 181 | -- clip the gradients 182 | self.grads:clamp(-5, 5) 183 | 184 | return loss, self.grads 185 | end 186 | 187 | local _, loss_ = optim.adagrad(feval, self.params, sgd_param) 188 | return loss_[1] 189 | end 190 | 191 | -------------------------------------------------------------------------------- /train_mlp.lua: -------------------------------------------------------------------------------- 1 | 2 | require '.' 
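--[[ (editor's note) trains a feed-forward classifier on features stored in
an HDF5 file (groups /train, /valid, /test, each with x, e, oe and y
datasets, as read by DataIter below); an invocation sketch with a
hypothetical dataset path:
  th train_mlp.lua --dataset feats.h5 --snhids '400,300,300,2' \
    --ftype '|x|oe|' --batchSize 256 --lr 0.01 --save mlp.t7 --]]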
3 | require 'MLP' 4 | require 'hdf5' 5 | 6 | local function getOpts() 7 | local cmd = torch.CmdLine() 8 | cmd:text('====== MLP v 1.0 ======') 9 | cmd:text() 10 | cmd:option('--seed', 123, 'random seed') 11 | cmd:option('--useGPU', false, 'use gpu') 12 | cmd:option('--snhids', '400,300,300,2', 'string hidden sizes for each layer') 13 | cmd:option('--activ', 'tanh', 'options: tanh, relu') 14 | cmd:option('--dropout', 0, 'dropout rate (dropping)') 15 | cmd:option('--maxEpoch', 10, 'max number of epochs') 16 | cmd:option('--dataset', 17 | '/disk/scratch/XingxingZhang/treelstm/dataset/depparse/eot.penn_wsj.conllx.sort.h5', 18 | 'dataset') 19 | cmd:option('--ftype', '|x|oe|', '') 20 | cmd:option('--ytype', 1, '') 21 | cmd:option('--batchSize', 256, '') 22 | cmd:option('--lr', 0.01, '') 23 | cmd:option('--optimMethod', 'AdaGrad', 'options: SGD, AdaGrad') 24 | cmd:option('--save', 'model.t7', 'save path') 25 | 26 | return cmd:parse(arg) 27 | end 28 | 29 | local EPOCH_INFO = '' 30 | 31 | local DataIter = {} 32 | function DataIter.getNExamples(dataPath, label) 33 | local h5in = hdf5.open(dataPath, 'r') 34 | local x_data = h5in:read(string.format('/%s/x', label)) 35 | local N = x_data:dataspaceSize()[1] 36 | h5in:close() -- only the row count is needed here; release the handle 37 | return N 38 | end 39 | 40 | -- ftype: x | x, e | x, oe | x, e, oe 41 | function DataIter.createBatch(dataPath, label, ftype, ytype, batchSize) 42 | local h5in = hdf5.open(dataPath, 'r') 43 | local x_data = h5in:read(string.format('/%s/x', label)) 44 | local e_data = h5in:read(string.format('/%s/e', label)) 45 | local oe_data = h5in:read(string.format('/%s/oe', label)) 46 | local y_data = h5in:read(string.format('/%s/y', label)) 47 | local N = x_data:dataspaceSize()[1] 48 | local x_width = x_data:dataspaceSize()[2] 49 | local e_width = e_data:dataspaceSize()[2] 50 | local oe_width = oe_data:dataspaceSize()[2] 51 | 52 | -- print('N = ') 53 | -- print(N) 54 | local istart = 1 55 | 56 | return function() 57 | if istart <= N then 58 | local iend = math.min(istart + batchSize - 1, N) 59 | local x = x_data:partial({istart, iend}, {1, x_width}) 60 | local e = e_data:partial({istart, iend}, {1, e_width}) 61 | local oe = oe_data:partial({istart, iend}, {1, oe_width}) 62 | -- print('OK') 63 | local y = y_data:partial({istart, iend}, {ytype, ytype}):view(-1) + 1 64 | -- print('OK, too') 65 | 66 | local xd = {} 67 | if ftype:find('|x|') then 68 | table.insert(xd, x) 69 | end 70 | if ftype:find('|e|') then 71 | table.insert(xd, e) 72 | end 73 | if ftype:find('|oe|') then 74 | table.insert(xd, oe) 75 | end 76 | istart = iend + 1 77 | 78 | if #xd == 1 then 79 | return xd[1], y 80 | else 81 | local d = 0 82 | for i = 1, #xd do 83 | d = d + xd[i]:size(2) 84 | end 85 | local x_ = torch.zeros(x:size(1), d) 86 | d = 0 87 | for i = 1, #xd do 88 | x_[{ {}, {d + 1, d + xd[i]:size(2)} }] = xd[i] 89 | d = d + xd[i]:size(2) 90 | end 91 | 92 | return x_, y 93 | end 94 | else 95 | h5in:close() 96 | end 97 | end 98 | end 99 | 100 | local function train(mlp, opts) 101 | local dataIter = DataIter.createBatch(opts.dataset, 'train', 102 | opts.ftype, opts.ytype, opts.batchSize) 103 | 104 | local dataSize = DataIter.getNExamples(opts.dataset, 'train') 105 | local percent, inc = 0.001, 0.001 106 | local timer = torch.Timer() 107 | -- local sgdParam = {learningRate = opts.curLR} 108 | local sgdParam = opts.sgdParam 109 | local cnt = 0 110 | local totalLoss = 0 111 | local totalCnt = 0 112 | for x, y in dataIter do 113 | local loss = mlp:trainBatch(x, y, sgdParam) 114 | totalLoss = totalLoss + loss * x:size(1) 115 | totalCnt =
totalCnt + x:size(1) 116 | 117 | local ratio = totalCnt/dataSize 118 | if ratio >= percent then 119 | local wps = totalCnt / timer:time().real 120 | xprint( '\r%s %.3f %.4f (%s) / %.2f wps ... ', EPOCH_INFO, ratio, totalLoss/totalCnt, readableTime(timer:time().real), wps ) 121 | percent = math.floor(ratio / inc) * inc 122 | percent = percent + inc 123 | end 124 | 125 | cnt = cnt + 1 126 | if cnt % 5 == 0 then 127 | collectgarbage() 128 | end 129 | end 130 | 131 | return totalLoss / totalCnt 132 | end 133 | 134 | local function valid(mlp, label, opts) 135 | local dataIter = DataIter.createBatch(opts.dataset, label, 136 | opts.ftype, opts.ytype, opts.batchSize) 137 | 138 | local cnt = 0 139 | local correct, total = 0, 0 140 | for x, y in dataIter do 141 | local correct_, total_ = mlp:validBatch(x, y) 142 | correct = correct + correct_ 143 | total = total + total_ 144 | cnt = cnt + 1 145 | if cnt % 5 == 0 then collectgarbage() end 146 | end 147 | 148 | return correct, total 149 | end 150 | 151 | local function main() 152 | local opts = getOpts() 153 | torch.manualSeed(opts.seed) 154 | if opts.useGPU then 155 | require 'cutorch' 156 | require 'cunn' 157 | cutorch.manualSeed(opts.seed) 158 | end 159 | local mlp = MLP(opts) 160 | opts.sgdParam = {learningRate = opts.lr} 161 | opts.curLR = opts.lr 162 | print(opts) 163 | 164 | local timer = torch.Timer() 165 | local bestAcc = 0 166 | local bestModel = torch.FloatTensor(mlp.params:size()) 167 | for epoch = 1, opts.maxEpoch do 168 | EPOCH_INFO = string.format('epoch %d', epoch) 169 | local startTime = timer:time().real 170 | local trainCost = train(mlp, opts) 171 | -- local trainCost = 123 172 | xprint('\repoch %d TRAIN nll %f ', epoch, trainCost) 173 | local validCor, validTot = valid(mlp, 'valid', opts) 174 | local validAcc = validCor/validTot 175 | xprint('VALID %d/%d = %f ', validCor, validTot, validAcc) 176 | local endTime = timer:time().real 177 | xprintln('lr = %.4g (%s)', opts.curLR, readableTime(endTime - startTime)) 178 | 179 | if validAcc > bestAcc then 180 | bestAcc = validAcc 181 | mlp:getModel(bestModel) 182 | end 183 | end 184 | 185 | mlp:setModel(bestModel) 186 | opts.sgdParam = nil 187 | mlp:save(opts.save, true) 188 | xprintln('model saved at %s', opts.save) 189 | 190 | local validCor, validTot = valid(mlp, 'valid', opts) 191 | local validAcc = validCor/validTot 192 | xprint('VALID %d/%d = %f, ', validCor, validTot, validAcc) 193 | local testCor, testTot = valid(mlp, 'test', opts) 194 | local testAcc = testCor/testTot 195 | xprint('TEST %d/%d = %f \n', testCor, testTot, testAcc) 196 | end 197 | 198 | main() 199 | 200 | -------------------------------------------------------------------------------- /main.lua: -------------------------------------------------------------------------------- 1 | 2 | require '.' 
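-- main.lua is the NLL training driver for the TreeLSTM language model: it
-- trains with early stopping on validation perplexity, saves the best model,
-- and then reloads it for a final check (see verifyModel below).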
3 | require 'shortcut' 4 | require 'TreeLSTMLM' 5 | require 'TreeLM_Dataset' 6 | 7 | local EPOCH_INFO = '' 8 | 9 | local function getOpts() 10 | local cmd = torch.CmdLine() 11 | cmd:text('====== Tree LSTM Language Model ======') 12 | cmd:text() 13 | cmd:option('--seed', 123, 'random seed') 14 | cmd:option('--dataset', '', 'dataset path') 15 | cmd:option('--maxEpoch', 100, 'maximum number of epochs') 16 | cmd:option('--batchSize', 64, '') 17 | cmd:option('--nin', 100, 'word embedding size') 18 | cmd:option('--nhid', 300, 'hidden unit size') 19 | cmd:option('--nlayers', 1, 'number of hidden layers') 20 | cmd:option('--lr', 0.1, 'learning rate') 21 | cmd:option('--lrDiv', 0, 'learning rate decay when there is no significant improvement. 0 means turn off') 22 | cmd:option('--minImprovement', 1.0001, 'if the improvement on log likelihood is smaller than this factor, decrease patience') 23 | cmd:option('--optimMethod', 'AdaGrad', 'optimization algorithm') 24 | cmd:option('--gradClip', 5, '> 0 means to do Pascanu et al.\'s grad norm rescale http://arxiv.org/pdf/1502.04623.pdf; < 0 means to truncate the gradient larger than gradClip; 0 means turn off gradient clip') 25 | cmd:option('--initRange', 0.1, 'init range') 26 | cmd:option('--initHidVal', 0.01, 'init values for hidden states') 27 | cmd:option('--seqLen', 151, 'maximum sequence length') 28 | cmd:option('--useGPU', false, 'use GPU') 29 | cmd:option('--patience', 2, 'stop training if no lower valid PPL is observed in [patience] consecutive epoch(s)') 30 | cmd:option('--save', 'model.t7', 'save model path') 31 | 32 | return cmd:parse(arg) 33 | end 34 | 35 | local function train(rnn, lmdata, opts) 36 | local dataIter = lmdata:createBatch('train', opts.batchSize) 37 | local dataSize, curDataSize = lmdata:getTrainSize(), 0 38 | local percent, inc = 0.001, 0.001 39 | local timer = torch.Timer() 40 | -- local sgdParam = {learningRate = opts.curLR} 41 | local sgdParam = opts.sgdParam 42 | local cnt = 0 43 | local totalLoss = 0 44 | local totalCnt = 0 45 | for x, y in dataIter do 46 | local loss = rnn:trainBatch(x, y, sgdParam) 47 | local nll = loss * x:size(2) / (y:ne(0):sum()) 48 | totalLoss = totalLoss + loss * x:size(2) 49 | totalCnt = totalCnt + y:ne(0):sum() 50 | 51 | curDataSize = curDataSize + x:size(2) 52 | local ratio = curDataSize/dataSize 53 | if ratio >= percent then 54 | local wps = totalCnt / timer:time().real 55 | xprint( '\r%s %.3f %.4f (%s) / %.2f wps ... ', EPOCH_INFO, ratio, totalLoss/totalCnt, readableTime(timer:time().real), wps ) 56 | percent = math.floor(ratio / inc) * inc 57 | percent = percent + inc 58 | end 59 | 60 | cnt = cnt + 1 61 | if cnt % 5 == 0 then 62 | collectgarbage() 63 | end 64 | end 65 | 66 | return totalLoss / totalCnt 67 | end 68 | 69 | local function valid(rnn, lmdata, opts, splitLabel) 70 | local dataIter = lmdata:createBatch(splitLabel, opts.batchSize) 71 | local totalCnt = 0 72 | local totalLoss = 0 73 | local cnt = 0 74 | for x, y in dataIter do 75 | local loss = rnn:validBatch(x, y) 76 | totalLoss = totalLoss + loss * x:size(2) 77 | totalCnt = totalCnt + y:ne(0):sum() 78 | cnt = cnt + 1 79 | if cnt % 5 == 0 then 80 | collectgarbage() 81 | end 82 | end 83 | local entropy = totalLoss / totalCnt 84 | local ppl = torch.exp(entropy) 85 | return {entropy = entropy, ppl = ppl} 86 | end 87 | 88 | local function verifyModel(modelPath) 89 | xprintln('\n==verify trained model==') 90 | local optsPath = modelPath:sub(1, -4) ..
'.state.t7' 91 | local opts = torch.load(optsPath) 92 | xprintln('load state from %s done!', optsPath) 93 | 94 | print(opts) 95 | local lmdata = TreeLM_Dataset(opts.dataset) 96 | local rnn = TreeLSTMLM(opts) 97 | xprintln( 'load model from %s', opts.save ) 98 | rnn:load(opts.save) 99 | xprintln( 'load model from %s done!', opts.save ) 100 | 101 | xprintln('\n') 102 | local validRval = valid(rnn, lmdata, opts, 'valid') 103 | xprint('VALID %f ', validRval.ppl) 104 | local testRval = valid(rnn, lmdata, opts, 'test') 105 | xprintln('TEST %f ', testRval.ppl) 106 | end 107 | 108 | local function initOpts(opts) 109 | -- for different optimization algorithms 110 | local optimMethods = {'AdaGrad', 'Adam', 'AdaDelta', 'SGD'} 111 | if not table.contains(optimMethods, opts.optimMethod) then 112 | error('invalid optimization problem ' .. opts.optimMethod) 113 | end 114 | 115 | opts.curLR = opts.lr 116 | opts.minLR = 1e-7 117 | opts.sgdParam = {learningRate = opts.lr} 118 | if opts.optimMethod == 'AdaDelta' then 119 | opts.rho = 0.95 120 | opts.eps = 1e-6 121 | opts.sgdParam.rho = opts.rho 122 | opts.sgdParam.eps = opts.eps 123 | elseif opts.optimMethod == 'SGD' then 124 | if opts.lrDiv <= 1 then 125 | opts.lrDiv = 2 126 | end 127 | end 128 | 129 | end 130 | 131 | local function main() 132 | local opts = getOpts() 133 | initOpts(opts) 134 | local lmdata = TreeLM_Dataset(opts.dataset) 135 | opts.nvocab = lmdata:getVocabSize() 136 | 137 | print(opts) 138 | torch.manualSeed(opts.seed) 139 | if opts.useGPU then 140 | require 'cutorch' 141 | require 'cunn' 142 | cutorch.manualSeed(opts.seed) 143 | end 144 | 145 | local rnn = TreeLSTMLM(opts) 146 | local bestValid = {ppl = 1e309, entropy = 1e309} 147 | local lastValid = {ppl = 1e309, entropy = 1e309} 148 | local bestModel = torch.FloatTensor(rnn.params:size()) 149 | local patience = opts.patience 150 | local divLR = false 151 | local timer = torch.Timer() 152 | local epochNo = 0 153 | for epoch = 1, opts.maxEpoch do 154 | epochNo = epochNo + 1 155 | EPOCH_INFO = string.format('epoch %d', epoch) 156 | local startTime = timer:time().real 157 | local trainCost = train(rnn, lmdata, opts) 158 | xprint('\repoch %d TRAIN nll %f ', epoch, trainCost) 159 | local validRval = valid(rnn, lmdata, opts, 'valid') 160 | xprint('VALID %f ', validRval.ppl) 161 | local testRval = valid(rnn, lmdata, opts, 'test') 162 | xprint('TEST %f ', testRval.ppl) 163 | local endTime = timer:time().real 164 | xprintln('lr = %.4g (%s) p = %d', opts.curLR, readableTime(endTime - startTime), patience) 165 | 166 | if validRval.ppl < bestValid.ppl then 167 | bestValid.ppl = validRval.ppl 168 | bestValid.entropy = validRval.entropy 169 | bestValid.epoch = epoch 170 | rnn:getModel(bestModel) 171 | -- for non SGD algorithm, we will reset the patience 172 | -- if opts.optimMethod ~= 'SGD' then 173 | if opts.lrDiv <= 1 then 174 | patience = opts.patience 175 | end 176 | else 177 | -- non SGD algorithm decrease patience 178 | if opts.lrDiv <= 1 then 179 | -- if opts.optimMethod ~= 'SGD' then 180 | patience = patience - 1 181 | if patience == 0 then 182 | xprintln('No improvement on PPL for %d epoch(s). 
Training finished!', opts.patience) 183 | break 184 | end 185 | else 186 | -- SGD with learning rate decay 187 | rnn:setModel(bestModel) 188 | end 189 | 190 | end -- if validRval.ppl < bestValid.ppl 191 | 192 | -- control the learning rate decay 193 | -- if opts.optimMethod == 'SGD' then 194 | if opts.lrDiv > 1 then 195 | if epoch >= 10 and patience > 1 then 196 | patience = 1 197 | end 198 | 199 | if validRval.entropy * opts.minImprovement > lastValid.entropy then 200 | if not divLR then -- patience == 1 201 | patience = patience - 1 202 | if patience < 1 then divLR = true end 203 | else 204 | xprintln('no significant improvement! cur ppl %f, best ppl %f', validRval.ppl, bestValid.ppl) 205 | break 206 | end 207 | end 208 | 209 | if divLR then 210 | opts.curLR = opts.curLR / opts.lrDiv 211 | opts.sgdParam.learningRate = opts.curLR 212 | end 213 | 214 | if opts.curLR < opts.minLR then 215 | xprintln('min lr is met! cur lr %e min lr %e', opts.curLR, opts.minLR) 216 | break 217 | end 218 | lastValid.ppl = validRval.ppl 219 | lastValid.entropy = validRval.entropy 220 | end 221 | end 222 | 223 | if epochNo > opts.maxEpoch then 224 | xprintln('Max number of epoch is met. Training finished!') 225 | end 226 | 227 | lmdata:close() 228 | 229 | rnn:setModel(bestModel) 230 | opts.sgdParam = nil 231 | rnn:save(opts.save, true) 232 | xprintln('model saved at %s', opts.save) 233 | 234 | verifyModel(opts.save) 235 | end 236 | 237 | -- here is the entry 238 | main() 239 | 240 | -------------------------------------------------------------------------------- /rerank.lua: -------------------------------------------------------------------------------- 1 | 2 | require '.' 3 | require 'shortcut' 4 | require 'TreeLSTMNCELM' 5 | require 'hdf5' 6 | require 'BiTreeLSTMNCELM' 7 | 8 | local model_utils = require 'model_utils' 9 | 10 | -- currently only support TreeLSTMNCELM 11 | local Reranker = torch.class('TreeLSTMLMReranker') 12 | 13 | function Reranker:__init(modelPath, useGPU) 14 | local optsPath = modelPath:sub(1, -4) .. 
'.state.t7' 15 | print(optsPath) 16 | local opts = torch.load(optsPath) 17 | xprintln('load state from %s done!', optsPath) 18 | 19 | opts.useGPU = useGPU 20 | print(opts) 21 | 22 | if opts.model == 'TreeLSTMNCE' then 23 | self.rnnlm = TreeLSTMNCELM(opts) 24 | elseif opts.model == 'BiTreeLSTMNCE' then 25 | self.rnnlm = BiTreeLSTMNCELM(opts) 26 | else 27 | error('currently only support TreeLSTMNCELM') 28 | end 29 | 30 | xprintln( 'load model from %s', modelPath ) 31 | self.rnnlm:load(modelPath) 32 | xprintln( 'load model from %s done!', modelPath ) 33 | end 34 | 35 | function Reranker:rerank(testFile, outFile, batchSize) 36 | self.rnnlm:disableDropout() 37 | 38 | local logp_sents = {} 39 | local dataIter 40 | if self.rnnlm.name:starts('BiTree') then 41 | dataIter = Reranker.createBatchBidirectional(testFile, batchSize) 42 | else 43 | dataIter = Reranker.createBatch(testFile, batchSize) 44 | end 45 | 46 | local cnt = 0 47 | for x, y, lchild, lc_mask in dataIter do 48 | -- yPred | size: (seqlen*bs, nvocab) 49 | local _tmp, yPred 50 | if self.rnnlm.name:starts('BiTree') then 51 | _tmp, yPred = self.rnnlm:validBatch(x, y, lchild, lc_mask) 52 | else 53 | _tmp, yPred = self.rnnlm:validBatch(x, y) 54 | end 55 | local mask = y:ne(0):double() 56 | y[y:eq(0)] = 1 57 | local y_ = y:view(y:size(1) * y:size(2), 1) 58 | local logps = yPred:gather(2, y_) -- shape: seqlen*bs, 1 59 | local logp_sents_ = logps:cmul(mask):view(y:size(1), y:size(2)):sum(1):squeeze() 60 | for i = 1, logp_sents_:size(1) do 61 | logp_sents[#logp_sents + 1] = logp_sents_[i] 62 | end 63 | 64 | cnt = cnt + y:size(2) 65 | if cnt % 100 == 0 then 66 | xprintln('cnt = %d', cnt) 67 | end 68 | end 69 | 70 | local lmfile = './msr_scripts/Holmes.lm_format.questions.txt' 71 | local lines = xreadlines(lmfile) 72 | assert(#lines == #logp_sents, 'there should be the same number of sentences in testFile and lmfile') 73 | 74 | local fout = io.open(outFile, 'w') 75 | for i = 1, #lines do 76 | fout:write( string.format('%s\t%f\n', lines[i]:trim(), logp_sents[i]) ) 77 | end 78 | fout:close() 79 | 80 | local accFile = outFile .. '.acc' 81 | Reranker.score2accuracy(outFile, accFile) 82 | 83 | self.rnnlm:enableDropout() 84 | end 85 | 86 | function Reranker.score2accuracy(scoreFile, accFile) 87 | local bestof5_pl = './msr_scripts/bestof5.pl' 88 | local score_pl = './msr_scripts/score.pl' 89 | local ans_file = './msr_scripts/Holmes.lm_format.answers.txt' 90 | local tmp_file = scoreFile .. 
'.__sample.temp__' 91 | local cmd = string.format('cat %s | %s > %s', scoreFile, bestof5_pl, tmp_file) 92 | os.execute(cmd) 93 | cmd = string.format('%s %s %s > %s', score_pl, tmp_file, ans_file, accFile) 94 | os.execute(cmd) 95 | 96 | local lines = xreadlines(accFile) 97 | local nLines = #lines 98 | for i = nLines - 4, nLines do 99 | print(lines[i]) 100 | end 101 | end 102 | 103 | function Reranker.toBatch(xs, ys, batchSize) 104 | local dtype = 'torch.LongTensor' 105 | local maxn = 0 106 | for _, y_ in ipairs(ys) do 107 | if y_:size(1) > maxn then 108 | maxn = y_:size(1) 109 | end 110 | end 111 | local x = torch.ones(maxn, batchSize, 4):type(dtype) 112 | -- x:mul(1) 113 | x[{ {}, {}, 4 }] = torch.linspace(2, maxn + 1, maxn):resize(maxn, 1):expand(maxn, batchSize) 114 | local nsent = #ys 115 | local y = torch.zeros(maxn, batchSize):type(dtype) 116 | for i = 1, nsent do 117 | local sx, sy = xs[i], ys[i] 118 | x[{ {1, sx:size(1)}, i, {} }] = sx 119 | y[{ {1, sy:size(1)}, i }] = sy 120 | end 121 | 122 | return x, y 123 | end 124 | 125 | function Reranker.createBatch(testH5File, batchSize) 126 | local h5in = hdf5.open(testH5File, 'r') 127 | local label = 'test' 128 | local x_data = h5in:read(string.format('/%s/x_data', label)) 129 | local y_data = h5in:read(string.format('/%s/y_data', label)) 130 | local index = h5in:read(string.format('/%s/index', label)) 131 | local N = index:dataspaceSize()[1] 132 | 133 | local istart = 1 134 | 135 | return function() 136 | if istart <= N then 137 | local iend = math.min(istart + batchSize - 1, N) 138 | local xs = {} 139 | local ys = {} 140 | for i = istart, iend do 141 | local idx = index:partial({i, i}, {1, 2}) 142 | local start, len = idx[1][1], idx[1][2] 143 | local x = x_data:partial({start, start + len - 1}, {1, 4}) 144 | local y = y_data:partial({start, start + len - 1}) 145 | table.insert(xs, x) 146 | table.insert(ys, y) 147 | end 148 | 149 | istart = iend + 1 150 | 151 | local x, y = Reranker.toBatch(xs, ys, batchSize) 152 | 153 | return x, y 154 | else 155 | h5in:close() 156 | end 157 | end 158 | end 159 | 160 | function Reranker.toBatchBidirectional(xs, ys, lcs, batchSize) 161 | local dtype = 'torch.LongTensor' 162 | local maxn = 0 163 | for _, y_ in ipairs(ys) do 164 | if y_:size(1) > maxn then 165 | maxn = y_:size(1) 166 | end 167 | end 168 | local x = torch.ones(maxn, batchSize, 5):type(dtype) 169 | -- x:mul(self.UNK) 170 | x[{ {}, {}, 4 }] = torch.linspace(2, maxn + 1, maxn):resize(maxn, 1):expand(maxn, batchSize) 171 | x[{ {}, {}, 5 }] = 0 -- in default, I don't want them to have 172 | local nsent = #ys 173 | local y = torch.zeros(maxn, batchSize):type(dtype) 174 | for i = 1, nsent do 175 | local sx, sy = xs[i], ys[i] 176 | x[{ {1, sx:size(1)}, i, {} }] = sx 177 | y[{ {1, sy:size(1)}, i }] = sy 178 | end 179 | 180 | -- for left children 181 | assert(#lcs == #xs, 'should be the same!') 182 | local lcBatchSize = 0 183 | local maxLcSeqLen = 0 184 | for _, lc in ipairs(lcs) do 185 | if lc:dim() ~= 0 then 186 | lcBatchSize = lcBatchSize + 1 187 | maxLcSeqLen = math.max(maxLcSeqLen, lc:size(1)) 188 | end 189 | end 190 | local lchild = torch.Tensor():type(dtype) 191 | local lc_mask = torch.FloatTensor() 192 | 193 | if lcBatchSize ~= 0 then 194 | lchild:resize(maxLcSeqLen, lcBatchSize):fill(1) -- UNK should be 1 195 | lc_mask:resize(maxLcSeqLen, lcBatchSize):fill(0) 196 | local j = 0 197 | for i, lc in ipairs(lcs) do 198 | if lc:dim() ~= 0 then 199 | j = j + 1 200 | lchild[{ {1, lc:size(1)}, j }] = lc[{ {}, 1 }] 201 | lc_mask[{ {1, lc:size(1)}, j }] 
= lc[{ {}, 2 }] + 1 202 | local xcol = x[{ {}, i, 5 }] 203 | local idxs = xcol:ne(0) 204 | xcol[idxs] = (xcol[idxs] - 1) * lcBatchSize + j 205 | end 206 | end 207 | end 208 | 209 | return x, y, lchild, lc_mask 210 | end 211 | 212 | function Reranker.createBatchBidirectional(testH5File, batchSize) 213 | local h5in = hdf5.open(testH5File, 'r') 214 | local label = 'test' 215 | local x_data = h5in:read(string.format('/%s/x_data', label)) 216 | local y_data = h5in:read(string.format('/%s/y_data', label)) 217 | local index = h5in:read(string.format('/%s/index', label)) 218 | local l_data = h5in:read( string.format('/%s/l_data', label) ) 219 | local lindex = h5in:read( string.format('/%s/lindex', label) ) 220 | local N = index:dataspaceSize()[1] 221 | 222 | local istart = 1 223 | 224 | return function() 225 | if istart <= N then 226 | local iend = math.min(istart + batchSize - 1, N) 227 | local xs = {} 228 | local ys = {} 229 | local lcs = {} 230 | for i = istart, iend do 231 | local idx = index:partial({i, i}, {1, 2}) 232 | local start, len = idx[1][1], idx[1][2] 233 | local x = x_data:partial({start, start + len - 1}, {1, 5}) 234 | local y = y_data:partial({start, start + len - 1}) 235 | table.insert(xs, x) 236 | table.insert(ys, y) 237 | 238 | local lidx = lindex:partial({i, i}, {1, 2}) 239 | local lstart, llen = lidx[1][1], lidx[1][2] 240 | local lc 241 | if llen == 0 then 242 | lc = torch.IntTensor() -- to be the same type as l_data 243 | else 244 | lc = l_data:partial({lstart, lstart + llen - 1}, {1, 2}) 245 | end 246 | table.insert(lcs, lc) 247 | end 248 | 249 | istart = iend + 1 250 | 251 | return Reranker.toBatchBidirectional(xs, ys, lcs, batchSize) 252 | else 253 | h5in:close() 254 | end 255 | end 256 | end 257 | 258 | local function getOpts() 259 | local cmd = torch.CmdLine() 260 | cmd:text('====== Reranking for MSR sentence completion challenge ======') 261 | cmd:option('--useGPU', false, 'do you want to run this on a GPU?') 262 | cmd:option('--modelPath', '', 'path for the trained model; modelPath.state.t7 should be the option of the model') 263 | cmd:option('--testFile', '', 'test file for reranking (.h5)') 264 | cmd:option('--outFile', '', 'path for the output file') 265 | cmd:option('--batchSize', 20, 'batch size') 266 | 267 | return cmd:parse(arg) 268 | end 269 | 270 | local function main() 271 | local opts = getOpts() 272 | assert(5200 % opts.batchSize == 0) 273 | print(opts) 274 | local reranker = TreeLSTMLMReranker(opts.modelPath, opts.useGPU) 275 | reranker:rerank(opts.testFile, opts.outFile, opts.batchSize) 276 | end 277 | 278 | main() 279 | 280 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Top-down Tree Long Short-Term Memory Networks 2 | =============================================== 3 | 4 | 5 | A [Torch](https://github.com/torch) implementation of the Top-down TreeLSTM described in the following paper. 6 | 7 | ### [Top-down Tree Long Short-Term Memory Networks](http://aclweb.org/anthology/N/N16/N16-1035.pdf) 8 | Xingxing Zhang, Liang Lu and Mirella Lapata. In *Proceedings of the 2016 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies* (NAACL 2016). 
9 | ``` 10 | @InProceedings{zhang-lu-lapata:2016:N16-1, 11 | author = {Zhang, Xingxing and Lu, Liang and Lapata, Mirella}, 12 | title = {Top-down Tree Long Short-Term Memory Networks}, 13 | booktitle = {Proceedings of the 2016 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies}, 14 | month = {June}, 15 | year = {2016}, 16 | address = {San Diego, California}, 17 | publisher = {Association for Computational Linguistics}, 18 | pages = {310--320}, 19 | url = {http://www.aclweb.org/anthology/N16-1035} 20 | } 21 | ``` 22 | 23 | ### Implemented Models 24 | * TreeLSTM 25 | * TreeLSTM-NCE 26 | * LdTreeLSTM 27 | * LdTreeLSTM-NCE 28 | 29 | TreeLSTM and LdTreeLSTM (see the paper above for details) are trained with Negative Log-likelihood (NLL), while TreeLSTM-NCE and LdTreeLSTM-NCE are trained with [Noise Contrastive Estimation](https://www.cs.helsinki.fi/u/ahyvarin/papers/Gutmann12JMLR.pdf) (NCE) (see [this paper](https://www.cs.helsinki.fi/u/ahyvarin/papers/Gutmann12JMLR.pdf) and also [this paper](https://www.cs.toronto.edu/~amnih/papers/ncelm.pdf) for details). 30 | 31 | Note that in our experiments, the normalization term Z of NCE is learned automatically. The implemented NCE module also supports keeping Z fixed. 32 | 33 | # Requirements 34 | * an Nvidia GPU 35 | * [CUDA 6.5.19](http://www.nvidia.com/object/cuda_home_new.html) (a higher version should be fine) 36 | * [Torch](https://github.com/torch) 37 | * [torch-hdf5](https://github.com/deepmind/torch-hdf5) 38 | 39 | Torch can be installed with the instructions [here](http://torch.ch/docs/getting-started.html). 40 | You also need to install some Torch components. 41 | ``` 42 | luarocks install nn 43 | luarocks install nngraph 44 | luarocks install cutorch 45 | luarocks install cunn 46 | ``` 47 | You may find [this document](https://github.com/deepmind/torch-hdf5/blob/master/doc/usage.md) useful when installing torch-hdf5 (DON'T use luarocks). 48 | 49 | Please also note that to run the code, you need to use an old version of Torch; see the instructions [here](OLD_VERSION.md). 50 | 51 |
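As a quick sanity check that torch-hdf5 is installed correctly before you preprocess anything, the following minimal round-trip should run without errors (the file path here is made up):
```
require 'hdf5'
-- write a small tensor, then read it back (smoke test only)
local fout = hdf5.open('/tmp/hdf5_smoke_test.h5', 'w')
fout:write('/x', torch.rand(4, 3))
fout:close()
local fin = hdf5.open('/tmp/hdf5_smoke_test.h5', 'r')
local x = fin:read('/x'):all()
fin:close()
print(x:size())  -- should print 4x3
```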
52 | # Language Modeling (MSR Sentence Completion) 53 | 54 | Pre-trained models (TreeLSTM-400 and LdTreeLSTM-400) are available at https://drive.google.com/file/d/0B6-YKFW-MnbOc3pya29UZkZpUFU/view?usp=sharing 55 | 56 | ## Preprocessing 57 | First, parse the dataset into dependency trees using the [Stanford CoreNLP toolkit](http://stanfordnlp.github.io/CoreNLP/). 58 | It should look like this: 59 | ``` 60 | SILAP10.TXT#0 det(Etext-4, The-1) nn(Etext-4, Project-2) nn(Etext-4, Gutenberg-3) root(ROOT-0, Etext-4) prep(Etext-4, of-5) det(Rise-7, The-6) pobj(of-5, Rise-7) prep(Rise-7, of-8) nn(Lapham-10, Silas-9) pobj(of-8, Lapham-10) prep(Etext-4, by-11) nn(Howells-14, William-12) nn(Howells-14, Dean-13) pobj(by-11, Howells-14) det(RISE-16, THE-15) dep(Etext-4, RISE-16) prep(RISE-16, OF-17) nn(LAPHAM-19, SILAS-18) pobj(OF-17, LAPHAM-19) prep(RISE-16, by-20) nn(Howells-23, William-21) nn(Howells-23, Dean-22) pobj(by-20, Howells-23) npadvmod(Howells-23, I-24) punct(Etext-4, .-25) 61 | SILAP10.TXT#1 advmod(went-4, WHEN-1) nn(Hubbard-3, Bartley-2) nsubj(went-4, Hubbard-3) advcl(received-40, went-4) aux(interview-6, to-5) xcomp(went-4, interview-6) nn(Lapham-8, Silas-7) dobj(interview-6, Lapham-8) prep(interview-6, for-9) det(Men-13, the-10) punct(Men-13, ``-11) amod(Men-13, Solid-12) pobj(for-9, Men-13) prep(Men-13, of-14) pobj(of-14, Boston-15) punct(Men-13, ''-16) dep(Men-13, series-17) punct(series-17, ,-18) dobj(undertook-21, which-19) nsubj(undertook-21, he-20) rcmod(series-17, undertook-21) aux(finish-23, to-22) xcomp(undertook-21, finish-23) prt(finish-23, up-24) prep(finish-23, in-25) det(Events-27, The-26) pobj(in-25, Events-27) punct(received-40, ,-28) mark(replaced-31, after-29) nsubj(replaced-31, he-30) advcl(received-40, replaced-31) poss(projector-34, their-32) amod(projector-34, original-33) dobj(replaced-31, projector-34) prep(replaced-31, on-35) det(newspaper-37, that-36) pobj(on-35, newspaper-37) punct(received-40, ,-38) nsubj(received-40, Lapham-39) root(ROOT-0, received-40) dobj(received-40, him-41) prep(received-40, in-42) poss(office-45, his-43) amod(office-45, private-44) pobj(in-42, office-45) prep(received-40, by-46) amod(appointment-48, previous-47) pobj(by-46, appointment-48) punct(received-40, .-49) 62 | ... 63 | ... 64 | ... 65 | ``` 66 | Each line is a sentence (format: label \t dependency tuples), where *SILAP10.TXT#0* is the label for the sentence (it can be any string; its content doesn't matter). 67 | 68 | The dataset after the preprocessing above can be downloaded [here](https://drive.google.com/file/d/0B6-YKFW-MnbOcmE5TmRyVlZYTjg/view?usp=sharing). 69 | 70 | Then, convert the dependency tree dataset into HDF5 format and sort the dataset so that sentences in each batch have similar lengths. Sorting the dataset speeds up training, and is a commonly used strategy for training RNN or sequence-based models. 71 | 72 | ### Create Dataset for TreeLSTM 73 | ``` 74 | cd scripts 75 | ./run_msr.sh 76 | ./run_msr_sort.sh 77 | ./run_msr_test.sh 78 | ``` 79 | Note that the program will crash when running *./run_msr.sh*. You can ignore the crash, or use the *--sort 0* switch instead of *--sort 20*. 80 | 81 | ### Create Dataset for LdTreeLSTM 82 | ``` 83 | cd scripts 84 | ./run_msr_bid.sh 85 | ./run_msr_sort_bid.sh 86 | ./run_msr_test_bid.sh 87 | ``` 88 | Alternatively, you can contact the first author to request the dataset after preprocessing. 89 | 90 | ## Training and Evaluation 91 | Basically, it is just one command. 92 | ``` 93 | cd experiments/msr 94 | # to run TreeLSTM with hidden size 400 95 | ./treelstm_h400.sh 96 | # to run LdTreeLSTM with hidden size 400 97 | ./ldtreelstm_h400.sh 98 | ``` 99 | But don't forget to specify where your code and your dataset are located by modifying treelstm_h400.sh or ldtreelstm_h400.sh. 100 | ``` 101 | # where is your code?
(you should use an absolute path) 102 | codedir=/afs/inf.ed.ac.uk/group/project/xTreeLSTM/xtreelstm/td-treelstm-release 103 | # where is your dataset (you should use an absolute path) 104 | dataset=/disk/scratch/XingxingZhang/treelstm/dataset/msr/msr.dep.100.bid.sort20.h5 105 | # label for this model 106 | label=.ldtree.400 107 | 108 | # where is your testset (you should use an absolute path); this will only be used in evaluation 109 | testfile=/disk/scratch/XingxingZhang/treelstm/dataset/msr/msr.dep.100.bid.question.h5 110 | ``` 111 | # Dependency Parsing Reranking 112 | 113 | ## Preprocessing 114 | 115 | ### For TreeLSTM 116 | ``` 117 | cd scripts 118 | ./run_conllx_sort.sh 119 | ``` 120 | 121 | ### For LdTreeLSTM 122 | ``` 123 | cd scripts 124 | ./run_conllx_sort_bid.sh 125 | ``` 126 | 127 | ## Train and Evaluate Dependency Reranking Models 128 | Training TreeLSTMs and LdTreeLSTMs is quite similar. 129 | The following is about training a TreeLSTM. 130 | ``` 131 | cd experiments/depparse 132 | ./treelstm_we_train.sh 133 | 134 | ``` 135 | Then, you will get a trained TreeLSTM. We can use this TreeLSTM 136 | to rerank the *K*-best dependencies produced by the second-order [MSTParser](http://www.seas.upenn.edu/~strctlrn/MSTParser/MSTParser.html). 137 | 138 | The following script will use the trained dependency model to rerank the top 20 dependencies from MSTParser on the validation set. The script will try different values of *K* and choose the one that gives the best UAS. 139 | ``` 140 | ./treelstm_we_rerank_valid.sh 141 | ``` 142 | Given the *K* we've got from the validation set, we can get the reranking performance on the test set by using the following script. 143 | ``` 144 | ./treelstm_we_rerank_test.sh 145 | ``` 146 |
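Conceptually, the reranking step just picks, among the top-*K* candidate trees of each sentence, the one the TreeLSTM assigns the highest log-probability. A minimal sketch of that selection logic (the function and field names here are made up; the real logic lives in the rerank scripts):
```
-- candidates: top-K parses of one sentence, each carrying a TreeLSTM
-- log-probability (see rerank.lua for how per-sentence log-probs are computed)
local function pickBest(candidates, K)
  local best, bestScore = nil, -math.huge
  for i = 1, math.min(K, #candidates) do
    if candidates[i].logp > bestScore then
      best, bestScore = candidates[i], candidates[i].logp
    end
  end
  return best
end
```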
147 | # Dependency Tree Generation 148 | 149 | ### How will we generate dependency trees? (see Section 3.4 of the paper for details) 150 | * Run the Language Modeling experiment or the dependency parsing experiment to get a trained TreeLSTM or LdTreeLSTM 151 | * Generate training data for the four classifiers (Add-Left, Add-Right, Add-Nx-Left, Add-Nx-Right) 152 | * Train Add-Left, Add-Right, Add-Nx-Left and Add-Nx-Right 153 | * Generate dependency trees with a trained TreeLSTM (or LdTreeLSTM) and the four classifiers 154 | 155 | ### Generate Training Data 156 | Go to *sampler.lua* and run the following code: 157 | ``` 158 | -- model_1.0.w200.t7 is the trained TreeLSTM 159 | -- penn_wsj.conllx.sort.h5 is the dataset for the trained TreeLSTM 160 | -- eot.penn_wsj.conllx.sort.h5 is the output dataset for the four classifiers 161 | generateDataset('/disk/scratch/XingxingZhang/treelstm/experiments/ptb_depparse/2layer_w_wo_we/model_1.0.w200.t7', 162 | '/disk/scratch/XingxingZhang/treelstm/dataset/depparse/dataset/penn_wsj.conllx.sort.h5', 163 | '/disk/scratch/XingxingZhang/treelstm/dataset/depparse/eot.penn_wsj.conllx.sort.h5') 164 | 165 | ``` 166 | 167 | ### Train the Four Classifiers 168 | Use *train_mlp.lua*: 169 | ``` 170 | $ th train_mlp.lua -h 171 | Usage: /afs/inf.ed.ac.uk/user/s12/s1270921/torch/install/lib/luarocks/rocks/trepl/scm-1/bin/th [options] 172 | ====== MLP v 1.0 ====== 173 | 174 | --seed random seed [123] 175 | --useGPU use gpu [false] 176 | --snhids string hidden sizes for each layer [400,300,300,2] 177 | --activ options: tanh, relu [tanh] 178 | --dropout dropout rate (dropping) [0] 179 | --maxEpoch max number of epochs [10] 180 | --dataset dataset [/disk/scratch/XingxingZhang/treelstm/dataset/depparse/eot.penn_wsj.conllx.sort.h5] 181 | --ftype [|x|oe|] 182 | --ytype [1] 183 | --batchSize [256] 184 | --lr [0.01] 185 | --optimMethod options: SGD, AdaGrad [AdaGrad] 186 | --save save path [model.t7] 187 | 188 | ``` 189 | Note that *--ytype* 1, 2, 3, 4 correspond to the four classifiers. Here is a sample script: 190 | ``` 191 | ID=`./gpu_lock.py --id-to-hog 2` 192 | echo $ID 193 | if [ $ID -eq -1 ]; then 194 | echo "no gpu is free" 195 | exit 196 | fi 197 | ./gpu_lock.py 198 | 199 | curdir=`pwd` 200 | codedir=/afs/inf.ed.ac.uk/group/project/xTreeLSTM/xtreelstm/MLP_test 201 | lr=0.01 202 | label=yt1.x.oe 203 | model=model.$label.t7 204 | log=log.$label.txt 205 | echo $curdir 206 | echo $codedir 207 | 208 | cd $codedir 209 | CUDA_VISIBLE_DEVICES=$ID th train_mlp.lua --useGPU \ 210 | --activ relu --dropout 0.5 --lr $lr --maxEpoch 10 \ 211 | --snhids "400,300,300,2" --ftype "|x|oe|" --ytype 1 \ 212 | --save $curdir/$model | tee $curdir/$log 213 | 214 | cd $curdir 215 | 216 | ./gpu_lock.py --free $ID 217 | ./gpu_lock.py 218 | 219 | ```
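For reference, *--snhids "400,300,300,2"* describes the layer sizes: a 400-dimensional input, two 300-unit hidden layers and a 2-way output. A rough sketch of the kind of network this implies (illustrative only; the actual construction lives in *MLP.lua*):
```
require 'nn'
-- illustrative MLP for --snhids "400,300,300,2", --activ relu, --dropout 0.5
local snhids = {400, 300, 300, 2}
local mlp = nn.Sequential()
for i = 1, #snhids - 1 do
  mlp:add(nn.Linear(snhids[i], snhids[i + 1]))
  if i < #snhids - 1 then
    mlp:add(nn.ReLU())
    mlp:add(nn.Dropout(0.5))
  end
end
mlp:add(nn.LogSoftMax())  -- 2-way classification output
```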
220 | ### Generation by Sampling 221 | Go to *sampler.lua* and run the following code. The code will output dependency trees in LaTeX format. 222 | ``` 223 | -- model_1.0.w200.t7: trained TreeLSTM 224 | -- model.yt%d.x.oe.t7: trained classifiers; note that model.yt1.x.oe.t7, model.yt2.x.oe.t7, model.yt3.x.oe.t7 and model.yt4.x.oe.t7 must all exist 225 | sampleTrees('/disk/scratch/XingxingZhang/treelstm/experiments/ptb_depparse/2layer_w_wo_we/model_1.0.w200.t7', 226 | '/disk/scratch/XingxingZhang/treelstm/experiments/sampling/eot_classify/model.yt%d.x.oe.t7', 227 | 's100.txt', -- output dependency trees 228 | 1, -- rand seed 229 | 100) -- number of tree samples 230 | ``` 231 | -------------------------------------------------------------------------------- /main_nce.lua: -------------------------------------------------------------------------------- 1 | 2 | require '.' 3 | require 'shortcut' 4 | require 'TreeLSTMLM' 5 | require 'TreeLM_Dataset' 6 | 7 | require 'TreeLSTMNCELM' 8 | require 'TreeLM_NCE_Dataset' 9 | 10 | local model_utils = require 'model_utils' 11 | local EPOCH_INFO = '' 12 | 13 | local function getOpts() 14 | local cmd = torch.CmdLine() 15 | cmd:text('====== Tree LSTM NCE Language Model ======') 16 | cmd:text('version 2.2 add word embedding support') 17 | cmd:text() 18 | cmd:option('--seed', 123, 'random seed') 19 | cmd:option('--model', 'TreeLSTM', 'model options: TreeLSTM, TreeLSTMNCE') 20 | cmd:option('--dataset', '', 'dataset path') 21 | cmd:option('--maxEpoch', 100, 'maximum number of epochs') 22 | cmd:option('--batchSize', 64, '') 23 | cmd:option('--validBatchSize', 16, '') 24 | cmd:option('--nin', 50, 'word embedding size') 25 | cmd:option('--nhid', 100, 'hidden unit size') 26 | cmd:option('--nlayers', 1, 'number of hidden layers') 27 | cmd:option('--wordEmbedding', '', 'path for the word embedding file') 28 | cmd:option('--lr', 0.1, 'learning rate') 29 | cmd:option('--lrDiv', 0, 'learning rate decay when there is no significant improvement. 0 means turn off')
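  -- --lrDiv and --minImprovement drive the decay logic in the training loop
  -- below: with --lrDiv > 1 (SGD-style decay) the learning rate is divided by
  -- lrDiv once validation entropy stops improving by a factor of
  -- --minImprovement; with --lrDiv <= 1 a patience counter (--patience)
  -- stops training instead.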
30 | cmd:option('--minImprovement', 1.0001, 'if the improvement on log likelihood is smaller than this factor, decrease patience') 31 | cmd:option('--optimMethod', 'AdaGrad', 'optimization algorithm') 32 | cmd:option('--gradClip', 5, '> 0 means to do Pascanu et al.\'s grad norm rescale http://arxiv.org/pdf/1502.04623.pdf; < 0 means to truncate the gradient larger than gradClip; 0 means turn off gradient clip') 33 | cmd:option('--initRange', 0.1, 'init range') 34 | cmd:option('--initHidVal', 0.01, 'init values for hidden states') 35 | cmd:option('--seqLen', 151, 'maximum sequence length') 36 | cmd:option('--useGPU', false, 'use GPU') 37 | cmd:option('--patience', 2, 'stop training if no lower valid PPL is observed in [patience] consecutive epoch(s)') 38 | cmd:option('--save', 'model.t7', 'save model path') 39 | 40 | cmd:text() 41 | cmd:text('Options for NCE') 42 | cmd:option('--nneg', 20, 'number of negative samples') 43 | cmd:option('--power', 0.75, 'power for the unigram frequency distribution') 44 | cmd:option('--lnZ', 9.5, 'default normalization term') 45 | cmd:option('--learnZ', false, 'learn the normalization constant Z') 46 | cmd:option('--normalizeUNK', false, 'whether to normalize UNK or not') 47 | 48 | cmd:text() 49 | cmd:text('Options for long jobs') 50 | cmd:option('--savePerEpoch', false, 'save model every epoch') 51 | cmd:option('--saveBeforeLrDiv', false, 'save model before lr div') 52 | 53 | cmd:text() 54 | cmd:text('Options for regularization') 55 | cmd:option('--dropout', 0, 'dropout rate (dropping)') 56 | 57 | return cmd:parse(arg) 58 | end 59 | 60 | local function train(rnn, lmdata, opts) 61 | local dataIter 62 | if opts.model:find('NCE') then 63 | dataIter = lmdata:createBatch('train', opts.batchSize, true) 64 | else 65 | dataIter = lmdata:createBatch('train', opts.batchSize) 66 | end 67 | 68 | local dataSize, curDataSize = lmdata:getTrainSize(), 0 69 | local percent, inc = 0.001, 0.001 70 | local timer = torch.Timer() 71 | -- local sgdParam = {learningRate = opts.curLR} 72 | local sgdParam = opts.sgdParam 73 | local cnt = 0 74 | local totalLoss = 0 75 | local totalCnt = 0 76 | for x, y, y_neg, y_prob, y_neg_prob, mask in dataIter do 77 | local loss 78 | if y_neg then 79 | loss = rnn:trainBatch(x, y, y_neg, y_prob, y_neg_prob, mask, sgdParam) 80 | else 81 | loss = rnn:trainBatch(x, y, sgdParam) 82 | end 83 | 84 | local nll -- per-token NLL for this batch; assigned in the branch below 85 | if mask then 86 | nll = loss * x:size(2) / (mask:sum()) 87 | else 88 | nll = loss * x:size(2) / (y:ne(0):sum()) 89 | end 90 | 91 | totalLoss = totalLoss + loss * x:size(2) 92 | if mask then 93 | totalCnt = totalCnt + mask:sum() 94 | else 95 | totalCnt = totalCnt + y:ne(0):sum() 96 | end 97 | 98 | curDataSize = curDataSize + x:size(2) 99 | local ratio = curDataSize/dataSize 100 | if ratio >= percent then 101 | local wps = totalCnt / timer:time().real 102 | xprint( '\r%s %.3f %.4f (%s) / %.2f wps ...
', EPOCH_INFO, ratio, totalLoss/totalCnt, readableTime(timer:time().real), wps ) 103 | percent = math.floor(ratio / inc) * inc 104 | percent = percent + inc 105 | end 106 | 107 | cnt = cnt + 1 108 | if cnt % 5 == 0 then 109 | collectgarbage() 110 | end 111 | end 112 | 113 | return totalLoss / totalCnt 114 | end 115 | 116 | local function valid(rnn, lmdata, opts, splitLabel) 117 | rnn:disableDropout() 118 | 119 | local dataIter = lmdata:createBatch(splitLabel, opts.validBatchSize) 120 | local totalCnt = 0 121 | local totalLoss = 0 122 | local cnt = 0 123 | for x, y in dataIter do 124 | local loss = rnn:validBatch(x, y) 125 | totalLoss = totalLoss + loss * x:size(2) 126 | totalCnt = totalCnt + y:ne(0):sum() 127 | cnt = cnt + 1 128 | if cnt % 5 == 0 then 129 | collectgarbage() 130 | end 131 | end 132 | 133 | rnn:enableDropout() 134 | 135 | local entropy = totalLoss / totalCnt 136 | local ppl = torch.exp(entropy) 137 | return {entropy = entropy, ppl = ppl} 138 | end 139 | 140 | local function verifyModel(modelPath) 141 | xprintln('\n==verify trained model==') 142 | local optsPath = modelPath:sub(1, -4) .. '.state.t7' 143 | local opts = torch.load(optsPath) 144 | xprintln('load state from %s done!', optsPath) 145 | 146 | print(opts) 147 | local lmdata = nil 148 | if opts.model == 'TreeLSTM' then 149 | lmdata = TreeLM_Dataset(opts.dataset) 150 | elseif opts.model == 'TreeLSTMNCE' then 151 | lmdata = TreeLM_NCE_Dataset(opts.dataset, opts.nneg, opts.power, opts.normalizeUNK) 152 | end 153 | -- local lmdata = TreeLM_Dataset(opts.dataset) 154 | 155 | local rnn 156 | if opts.model == 'TreeLSTM' then 157 | rnn = TreeLSTMLM(opts) 158 | elseif opts.model == 'TreeLSTMNCE' then 159 | rnn = TreeLSTMNCELM(opts) 160 | end 161 | 162 | -- local rnn = TreeLSTMLM(opts) 163 | xprintln( 'load model from %s', opts.save ) 164 | rnn:load(opts.save) 165 | xprintln( 'load model from %s done!', opts.save ) 166 | 167 | xprintln('\n') 168 | local validRval = valid(rnn, lmdata, opts, 'valid') 169 | xprint('VALID %f ', validRval.ppl) 170 | local testRval = valid(rnn, lmdata, opts, 'test') 171 | xprintln('TEST %f ', testRval.ppl) 172 | end 173 | 174 | local function initOpts(opts) 175 | -- for different models 176 | local nceParams = {'nneg', 'power', 'normalizeUNK', 'learnZ', 'lnZ'} 177 | if opts.model == 'TreeLSTM' then 178 | -- delete nce params 179 | for _, nceparam in ipairs(nceParams) do 180 | opts[nceparam] = nil 181 | end 182 | end 183 | 184 | -- for different optimization algorithms 185 | local optimMethods = {'AdaGrad', 'Adam', 'AdaDelta', 'SGD'} 186 | if not table.contains(optimMethods, opts.optimMethod) then 187 | error('invalid optimization problem ' .. 
opts.optimMethod) 188 | end 189 | 190 | opts.curLR = opts.lr 191 | opts.minLR = 1e-7 192 | opts.sgdParam = {learningRate = opts.lr} 193 | if opts.optimMethod == 'AdaDelta' then 194 | opts.rho = 0.95 195 | opts.eps = 1e-6 196 | opts.sgdParam.rho = opts.rho 197 | opts.sgdParam.eps = opts.eps 198 | elseif opts.optimMethod == 'SGD' then 199 | if opts.lrDiv <= 1 then 200 | opts.lrDiv = 2 201 | end 202 | end 203 | 204 | end 205 | 206 | local function main() 207 | local opts = getOpts() 208 | print('version 2.2 add word embedding support') 209 | 210 | initOpts(opts) 211 | 212 | local lmdata = nil 213 | if opts.model == 'TreeLSTM' then 214 | lmdata = TreeLM_Dataset(opts.dataset) 215 | elseif opts.model == 'TreeLSTMNCE' then 216 | lmdata = TreeLM_NCE_Dataset(opts.dataset, opts.nneg, opts.power, opts.normalizeUNK) 217 | end 218 | opts.nvocab = lmdata:getVocabSize() 219 | 220 | print(opts) 221 | torch.manualSeed(opts.seed) 222 | if opts.useGPU then 223 | require 'cutorch' 224 | require 'cunn' 225 | cutorch.manualSeed(opts.seed) 226 | end 227 | 228 | local rnn = nil 229 | if opts.model == 'TreeLSTM' then 230 | rnn = TreeLSTMLM(opts) 231 | elseif opts.model == 'TreeLSTMNCE' then 232 | rnn = TreeLSTMNCELM(opts) 233 | end 234 | 235 | local bestValid = {ppl = 1e309, entropy = 1e309} 236 | local lastValid = {ppl = 1e309, entropy = 1e309} 237 | local bestModel = torch.FloatTensor(rnn.params:size()) 238 | local patience = opts.patience 239 | local divLR = false 240 | local timer = torch.Timer() 241 | local epochNo = 0 242 | for epoch = 1, opts.maxEpoch do 243 | epochNo = epochNo + 1 244 | EPOCH_INFO = string.format('epoch %d', epoch) 245 | local startTime = timer:time().real 246 | local trainCost = train(rnn, lmdata, opts) 247 | -- print('training ignored!!!') 248 | -- local trainCost = 123 249 | xprint('\repoch %d TRAIN nll %f ', epoch, trainCost) 250 | local validRval = valid(rnn, lmdata, opts, 'valid') 251 | xprint('VALID %f ', validRval.ppl) 252 | --[[ 253 | local testRval = valid(rnn, lmdata, opts, 'test') 254 | xprint('TEST %f ', testRval.ppl) 255 | --]] 256 | local endTime = timer:time().real 257 | xprintln('lr = %.4g (%s) p = %d', opts.curLR, readableTime(endTime - startTime), patience) 258 | 259 | if validRval.ppl < bestValid.ppl then 260 | bestValid.ppl = validRval.ppl 261 | bestValid.entropy = validRval.entropy 262 | bestValid.epoch = epoch 263 | rnn:getModel(bestModel) 264 | -- for non SGD algorithm, we will reset the patience 265 | -- if opts.optimMethod ~= 'SGD' then 266 | if opts.lrDiv <= 1 then 267 | patience = opts.patience 268 | end 269 | else 270 | -- non SGD algorithm decrease patience 271 | if opts.lrDiv <= 1 then 272 | -- if opts.optimMethod ~= 'SGD' then 273 | patience = patience - 1 274 | if patience == 0 then 275 | xprintln('No improvement on PPL for %d epoch(s). Training finished!', opts.patience) 276 | break 277 | end 278 | else 279 | -- SGD with learning rate decay 280 | rnn:setModel(bestModel) 281 | end 282 | 283 | end -- if validRval.ppl < bestValid.ppl 284 | 285 | if opts.savePerEpoch then 286 | local tmpPath = opts.save:sub(1, -4) .. '.tmp.t7' 287 | rnn:save(tmpPath, true) 288 | end 289 | 290 | if opts.saveBeforeLrDiv then 291 | if opts.optimMethod == 'SGD' and opts.curLR == opts.lr then 292 | local tmpPath = opts.save:sub(1, -4) .. 
'.blrd.t7' 293 | rnn:save(tmpPath, true) 294 | end 295 | end 296 | 297 | -- control the learning rate decay 298 | -- if opts.optimMethod == 'SGD' then 299 | if opts.lrDiv > 1 then 300 | if epoch >= 10 and patience > 1 then 301 | patience = 1 302 | end 303 | 304 | if validRval.entropy * opts.minImprovement > lastValid.entropy then 305 | if not divLR then -- patience == 1 306 | patience = patience - 1 307 | if patience < 1 then divLR = true end 308 | else 309 | xprintln('no significant improvement! cur ppl %f, best ppl %f', validRval.ppl, bestValid.ppl) 310 | break 311 | end 312 | end 313 | 314 | if divLR then 315 | opts.curLR = opts.curLR / opts.lrDiv 316 | opts.sgdParam.learningRate = opts.curLR 317 | end 318 | 319 | if opts.curLR < opts.minLR then 320 | xprintln('min lr is met! cur lr %e min lr %e', opts.curLR, opts.minLR) 321 | break 322 | end 323 | lastValid.ppl = validRval.ppl 324 | lastValid.entropy = validRval.entropy 325 | end 326 | end 327 | 328 | if epochNo > opts.maxEpoch then 329 | xprintln('Max number of epoch is met. Training finished!') 330 | end 331 | 332 | lmdata:close() 333 | 334 | rnn:setModel(bestModel) 335 | opts.sgdParam = nil 336 | rnn:save(opts.save, true) 337 | xprintln('model saved at %s', opts.save) 338 | 339 | -- verifyModel(opts.save) 340 | end 341 | 342 | -- here is the entry 343 | main() 344 | 345 | --------------------------------------------------------------------------------