├── .gitignore
├── LICENSE
├── README.md
├── conllx_scripts
│   ├── conllx2006_eval.lua
│   ├── conllx_eval.lua
│   ├── eval.lua
│   ├── eval.pl
│   ├── eval_new.pl
│   ├── extract_embed.lua
│   ├── replace_conllx_field.lua
│   └── split_dev.lua
├── dataiter
│   ├── DepDataIter.lua
│   └── DepPosDataIter.lua
├── dense_parser.lua
├── experiments
│   ├── chinese
│   │   ├── gen_lbl_train.sh
│   │   ├── mst-post.sh
│   │   ├── run_lbl.sh
│   │   ├── train.sh
│   │   └── tune.sh
│   ├── czech
│   │   ├── gen_lbl_train.sh
│   │   ├── mst-post.sh
│   │   ├── run_lbl.sh
│   │   ├── train.sh
│   │   └── tune.sh
│   ├── english
│   │   ├── gen_lbl_train.sh
│   │   ├── mst-post.sh
│   │   ├── run_lbl.sh
│   │   ├── train.sh
│   │   └── tune.sh
│   ├── german
│   │   ├── gen_lbl_train.sh
│   │   ├── mst-post.sh
│   │   ├── run_lbl.sh
│   │   ├── train.sh
│   │   └── tune.sh
│   └── run_parser
│       ├── run_chinese.sh
│       ├── run_czech.sh
│       ├── run_english.sh
│       └── run_german.sh
├── graph_alg
│   ├── ChuLiuEdmonds.lua
│   ├── Eisner.lua
│   └── PostDepGraph.lua
├── init.lua
├── layers
│   ├── .DS_Store
│   ├── Contiguous.lua
│   ├── DetailedMaskedNLLCriterion.lua
│   ├── EMaskedClassNLLCriterion.lua
│   ├── Linear3D.lua
│   ├── LookupTable_ft.lua
│   ├── MaskedClassNLLCriterion.lua
│   └── PReplicate.lua
├── model_opts.lua
├── mst_postprocess.lua
├── nnets
│   ├── MLP.lua
│   ├── SelectNet.lua
│   ├── SelectNetPos.lua
│   └── basic.lua
├── post_train.lua
├── train.lua
├── train_labeled.lua
└── utils
    ├── .DS_Store
    ├── model_utils.lua
    ├── shortcut.lua
    ├── wordembedding.lua
    └── wordembedding_ft.lua

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Prerequisites
2 | *.d
3 | 
4 | # Compiled Object files
5 | *.slo
6 | *.lo
7 | *.o
8 | *.obj
9 | 
10 | # Precompiled Headers
11 | *.gch
12 | *.pch
13 | 
14 | # Compiled Dynamic libraries
15 | *.so
16 | *.dylib
17 | *.dll
18 | 
19 | # Fortran module files
20 | *.mod
21 | *.smod
22 | 
23 | # Compiled Static libraries
24 | *.lai
25 | *.la
26 | *.a
27 | *.lib
28 | 
29 | # Executables
30 | *.exe
31 | *.out
32 | *.app
33 | 
34 | # dependency files
35 | *.dep
36 | 

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2016-Present University of Edinburgh (author: Xingxing Zhang)
2 | 
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 | 
7 |     http://www.apache.org/licenses/LICENSE-2.0
8 | 
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | 

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dependency Parsing as Head Selection
2 | 
3 | This is an implementation of the DeNSe (**De**pendency
4 | **N**eural **Se**lection) parser described in [Dependency Parsing as Head Selection](https://arxiv.org/abs/1606.01280).
5 | 
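The idea in a nutshell: the sentence is encoded with a bidirectional LSTM, and every word independently selects its most likely head (including an artificial ROOT); an MST algorithm is only applied afterwards, when the greedy selections do not already form a tree. The sketch below is illustrative only, not the repo's code (the real model lives in `nnets/SelectNetPos.lua`): it uses toy dimensions and random parameters to show a head scorer of the form `g(h_j, h_i) = v^T tanh(U h_j + W h_i)` followed by a softmax over candidate heads.

```lua
-- Illustrative sketch of head selection (assumed, simplified scorer).
require 'torch'
require 'nn'

local nhid, N = 4, 6                   -- toy sizes; the paper uses nhid = 300
local h = torch.randn(N, 2 * nhid)     -- stand-in for BiLSTM states (row 1 = ROOT)
local U = torch.randn(nhid, 2 * nhid)
local W = torch.randn(nhid, 2 * nhid)
local v = torch.randn(nhid)

local i = 3                            -- score candidate heads for the word at position 3
local scores = torch.Tensor(N)
for j = 1, N do
  scores[j] = v:dot(torch.tanh(U * h[j] + W * h[i]))
end
print(nn.SoftMax():forward(scores))    -- P(head of word i = j), j = 1..N
```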
6 | # Dependencies
7 | * [CUDA 7.0.28](http://www.nvidia.com/object/cuda_home_new.html)
8 | * [Torch](https://github.com/torch)
9 | * [torch-hdf5](https://github.com/deepmind/torch-hdf5)
10 | 
11 | You may also need to install some Torch components:
12 | ```
13 | luarocks install nn
14 | luarocks install nngraph
15 | luarocks install cutorch
16 | luarocks install cunn
17 | ```
18 | The parser was developed with an old version of Torch (circa February 2016).
19 | 
20 | # Run the Parser
21 | The parser can parse text in CoNLL-X format (note that POS tags must be provided).
22 | If a gold standard file is provided via `--gold`, the parser will also print out the UAS and LAS.
23 | ```
24 | CUDA_VISIBLE_DEVICES=3 th dense_parser.lua --modelPath $model --classifierPath $classifier \
25 |   --input $input --output $output --gold $input --mstalg Eisner
26 | ```
27 | Feel free to try the scripts in `experiments/run_parser`.
28 | 
29 | # Get Train/Dev Splits for German and Czech
30 | Please refer to the `main` function of `conllx_scripts/split_dev.lua`.
31 | 
32 | # Convert Pre-trained Embeddings
33 | You need to convert GloVe vectors from text format to `t7` format:
34 | ```
35 | conllx_scripts/extract_embed.lua -h
36 | ```
37 | 
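For reference, `conllx_scripts/extract_embed.lua` saves the converted file with `torch.save(WEOutPath, {final_we, word2idx, idx2word})`, i.e. a `FloatTensor` of the kept vectors plus the two lookup tables. A quick sanity check from `th` (the file name below is just an example):

```lua
local t = torch.load('ptb.glove.840B.300d.t7')
local we, word2idx, idx2word = t[1], t[2], t[3]
print(we:size())                      -- (#words kept) x embedding width
assert(word2idx[ idx2word[1] ] == 1)  -- the two maps are inverses
```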
38 | # Train an Unlabeled Parser
39 | Without loss of generality, we use Czech as an example.
40 | 
41 | First, train the model with the Adam algorithm using the script `experiments/czech/train.sh`:
42 | ```
43 | CUDA_VISIBLE_DEVICES=$ID th train.lua --useGPU \
44 |   --model SelectNetPos \
45 |   --seqLen 112 \
46 |   --maxTrainLen 110 \
47 |   --freqCut 1 \
48 |   --nhid 300 \
49 |   --nin 300 \
50 |   --nlayers 2 \
51 |   --dropout 0.35 \
52 |   --recDropout 0.1 \
53 |   --lr $lr \
54 |   --train $train \
55 |   --valid $valid \
56 |   --test $test \
57 |   --optimMethod Adam \
58 |   --save $model \
59 |   --batchSize 20 \
60 |   --validBatchSize 20 \
61 |   --maxEpoch 15 \
62 |   --npin 40 \
63 |   --evalType conllx \
64 |   | tee $log
65 | ```
66 | Once Adam has converged, we switch to plain SGD using `experiments/czech/tune.sh`, which usually gives a slight further improvement.
67 | ```
68 | CUDA_VISIBLE_DEVICES=$ID th post_train.lua \
69 |   --load $load \
70 |   --save $model \
71 |   --lr $lr \
72 |   --maxEpoch 10 \
73 |   --optimMethod SGD \
74 |   | tee $log
75 | ```
76 | Lastly, we use an MST algorithm to adjust any non-tree outputs with `experiments/czech/mst-post.sh`:
77 | ```
78 | CUDA_VISIBLE_DEVICES=3 th mst_postprocess.lua \
79 |   --modelPath $model \
80 |   --mstalg ChuLiuEdmonds \
81 |   --validout $validout \
82 |   --testout $testout | tee $log
83 | ```
84 | 
85 | # Train a Labeled Parser
86 | Based on the trained unlabeled parser, we first generate training data for the labeled parser with `experiments/czech/gen_lbl_train.sh`:
87 | ```
88 | CUDA_VISIBLE_DEVICES=3 th train_labeled.lua --mode generate \
89 |   --modelPath $model \
90 |   --outTrainDataPath $outTrain \
91 |   --inTrain $inTrain \
92 |   --inValid $inValid \
93 |   --inTest $inTest \
94 |   --outValid $outValid \
95 |   --outTest $outTest \
96 |   --language Other | tee $log
97 | ```
98 | Then we train the labeled parser (actually an MLP) with `experiments/czech/run_lbl.sh`:
99 | ```
100 | CUDA_VISIBLE_DEVICES=3 th train_labeled.lua --mode train \
101 |   --useGPU \
102 |   --snhids "1880,800,800,82" \
103 |   --activ relu \
104 |   --lr 0.01 \
105 |   --optimMethod AdaGrad \
106 |   --dropout 0.5 \
107 |   --inDropout 0.05 \
108 |   --batchSize 256 \
109 |   --maxEpoch 20 \
110 |   --ftype "|x|xe|xpe|" \
111 |   --dataset $dataset \
112 |   --inTrain $inTrain \
113 |   --inValid $inValid \
114 |   --inTest $inTest \
115 |   --language Other \
116 |   --save $model | tee $log
117 | ```
118 | 
119 | # Downloads
120 | ## Pre-trained Models
121 | https://drive.google.com/file/d/0B6-YKFW-MnbOVjlQNmlYWTFPT2c/view?usp=sharing
122 | ## Pre-trained Chinese Embeddings
123 | https://drive.google.com/file/d/0B6-YKFW-MnbOMjdXSVlKTkFwR0E/view?usp=sharing
124 | 
125 | # Citation
126 | ```
127 | @InProceedings{zhang-cheng-lapata:2017:EACLlong,
128 |   author    = {Zhang, Xingxing and Cheng, Jianpeng and Lapata, Mirella},
129 |   title     = {Dependency Parsing as Head Selection},
130 |   booktitle = {Proceedings of the 15th Conference of the European Chapter of the Association for Computational Linguistics: Volume 1, Long Papers},
131 |   month     = {April},
132 |   year      = {2017},
133 |   address   = {Valencia, Spain},
134 |   publisher = {Association for Computational Linguistics},
135 |   pages     = {665--676},
136 |   url       = {http://www.aclweb.org/anthology/E17-1063}
137 | }
138 | 
139 | ```
140 | 
--------------------------------------------------------------------------------
/conllx_scripts/conllx2006_eval.lua:
--------------------------------------------------------------------------------
1 | 
2 | include '../utils/shortcut.lua'
3 | 
4 | local CoNLLXEval = {}
5 | 
6 | function CoNLLXEval.toStd(sysFile, goldFile)
7 |   local tmpFile = sysFile ..
'__tmp__' 8 | local fin_s = io.open(sysFile) 9 | local fin_g = io.open(goldFile) 10 | local fout = io.open(tmpFile, 'w') 11 | 12 | while true do 13 | local sline = fin_s:read() 14 | local gline = fin_g:read() 15 | if sline == nil then assert(gline == nil) break end 16 | if sline:len() ~= 0 then 17 | local sfields = sline:splitc('\t') 18 | local gfields = gline:splitc('\t') 19 | assert(#sfields == #gfields and #sfields == 10) 20 | for i = 1, 6 do 21 | sfields[i] = gfields[i] 22 | end 23 | sline = table.concat(sfields, '\t') 24 | end 25 | fout:write(sline) 26 | fout:write('\n') 27 | end 28 | 29 | fin_s:close() 30 | fin_g:close() 31 | fout:close() 32 | 33 | return tmpFile 34 | end 35 | 36 | function CoNLLXEval.run(sysFile, goldFile, params) 37 | params = params or '-q' 38 | local eval_script_name = 'eval.pl' 39 | if paths.dirp('./conllx_scripts') then 40 | eval_script_name = './conllx_scripts/' .. eval_script_name 41 | end 42 | local cmd = string.format('perl %s -g %s -s %s %s', eval_script_name, goldFile, sysFile, params) 43 | local fin = io.popen(cmd, 'r') 44 | local s = fin:read('*a') 45 | -- print '----------------------' 46 | -- print(s) 47 | -- print '----------------------' 48 | local reg = 'Labeled attachment score: %d+ / %d+ %* 100 = ([^%s]+) %%' 49 | local _, _, lval = s:find(reg) 50 | local LAS = tonumber(lval) 51 | reg = 'Unlabeled attachment score: %d+ / %d+ %* 100 = ([^%s]+) %%' 52 | local _, _, uval = s:find(reg) 53 | local UAS = tonumber(uval) 54 | 55 | return LAS, UAS 56 | end 57 | 58 | function CoNLLXEval.eval(sysFile, goldFile) 59 | local tmp_sysfile = CoNLLXEval.toStd(sysFile, goldFile) 60 | local LAS, UAS = CoNLLXEval.run(tmp_sysfile, goldFile) 61 | 62 | local LAS_punct, UAS_punct = CoNLLXEval.run(tmp_sysfile, goldFile, "-q -p") 63 | os.execute(string.format('rm %s', tmp_sysfile)) 64 | 65 | xprintln('==no punct==') 66 | xprintln('LAS = %.2f, UAS = %.2f', LAS, UAS) 67 | xprintln('==with punct==') 68 | xprintln('LAS = %.2f, UAS = %.2f', LAS_punct, UAS_punct) 69 | 70 | return LAS_punct, UAS_punct, LAS, UAS 71 | end 72 | 73 | local function getOpts() 74 | local cmd = torch.CmdLine() 75 | cmd:text('====== Evaluation Script for Dependency Parser: conllx 2006 standard ======') 76 | cmd:option('--sysFile', '', 'system output') 77 | cmd:option('--goldFile', '', 'gold standard') 78 | 79 | return cmd:parse(arg) 80 | end 81 | 82 | local function main() 83 | -- local sysFile = '/disk/scratch/Dataset/conll/2006/zxx_version/czech/czech_gold_test_1.conll' 84 | -- local goldFile = '/disk/scratch/Dataset/conll/2006/zxx_version/czech/czech_gold_test.conll' 85 | local opts = getOpts() 86 | CoNLLXEval.eval(opts.sysFile, opts.goldFile) 87 | end 88 | 89 | if not package.loaded['conllx2006_eval'] then 90 | main() 91 | else 92 | return CoNLLXEval 93 | end 94 | -------------------------------------------------------------------------------- /conllx_scripts/conllx_eval.lua: -------------------------------------------------------------------------------- 1 | 2 | include '../utils/shortcut.lua' 3 | 4 | local CoNLLXEval = {} 5 | 6 | function CoNLLXEval.conllxLineIterator(infile) 7 | local fin = io.open(infile) 8 | local bufs = {} 9 | 10 | return function() 11 | while true do 12 | local line = fin:read() 13 | if line == nil then 14 | fin:close() 15 | break 16 | end 17 | line = line:trim() 18 | if line:len() == 0 then 19 | local rlines = {} 20 | for i, buf in ipairs(bufs) do 21 | rlines[i] = buf 22 | end 23 | table.clear(bufs) 24 | 25 | return rlines 26 | else 27 | table.insert(bufs, line) 28 | end 29 | 
end 30 | end 31 | 32 | end 33 | 34 | function CoNLLXEval.eval(sysFile, goldFile) 35 | -- local punctTags ={ "``", "''", ".", ",", ":" } 36 | local punctTags ={ "``", "''", ".", ",", ":", "PU" } 37 | local punctTagSet = {} 38 | for _, pt in ipairs(punctTags) do 39 | punctTagSet[pt] = true 40 | end 41 | -- print(punctTagSet) 42 | 43 | local sysIter = CoNLLXEval.conllxLineIterator(sysFile) 44 | local goldIter = CoNLLXEval.conllxLineIterator(goldFile) 45 | local sen_cnt = 0 46 | local total, noPunctTotal = 0, 0 47 | local nUA, noPunctNUA = 0, 0 48 | local nLA, noPunctNLA = 0, 0 49 | 50 | for sysLines in sysIter do 51 | local goldLines = goldIter() 52 | assert(#sysLines == #goldLines, 'the sys sentence and the gold sentence should contain the same number of words') 53 | for i = 1, #sysLines do 54 | local sfields = sysLines[i]:splitc('\t ') 55 | local gfields = goldLines[i]:splitc('\t ') 56 | local sAid, gAid = tonumber(sfields[7]), tonumber(gfields[7]) 57 | local sDep, gDep = sfields[8], gfields[8] 58 | if sAid == gAid then 59 | nUA = nUA + 1 60 | if sDep == gDep then nLA = nLA + 1 end 61 | end 62 | 63 | total = total + 1 64 | 65 | local gtag = gfields[5] 66 | if not punctTagSet[gtag] then 67 | noPunctTotal = noPunctTotal + 1 68 | if sAid == gAid then 69 | noPunctNUA = noPunctNUA + 1 70 | if sDep == gDep then noPunctNLA = noPunctNLA + 1 end 71 | end 72 | end 73 | end 74 | 75 | sen_cnt = sen_cnt + 1 76 | end 77 | 78 | xprintln('totally %d sentences', sen_cnt) 79 | local LAS, UAS = nLA / total * 100, nUA / total * 100 80 | local noPunctLAS, noPunctUAS = noPunctNLA / noPunctTotal * 100, noPunctNUA / noPunctTotal * 100 81 | 82 | xprintln('==no punct==') 83 | xprintln('LAS = %.2f, UAS = %.2f', noPunctLAS, noPunctUAS) 84 | xprintln('==with punct==') 85 | xprintln('LAS = %.2f, UAS = %.2f', LAS, UAS) 86 | 87 | return LAS, UAS, noPunctLAS, noPunctUAS 88 | end 89 | 90 | return CoNLLXEval 91 | -------------------------------------------------------------------------------- /conllx_scripts/eval.lua: -------------------------------------------------------------------------------- 1 | 2 | local conllx_eval = require('conllx_eval') 3 | 4 | local function getOpts() 5 | local cmd = torch.CmdLine() 6 | cmd:text('====== Evaluation Script for Dependency Parser ======') 7 | cmd:option('--sysFile', '', 'system output') 8 | cmd:option('--goldFile', '', 'gold standard') 9 | 10 | return cmd:parse(arg) 11 | end 12 | 13 | local function main() 14 | local opts = getOpts() 15 | local conllx_eval = require('conllx_eval') 16 | conllx_eval.eval(opts.sysFile, opts.goldFile) 17 | end 18 | 19 | main() 20 | -------------------------------------------------------------------------------- /conllx_scripts/extract_embed.lua: -------------------------------------------------------------------------------- 1 | 2 | include '../utils/shortcut.lua' 3 | 4 | local function extractTxtWE(vocabF, WEPath, WEOutPath) 5 | local vocab = torch.load(vocabF) 6 | print(table.keys(vocab)) 7 | 8 | local fin = io.open(WEPath) 9 | local msize, nsize = 0, 0 10 | local cnt = 0 11 | local wetable = {} 12 | local idx2word = {} 13 | while true do 14 | local line = fin:read() 15 | if line == nil then break end 16 | local fields = line:splitc(' \t') 17 | local width = #fields - 1 18 | if nsize == 0 then 19 | nsize = width 20 | else 21 | assert(nsize == width) 22 | end 23 | local word = fields[1] 24 | if vocab.word2idx[word] ~= nil then 25 | msize = msize + 1 26 | idx2word[msize] = word 27 | local v = {} 28 | for i = 2, width + 1 do 29 | table.insert(v, 
tonumber(fields[i])) 30 | end 31 | table.insert(wetable, v) 32 | end 33 | 34 | cnt = cnt + 1 35 | if cnt % 10000 == 0 then 36 | printf('cnt = %d\n', cnt) 37 | end 38 | end 39 | print('totally ' .. msize .. ' lines remain') 40 | 41 | local word2idx = {} 42 | for i, w in pairs(idx2word) do 43 | word2idx[w] = i 44 | end 45 | print(#word2idx, #idx2word) 46 | 47 | local final_we = torch.FloatTensor(wetable) 48 | print 'begin to save' 49 | torch.save(WEOutPath, {final_we, word2idx, idx2word}) 50 | print( 'save done at ' .. WEOutPath) 51 | end 52 | 53 | local function main() 54 | local cmd = torch.CmdLine() 55 | cmd:option('--vocab', 'ptb.train.tmp.vocab.t7', 'path for vocab file') 56 | cmd:option('--wepath', 'glove.840B.300d.txt', 'glove vectors') 57 | cmd:option('--weoutpath', 'ptb.glove.840B.300d.t7', 'output embedding file') 58 | local opts = cmd:parse(arg) 59 | 60 | extractTxtWE(opts.vocab, opts.wepath, opts.weoutpath) 61 | end 62 | 63 | main() 64 | 65 | -------------------------------------------------------------------------------- /conllx_scripts/replace_conllx_field.lua: -------------------------------------------------------------------------------- 1 | 2 | include '../utils/shortcut.lua' 3 | 4 | local ReplaceConllxField = {} 5 | 6 | function ReplaceConllxField.replace(infile, outfile, index, value) 7 | local fin = io.open(infile) 8 | local fout = io.open(outfile, 'w') 9 | while true do 10 | local line = fin:read() 11 | if line == nil then break end 12 | line = line:trim() 13 | if line:len() ~= 0 then 14 | local fields = line:splitc('\t') 15 | assert(#fields == 10, 'MUST have 10 fields!') 16 | fields[index] = value 17 | fout:write( table.concat(fields, '\t') ) 18 | fout:write('\n') 19 | else 20 | fout:write('\n') 21 | end 22 | end 23 | fin:close() 24 | fout:close() 25 | end 26 | 27 | local function main() 28 | ReplaceConllxField.replace('/disk/scratch/XingxingZhang/dep_parse/experiments/we_select_ft_pos_chinese/ft/mst_post/valid.dep', 29 | 'test.out', 30 | 8, 31 | 'ROOT') 32 | end 33 | 34 | if not package.loaded['replace_conllx_field'] then 35 | main() 36 | end 37 | 38 | return ReplaceConllxField 39 | -------------------------------------------------------------------------------- /conllx_scripts/split_dev.lua: -------------------------------------------------------------------------------- 1 | 2 | include '../utils/shortcut.lua' 3 | 4 | local function conllxLineIterator(infile) 5 | local fin = io.open(infile) 6 | local bufs = {} 7 | 8 | return function() 9 | while true do 10 | local line = fin:read() 11 | if line == nil then 12 | fin:close() 13 | break 14 | end 15 | line = line:trim() 16 | if line:len() == 0 then 17 | local rlines = {} 18 | for i, buf in ipairs(bufs) do 19 | rlines[i] = buf 20 | end 21 | table.clear(bufs) 22 | 23 | return rlines 24 | else 25 | table.insert(bufs, line) 26 | end 27 | end 28 | end 29 | 30 | end 31 | 32 | local function get_corpus_information(train_file, maxlen) 33 | local train_sents = {} 34 | local train_iter = conllxLineIterator(train_file) 35 | local word_cnt, sent_cnt = 0, 0 36 | local sent_len = { max = 0, avg = 0, min = 123456789, over_cnt = 0, maxlen = maxlen } 37 | for sent in train_iter do 38 | sent_cnt = sent_cnt + 1 39 | local nword = #sent 40 | word_cnt = word_cnt + nword 41 | sent_len.avg = sent_len.avg + nword 42 | sent_len.max = math.max(sent_len.max, nword) 43 | sent_len.min = math.min(sent_len.min, nword) 44 | if nword > maxlen then 45 | sent_len.over_cnt = sent_len.over_cnt + 1 46 | end 47 | end 48 | sent_len.avg = sent_len.avg / 
sent_cnt 49 | print '==training set information==' 50 | printf('word cnt = %d, sent cnt = %d\n', word_cnt, sent_cnt) 51 | print 'Sentence Length:' 52 | print(sent_len) 53 | print(sent_len.over_cnt / sent_cnt) 54 | sent_len.sent_cnt = sent_cnt 55 | 56 | return sent_len 57 | end 58 | 59 | local function split_dev(train_file, test_file, out_file) 60 | local maxlen = 110 61 | print '==Corpus Information for Train==' 62 | local train_info = get_corpus_information(train_file, maxlen) 63 | print(train_info) 64 | 65 | print '==Corpus Information for Test==' 66 | local test_info = get_corpus_information(test_file, maxlen) 67 | print(test_info) 68 | 69 | local out_train_file = out_file .. 'train.conll' 70 | local out_dev_file = out_file .. 'dev.conll' 71 | local out_test_file = out_file .. 'test.conll' 72 | assert(train_file ~= out_train_file, 'MUST have a different name!!!') 73 | assert(test_file ~= out_test_file, 'MUST have different name!!!') 74 | local dev_cnt = test_info.sent_cnt + 10 75 | local train_cnt = train_info.sent_cnt - dev_cnt 76 | 77 | local train_iter = conllxLineIterator(train_file) 78 | local cnt = 0 79 | local fout_train = io.open(out_train_file, 'w') 80 | local fout_dev = io.open(out_dev_file, 'w') 81 | for sent in train_iter do 82 | cnt = cnt + 1 83 | local fout = cnt <= train_cnt and fout_train or fout_dev 84 | local label = cnt <= train_cnt and 'Train' or 'Dev' 85 | if #sent <= maxlen then 86 | for _, line in ipairs(sent) do 87 | fout:write(line .. '\n') 88 | end 89 | fout:write('\n') 90 | else 91 | printf('[%s] longer than %d\n', label, maxlen) 92 | end 93 | end 94 | fout_train:close() 95 | fout_dev:close() 96 | 97 | os.execute(string.format('cp %s %s', test_file, out_test_file)) 98 | 99 | print '==Final Dataset==' 100 | print '===Train==' 101 | get_corpus_information(out_train_file, maxlen) 102 | print '===Dev===' 103 | get_corpus_information(out_dev_file, maxlen) 104 | print '===Test===' 105 | get_corpus_information(out_test_file, maxlen) 106 | end 107 | 108 | local function main() 109 | --[[ 110 | split_dev('/disk/scratch/Dataset/conll/2006/zxx_version/czech/czech_pdt_train.conll', 111 | '/disk/scratch/Dataset/conll/2006/zxx_version/czech/czech_pdt_test.conll', 112 | '/disk/scratch/Dataset/conll/2006/zxx_version/czech/czech_gold_') 113 | --]] 114 | 115 | split_dev('/disk/scratch/Dataset/conll/2006/zxx_version/german/german_tiger_train.conll', 116 | '/disk/scratch/Dataset/conll/2006/zxx_version/german/german_tiger_test.conll', 117 | '/disk/scratch/Dataset/conll/2006/zxx_version/german/german_gold_') 118 | end 119 | 120 | main() -------------------------------------------------------------------------------- /dataiter/DepDataIter.lua: -------------------------------------------------------------------------------- 1 | 2 | local DataIter = torch.class('DepDataIter') 3 | 4 | function DataIter.conllx_iter(infile) 5 | local fin = io.open(infile) 6 | 7 | return function() 8 | local items = {} 9 | while true do 10 | local line = fin:read() 11 | if line == nil then break end 12 | line = line:trim() 13 | if line:len() == 0 then 14 | break 15 | end 16 | local fields = line:splitc('\t') 17 | assert(#fields == 10, 'MUST have 10 fields') 18 | local item = {p1 = tonumber(fields[1]), wd = fields[2], 19 | pos = fields[5], p2 = fields[7], rel = fields[8]} 20 | table.insert(items, item) 21 | end 22 | if #items > 0 then 23 | return items 24 | else 25 | fin:close() 26 | end 27 | end 28 | end 29 | 30 | function DataIter.getDataSize(infiles) 31 | local sizes = {} 32 | for _, infile in 
ipairs(infiles) do 33 | local size = 0 34 | local diter = DataIter.conllx_iter(infile) 35 | for ds in diter do 36 | size = size + 1 37 | end 38 | sizes[#sizes + 1] = size 39 | end 40 | 41 | return sizes 42 | end 43 | 44 | function DataIter.showVocab(vocab) 45 | for k, v in pairs(vocab) do 46 | xprint(k) 47 | if type(v) == 'table' then 48 | print ' -- table' 49 | else 50 | print( ' -- ' .. tostring(v) ) 51 | end 52 | end 53 | end 54 | 55 | function DataIter.createVocab(infile, ignoreCase, freqCut, maxNVocab) 56 | local wordFreq = {} 57 | local wordVec = {} 58 | local diter = DataIter.conllx_iter(infile) 59 | 60 | local function addwd(wd) 61 | local wd = ignoreCase and wd:lower() or wd 62 | local freq = wordFreq[wd] 63 | if freq == nil then 64 | wordFreq[wd] = 1 65 | wordVec[#wordVec + 1] = wd 66 | else 67 | wordFreq[wd] = freq + 1 68 | end 69 | end 70 | 71 | for sent in diter do 72 | addwd('###root###') 73 | for _, ditem in ipairs(sent) do 74 | addwd(ditem.wd) 75 | end 76 | end 77 | 78 | local idx2word 79 | if freqCut and freqCut >= 0 then 80 | idx2word = {} 81 | idx2word[#idx2word + 1] = 'UNK' 82 | for _, wd in ipairs(wordVec) do 83 | if wordFreq[wd] > freqCut then idx2word[#idx2word + 1] = wd end 84 | end 85 | 86 | printf('original word count = %d, after freq cut = %d, word count = %d\n', #wordVec, freqCut, #idx2word) 87 | end 88 | 89 | if maxNVocab and maxNVocab > 0 then 90 | if #idx2word > 0 then 91 | print( 'WARING: rewrote idx2word with maxNVocab = ' .. maxNVocab ) 92 | end 93 | idx2word = {} 94 | idx2word[#idx2word + 1] = 'UNK' 95 | local wfs = {} 96 | for _, k in ipairs(wordVec) do table.insert(wfs, {k, wordFreq[k]}) end 97 | table.sort(wfs, function(x, y) return x[2] > y[2] end) 98 | local lfreq = -1 99 | for cnt, kv in ipairs(wfs) do 100 | idx2word[#idx2word + 1] = kv[1] 101 | lfreq = kv[2] 102 | if cnt >= maxNVocab-1 then break end 103 | end 104 | printf('original word count = %d, after maxNVocab = %d, word count = %d, lowest freq = %d\n', #wordVec, maxNVocab, #idx2word, lfreq) 105 | end 106 | 107 | local word2idx = {} 108 | for i, w in ipairs(idx2word) do word2idx[w] = i end 109 | local vocab = {word2idx = word2idx, idx2word = idx2word, 110 | freqCut = freqCut, ignoreCase = ignoreCase, maxNVocab = maxNVocab, 111 | UNK_STR = 'UNK', UNK = word2idx['UNK'], 112 | ROOT_STR = '###root###', ROOT = word2idx['###root###']} 113 | vocab['nvocab'] = table.len(word2idx) 114 | 115 | DataIter.showVocab(vocab) 116 | 117 | return vocab 118 | end 119 | 120 | function DataIter.toBatch(sents, vocab, batchSize) 121 | local dtype = 'torch.LongTensor' 122 | local maxn = -1 123 | for _, sent in ipairs(sents) do if sent:size(1) > maxn then maxn = sent:size(1) end end 124 | assert(maxn ~= -1) 125 | local x = (torch.ones(maxn + 1, batchSize) * vocab.UNK):type(dtype) 126 | local x_mask = torch.zeros(maxn + 1, batchSize) 127 | local y = torch.zeros(maxn, batchSize):type(dtype) 128 | x[{ 1, {} }] = vocab.ROOT 129 | x_mask[{ 1, {} }] = 1 130 | for i, sent in ipairs(sents) do 131 | local slen = sent:size(1) 132 | x[{ {2, slen + 1}, i }] = sent[{ {}, 1 }] 133 | x_mask[{ {2, slen + 1}, i }] = 1 134 | y[{ {1, slen}, i }] = sent[{ {}, 2 }] 135 | end 136 | 137 | return x, x_mask, y 138 | end 139 | 140 | function DataIter.sent2dep(vocab, sent) 141 | local d = {} 142 | local word2idx = vocab.word2idx 143 | for _, ditem in ipairs(sent) do 144 | local wd = vocab.ignoreCase and ditem.wd:lower() or ditem.wd 145 | local wid = word2idx[wd] or vocab.UNK 146 | d[#d + 1] = {wid, ditem.p2 + 1} 147 | end 148 | return 
torch.Tensor(d), #d 149 | end 150 | 151 | 152 | function DataIter.createBatch(vocab, infile, batchSize, maxlen) 153 | maxlen = maxlen or 100 154 | local diter = DataIter.conllx_iter(infile) 155 | local isEnd = false 156 | 157 | return function() 158 | if not isEnd then 159 | 160 | local sents = {} 161 | for i = 1, batchSize do 162 | local sent = diter() 163 | if sent == nil then isEnd = true break end 164 | local s, len = DataIter.sent2dep(vocab, sent) 165 | if len <= maxlen then 166 | sents[#sents + 1] = s 167 | else 168 | print ( 'delete sentence with length ' .. tostring(len) ) 169 | end 170 | end 171 | if #sents > 0 then 172 | return DataIter.toBatch(sents, vocab, batchSize) 173 | end 174 | 175 | end 176 | end 177 | end 178 | 179 | function DataIter.createBatchSort(vocab, infile, batchSize, maxlen) 180 | maxlen = maxlen or 100 181 | local diter = DataIter.conllx_iter(infile) 182 | local all_sents = {} 183 | for sent in diter do 184 | local s, len = DataIter.sent2dep(vocab, sent) 185 | all_sents[#all_sents + 1] = s 186 | end 187 | -- print(all_sents[1]) 188 | table.sort(all_sents, function(a, b) return a:size(1) < b:size(1) end) 189 | 190 | local cnt = 0 191 | local ndata = #all_sents 192 | 193 | return function() 194 | 195 | local sents = {} 196 | for i = 1, batchSize do 197 | cnt = cnt + 1 198 | if cnt <= ndata then 199 | sents[#sents + 1] = all_sents[cnt] 200 | end 201 | end 202 | 203 | if #sents > 0 then 204 | return DataIter.toBatch(sents, vocab, batchSize) 205 | end 206 | 207 | end 208 | end 209 | 210 | function DataIter.loadAllSents(vocab, infile, maxlen) 211 | local diter = DataIter.conllx_iter(infile) 212 | local all_sents = {} 213 | local del_cnt = 0 214 | for sent in diter do 215 | local s, len = DataIter.sent2dep(vocab, sent) 216 | if len <= maxlen then 217 | all_sents[#all_sents + 1] = s 218 | else 219 | del_cnt = del_cnt + 1 220 | end 221 | end 222 | if del_cnt > 0 then 223 | printf( 'WARNING!!! 
delete %d sentences that longer than %d\n', del_cnt, maxlen) 224 | end 225 | 226 | return all_sents 227 | end 228 | 229 | function DataIter.createBatchShuffleSort(all_sents_, vocab, batchSize, sort_flag, shuffle) 230 | assert(sort_flag ~= nil and (shuffle == true or shuffle == false)) 231 | 232 | local function shuffle_dataset(all_sents) 233 | local tmp_sents = {} 234 | local idxs = torch.randperm(#all_sents) 235 | for i = 1, idxs:size(1) do 236 | tmp_sents[#tmp_sents + 1] = all_sents[ idxs[i] ] 237 | end 238 | return tmp_sents 239 | end 240 | 241 | local all_sents 242 | if shuffle then 243 | all_sents = shuffle_dataset(all_sents_) 244 | end 245 | 246 | local len_idxs = {} 247 | for i, sent in ipairs(all_sents) do 248 | len_idxs[#len_idxs + 1] = {sent:size(1), i} 249 | end 250 | 251 | local kbatches = sort_flag * batchSize 252 | local new_idxs = {} 253 | local N = #len_idxs 254 | for istart = 1, N, kbatches do 255 | iend = math.min(istart + kbatches - 1, N) 256 | local tmpa = {} 257 | for i = istart, iend do 258 | tmpa[#tmpa + 1] = len_idxs[i] 259 | end 260 | table.sort(tmpa, function( a, b ) return a[1] < b[1] end) 261 | for _, tmp in ipairs(tmpa) do 262 | new_idxs[#new_idxs + 1] = tmp[2] 263 | end 264 | end 265 | 266 | local final_all_sents = {} 267 | for _, idx in ipairs(new_idxs) do 268 | final_all_sents[#final_all_sents + 1] = all_sents[idx] 269 | end 270 | 271 | local cnt, ndata = 0, #final_all_sents 272 | return function() 273 | 274 | local sents = {} 275 | for i = 1, batchSize do 276 | cnt = cnt + 1 277 | if cnt > ndata then break end 278 | sents[#sents + 1] = final_all_sents[cnt] 279 | end 280 | 281 | if #sents > 0 then 282 | return DataIter.toBatch(sents, vocab, batchSize) 283 | end 284 | 285 | end 286 | end 287 | 288 | local function main() 289 | --[[ 290 | require '../utils/shortcut' 291 | local infile = '/Users/xing/Desktop/depparse/train.autopos' 292 | local diter = DepDataIter.conllx_iter(infile) 293 | local cnt = 0 294 | for item in diter do 295 | -- print(item) 296 | cnt = cnt + 1 297 | -- if cnt == 1 then break end 298 | end 299 | printf('totally %d sentences\n', cnt) 300 | --]] 301 | 302 | --[[ 303 | require '../utils/shortcut' 304 | local infile = '/Users/xing/Desktop/depparse/train.autopos' 305 | local vocab = DepDataIter.createVocab(infile, true) 306 | --]] 307 | 308 | require '../utils/shortcut' 309 | local infile = '/Users/xing/Desktop/depparse/train.autopos' 310 | local vocab = DepDataIter.createVocab(infile, true, 1) 311 | print 'get vocab done!' 312 | 313 | local validfile = '/Users/xing/Desktop/depparse/train.autopos' 314 | local batchIter = DepDataIter.createBatch(vocab, validfile, 30, 100) 315 | local cnt = 0 316 | for x, x_maks, y in batchIter do 317 | cnt = cnt + 1 318 | end 319 | print( 'totally ' .. 
cnt ) 320 | 321 | -- batchIter() 322 | -- batchIter() 323 | end 324 | 325 | if not package.loaded['DepDataIter'] then 326 | main() 327 | end 328 | 329 | -------------------------------------------------------------------------------- /dataiter/DepPosDataIter.lua: -------------------------------------------------------------------------------- 1 | 2 | local DataIter = torch.class('DepPosDataIter') 3 | 4 | function DataIter.conllx_iter(infile) 5 | local fin = io.open(infile) 6 | 7 | return function() 8 | local items = {} 9 | while true do 10 | local line = fin:read() 11 | if line == nil then break end 12 | line = line:trim() 13 | if line:len() == 0 then 14 | break 15 | end 16 | local fields = line:splitc('\t') 17 | assert(#fields == 10, 'MUST have 10 fields') 18 | local item = {p1 = tonumber(fields[1]), wd = fields[2], 19 | pos = fields[5], p2 = fields[7], rel = fields[8]} 20 | table.insert(items, item) 21 | end 22 | if #items > 0 then 23 | return items 24 | else 25 | fin:close() 26 | end 27 | end 28 | end 29 | 30 | 31 | function DataIter.createDepRelVocab(infile) 32 | local lbl_freq = {} 33 | local lbl_vec = {} 34 | local diter = DataIter.conllx_iter(infile) 35 | for sent in diter do 36 | for _, ditem in ipairs(sent) do 37 | local rel = ditem.rel 38 | local freq = lbl_freq[rel] 39 | if freq == nil then 40 | lbl_vec[#lbl_vec + 1] = rel 41 | lbl_freq[rel] = 1 42 | else 43 | lbl_freq[rel] = freq + 1 44 | end 45 | end 46 | end 47 | 48 | local rel2idx = {} 49 | for i, r in ipairs(lbl_vec) do 50 | rel2idx[r] = i 51 | end 52 | 53 | local lbl_vocab = {} 54 | lbl_vocab.rel2idx = rel2idx 55 | lbl_vocab.idx2rel = lbl_vec 56 | 57 | return lbl_vocab 58 | end 59 | 60 | 61 | function DataIter.getDataSize(infiles) 62 | local sizes = {} 63 | for _, infile in ipairs(infiles) do 64 | local size = 0 65 | local diter = DataIter.conllx_iter(infile) 66 | for ds in diter do 67 | size = size + 1 68 | end 69 | sizes[#sizes + 1] = size 70 | end 71 | 72 | return sizes 73 | end 74 | 75 | function DataIter.showVocab(vocab) 76 | for k, v in pairs(vocab) do 77 | xprint(k) 78 | if type(v) == 'table' then 79 | print ' -- table' 80 | else 81 | print( ' -- ' .. 
tostring(v) ) 82 | end 83 | end 84 | end 85 | 86 | function DataIter.createVocab(infile, ignoreCase, freqCut, maxNVocab) 87 | local wordFreq = {} 88 | local wordVec = {} 89 | local diter = DataIter.conllx_iter(infile) 90 | 91 | local function addwd(wd) 92 | local wd = ignoreCase and wd:lower() or wd 93 | local freq = wordFreq[wd] 94 | if freq == nil then 95 | wordFreq[wd] = 1 96 | wordVec[#wordVec + 1] = wd 97 | else 98 | wordFreq[wd] = freq + 1 99 | end 100 | end 101 | 102 | local tagFreq = {} 103 | local idx2pos = {} 104 | local function addtag(tag) 105 | local freq = tagFreq[tag] 106 | if freq == nil then 107 | tagFreq[tag] = 1 108 | idx2pos[#idx2pos + 1] = tag 109 | else 110 | tagFreq[tag] = freq + 1 111 | end 112 | end 113 | 114 | for sent in diter do 115 | addwd('###root###') 116 | addtag('###root###') 117 | for _, ditem in ipairs(sent) do 118 | addwd(ditem.wd) 119 | addtag(ditem.pos) 120 | end 121 | end 122 | 123 | local pos2idx = {} 124 | for i, pos in pairs(idx2pos) do 125 | pos2idx[pos] = i 126 | end 127 | printf('totally number of tags is %d\n', #idx2pos) 128 | 129 | local idx2word 130 | if freqCut and freqCut >= 0 then 131 | idx2word = {} 132 | idx2word[#idx2word + 1] = 'UNK' 133 | for _, wd in ipairs(wordVec) do 134 | if wordFreq[wd] > freqCut then idx2word[#idx2word + 1] = wd end 135 | end 136 | 137 | printf('original word count = %d, after freq cut = %d, word count = %d\n', #wordVec, freqCut, #idx2word) 138 | end 139 | 140 | if maxNVocab and maxNVocab > 0 then 141 | if #idx2word > 0 then 142 | print( 'WARING: rewrote idx2word with maxNVocab = ' .. maxNVocab ) 143 | end 144 | idx2word = {} 145 | idx2word[#idx2word + 1] = 'UNK' 146 | local wfs = {} 147 | for _, k in ipairs(wordVec) do table.insert(wfs, {k, wordFreq[k]}) end 148 | table.sort(wfs, function(x, y) return x[2] > y[2] end) 149 | local lfreq = -1 150 | for cnt, kv in ipairs(wfs) do 151 | idx2word[#idx2word + 1] = kv[1] 152 | lfreq = kv[2] 153 | if cnt >= maxNVocab-1 then break end 154 | end 155 | printf('original word count = %d, after maxNVocab = %d, word count = %d, lowest freq = %d\n', #wordVec, maxNVocab, #idx2word, lfreq) 156 | end 157 | 158 | local word2idx = {} 159 | for i, w in ipairs(idx2word) do word2idx[w] = i end 160 | local vocab = {word2idx = word2idx, idx2word = idx2word, 161 | freqCut = freqCut, ignoreCase = ignoreCase, maxNVocab = maxNVocab, 162 | UNK_STR = 'UNK', UNK = word2idx['UNK'], 163 | ROOT_STR = '###root###', ROOT = word2idx['###root###']} 164 | vocab['nvocab'] = table.len(word2idx) 165 | 166 | vocab['idx2pos'] = idx2pos 167 | vocab['pos2idx'] = pos2idx 168 | vocab['npos'] = table.len(pos2idx) 169 | vocab['ROOT_POS'] = pos2idx['###root###'] 170 | vocab['ROOT_POS_STR'] = '###root###' 171 | -- print(idx2pos) 172 | 173 | DataIter.showVocab(vocab) 174 | 175 | return vocab 176 | end 177 | 178 | 179 | --[[ 180 | function DataIter.toBatch(sents, vocab, batchSize) 181 | local dtype = 'torch.LongTensor' 182 | local maxn = -1 183 | for _, sent in ipairs(sents) do if sent:size(1) > maxn then maxn = sent:size(1) end end 184 | assert(maxn ~= -1) 185 | local x = (torch.ones(maxn + 1, batchSize) * vocab.UNK):type(dtype) 186 | local x_mask = torch.zeros(maxn + 1, batchSize) 187 | local y = torch.zeros(maxn, batchSize):type(dtype) 188 | x[{ 1, {} }] = vocab.ROOT 189 | x_mask[{ 1, {} }] = 1 190 | for i, sent in ipairs(sents) do 191 | local slen = sent:size(1) 192 | x[{ {2, slen + 1}, i }] = sent[{ {}, 1 }] 193 | x_mask[{ {2, slen + 1}, i }] = 1 194 | y[{ {1, slen}, i }] = sent[{ {}, 2 }] 195 | end 196 | 197 | 
return x, x_mask, y 198 | end 199 | --]] 200 | 201 | function DataIter.toBatch(sents, vocab, batchSize) 202 | local dtype = 'torch.LongTensor' 203 | local maxn = -1 204 | for _, sent in ipairs(sents) do if sent:size(1) > maxn then maxn = sent:size(1) end end 205 | assert(maxn ~= -1) 206 | local x = (torch.ones(maxn + 1, batchSize) * vocab.UNK):type(dtype) 207 | local x_mask = torch.zeros(maxn + 1, batchSize) 208 | local x_pos = (torch.ones(maxn + 1, batchSize) * vocab.ROOT_POS):type(dtype) 209 | local y = torch.zeros(maxn, batchSize):type(dtype) 210 | x[{ 1, {} }] = vocab.ROOT 211 | x_pos[{ 1, {} }] = vocab.ROOT_POS 212 | x_mask[{ 1, {} }] = 1 213 | for i, sent in ipairs(sents) do 214 | local slen = sent:size(1) 215 | x[{ {2, slen + 1}, i }] = sent[{ {}, 1 }] 216 | x_pos[{ {2, slen + 1}, i }] = sent[{ {}, 3 }] 217 | x_mask[{ {2, slen + 1}, i }] = 1 218 | y[{ {1, slen}, i }] = sent[{ {}, 2 }] 219 | end 220 | 221 | return x, x_mask, x_pos, y 222 | end 223 | 224 | --[[ 225 | function DataIter.sent2dep(vocab, sent) 226 | local d = {} 227 | local word2idx = vocab.word2idx 228 | for _, ditem in ipairs(sent) do 229 | local wd = vocab.ignoreCase and ditem.wd:lower() or ditem.wd 230 | local wid = word2idx[wd] or vocab.UNK 231 | d[#d + 1] = {wid, ditem.p2 + 1} 232 | end 233 | return torch.Tensor(d), #d 234 | end 235 | --]] 236 | 237 | function DataIter.sent2dep(vocab, sent) 238 | local d = {} 239 | local word2idx = vocab.word2idx 240 | local pos2idx = vocab.pos2idx 241 | for _, ditem in ipairs(sent) do 242 | local wd = vocab.ignoreCase and ditem.wd:lower() or ditem.wd 243 | local wid = word2idx[wd] or vocab.UNK 244 | local posid = pos2idx[ditem.pos] 245 | d[#d + 1] = {wid, ditem.p2 + 1, posid} 246 | end 247 | return torch.Tensor(d), #d 248 | end 249 | 250 | 251 | function DataIter.createBatch(vocab, infile, batchSize, maxlen) 252 | maxlen = maxlen or 100 253 | local diter = DataIter.conllx_iter(infile) 254 | local isEnd = false 255 | 256 | return function() 257 | if not isEnd then 258 | 259 | local sents = {} 260 | for i = 1, batchSize do 261 | local sent = diter() 262 | if sent == nil then isEnd = true break end 263 | local s, len = DataIter.sent2dep(vocab, sent) 264 | if len <= maxlen then 265 | sents[#sents + 1] = s 266 | else 267 | print ( 'delete sentence with length ' .. tostring(len) ) 268 | end 269 | end 270 | if #sents > 0 then 271 | return DataIter.toBatch(sents, vocab, batchSize) 272 | end 273 | 274 | end 275 | end 276 | end 277 | 278 | 279 | function DataIter.createBatchLabel(vocab, rel_vocab, infile, batchSize, maxlen) 280 | maxlen = maxlen or 100 281 | local diter = DataIter.conllx_iter(infile) 282 | local isEnd = false 283 | local rel2idx = rel_vocab.rel2idx 284 | 285 | return function() 286 | if not isEnd then 287 | 288 | local sents = {} 289 | local sent_rels = {} 290 | local sent_ori_rels = {} 291 | for i = 1, batchSize do 292 | local sent = diter() 293 | if sent == nil then isEnd = true break end 294 | local s, len = DataIter.sent2dep(vocab, sent) 295 | if len <= maxlen then 296 | sents[#sents + 1] = s 297 | local sent_rel = {} 298 | local sent_ori_rel = {} 299 | for i, item in ipairs(sent) do 300 | sent_rel[i] = rel2idx[item.rel] 301 | sent_ori_rel[i] = item.rel 302 | end 303 | sent_rels[#sent_rels + 1] = sent_rel 304 | sent_ori_rels[#sent_ori_rels + 1] = sent_ori_rel 305 | else 306 | print ( 'delete sentence with length ' .. 
tostring(len) ) 307 | end 308 | end 309 | if #sents > 0 then 310 | local x, x_mask, x_pos, y = DataIter.toBatch(sents, vocab, batchSize) 311 | return x, x_mask, x_pos, y, sent_rels, sent_ori_rels 312 | end 313 | 314 | end 315 | end 316 | end 317 | 318 | 319 | function DataIter.createBatchSort(vocab, infile, batchSize, maxlen) 320 | maxlen = maxlen or 100 321 | local diter = DataIter.conllx_iter(infile) 322 | local all_sents = {} 323 | for sent in diter do 324 | local s, len = DataIter.sent2dep(vocab, sent) 325 | all_sents[#all_sents + 1] = s 326 | end 327 | -- print(all_sents[1]) 328 | table.sort(all_sents, function(a, b) return a:size(1) < b:size(1) end) 329 | 330 | local cnt = 0 331 | local ndata = #all_sents 332 | 333 | return function() 334 | 335 | local sents = {} 336 | for i = 1, batchSize do 337 | cnt = cnt + 1 338 | if cnt <= ndata then 339 | sents[#sents + 1] = all_sents[cnt] 340 | end 341 | end 342 | 343 | if #sents > 0 then 344 | return DataIter.toBatch(sents, vocab, batchSize) 345 | end 346 | 347 | end 348 | end 349 | 350 | function DataIter.loadAllSents(vocab, infile, maxlen) 351 | local diter = DataIter.conllx_iter(infile) 352 | local all_sents = {} 353 | local del_cnt = 0 354 | for sent in diter do 355 | local s, len = DataIter.sent2dep(vocab, sent) 356 | if len <= maxlen then 357 | all_sents[#all_sents + 1] = s 358 | else 359 | del_cnt = del_cnt + 1 360 | end 361 | end 362 | if del_cnt > 0 then 363 | printf( 'WARNING!!! delete %d sentences that longer than %d\n', del_cnt, maxlen) 364 | end 365 | 366 | return all_sents 367 | end 368 | 369 | function DataIter.createBatchShuffleSort(all_sents_, vocab, batchSize, sort_flag, shuffle) 370 | assert(sort_flag ~= nil and (shuffle == true or shuffle == false)) 371 | 372 | local function shuffle_dataset(all_sents) 373 | local tmp_sents = {} 374 | local idxs = torch.randperm(#all_sents) 375 | for i = 1, idxs:size(1) do 376 | tmp_sents[#tmp_sents + 1] = all_sents[ idxs[i] ] 377 | end 378 | return tmp_sents 379 | end 380 | 381 | local all_sents 382 | if shuffle then 383 | all_sents = shuffle_dataset(all_sents_) 384 | end 385 | 386 | local len_idxs = {} 387 | for i, sent in ipairs(all_sents) do 388 | len_idxs[#len_idxs + 1] = {sent:size(1), i} 389 | end 390 | 391 | local kbatches = sort_flag * batchSize 392 | local new_idxs = {} 393 | local N = #len_idxs 394 | for istart = 1, N, kbatches do 395 | iend = math.min(istart + kbatches - 1, N) 396 | local tmpa = {} 397 | for i = istart, iend do 398 | tmpa[#tmpa + 1] = len_idxs[i] 399 | end 400 | table.sort(tmpa, function( a, b ) return a[1] < b[1] end) 401 | for _, tmp in ipairs(tmpa) do 402 | new_idxs[#new_idxs + 1] = tmp[2] 403 | end 404 | end 405 | 406 | local final_all_sents = {} 407 | for _, idx in ipairs(new_idxs) do 408 | final_all_sents[#final_all_sents + 1] = all_sents[idx] 409 | end 410 | 411 | local cnt, ndata = 0, #final_all_sents 412 | return function() 413 | 414 | local sents = {} 415 | for i = 1, batchSize do 416 | cnt = cnt + 1 417 | if cnt > ndata then break end 418 | sents[#sents + 1] = final_all_sents[cnt] 419 | end 420 | 421 | if #sents > 0 then 422 | return DataIter.toBatch(sents, vocab, batchSize) 423 | end 424 | 425 | end 426 | end 427 | 428 | local function main() 429 | --[[ 430 | require '../utils/shortcut' 431 | local infile = '/Users/xing/Desktop/depparse/train.autopos' 432 | local diter = DepDataIter.conllx_iter(infile) 433 | local cnt = 0 434 | for item in diter do 435 | -- print(item) 436 | cnt = cnt + 1 437 | -- if cnt == 1 then break end 438 | end 439 | 
printf('totally %d sentences\n', cnt) 440 | --]] 441 | 442 | --[[ 443 | require '../utils/shortcut' 444 | local infile = '/Users/xing/Desktop/depparse/train.autopos' 445 | local vocab = DepDataIter.createVocab(infile, true) 446 | --]] 447 | 448 | require '../utils/shortcut' 449 | local infile = '/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dataset/train.autopos' 450 | local vocab = DepPosDataIter.createVocab(infile, true, 1) 451 | print 'get vocab done!' 452 | 453 | local validfile = '/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dataset/valid.autopos' 454 | local batchIter = DepPosDataIter.createBatch(vocab, validfile, 32, 100) 455 | local cnt = 0 456 | for x, x_mask, x_pos, y in batchIter do 457 | cnt = cnt + 1 458 | if cnt < 3 then 459 | print 'x = ' 460 | print(x) 461 | print 'x_mask = ' 462 | print(x_mask) 463 | print 'x_pos = ' 464 | print(x_pos) 465 | print 'y = ' 466 | print(y) 467 | end 468 | end 469 | print( 'totally ' .. cnt ) 470 | end 471 | 472 | if not package.loaded['DepPosDataIter'] then 473 | main() 474 | end 475 | 476 | -------------------------------------------------------------------------------- /dense_parser.lua: -------------------------------------------------------------------------------- 1 | 2 | require '.' 3 | require 'shortcut' 4 | require 'SelectNetPos' 5 | require 'DepPosDataIter' 6 | require 'PostDepGraph' 7 | require 'ChuLiuEdmonds' 8 | require 'Eisner' 9 | require 'MLP' 10 | 11 | local Parser = torch.class('DeNSeParser') 12 | 13 | function Parser:showOpts() 14 | local tmp_vocab = self.opts.vocab 15 | self.opts.vocab = nil 16 | print(self.opts) 17 | self.opts.vocab = tmp_vocab 18 | end 19 | 20 | 21 | function Parser:load(modelPath, classifierPath) 22 | self.opts = torch.load( modelPath:sub(1, -3) .. 'state.t7' ) 23 | local opts = self.opts 24 | -- disable loading pre-trained word embeddings 25 | opts.wordEmbedding = '' 26 | 27 | torch.manualSeed(opts.seed + 1) 28 | if opts.useGPU then 29 | require 'cutorch' 30 | require 'cunn' 31 | cutorch.manualSeed(opts.seed + 1) 32 | end 33 | 34 | self.net = SelectNetPos(opts) 35 | self:showOpts() 36 | 37 | xprintln('load from %s ...', modelPath) 38 | self.net:load(modelPath) 39 | xprintln('load from %s done!', modelPath) 40 | 41 | self.lbl_opts = torch.load(classifierPath:sub(1, -3) .. 'state.t7') 42 | self.mlp = MLP(self.lbl_opts) 43 | xprintln('load classifier from %s ...', modelPath) 44 | self.mlp:load(classifierPath) 45 | xprintln('load classifier from %s done!', modelPath) 46 | 47 | if self.mlp.opts.rel_vocab == nil then 48 | self.mlp.opts.rel_vocab = DepPosDataIter.createDepRelVocab(self.mlp.opts.inTrain) 49 | xprintln('load rel vocab done! 
You should use new version `train_lableded.lua`') 50 | end 51 | end 52 | 53 | 54 | function Parser:runChuLiuEdmonds(dsent, sent_dep, sent_graph) 55 | local new_dsent = {} 56 | for i, ditem in ipairs(dsent) do 57 | local new_ditem = {p1 = ditem.p1, wd = ditem.wd, pos = ditem.pos, p2 = sent_dep[i]} 58 | new_dsent[#new_dsent + 1] = new_ditem 59 | end 60 | 61 | -- check connectivity 62 | local dgraph = PostDepGraph(new_dsent) 63 | if not dgraph:checkConnectivity() then 64 | local N = #sent_graph + 1 65 | local edges = {} 66 | for i, sp in ipairs(sent_graph) do 67 | for j = 1, sp:size(1) do 68 | edges[#edges + 1] = {j, i+1, sp[j]} 69 | end 70 | end 71 | -- run ChuLiuEdmonds 72 | local cle = ChuLiuEdmonds() 73 | cle:load(N, edges) 74 | local _, selectedEdges = cle:solve(1, N) 75 | table.sort(selectedEdges, function(a, b) return a.v < b.v end) 76 | for i, ditem in ipairs(new_dsent) do 77 | local edge = selectedEdges[i] 78 | assert(edge.v == i+1) 79 | ditem.p2 = edge.u - 1 80 | ditem.p1 = edge.v - 1 81 | end 82 | end 83 | 84 | return new_dsent 85 | end 86 | 87 | 88 | function Parser:runEisner(dsent, sent_dep, sent_graph) 89 | local new_dsent = {} 90 | for i, ditem in ipairs(dsent) do 91 | local new_ditem = {p1 = ditem.p1, wd = ditem.wd, pos = ditem.pos, p2 = sent_dep[i]} 92 | new_dsent[#new_dsent + 1] = new_ditem 93 | end 94 | 95 | -- check connectivity 96 | local dgraph = PostDepGraph(new_dsent) 97 | if not (dgraph:checkConnectivity() and dgraph:isProjective()) then 98 | local N = #sent_graph + 1 99 | local edges = {} 100 | for i, sp in ipairs(sent_graph) do 101 | for j = 1, sp:size(1) do 102 | edges[#edges + 1] = {j, i+1, sp[j]} 103 | end 104 | end 105 | -- run Eisner's algorithm 106 | local eisner = Eisner() 107 | eisner:load(N, edges) 108 | local _, selectedEdges = eisner:solve() 109 | table.sort(selectedEdges, function(a, b) return a.v < b.v end) 110 | for i, ditem in ipairs(new_dsent) do 111 | local edge = selectedEdges[i] 112 | assert(edge.v == i+1) 113 | ditem.p2 = edge.u - 1 114 | ditem.p1 = edge.v - 1 115 | end 116 | end 117 | 118 | return new_dsent 119 | end 120 | 121 | 122 | function Parser:parseConllx(inputFile, outputFile, postAlg) 123 | local dataIter = DepPosDataIter.createBatch(self.opts.vocab, inputFile, self.opts.batchSize, self.opts.maxTrainLen) 124 | local totalCnt = 0 125 | local totalLoss = 0 126 | local cnt = 0 127 | 128 | local fout = io.open(outputFile, 'w') 129 | local y_tmp = torch.LongTensor(self.opts.maxTrainLen, self.opts.batchSize) 130 | local cls_in_dim = 4 * self.opts.nhid + 2 * self.opts.npin + 2 * self.opts.nin 131 | local cls_in = torch.CudaTensor(self.opts.maxTrainLen * self.opts.batchSize, cls_in_dim) 132 | local dep_iter = DepPosDataIter.conllx_iter(inputFile) 133 | 134 | for x, x_mask, x_pos, y in dataIter do 135 | local loss, y_preds = self.net:validBatch(x, x_mask, x_pos, y) 136 | local x_emb = self.net.mod_map.forward_lookup:forward(x) 137 | local x_pos_emb = self.net.mod_map.forward_pos_lookup:forward(x_pos) 138 | local fwd_bak_hs = self.net.all_fwd_bak_hs 139 | 140 | totalLoss = totalLoss + loss * x:size(2) 141 | local y_mask = x_mask[{ {2, -1}, {} }] 142 | 143 | local y_p = y_tmp:resize(y:size(1), y:size(2)) 144 | -- WARNING: y_preds start from 2! 
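-- (why the offset: position 1 of x is the artificial ROOT prepended by
-- DepPosDataIter.toBatch, so the net emits head distributions only for
-- timesteps 2..x:size(1); the y_p[{ t-1, {} }] assignment below shifts
-- them back to word positions 1..n)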
145 | for t = 2, x:size(1) do 146 | local _, mi = y_preds[t]:max(2) 147 | if self.opts.useGPU then mi = mi:double() end 148 | y_p[{ t-1, {} }] = mi 149 | end 150 | 151 | -- get labeled output (bs, seqlen, dim) 152 | cls_in:resize(x:size(2), x:size(1)-1, cls_in_dim):zero() 153 | 154 | -- collects sentence dependents 155 | -- and graph answers 156 | local new_dsents = {} 157 | for i = 1, y_mask:size(2) do 158 | local slen = y_mask[{ {}, i }]:sum() 159 | if slen > 0 then 160 | local dsent = dep_iter() 161 | local sent_dep = {} 162 | local sent_graph = {} 163 | for j = 1, slen do 164 | sent_dep[#sent_dep + 1] = y_p[{ j, i }] - 1 165 | local tmp = y_preds[j+1][{ i, {1, slen + 1} }]:double() 166 | sent_graph[j] = tmp 167 | end 168 | 169 | -- run post-processing algorithm 170 | assert(#sent_dep == #dsent) 171 | assert(#sent_graph == #dsent) 172 | local new_dsent 173 | if postAlg == 'ChuLiuEdmonds' then 174 | new_dsent = self:runChuLiuEdmonds(dsent, sent_dep, sent_graph) 175 | elseif postAlg == 'Eisner' then 176 | new_dsent = self:runEisner(dsent, sent_dep, sent_graph) 177 | else 178 | error('only support ChuLiuEdmonds and Eisner') 179 | end 180 | 181 | -- prepare labeled input 182 | for j, ditem in ipairs(new_dsent) do 183 | local parent_id = ditem.p2 + 1 184 | local start = 1 185 | cls_in[{ i, j, {start, 2 * self.opts.nhid + start - 1} }] = fwd_bak_hs[{ i, j+1, {} }] 186 | start = start + 2 * self.opts.nhid 187 | cls_in[{ i, j, {start, 2 * self.opts.nhid + start - 1} }] = fwd_bak_hs[{ i, parent_id, {} }] 188 | start = start + 2 * self.opts.nhid 189 | cls_in[{ i, j, {start, self.opts.nin * 2 + start - 1} }] = 190 | torch.cat({x_emb[{ j+1, i, {} }], x_emb[{ parent_id, i, {} }]}, 1) 191 | start = start + 2 * self.opts.nin 192 | cls_in[{ i, j, {start, self.opts.npin * 2 + start - 1} }] = 193 | torch.cat({x_pos_emb[{ j+1, i, {} }], x_pos_emb[{ parent_id, i, {} }]}, 1) 194 | end 195 | 196 | new_dsents[#new_dsents + 1] = new_dsent 197 | end 198 | end 199 | 200 | -- run labeld classifier 201 | local labels_ = self.mlp:predictLabelBatch( 202 | cls_in:view( cls_in:size(1) * cls_in:size(2), cls_in:size(3) ) 203 | ) 204 | local labels = labels_:view( cls_in:size(1), cls_in:size(2) ) 205 | -- output everything! 206 | for i, dsent in ipairs(new_dsents) do 207 | for j, ditem in ipairs(dsent) do 208 | -- 1 Influential _ JJ JJ _ 2 amod _ _ 209 | local lbl = self.mlp.opts.rel_vocab.idx2rel[ labels[{ i, j }] ] 210 | fout:write( string.format('%d\t%s\t_\t_\t%s\t_\t%d\t%s\t_\t_\n', ditem.p1, ditem.wd, ditem.pos, ditem.p2, lbl) ) 211 | end 212 | fout:write('\n') 213 | end 214 | 215 | totalCnt = totalCnt + y_mask:sum() 216 | cnt = cnt + 1 217 | if cnt % 5 == 0 then 218 | collectgarbage() 219 | xprintln('cnt = %d * %d = %d', cnt, self.opts.batchSize, cnt * self.opts.batchSize) 220 | end 221 | end 222 | fout:close() 223 | end 224 | 225 | 226 | local function getOpts() 227 | local cmd = torch.CmdLine() 228 | cmd:option('--modelPath', '/disk/scratch/s1270921/dep_parse/experiments/pre-trained-models/german/model_0.001.tune.t7', 'model path') 229 | cmd:option('--classifierPath', '/disk/scratch/s1270921/dep_parse/experiments/pre-trained-models/german/lbl_classifier.t7', 'label classifer path') 230 | 231 | cmd:option('--input', '/disk/scratch/s1270921/dep_parse/data_conll/german/german_gold_test.conll', 'input conllx file') 232 | cmd:option('--output', 'output.txt', 'output conllx file') 233 | cmd:option('--gold', '', 'gold standard file (optional). 
Empty means no evaluation') 234 | cmd:option('--mstalg', 'ChuLiuEdmonds', 'MST algorithm: ChuLiuEdmonds or Eisner') 235 | 236 | return cmd:parse(arg) 237 | end 238 | 239 | 240 | local function main() 241 | local opts = getOpts() 242 | local dense = DeNSeParser() 243 | dense:load(opts.modelPath, opts.classifierPath) 244 | dense:parseConllx(opts.input, opts.output, opts.mstalg) 245 | 246 | if opts.gold ~= '' or opts.gold == nil then 247 | print '\n\n*** Stanford ***' 248 | local conllx_eval = require 'conllx_eval' 249 | local LAS, UAS, noPunctLAS, noPunctUAS = conllx_eval.eval(opts.output, opts.gold) 250 | 251 | print '\n\n*** CoNLL-X 2006 ***' 252 | conllx_eval = require 'conllx2006_eval' 253 | LAS, UAS, noPunctLAS, noPunctUAS = conllx_eval.eval(opts.output, opts.gold) 254 | end 255 | end 256 | 257 | 258 | if not package.loaded['dense_parser'] then 259 | main() 260 | else 261 | print '[dense_parser] loaded as package!' 262 | end 263 | 264 | -------------------------------------------------------------------------------- /experiments/chinese/gen_lbl_train.sh: -------------------------------------------------------------------------------- 1 | 2 | codedir=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dense_release 3 | 4 | curdir=`pwd` 5 | 6 | model=$curdir/model_0.001.tune.t7 7 | 8 | outTrain=$curdir/label_train.h5 9 | 10 | inTrain=/disk/scratch/s1270921/dep_parse/dataset/ctb/train.gold.conll 11 | inValid=/disk/scratch/s1270921/dep_parse/dataset/ctb/dev.gold.conll 12 | inTest=/disk/scratch/s1270921/dep_parse/dataset/ctb/test.gold.conll 13 | 14 | 15 | outValid=$curdir/valid.dep 16 | outTest=$curdir/test.dep 17 | 18 | log=$curdir/gen-log.txt 19 | 20 | cd $codedir 21 | 22 | CUDA_VISIBLE_DEVICES=2 th train_labeled.lua --mode generate \ 23 | --modelPath $model \ 24 | --outTrainDataPath $outTrain \ 25 | --inTrain $inTrain \ 26 | --inValid $inValid \ 27 | --inTest $inTest \ 28 | --outValid $outValid \ 29 | --outTest $outTest \ 30 | --language Chinese | tee $log 31 | 32 | cd $curdir 33 | 34 | -------------------------------------------------------------------------------- /experiments/chinese/mst-post.sh: -------------------------------------------------------------------------------- 1 | 2 | curdir=`pwd` 3 | codedir=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dense_release 4 | 5 | model=$curdir/model_0.001.tune.t7 6 | validout=$curdir/valid 7 | testout=$curdir/test 8 | log=$curdir/log.txt 9 | 10 | cd $codedir 11 | 12 | CUDA_VISIBLE_DEVICES=2 th mst_postprocess.lua \ 13 | --mstalg Eisner \ 14 | --modelPath $model \ 15 | --validout $validout \ 16 | --testout $testout | tee $log 17 | 18 | cd $curdir 19 | 20 | -------------------------------------------------------------------------------- /experiments/chinese/run_lbl.sh: -------------------------------------------------------------------------------- 1 | 2 | codedir=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dense_release 3 | curdir=`pwd` 4 | 5 | dataset=$curdir/label_train.h5 6 | model=$curdir/lbl_lassifier.t7 7 | 8 | inTrain=/disk/scratch/s1270921/dep_parse/dataset/ctb/train.gold.conll 9 | inValid=/disk/scratch/s1270921/dep_parse/dataset/ctb/dev.gold.conll 10 | inTest=/disk/scratch/s1270921/dep_parse/dataset/ctb/test.gold.conll 11 | 12 | log=$curdir/lbl_log.txt 13 | 14 | cd $codedir 15 | 16 | CUDA_VISIBLE_DEVICES=2 th train_labeled.lua --mode train \ 17 | --useGPU \ 18 | --snhids "1900,800,800,12" \ 19 | --activ relu \ 20 | --lr 0.01 \ 21 | --optimMethod AdaGrad \ 22 | --dropout 0.5 \ 23 | --inDropout 0.05 \ 24 | --batchSize 256 \ 
25 | --maxEpoch 20 \ 26 | --ftype "|x|xe|xpe|" \ 27 | --dataset $dataset \ 28 | --inTrain $inTrain \ 29 | --inValid $inValid \ 30 | --inTest $inTest \ 31 | --save $model | tee $log 32 | 33 | cd $curdir 34 | 35 | 36 | -------------------------------------------------------------------------------- /experiments/chinese/train.sh: -------------------------------------------------------------------------------- 1 | 2 | # ID=`./gpu_lock.py --id-to-hog 0` 3 | ID=2 4 | echo $ID 5 | if [ $ID -eq -1 ]; then 6 | echo "this gpu is not free" 7 | exit 8 | fi 9 | # ./gpu_lock.py 10 | 11 | codedir=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dense_release 12 | 13 | curdir=`pwd` 14 | lr=0.001 15 | label=.std.ft0.lpos.dp 16 | model=$curdir/model_$lr$label.t7 17 | log=$curdir/log_$lr$label.txt 18 | 19 | train=/disk/scratch/s1270921/dep_parse/dataset/ctb/train.gold.conll 20 | valid=/disk/scratch/s1270921/dep_parse/dataset/ctb/dev.gold.conll 21 | test=/disk/scratch/s1270921/dep_parse/dataset/ctb/test.gold.conll 22 | 23 | wembed=/disk/scratch/s1270921/dep_parse/dataset/ctb/vectors.ctb.t7 24 | 25 | cd $codedir 26 | 27 | CUDA_VISIBLE_DEVICES=$ID th train.lua --useGPU \ 28 | --model SelectNetPos \ 29 | --seqLen 142 \ 30 | --maxTrainLen 140 \ 31 | --freqCut 1 \ 32 | --nhid 300 \ 33 | --nin 300 \ 34 | --nlayers 2 \ 35 | --dropout 0.35 \ 36 | --recDropout 0.05 \ 37 | --lr $lr \ 38 | --train $train \ 39 | --valid $valid \ 40 | --test $test \ 41 | --optimMethod Adam \ 42 | --save $model \ 43 | --batchSize 20 \ 44 | --validBatchSize 20 \ 45 | --maxEpoch 15 \ 46 | --wordEmbedding $wembed \ 47 | --embedOption fineTune \ 48 | --fineTuneFactor 0 \ 49 | --npin 50 \ 50 | | tee $log 51 | 52 | 53 | cd $curdir 54 | 55 | # ./gpu_lock.py --free $ID 56 | # ./gpu_lock.py 57 | 58 | -------------------------------------------------------------------------------- /experiments/chinese/tune.sh: -------------------------------------------------------------------------------- 1 | 2 | ID=2 3 | echo $ID 4 | if [ $ID -eq -1 ]; then 5 | echo "this gpu is not free" 6 | exit 7 | fi 8 | 9 | codedir=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dense_release 10 | curdir=`pwd` 11 | lr=0.001 12 | label=.tune 13 | model=$curdir/model_$lr$label.t7 14 | log=$curdir/log_$lr$label.txt 15 | 16 | load=$curdir/model_0.001.std.ft0.lpos.dp.t7 17 | 18 | cd $codedir 19 | 20 | CUDA_VISIBLE_DEVICES=$ID th post_train.lua \ 21 | --load $load \ 22 | --save $model \ 23 | --lr $lr \ 24 | --maxEpoch 10 \ 25 | --optimMethod SGD \ 26 | | tee $log 27 | 28 | cd $curdir 29 | 30 | 31 | -------------------------------------------------------------------------------- /experiments/czech/gen_lbl_train.sh: -------------------------------------------------------------------------------- 1 | 2 | codedir=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dense_release 3 | curdir=`pwd` 4 | 5 | # model=/disk/scratch1/XingxingZhang/dep_parse/experiments/release_test/czech/model_0.001.tune.ori.t7 6 | model=$curdir/model_0.001.tune.t7 7 | 8 | outTrain=$curdir/label_train.h5 9 | 10 | inTrain=/disk/scratch/s1270921/dep_parse/data_conll/czech/czech_gold_train.conll 11 | inValid=/disk/scratch/s1270921/dep_parse/data_conll/czech/czech_gold_dev.conll 12 | inTest=/disk/scratch/s1270921/dep_parse/data_conll/czech/czech_gold_test.conll 13 | 14 | 15 | outValid=$curdir/valid.dep 16 | outTest=$curdir/test.dep 17 | 18 | log=$curdir/gen-log.txt 19 | 20 | cd $codedir 21 | 22 | CUDA_VISIBLE_DEVICES=3 th train_labeled.lua --mode generate \ 23 | --modelPath $model \ 24 | --outTrainDataPath 
$outTrain \ 25 | --inTrain $inTrain \ 26 | --inValid $inValid \ 27 | --inTest $inTest \ 28 | --outValid $outValid \ 29 | --outTest $outTest \ 30 | --language Other | tee $log 31 | 32 | cd $curdir 33 | 34 | -------------------------------------------------------------------------------- /experiments/czech/mst-post.sh: -------------------------------------------------------------------------------- 1 | 2 | curdir=`pwd` 3 | codedir=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dense_release 4 | 5 | model=$curdir/model_0.001.tune.t7 6 | validout=$curdir/valid 7 | testout=$curdir/test 8 | log=$curdir/log.txt 9 | 10 | cd $codedir 11 | 12 | CUDA_VISIBLE_DEVICES=3 th mst_postprocess.lua \ 13 | --modelPath $model \ 14 | --mstalg ChuLiuEdmonds \ 15 | --validout $validout \ 16 | --testout $testout | tee $log 17 | 18 | cd $curdir 19 | 20 | -------------------------------------------------------------------------------- /experiments/czech/run_lbl.sh: -------------------------------------------------------------------------------- 1 | 2 | codedir=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dense_release 3 | curdir=`pwd` 4 | 5 | dataset=$curdir/label_train.h5 6 | model=$curdir/lbl_classifier.t7 7 | 8 | inTrain=/disk/scratch/s1270921/dep_parse/data_conll/czech/czech_gold_train.conll 9 | inValid=/disk/scratch/s1270921/dep_parse/data_conll/czech/czech_gold_dev.conll 10 | inTest=/disk/scratch/s1270921/dep_parse/data_conll/czech/czech_gold_test.conll 11 | 12 | log=$curdir/lbl_log.txt 13 | 14 | cd $codedir 15 | 16 | CUDA_VISIBLE_DEVICES=3 th train_labeled.lua --mode train \ 17 | --useGPU \ 18 | --snhids "1880,800,800,82" \ 19 | --activ relu \ 20 | --lr 0.01 \ 21 | --optimMethod AdaGrad \ 22 | --dropout 0.5 \ 23 | --inDropout 0.05 \ 24 | --batchSize 256 \ 25 | --maxEpoch 20 \ 26 | --ftype "|x|xe|xpe|" \ 27 | --dataset $dataset \ 28 | --inTrain $inTrain \ 29 | --inValid $inValid \ 30 | --inTest $inTest \ 31 | --language Other \ 32 | --save $model | tee $log 33 | 34 | cd $curdir 35 | 36 | -------------------------------------------------------------------------------- /experiments/czech/train.sh: -------------------------------------------------------------------------------- 1 | 2 | # gpu id 3 | ID=3 4 | 5 | codedir=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dense_release 6 | curdir=`pwd` 7 | lr=0.001 8 | label=.dp0.35.r0.1.bs20 9 | model=$curdir/model_$lr$label.t7 10 | log=$curdir/log_$lr$label.txt 11 | 12 | 13 | train=/disk/scratch/s1270921/dep_parse/data_conll/czech/czech_gold_train.conll 14 | valid=/disk/scratch/s1270921/dep_parse/data_conll/czech/czech_gold_dev.conll 15 | test=/disk/scratch/s1270921/dep_parse/data_conll/czech/czech_gold_test.conll 16 | 17 | 18 | cd $codedir 19 | 20 | CUDA_VISIBLE_DEVICES=$ID th train.lua --useGPU \ 21 | --model SelectNetPos \ 22 | --seqLen 112 \ 23 | --maxTrainLen 110 \ 24 | --freqCut 1 \ 25 | --nhid 300 \ 26 | --nin 300 \ 27 | --nlayers 2 \ 28 | --dropout 0.35 \ 29 | --recDropout 0.1 \ 30 | --lr $lr \ 31 | --train $train \ 32 | --valid $valid \ 33 | --test $test \ 34 | --optimMethod Adam \ 35 | --save $model \ 36 | --batchSize 20 \ 37 | --validBatchSize 20 \ 38 | --maxEpoch 15 \ 39 | --npin 40 \ 40 | --evalType conllx \ 41 | | tee $log 42 | 43 | cd $curdir 44 | 45 | # ./gpu_lock.py --free $ID 46 | # ./gpu_lock.py 47 | 48 | 49 | -------------------------------------------------------------------------------- /experiments/czech/tune.sh: -------------------------------------------------------------------------------- 1 | 2 | ID=3 3 | echo $ID 4 | if [ $ID 
-eq -1 ]; then 5 | echo "this gpu is not free" 6 | exit 7 | fi 8 | 9 | codedir=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dense_release 10 | curdir=`pwd` 11 | lr=0.001 12 | label=.tune 13 | model=$curdir/model_$lr$label.t7 14 | log=$curdir/log_$lr$label.txt 15 | 16 | load=$curdir/model_0.001.dp0.35.r0.1.bs20.t7 17 | 18 | cd $codedir 19 | 20 | CUDA_VISIBLE_DEVICES=$ID th post_train.lua \ 21 | --load $load \ 22 | --save $model \ 23 | --lr $lr \ 24 | --maxEpoch 10 \ 25 | --optimMethod SGD \ 26 | | tee $log 27 | 28 | cd $curdir 29 | 30 | 31 | -------------------------------------------------------------------------------- /experiments/english/gen_lbl_train.sh: -------------------------------------------------------------------------------- 1 | 2 | codedir=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dense_release 3 | curdir=`pwd` 4 | 5 | model=$curdir/model_0.001.tune.t7 6 | 7 | outTrain=$curdir/label_train.h5 8 | 9 | inTrain=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dataset/train.autopos 10 | inValid=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dataset/valid.autopos 11 | inTest=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dataset/test.autopos 12 | 13 | outValid=$curdir/valid.dep 14 | outTest=$curdir/test.dep 15 | 16 | log=$curdir/gen-log.txt 17 | 18 | cd $codedir 19 | 20 | CUDA_VISIBLE_DEVICES=3 th train_labeled.lua --mode generate \ 21 | --modelPath $model \ 22 | --outTrainDataPath $outTrain \ 23 | --inTrain $inTrain \ 24 | --inValid $inValid \ 25 | --inTest $inTest \ 26 | --outValid $outValid \ 27 | --outTest $outTest \ 28 | --language English | tee $log 29 | 30 | cd $curdir 31 | 32 | -------------------------------------------------------------------------------- /experiments/english/mst-post.sh: -------------------------------------------------------------------------------- 1 | 2 | curdir=`pwd` 3 | codedir=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dense_release 4 | 5 | model=/disk/scratch/s1270921/dep_parse/experiments/English/release/model_0.001.tune.t7 6 | validout=$curdir/valid 7 | testout=$curdir/test 8 | log=$curdir/mst_log.txt 9 | 10 | cd $codedir 11 | 12 | CUDA_VISIBLE_DEVICES=0 th mst_postprocess.lua \ 13 | --mstalg Eisner \ 14 | --modelPath $model \ 15 | --validout $validout \ 16 | --testout $testout | tee $log 17 | 18 | cd $curdir 19 | 20 | -------------------------------------------------------------------------------- /experiments/english/run_lbl.sh: -------------------------------------------------------------------------------- 1 | 2 | codedir=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dense_release 3 | curdir=`pwd` 4 | 5 | dataset=$curdir/label_train.h5 6 | model=$curdir/lbl_classifier.t7 7 | 8 | log=$curdir/lbl_log.txt 9 | 10 | cd $codedir 11 | 12 | CUDA_VISIBLE_DEVICES=3 th train_labeled.lua --mode train \ 13 | --useGPU \ 14 | --snhids "1860,800,800,45" \ 15 | --activ relu \ 16 | --lr 0.01 \ 17 | --optimMethod AdaGrad \ 18 | --dropout 0.5 \ 19 | --inDropout 0.05 \ 20 | --batchSize 256 \ 21 | --maxEpoch 20 \ 22 | --ftype "|x|xe|xpe|" \ 23 | --dataset $dataset \ 24 | --save $model | tee $log 25 | 26 | cd $curdir 27 | 28 | 29 | -------------------------------------------------------------------------------- /experiments/english/train.sh: -------------------------------------------------------------------------------- 1 | 2 | # ID=`./gpu_lock.py --id-to-hog 0` 3 | ID=2 4 | echo $ID 5 | if [ $ID -eq -1 ]; then 6 | echo "this gpu is not free" 7 | exit 8 | fi 9 | # ./gpu_lock.py 10 | 11 | 
codedir=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dense_release 12 | curdir=`pwd` 13 | lr=0.001 14 | label=.std.ft0.300.dp0.35.r0.1 15 | model=$curdir/model_$lr$label.t7 16 | log=$curdir/log_$lr$label.txt 17 | 18 | train=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dataset/train.autopos 19 | valid=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dataset/valid.autopos 20 | test=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dataset/test.autopos 21 | 22 | wembed=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dataset/glove.840B.300d.penn.t7 23 | 24 | cd $codedir 25 | 26 | CUDA_VISIBLE_DEVICES=$ID th train.lua --useGPU \ 27 | --model SelectNetPos \ 28 | --seqLen 120 \ 29 | --maxTrainLen 100 \ 30 | --freqCut 1 \ 31 | --nhid 300 \ 32 | --nin 300 \ 33 | --nlayers 2 \ 34 | --dropout 0.35 \ 35 | --recDropout 0.1 \ 36 | --lr $lr \ 37 | --train $train \ 38 | --valid $valid \ 39 | --test $test \ 40 | --optimMethod Adam \ 41 | --save $model \ 42 | --batchSize 32 \ 43 | --validBatchSize 32 \ 44 | --maxEpoch 15 \ 45 | --wordEmbedding $wembed \ 46 | --embedOption fineTune \ 47 | --fineTuneFactor 0 \ 48 | --npin 30 \ 49 | | tee $log 50 | 51 | cd $curdir 52 | 53 | # ./gpu_lock.py --free $ID 54 | # ./gpu_lock.py 55 | 56 | -------------------------------------------------------------------------------- /experiments/english/tune.sh: -------------------------------------------------------------------------------- 1 | 2 | # ID=`./gpu_lock.py --id-to-hog 0` 3 | ID=3 4 | echo $ID 5 | if [ $ID -eq -1 ]; then 6 | echo "this gpu is not free" 7 | exit 8 | fi 9 | # ./gpu_lock.py 10 | 11 | codedir=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dense_release 12 | curdir=`pwd` 13 | lr=0.001 14 | label=.tune 15 | model=$curdir/model_$lr$label.t7 16 | log=$curdir/log_$lr$label.txt 17 | 18 | load=/disk/scratch/s1270921/dep_parse/experiments/English/release/model_0.001.std.ft0.300.dp0.35.r0.1.t7 19 | 20 | cd $codedir 21 | 22 | CUDA_VISIBLE_DEVICES=$ID th post_train.lua \ 23 | --load $load \ 24 | --save $model \ 25 | --lr $lr \ 26 | --maxEpoch 10 \ 27 | --optimMethod SGD \ 28 | | tee $log 29 | 30 | cd $curdir 31 | 32 | # ./gpu_lock.py --free $ID 33 | # ./gpu_lock.py 34 | 35 | -------------------------------------------------------------------------------- /experiments/german/gen_lbl_train.sh: -------------------------------------------------------------------------------- 1 | 2 | codedir=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dense_release 3 | curdir=`pwd` 4 | 5 | model=$curdir/model_0.001.tune.t7 6 | 7 | outTrain=$curdir/label_train.h5 8 | 9 | inTrain=/disk/scratch/s1270921/dep_parse/data_conll/german/german_gold_train.conll 10 | inValid=/disk/scratch/s1270921/dep_parse/data_conll/german/german_gold_dev.conll 11 | inTest=/disk/scratch/s1270921/dep_parse/data_conll/german/german_gold_test.conll 12 | 13 | outValid=$curdir/valid.dep 14 | outTest=$curdir/test.dep 15 | 16 | log=$curdir/gen-log.txt 17 | 18 | cd $codedir 19 | 20 | CUDA_VISIBLE_DEVICES=2 th train_labeled.lua --mode generate \ 21 | --modelPath $model \ 22 | --outTrainDataPath $outTrain \ 23 | --inTrain $inTrain \ 24 | --inValid $inValid \ 25 | --inTest $inTest \ 26 | --outValid $outValid \ 27 | --outTest $outTest \ 28 | --language Other | tee $log 29 | 30 | cd $curdir 31 | 32 | 33 | -------------------------------------------------------------------------------- /experiments/german/mst-post.sh: -------------------------------------------------------------------------------- 1 | 2 | curdir=`pwd` 3 | 4 | 
codedir=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dense_release 5 | 6 | model=/disk/scratch/s1270921/dep_parse/experiments/German/model_0.001.tune.t7 7 | validout=$curdir/valid 8 | testout=$curdir/test 9 | log=$curdir/log.txt 10 | 11 | cd $codedir 12 | 13 | CUDA_VISIBLE_DEVICES=3 th mst_postprocess.lua \ 14 | --modelPath $model \ 15 | --mstalg ChuLiuEdmonds \ 16 | --validout $validout \ 17 | --testout $testout | tee $log 18 | 19 | cd $curdir 20 | 21 | -------------------------------------------------------------------------------- /experiments/german/run_lbl.sh: -------------------------------------------------------------------------------- 1 | 2 | codedir=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dense_release 3 | curdir=`pwd` 4 | 5 | dataset=$curdir/label_train.h5 6 | model=$curdir/lbl_classifier.t7 7 | 8 | 9 | inTrain=/disk/scratch/s1270921/dep_parse/data_conll/german/german_gold_train.conll 10 | inValid=/disk/scratch/s1270921/dep_parse/data_conll/german/german_gold_dev.conll 11 | inTest=/disk/scratch/s1270921/dep_parse/data_conll/german/german_gold_test.conll 12 | 13 | 14 | log=$curdir/lbl_log.txt 15 | 16 | cd $codedir 17 | 18 | CUDA_VISIBLE_DEVICES=2 th train_labeled.lua --mode train \ 19 | --useGPU \ 20 | --snhids "1880,800,800,82" \ 21 | --activ relu \ 22 | --lr 0.01 \ 23 | --optimMethod AdaGrad \ 24 | --dropout 0.5 \ 25 | --inDropout 0.05 \ 26 | --batchSize 256 \ 27 | --maxEpoch 20 \ 28 | --ftype "|x|xe|xpe|" \ 29 | --dataset $dataset \ 30 | --inTrain $inTrain \ 31 | --inValid $inValid \ 32 | --inTest $inTest \ 33 | --language Other \ 34 | --save $model | tee $log 35 | 36 | cd $curdir 37 | 38 | -------------------------------------------------------------------------------- /experiments/german/train.sh: -------------------------------------------------------------------------------- 1 | 2 | # ID=`./gpu_lock.py --id-to-hog 0` 3 | ID=3 4 | echo $ID 5 | if [ $ID -eq -1 ]; then 6 | echo "this gpu is not free" 7 | exit 8 | fi 9 | # ./gpu_lock.py 10 | 11 | codedir=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dense_release 12 | curdir=`pwd` 13 | lr=0.001 14 | label=.std.ft0.dp0.4.r0.1.bs20 15 | model=$curdir/model_$lr$label.t7 16 | log=$curdir/log_$lr$label.txt 17 | 18 | : ' 19 | train=/disk/scratch/XingxingZhang/dep_parse/dataset/german/german_gold_train.conll 20 | valid=/disk/scratch/XingxingZhang/dep_parse/dataset/german/german_gold_dev.conll 21 | test=/disk/scratch/XingxingZhang/dep_parse/dataset/german/german_gold_test.conll 22 | ' 23 | train=/disk/scratch/s1270921/dep_parse/data_conll/german/german_gold_train.conll 24 | valid=/disk/scratch/s1270921/dep_parse/data_conll/german/german_gold_dev.conll 25 | test=/disk/scratch/s1270921/dep_parse/data_conll/german/german_gold_test.conll 26 | 27 | 28 | 29 | cd $codedir 30 | 31 | CUDA_VISIBLE_DEVICES=$ID th train.lua --useGPU \ 32 | --model SelectNetPos \ 33 | --seqLen 112 \ 34 | --maxTrainLen 110 \ 35 | --freqCut 1 \ 36 | --nhid 300 \ 37 | --nin 300 \ 38 | --nlayers 2 \ 39 | --dropout 0.4 \ 40 | --recDropout 0.1 \ 41 | --lr $lr \ 42 | --train $train \ 43 | --valid $valid \ 44 | --test $test \ 45 | --optimMethod Adam \ 46 | --save $model \ 47 | --batchSize 20 \ 48 | --validBatchSize 20 \ 49 | --maxEpoch 15 \ 50 | --npin 40 \ 51 | --evalType conllx \ 52 | | tee $log 53 | 54 | cd $curdir 55 | 56 | # ./gpu_lock.py --free $ID 57 | # ./gpu_lock.py 58 | 59 | -------------------------------------------------------------------------------- /experiments/german/tune.sh: 
-------------------------------------------------------------------------------- 1 | 2 | # ID=`./gpu_lock.py --id-to-hog 0` 3 | ID=3 4 | echo $ID 5 | if [ $ID -eq -1 ]; then 6 | echo "this gpu is not free" 7 | exit 8 | fi 9 | # ./gpu_lock.py 10 | 11 | # codedir=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/select_net_multi_ling_v1.3.10 12 | codedir=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dense_release 13 | 14 | curdir=`pwd` 15 | lr=0.001 16 | label=.tune 17 | model=$curdir/model_$lr$label.t7 18 | log=$curdir/log_$lr$label.txt 19 | 20 | # load=/disk/scratch/XingxingZhang/dep_parse/experiments/german/unlabeled/model_0.001.std.ft0.dp0.4.r0.1.bs20.t7 21 | load=/disk/scratch/s1270921/dep_parse/experiments/German/model_0.001.std.ft0.dp0.4.r0.1.bs20.t7 22 | 23 | 24 | cd $codedir 25 | 26 | CUDA_VISIBLE_DEVICES=$ID th post_train.lua \ 27 | --load $load \ 28 | --save $model \ 29 | --lr $lr \ 30 | --maxEpoch 10 \ 31 | --optimMethod SGD \ 32 | | tee $log 33 | 34 | cd $curdir 35 | 36 | # ./gpu_lock.py --free $ID 37 | # ./gpu_lock.py 38 | 39 | -------------------------------------------------------------------------------- /experiments/run_parser/run_chinese.sh: -------------------------------------------------------------------------------- 1 | 2 | codedir=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dense_release 3 | curdir=`pwd` 4 | 5 | model=/disk/scratch/s1270921/dep_parse/experiments/pre-trained-models/chinese/model_0.001.tune.t7 6 | classifier=/disk/scratch/s1270921/dep_parse/experiments/pre-trained-models/chinese/lbl_classifier.t7 7 | input=/disk/scratch/s1270921/dep_parse/dataset/ctb/test.gold.conll 8 | output=$curdir/chinese.conllx 9 | 10 | cd $codedir 11 | 12 | CUDA_VISIBLE_DEVICES=3 th dense_parser.lua --modelPath $model --classifierPath $classifier \ 13 | --input $input --output $output --gold $input --mstalg Eisner 14 | 15 | cd $curdir 16 | 17 | -------------------------------------------------------------------------------- /experiments/run_parser/run_czech.sh: -------------------------------------------------------------------------------- 1 | 2 | codedir=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dense_release 3 | curdir=`pwd` 4 | 5 | model=/disk/scratch/s1270921/dep_parse/experiments/pre-trained-models/czech/model_0.001.tune.t7 6 | classifier=/disk/scratch/s1270921/dep_parse/experiments/pre-trained-models/czech/lbl_classifier.t7 7 | input=/disk/scratch/s1270921/dep_parse/data_conll/czech/czech_gold_test.conll 8 | output=$curdir/czech.conllx 9 | 10 | cd $codedir 11 | 12 | CUDA_VISIBLE_DEVICES=3 th dense_parser.lua --modelPath $model --classifierPath $classifier \ 13 | --input $input --output $output --gold $input --mstalg ChuLiuEdmonds 14 | 15 | cd $curdir 16 | 17 | -------------------------------------------------------------------------------- /experiments/run_parser/run_english.sh: -------------------------------------------------------------------------------- 1 | 2 | codedir=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dense_release 3 | curdir=`pwd` 4 | 5 | model=/disk/scratch/s1270921/dep_parse/experiments/pre-trained-models/english/model_0.001.tune.t7 6 | classifier=/disk/scratch/s1270921/dep_parse/experiments/pre-trained-models/english/lbl_classifier.t7 7 | input=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dataset/test.autopos 8 | output=$curdir/english.conllx 9 | 10 | cd $codedir 11 | 12 | CUDA_VISIBLE_DEVICES=3 th dense_parser.lua --modelPath $model --classifierPath $classifier \ 13 | --input $input --output $output --gold $input 
--mstalg Eisner 14 | 15 | cd $curdir 16 | 17 | -------------------------------------------------------------------------------- /experiments/run_parser/run_german.sh: -------------------------------------------------------------------------------- 1 | 2 | codedir=/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dense_release 3 | curdir=`pwd` 4 | 5 | model=/disk/scratch/s1270921/dep_parse/experiments/pre-trained-models/german/model_0.001.tune.t7 6 | classifier=/disk/scratch/s1270921/dep_parse/experiments/pre-trained-models/german/lbl_classifier.t7 7 | input=/disk/scratch/s1270921/dep_parse/data_conll/german/german_gold_test.conll 8 | output=$curdir/german.conllx 9 | 10 | cd $codedir 11 | 12 | CUDA_VISIBLE_DEVICES=3 th dense_parser.lua --modelPath $model --classifierPath $classifier \ 13 | --input $input --output $output --gold $input --mstalg ChuLiuEdmonds 14 | 15 | cd $curdir 16 | 17 | -------------------------------------------------------------------------------- /graph_alg/ChuLiuEdmonds.lua: -------------------------------------------------------------------------------- 1 | 2 | local CLE = torch.class('ChuLiuEdmonds') 3 | 4 | function CLE:load(size, inEdges) 5 | self.size = size 6 | self.edges = {} 7 | for i, e in ipairs(inEdges) do 8 | local edge = {u = e[1], v = e[2], weight = e[3]} 9 | edge.ori_u = edge.u 10 | edge.ori_v = edge.v 11 | edge.ori_weight = edge.weight 12 | table.insert(self.edges, edge) 13 | end 14 | end 15 | 16 | function CLE:solve(root, N) 17 | local cost = 0 18 | local INF = 1e308 19 | local inWeights = {} 20 | local inEdges = {} 21 | for i = 1, N do inWeights[i] = -INF end 22 | -- for i = 1, N do inWeights[i] = INF end 23 | -- find largest input edges 24 | for _, edge in ipairs(self.edges) do 25 | local u, v = edge.u, edge.v 26 | if inWeights[v] < edge.weight and u ~= v then 27 | -- if inWeights[v] > edge.weight and u ~= v then 28 | inWeights[v] = edge.weight 29 | inEdges[v] = edge 30 | end 31 | end 32 | -- check if there is no incomming edge 33 | for i = 1, N do 34 | if i ~= root and inWeights[i] == -INF then 35 | -- if i ~= root and inWeights[i] == INF then 36 | return -1 37 | end 38 | end 39 | inWeights[root] = 0 40 | -- find cycles 41 | local newid = {} 42 | local vis = {} 43 | for i = 1, N do 44 | newid[i] = -1 45 | vis[i] = -1 46 | end 47 | local nx_id = 1 48 | -- handle cycles 49 | for i = 1, N do 50 | cost = cost + inWeights[i] 51 | local v = i 52 | while vis[v] ~= i and newid[v] == -1 and v ~= root do 53 | vis[v] = i 54 | v = inEdges[v].u 55 | end 56 | if v ~= root and newid[v] == -1 then 57 | local u = inEdges[v].u 58 | while u ~= v do 59 | newid[u] = nx_id 60 | u = inEdges[u].u 61 | end 62 | newid[v] = nx_id 63 | nx_id = nx_id + 1 64 | end 65 | end 66 | if nx_id == 1 then -- no cycle, done! 
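-- Base case of Chu-Liu/Edmonds: nx_id is still 1, so no cycle was found and
-- the greedily picked max-weight incoming edges already form a spanning
-- arborescence; the lines below discard the in-edge picked for the root
-- (the root keeps no head) and return the remaining in-edges.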
67 | inEdges[root] = nil 68 | local finalEdges = {} 69 | for _, edge in pairs(inEdges) do 70 | table.insert(finalEdges, edge) 71 | end 72 | return cost, finalEdges 73 | end 74 | local max_cycle_id = nx_id - 1 75 | -- assign other nodes 76 | for i = 1, N do 77 | if newid[i] == -1 then 78 | newid[i] = nx_id 79 | nx_id = nx_id + 1 80 | end 81 | end 82 | -- rebuild graph and backup old edge information 83 | local cp_edges = {} 84 | for i, edge in ipairs(self.edges) do 85 | local u, v = edge.u, edge.v 86 | cp_edges[i] = {u = u, v = v, weight = edge.weight} 87 | edge.u, edge.v = newid[u], newid[v] 88 | if newid[u] ~= newid[v] then edge.weight = edge.weight - inWeights[v] end 89 | end 90 | 91 | local sub_cost, sub_inEdges = self:solve( newid[root], nx_id - 1 ) 92 | 93 | for i, edge in ipairs(self.edges) do 94 | edge.u = cp_edges[i].u 95 | edge.v = cp_edges[i].v 96 | edge.weight = cp_edges[i].weight 97 | end 98 | 99 | local finalEdges = {} 100 | for i, edge in ipairs(sub_inEdges) do 101 | table.insert(finalEdges, edge) 102 | local v = edge.v 103 | -- add edges in a circle 104 | if newid[v] <= max_cycle_id then 105 | local u = inEdges[v].u 106 | while u ~= v do 107 | table.insert(finalEdges, inEdges[u]) 108 | u = inEdges[u].u 109 | end 110 | end 111 | end 112 | 113 | return cost + sub_cost, finalEdges 114 | end 115 | 116 | local function main() 117 | require '../utils/shortcut' 118 | --[[ 119 | 4 6 120 | 0 6 121 | 4 6 122 | 0 0 123 | 7 20 124 | 1 2 125 | 1 3 126 | 2 3 127 | 3 4 128 | 3 1 129 | 3 2 130 | --]] 131 | --[[ 132 | 4 6 133 | 0 6 134 | 4 6 135 | 0 0 136 | 7 20 137 | 1 2 138 | 1 3 139 | 2 3 140 | 3 4 141 | 3 1 142 | 3 2 143 | 4 3 144 | 0 0 145 | 1 0 146 | 0 1 147 | 1 2 148 | 1 3 149 | 4 1 150 | 2 3 151 | --]] 152 | -- read data 153 | 154 | --[[ 155 | local function dist(p1, p2) 156 | return math.sqrt( (p1[1] - p2[1]) * (p1[1] - p2[1]) + (p1[2] - p2[2]) * (p1[2] - p2[2]) ) 157 | end 158 | 159 | local N, M 160 | local edges = {} 161 | local points = {} 162 | N = io.read('*number') 163 | M = io.read('*number') 164 | for i = 1, N do 165 | local x, y 166 | x = io.read('*number') 167 | y = io.read('*number') 168 | points[#points + 1] = {x, y} 169 | end 170 | -- print(points) 171 | for i = 1, M do 172 | local u, v 173 | u = io.read('*number') 174 | v = io.read('*number') 175 | edges[#edges + 1] = {u, v, dist(points[u], points[v])} 176 | end 177 | 178 | print(edges) 179 | --]] 180 | 181 | --[[ 182 | 6 9 183 | 1 2 5 184 | 1 3 6 185 | 2 3 1 186 | 3 5 2 187 | 4 2 5 188 | 4 5 3 189 | 5 2 3 190 | 5 6 2 191 | 6 4 3 192 | --]] 193 | local N, M 194 | local edges = {} 195 | N = io.read('*number') 196 | M = io.read('*number') 197 | for i = 1, M do 198 | local u, v, w 199 | u = io.read('*number') 200 | v = io.read('*number') 201 | w = io.read('*number') 202 | table.insert(edges, {u, v, w}) 203 | end 204 | 205 | local cle = ChuLiuEdmonds() 206 | cle:load(N, edges) 207 | local cost, out_edges = cle:solve(1, N) 208 | print 'get MST done!' 
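-- The solver returns the total weight of the arborescence together with the
-- selected edges; sorting by the child vertex v (as done below) lets each
-- node's chosen head u be read off in order.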
209 | print(cost) 210 | print(out_edges) 211 | table.sort(out_edges, function(a, b) return a.v < b.v end) 212 | print(out_edges) 213 | end 214 | 215 | if not package.loaded['ChuLiuEdmonds'] then 216 | main() 217 | end 218 | 219 | 220 | -------------------------------------------------------------------------------- /graph_alg/Eisner.lua: -------------------------------------------------------------------------------- 1 | 2 | local EisnerAlg = torch.class('Eisner') 3 | 4 | -- index starts from 1 5 | function EisnerAlg:load(size, inEdges) 6 | self.size = size 7 | self.adjMat = {} 8 | for i = 1, size do self.adjMat[i] = {} end 9 | for _, e in ipairs(inEdges) do 10 | self.adjMat[e[1]][e[2]] = e[3] 11 | end 12 | end 13 | 14 | function EisnerAlg:solve() 15 | -- init states 16 | local E = {} 17 | local A = {} 18 | for i = 1, self.size do 19 | E[i] = {} 20 | A[i] = {} 21 | for j = 1, self.size do 22 | E[i][j] = { 0, 0, 0, 0 } 23 | A[i][j] = { -1, -1, -1, -1 } 24 | end 25 | end 26 | -- do dynamic programming 27 | for L = 1, self.size - 1 do 28 | for s = 1, self.size - L do 29 | local t = s + L 30 | -- E[s][t][3] 31 | for q = s, t - 1 do 32 | local w_ts = self.adjMat[t][s] or 0 33 | -- E[s][t][3] = math.max( E[s][t][3], E[s][q][2] + E[q+1][t][1] + w_ts ) 34 | if E[s][t][3] == 0 or E[s][q][2] + E[q+1][t][1] + w_ts > E[s][t][3] then 35 | E[s][t][3] = E[s][q][2] + E[q+1][t][1] + w_ts 36 | A[s][t][3] = q 37 | end 38 | end 39 | -- E[s][t][4] 40 | for q = s, t - 1 do 41 | local w_st = self.adjMat[s][t] or 0 42 | -- E[s][t][4] = math.max( E[s][t][4], E[s][q][2] + E[q+1][t][1] + w_st ) 43 | if E[s][t][4] == 0 or E[s][q][2] + E[q+1][t][1] + w_st > E[s][t][4] then 44 | E[s][t][4] = E[s][q][2] + E[q+1][t][1] + w_st 45 | A[s][t][4] = q 46 | end 47 | end 48 | -- E[s][t][1] 49 | for q = s, t - 1 do 50 | -- E[s][t][1] = math.max( E[s][t][1], E[s][q][1] + E[q][t][3] ) 51 | if E[s][t][1] == 0 or E[s][q][1] + E[q][t][3] > E[s][t][1] then 52 | E[s][t][1] = E[s][q][1] + E[q][t][3] 53 | A[s][t][1] = q 54 | end 55 | end 56 | -- E[s][t][2] 57 | for q = s + 1, t do 58 | -- E[s][t][2] = math.max( E[s][t][2], E[s][q][4] + E[q][t][2] ) 59 | if E[s][t][2] == 0 or E[s][q][4] + E[q][t][2] > E[s][t][2] then 60 | E[s][t][2] = E[s][q][4] + E[q][t][2] 61 | A[s][t][2] = q 62 | end 63 | end 64 | end 65 | end 66 | 67 | -- find the edges 68 | local cost = E[1][self.size][2] 69 | 70 | local edges = {} 71 | local function getEdges(s, t, tt) 72 | if s < t then 73 | local q = A[s][t][tt] 74 | if tt == 3 then 75 | if self.adjMat[t][s] then 76 | edges[#edges + 1] = {u = t, v = s, weight = self.adjMat[t][s]} 77 | end 78 | getEdges(s, q, 2) 79 | getEdges(q + 1, t, 1) 80 | elseif tt == 4 then 81 | if self.adjMat[s][t] then 82 | edges[#edges + 1] = {u = s, v = t, weight = self.adjMat[s][t]} 83 | end 84 | getEdges(s, q, 2) 85 | getEdges(q + 1, t, 1) 86 | elseif tt == 1 then 87 | getEdges(s, q, 1) 88 | getEdges(q, t, 3) 89 | elseif tt == 2 then 90 | getEdges(s, q, 4) 91 | getEdges(q, t, 2) 92 | end 93 | end 94 | end 95 | 96 | getEdges(1, self.size, 2) 97 | 98 | return cost, edges 99 | end 100 | 101 | -- this is for test 102 | local function main() 103 | require '../utils/shortcut' 104 | local eisner = Eisner() 105 | local N, M 106 | local edges = {} 107 | N = io.read('*number') 108 | M = io.read('*number') 109 | --[[ 110 | 4 9 111 | 1 2 9 112 | 3 2 30 113 | 4 2 11 114 | 1 3 10 115 | 2 3 20 116 | 4 3 0 117 | 1 4 9 118 | 2 4 3 119 | 3 4 30 120 | --]] 121 | for i = 1, M do 122 | local u, v, w 123 | u = io.read('*number') 124 | v = io.read('*number') 
125 | w = io.read('*number') 126 | table.insert(edges, {u, v, w}) 127 | end 128 | eisner:load(N, edges) 129 | print( eisner:solve() ) 130 | end 131 | 132 | if not package.loaded['Eisner'] then 133 | main() 134 | end 135 | 136 | -------------------------------------------------------------------------------- /graph_alg/PostDepGraph.lua: -------------------------------------------------------------------------------- 1 | 2 | local Vertex = torch.class('PDVertex') 3 | 4 | function Vertex:__init(id, wd) 5 | self.id = id or -1 6 | self.wd = wd or 'UNKNOWN' 7 | self.adjList = {} 8 | end 9 | 10 | local DepGraph = torch.class('PostDepGraph') 11 | 12 | function DepGraph:__init(sent) 13 | self.vertices = {[0] = PDVertex(0, '###root###')} 14 | for _, item in ipairs(sent) do 15 | local id, wd = item.p1, item.wd 16 | self.vertices[id] = PDVertex(id, wd) 17 | end 18 | self.size = #sent + 1 19 | 20 | for _, item in ipairs(sent) do 21 | local u = tonumber(item.p2) 22 | local v = tonumber(item.p1) 23 | -- printf('u = %d, v = %d, N = %d\n', u, v, self.size) 24 | table.insert( self.vertices[u].adjList, self.vertices[v] ) 25 | end 26 | end 27 | 28 | function DepGraph:checkConnectivity() 29 | local visited = {} 30 | for i = 0, self.size-1 do visited[i] = false end 31 | local root = self.vertices[0] 32 | self:dfs(root, visited) 33 | for i = 0, self.size-1 do 34 | if visited[i] == false then return false end 35 | end 36 | return true 37 | end 38 | 39 | function DepGraph:dfs(u, visited) 40 | visited[u.id] = true 41 | for _, v in ipairs(u.adjList) do 42 | if not visited[v.id] then 43 | self:dfs(v, visited) 44 | end 45 | end 46 | end 47 | 48 | function DepGraph:isProjective() 49 | -- sort child 50 | for i = 0, self.size - 1 do 51 | local u = self.vertices[i] 52 | u.lchild = {} 53 | u.rchild = {} 54 | for _, v in ipairs(u.adjList) do 55 | if v.id < u.id then 56 | u.lchild[#u.lchild + 1] = v 57 | elseif v.id > u.id then 58 | u.rchild[#u.rchild + 1] = v 59 | else 60 | error('impossible!') 61 | end 62 | end 63 | table.sort(u.lchild, function(a, b) return a.id < b.id end) 64 | table.sort(u.rchild, function(a, b) return a.id < b.id end) 65 | end 66 | 67 | local root = self.vertices[0] 68 | self.dfs_id = 0 69 | self.is_proj = true 70 | self:proj_dfs(root) 71 | 72 | for i = 0, self.size - 1 do 73 | self.vertices[i].lchild = nil 74 | self.vertices[i].rchild = nil 75 | end 76 | 77 | return self.is_proj 78 | end 79 | 80 | function DepGraph:proj_dfs(u) 81 | if #u.lchild > 0 then 82 | for _, v in ipairs(u.lchild) do 83 | self:proj_dfs(v) 84 | end 85 | end 86 | 87 | -- printf('u.id = %d, visited id = %d\n', u.id, self.dfs_id) 88 | if u.id ~= self.dfs_id then self.is_proj = false end 89 | self.dfs_id = self.dfs_id + 1 90 | 91 | if #u.rchild > 0 then 92 | for _, v in ipairs(u.rchild) do 93 | self:proj_dfs(v) 94 | end 95 | end 96 | end 97 | 98 | function DepGraph:showSentence() 99 | local sents = {} 100 | for i = 0, self.size - 1 do 101 | sents[#sents + 1] = self.vertices[i].wd 102 | end 103 | print(table.concat(sents, ' ')) 104 | end 105 | 106 | -------------------------------------------------------------------------------- /init.lua: -------------------------------------------------------------------------------- 1 | 2 | function importdir(dpath) 3 | package.path = dpath .. "/?.lua;" .. 
package.path 4 | end 5 | 6 | -- load all local packages 7 | for dir in paths.iterdirs(".") do 8 | importdir(dir) 9 | end 10 | 11 | require 'torch' 12 | require 'nn' 13 | require 'nngraph' 14 | require 'optim' 15 | require 'shortcut' -------------------------------------------------------------------------------- /layers/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XingxingZhang/dense_parser/750c7d00ab33defd61e04739ef12ff8273afe304/layers/.DS_Store -------------------------------------------------------------------------------- /layers/Contiguous.lua: -------------------------------------------------------------------------------- 1 | local Contiguous, parent = torch.class('Contiguous', 'nn.Module') 2 | 3 | function Contiguous:updateOutput(input) 4 | if not input:isContiguous() then 5 | self.output:resizeAs(input):copy(input) 6 | else 7 | self.output:set(input) 8 | end 9 | return self.output 10 | end 11 | 12 | function Contiguous:updateGradInput(input, gradOutput) 13 | if not gradOutput:isContiguous() then 14 | self.gradInput:resizeAs(gradOutput):copy(gradOutput) 15 | else 16 | self.gradInput:set(gradOutput) 17 | end 18 | 19 | return self.gradInput 20 | end 21 | -------------------------------------------------------------------------------- /layers/DetailedMaskedNLLCriterion.lua: -------------------------------------------------------------------------------- 1 | 2 | --[[ 3 | -- when y contains zeros 4 | --]] 5 | 6 | local DetailedMaskedNLLCriterion, parent = torch.class('DetailedMaskedNLLCriterion', 'nn.Module') 7 | 8 | function DetailedMaskedNLLCriterion:__init() 9 | parent.__init(self) 10 | end 11 | 12 | function DetailedMaskedNLLCriterion:updateOutput(input_) 13 | local input, target, div = unpack(input_) 14 | div = div or 1 15 | if input:dim() == 2 then 16 | self.output:resize(target:size()) 17 | self.output:zero() 18 | local nlls = self.output 19 | local n = target:size(1) 20 | for i = 1, n do 21 | if target[i] ~= 0 then 22 | nlls[i] = -input[i][target[i]] 23 | else 24 | nlls[i] = 0 25 | end 26 | end 27 | 28 | return self.output 29 | else 30 | error('input must be matrix! Note only batch mode is supported!') 31 | end 32 | end 33 | 34 | function DetailedMaskedNLLCriterion:updateGradInput(input_) 35 | local input, target, div = unpack(input_) 36 | div = div or 1 37 | 38 | self.gradInput:resizeAs(input) 39 | self.gradInput:zero() 40 | local er = -1 / div 41 | if input:dim() == 2 then 42 | local n = target:size(1) 43 | for i = 1, n do 44 | if target[i] ~= 0 then 45 | self.gradInput[i][target[i]] = er 46 | end 47 | end 48 | return self.gradInput 49 | else 50 | error('input must be matrix! 
Note only batch mode is supported!') 51 | end 52 | end 53 | 54 | -------------------------------------------------------------------------------- /layers/EMaskedClassNLLCriterion.lua: -------------------------------------------------------------------------------- 1 | 2 | --[[ 3 | -- when y contains zeros 4 | --]] 5 | 6 | local EMaskedClassNLLCriterion, parent = torch.class('EMaskedClassNLLCriterion', 'nn.Module') 7 | 8 | function EMaskedClassNLLCriterion:__init() 9 | parent.__init(self) 10 | end 11 | 12 | function EMaskedClassNLLCriterion:updateOutput(input_) 13 | local input, target, div = unpack(input_) 14 | if input:dim() == 2 then 15 | local nll = 0 16 | local n = target:size(1) 17 | for i = 1, n do 18 | if target[i] ~= 0 then 19 | nll = nll - input[i][target[i]] 20 | end 21 | end 22 | self.output = nll / div 23 | return self.output 24 | else 25 | error('input must be matrix! Note only batch mode is supported!') 26 | end 27 | end 28 | 29 | function EMaskedClassNLLCriterion:updateGradInput(input_) 30 | local input, target, div = unpack(input_) 31 | -- print('self.gradInput', self.gradInput) 32 | self.gradInput:resizeAs(input) 33 | self.gradInput:zero() 34 | local er = -1 / div 35 | if input:dim() == 2 then 36 | local n = target:size(1) 37 | for i = 1, n do 38 | if target[i] ~= 0 then 39 | self.gradInput[i][target[i]] = er 40 | end 41 | end 42 | return self.gradInput 43 | else 44 | error('input must be matrix! Note only batch mode is supported!') 45 | end 46 | end 47 | 48 | -------------------------------------------------------------------------------- /layers/Linear3D.lua: -------------------------------------------------------------------------------- 1 | local Linear, parent = torch.class('Linear3D', 'nn.Module') 2 | 3 | function Linear:__init(inputSize, outputSize, bias) 4 | parent.__init(self) 5 | local bias = ((bias == nil) and true) or bias 6 | self.weight = torch.Tensor(outputSize, inputSize) 7 | self.gradWeight = torch.Tensor(outputSize, inputSize) 8 | if bias then 9 | self.bias = torch.Tensor(outputSize) 10 | self.gradBias = torch.Tensor(outputSize) 11 | end 12 | self:reset() 13 | end 14 | 15 | function Linear:reset(stdv) 16 | if stdv then 17 | stdv = stdv * math.sqrt(3) 18 | else 19 | stdv = 1./math.sqrt(self.weight:size(2)) 20 | end 21 | if nn.oldSeed then 22 | for i=1,self.weight:size(1) do 23 | self.weight:select(1, i):apply(function() 24 | return torch.uniform(-stdv, stdv) 25 | end) 26 | end 27 | if self.bias then 28 | for i=1,self.bias:nElement() do 29 | self.bias[i] = torch.uniform(-stdv, stdv) 30 | end 31 | end 32 | else 33 | self.weight:uniform(-stdv, stdv) 34 | if self.bias then self.bias:uniform(-stdv, stdv) end 35 | end 36 | return self 37 | end 38 | 39 | function Linear:updateOutput(input_) 40 | if input_:dim() == 3 then 41 | local input = input_:view(input_:size(1) * input_:size(2), input_:size(3)) 42 | local nframe = input:size(1) 43 | local nElement = self.output:nElement() 44 | self.output:resize(nframe, self.weight:size(1)) 45 | if self.output:nElement() ~= nElement then 46 | self.output:zero() 47 | end 48 | self.addBuffer = self.addBuffer or input.new() 49 | if self.addBuffer:nElement() ~= nframe then 50 | self.addBuffer:resize(nframe):fill(1) 51 | end 52 | self.output:addmm(0, self.output, 1, input, self.weight:t()) 53 | if self.bias then self.output:addr(1, self.addBuffer, self.bias) end 54 | 55 | self.output = self.output:view(input_:size(1), input_:size(2), self.output:size(2)) 56 | else 57 | error('input must be 3D tensor') 58 | end 59 | 60 | return 
self.output 61 | end 62 | 63 | function Linear:updateGradInput(input_, gradOutput_) 64 | if self.gradInput then 65 | assert(input_:dim() == 3, 'input_ must be 3D tensor') 66 | assert(gradOutput_:dim() == 3, 'gradOutput_ must be 3D tensor') 67 | local input = input_:view(input_:size(1)*input_:size(2), input_:size(3)) 68 | local gradOutput = gradOutput_:view(gradOutput_:size(1) * gradOutput_:size(2), gradOutput_:size(3)) 69 | 70 | local nElement = self.gradInput:nElement() 71 | self.gradInput:resizeAs(input) 72 | if self.gradInput:nElement() ~= nElement then 73 | self.gradInput:zero() 74 | end 75 | if input:dim() == 2 then 76 | self.gradInput:addmm(0, 1, gradOutput, self.weight) 77 | else 78 | error('input right now must be a matrix') 79 | end 80 | 81 | self.gradInput = self.gradInput:view( input_:size(1), input_:size(2), self.gradInput:size(2) ) 82 | 83 | return self.gradInput 84 | end 85 | end 86 | 87 | function Linear:accGradParameters(input_, gradOutput_, scale) 88 | assert(input_:dim() == 3, 'input_ must be 3D tensor') 89 | assert(gradOutput_:dim() == 3, 'gradOutput_ must be 3D tensor') 90 | local input = input_:view(input_:size(1)*input_:size(2), input_:size(3)) 91 | local gradOutput = gradOutput_:view(gradOutput_:size(1) * gradOutput_:size(2), gradOutput_:size(3)) 92 | 93 | scale = scale or 1 94 | if input:dim() == 2 then 95 | self.gradWeight:addmm(scale, gradOutput:t(), input) 96 | if self.bias then 97 | self.gradBias:addmv(scale, gradOutput:t(), self.addBuffer) 98 | end 99 | else 100 | error('input right now must be a matrix') 101 | end 102 | end 103 | 104 | -- we do not need to accumulate parameters when sharing 105 | Linear.sharedAccUpdateGradParameters = Linear.accUpdateGradParameters 106 | 107 | function Linear:clearState() 108 | if self.addBuffer then self.addBuffer:set() end 109 | return parent.clearState(self) 110 | end 111 | 112 | function Linear:__tostring__() 113 | return torch.type(self) .. 114 | string.format('(%d -> %d)', self.weight:size(2), self.weight:size(1)) .. 
115 | (self.bias == nil and ' without bias' or '') 116 | end 117 | -------------------------------------------------------------------------------- /layers/LookupTable_ft.lua: -------------------------------------------------------------------------------- 1 | local LookupTable, parent = torch.class('LookupTable_ft', 'nn.LookupTable') 2 | 3 | LookupTable.__version = 1 4 | 5 | function LookupTable:__init(nIndex, nOutput, updateMask) 6 | parent.__init(self, nIndex, nOutput) 7 | self.updateMask = torch.ones(nIndex):view(nIndex, 1):expand(nIndex, nOutput) 8 | if updateMask ~= nil then 9 | self:setUpdateMask(updateMask) 10 | end 11 | end 12 | 13 | function LookupTable:setUpdateMask(updateMask) 14 | local nIndex = self.weight:size(1) 15 | local nOutput = self.weight:size(2) 16 | assert(updateMask:nElement() == nIndex) 17 | self.updateMask:copy( updateMask:view(nIndex, 1):expand(nIndex, nOutput) ) 18 | end 19 | 20 | function LookupTable:applyGradMask() 21 | -- a slow solution 22 | self.gradWeight:cmul( self.updateMask ) 23 | end 24 | 25 | -- we do not need to accumulate parameters when sharing 26 | LookupTable.sharedAccUpdateGradParameters = LookupTable.accUpdateGradParameters 27 | 28 | -------------------------------------------------------------------------------- /layers/MaskedClassNLLCriterion.lua: -------------------------------------------------------------------------------- 1 | 2 | --[[ 3 | -- when y contains zeros 4 | --]] 5 | 6 | local MaskedClassNLLCriterion, parent = torch.class('MaskedClassNLLCriterion', 'nn.Module') 7 | 8 | function MaskedClassNLLCriterion:__init() 9 | parent.__init(self) 10 | end 11 | 12 | function MaskedClassNLLCriterion:updateOutput(input_) 13 | local input, target = unpack(input_) 14 | if input:dim() == 2 then 15 | local nll = 0 16 | local n = target:size(1) 17 | for i = 1, n do 18 | if target[i] ~= 0 then 19 | nll = nll - input[i][target[i]] 20 | end 21 | end 22 | self.output = nll / target:size(1) 23 | return self.output 24 | else 25 | error('input must be matrix! Note only batch mode is supported!') 26 | end 27 | end 28 | 29 | function MaskedClassNLLCriterion:updateGradInput(input_) 30 | local input, target = unpack(input_) 31 | -- print('self.gradInput', self.gradInput) 32 | self.gradInput:resizeAs(input) 33 | self.gradInput:zero() 34 | local er = -1 / target:size(1) 35 | if input:dim() == 2 then 36 | local n = target:size(1) 37 | for i = 1, n do 38 | if target[i] ~= 0 then 39 | self.gradInput[i][target[i]] = er 40 | end 41 | end 42 | return self.gradInput 43 | else 44 | error('input must be matrix! Note only batch mode is supported!') 45 | end 46 | end 47 | 48 | -------------------------------------------------------------------------------- /layers/PReplicate.lua: -------------------------------------------------------------------------------- 1 | local Replicate, parent = torch.class('PReplicate','nn.Module') 2 | 3 | function Replicate:__init(dim, ndim) 4 | parent.__init(self) 5 | -- self.nfeatures = nf 6 | self.dim = dim or 1 7 | self.ndim = ndim 8 | assert(self.dim > 0, "Can only replicate across positive integer dimensions.") 9 | end 10 | 11 | function Replicate:updateOutput(input_) 12 | local input, nf = unpack( input_ ) 13 | self.nfeatures = nf 14 | 15 | self.dim = self.dim or 1 --backwards compatible 16 | assert( 17 | self.dim <= input:dim()+1, 18 | "Not enough input dimensions to replicate along dimension " .. 19 | tostring(self.dim) .. 
".") 20 | local batchOffset = self.ndim and input:dim() > self.ndim and 1 or 0 21 | local rdim = self.dim + batchOffset 22 | local sz = torch.LongStorage(input:dim()+1) 23 | sz[rdim] = self.nfeatures 24 | for i = 1,input:dim() do 25 | local offset = 0 26 | if i >= rdim then 27 | offset = 1 28 | end 29 | sz[i+offset] = input:size(i) 30 | end 31 | local st = torch.LongStorage(input:dim()+1) 32 | st[rdim] = 0 33 | for i = 1,input:dim() do 34 | local offset = 0 35 | if i >= rdim then 36 | offset = 1 37 | end 38 | st[i+offset] = input:stride(i) 39 | end 40 | self.output = input.new(input:storage(),input:storageOffset(),sz,st) 41 | return self.output 42 | end 43 | 44 | function Replicate:updateGradInput(input_, gradOutput) 45 | local input, nf = unpack(input_) 46 | self.nfeatures = nf 47 | 48 | self.gradInput:resizeAs(input):zero() 49 | local batchOffset = self.ndim and input:dim() > self.ndim and 1 or 0 50 | local rdim = self.dim + batchOffset 51 | local sz = torch.LongStorage(input:dim()+1) 52 | sz[rdim] = 1 53 | for i = 1,input:dim() do 54 | local offset = 0 55 | if i >= rdim then 56 | offset = 1 57 | end 58 | sz[i+offset] = input:size(i) 59 | end 60 | local gradInput = self.gradInput:view(sz) 61 | gradInput:sum(gradOutput, rdim) 62 | 63 | return {self.gradInput, 0} 64 | end 65 | 66 | -------------------------------------------------------------------------------- /model_opts.lua: -------------------------------------------------------------------------------- 1 | 2 | local ModelOpts = {} 3 | 4 | function ModelOpts.getOpts() 5 | local cmd = torch.CmdLine() 6 | cmd:text('====== Select Network ======') 7 | cmd:option('--seed', 123, 'random seed') 8 | cmd:option('--model', 'SelectNet', 'model options: SelectNet or SelectNetPos') 9 | cmd:option('--train', '/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dataset/train.autopos', 'train file') 10 | cmd:option('--freqCut', 1, 'for word frequencies') 11 | cmd:option('--ignoreCase', false, 'whether you will ignore the case') 12 | cmd:option('--maxNVocab', 0, 'you can also set maximum number of vocabulary') 13 | cmd:option('--valid', '/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dataset/valid.autopos', 'valid file') 14 | cmd:option('--test', '/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dataset/test.autopos', 'test file (in default: no test file)') 15 | cmd:option('--maxEpoch', 30, 'maximum number of epochs') 16 | cmd:option('--batchSize', 32, '') 17 | cmd:option('--validBatchSize', 32, '') 18 | cmd:option('--nin', 100, 'word embedding size') 19 | cmd:option('--npin', 50, 'pos tag embeeding size') 20 | cmd:option('--nhid', 100, 'hidden unit size') 21 | cmd:option('--nlayers', 1, 'number of hidden layers') 22 | cmd:option('--lr', 0.1, 'learning rate') 23 | cmd:option('--lrDiv', 0, 'learning rate decay when there is no significant improvement. 
0 means turn off') 24 | cmd:option('--minImprovement', 1.001, 'improvement ratios on log likelihood smaller than this count as no significant improvement') 25 | cmd:option('--optimMethod', 'AdaGrad', 'optimization algorithm') 26 | cmd:option('--gradClip', 5, '> 0 means to do Pascanu et al.\'s grad norm rescale http://arxiv.org/pdf/1502.04623.pdf; < 0 means to truncate the gradient larger than gradClip; 0 means turn off gradient clip') 27 | cmd:option('--initRange', 0.1, 'init range') 28 | cmd:option('--seqLen', 150, 'maximum sequence length') 29 | cmd:option('--maxTrainLen', 120, 'maximum train sentence length') 30 | cmd:option('--useGPU', false, 'use GPU') 31 | cmd:option('--patience', 1, 'stop training if no lower valid PPL is observed in [patience] consecutive epoch(s)') 32 | cmd:option('--save', 'model.t7', 'save model path') 33 | 34 | cmd:option('--disableEearlyStopping', false, 'no early stopping during training') 35 | 36 | cmd:text() 37 | cmd:text('Options for long jobs') 38 | cmd:option('--savePerEpoch', false, 'save model every epoch') 39 | cmd:option('--saveBeforeLrDiv', false, 'save model before lr div') 40 | 41 | cmd:text() 42 | cmd:text('Options for regularization') 43 | cmd:option('--dropout', 0, 'dropout rate (probability of dropping a unit)') 44 | cmd:text() 45 | cmd:text('Options for rec dropout') 46 | cmd:option('--recDropout', 0, 'recurrent dropout') 47 | 48 | cmd:text() 49 | cmd:text('Options for Word Embedding initialization') 50 | cmd:option('--wordEmbedding', '', 'word embedding path') 51 | cmd:option('--embedOption', 'init', 'options: init, fineTune (if you use fineTune option, you must specify fineTuneFactor)') 52 | cmd:option('--fineTuneFactor', 0, '0 means no updates; other values (e.g. 0.01) scale the updates') 53 | 54 | cmd:text() 55 | cmd:text('Options for evaluation standard') 56 | cmd:option('--evalType', 'stanford', 'stanford or conllx') 57 | 58 | local opts = cmd:parse(arg) 59 | ModelOpts.initOpts(opts) 60 | 61 | return opts 62 | end 63 | 64 | function ModelOpts.initOpts(opts) 65 | -- for different optimization algorithms 66 | local optimMethods = {'AdaGrad', 'Adam', 'AdaDelta', 'SGD'} 67 | if not table.contains(optimMethods, opts.optimMethod) then 68 | error('invalid optimization method! ' .. opts.optimMethod) 69 | end 70 | 71 | opts.curLR = opts.lr 72 | opts.minLR = 1e-7 73 | opts.sgdParam = {learningRate = opts.lr} 74 | if opts.optimMethod == 'AdaDelta' then 75 | opts.rho = 0.95 76 | opts.eps = 1e-6 77 | opts.sgdParam.rho = opts.rho 78 | opts.sgdParam.eps = opts.eps 79 | elseif opts.optimMethod == 'SGD' then 80 | if opts.lrDiv <= 1 then 81 | opts.lrDiv = 2 82 | end 83 | end 84 | 85 | torch.manualSeed(opts.seed) 86 | if opts.useGPU then 87 | require 'cutorch' 88 | require 'cunn' 89 | cutorch.manualSeed(opts.seed) 90 | end 91 | end 92 | 93 | return ModelOpts 94 | 95 | -------------------------------------------------------------------------------- /mst_postprocess.lua: -------------------------------------------------------------------------------- 1 | 2 | require '.' 3 | require 'shortcut' 4 | require 'SelectNetPos' 5 | require 'DepPosDataIter' 6 | require 'PostDepGraph' 7 | require 'ChuLiuEdmonds' 8 | require 'Eisner' 9 | 10 | local MST = torch.class('MSTPostProcessor') 11 | 12 | function MST:showOpts() 13 | local tmp_vocab = self.opts.vocab 14 | self.opts.vocab = nil 15 | print(self.opts) 16 | self.opts.vocab = tmp_vocab 17 | end 18 | 19 | function MST:load(modelPath) 20 | self.opts = torch.load( modelPath:sub(1, -3) .. 
'state.t7' ) 21 | local opts = self.opts 22 | 23 | torch.manualSeed(opts.seed + 1) 24 | if opts.useGPU then 25 | require 'cutorch' 26 | require 'cunn' 27 | cutorch.manualSeed(opts.seed + 1) 28 | end 29 | 30 | local vocabPath = opts.train .. '.tmp.pos.vocab.t7' 31 | local recreateVocab = true 32 | if paths.filep(vocabPath) then 33 | opts.vocab = torch.load(vocabPath) 34 | if opts.vocab.ignoreCase == opts.ignoreCase and opts.vocab.freqCut == opts.freqCut and opts.vocab.maxNVocab == opts.maxNVocab then 35 | recreateVocab = false 36 | DepPosDataIter.showVocab(opts.vocab) 37 | print '****load from existing vocab!!!****\n\n' 38 | end 39 | end 40 | assert(not recreateVocab, 'you should load existing vocabulary') 41 | 42 | self.net = SelectNetPos(opts) 43 | self:showOpts() 44 | 45 | xprintln('load from %s ...', modelPath) 46 | self.net:load(modelPath) 47 | xprintln('load from %s done!', modelPath) 48 | end 49 | 50 | function MST:validConllx(validFile, outConllxFile) 51 | local dataIter = DepPosDataIter.createBatch(self.opts.vocab, validFile, self.opts.batchSize, 150) 52 | local totalCnt = 0 53 | local totalLoss = 0 54 | local cnt = 0 55 | 56 | local sents_dep = {} 57 | local y_tmp = torch.LongTensor(150, self.opts.batchSize) 58 | for x, x_mask, x_pos, y in dataIter do 59 | local loss, y_preds = self.net:validBatch(x, x_mask, x_pos, y) 60 | 61 | totalLoss = totalLoss + loss * x:size(2) 62 | local y_mask = x_mask[{ {2, -1}, {} }] 63 | 64 | local y_p = y_tmp:resize(y:size(1), y:size(2)) 65 | for t = 2, x:size(1) do 66 | local _, mi = y_preds[t]:max(2) 67 | if self.opts.useGPU then mi = mi:double() end 68 | y_p[{ t-1, {} }] = mi 69 | end 70 | 71 | for i = 1, y_mask:size(2) do 72 | local slen = y_mask[{ {}, i }]:sum() 73 | local sent_dep = {} 74 | for j = 1, slen do 75 | sent_dep[#sent_dep + 1] = y_p[{ j, i }] - 1 76 | end 77 | sents_dep[#sents_dep + 1] = sent_dep 78 | end 79 | 80 | totalCnt = totalCnt + y_mask:sum() 81 | cnt = cnt + 1 82 | if cnt % 5 == 0 then 83 | collectgarbage() 84 | end 85 | end 86 | 87 | outConllxFile = outConllxFile or '__tmp__.dep' 88 | 89 | local dep_iter = DepPosDataIter.conllx_iter(validFile) 90 | local sent_idx = 0 91 | local sys_out = outConllxFile 92 | local fout = io.open(sys_out, 'w') 93 | for dsent in dep_iter do 94 | sent_idx = sent_idx + 1 95 | local sent_dep = sents_dep[sent_idx] 96 | assert(#sent_dep == #dsent) 97 | for i, ditem in ipairs(dsent) do 98 | -- '%d\t%s\t_\t_\t%s\t_\t%d\t%s\t_\t_\n' 99 | -- 1 Influential _ JJ JJ _ 2 amod _ _ 100 | fout:write(string.format('%d\t%s\t_\t_\t%s\t_\t%d\tprep\t_\t_\n', ditem.p1, ditem.wd, ditem.pos, sent_dep[i])) 101 | end 102 | fout:write('\n') 103 | end 104 | fout:close() 105 | 106 | -- local conllx_eval = require 'conllx_eval' 107 | local conllx_eval = self.opts.evalType == 'stanford' and require 'conllx_eval' or require 'conllx2006_eval' 108 | local LAS, UAS, noPunctLAS, noPunctUAS = conllx_eval.eval(sys_out, validFile) 109 | 110 | local entropy = totalLoss / totalCnt 111 | local ppl = torch.exp(entropy) 112 | 113 | return {entropy = entropy, ppl = ppl, UAS = noPunctUAS} 114 | end 115 | 116 | function MST:validWithMSTPost(validFile, outConllxFile) 117 | local dataIter = DepPosDataIter.createBatch(self.opts.vocab, validFile, self.opts.batchSize, 150) 118 | local totalCnt = 0 119 | local totalLoss = 0 120 | local cnt = 0 121 | 122 | local sents_dep = {} 123 | local sents_graph = {} 124 | local y_tmp = torch.LongTensor(150, self.opts.batchSize) 125 | for x, x_mask, x_pos, y in dataIter do 126 | local loss, y_preds = 
self.net:validBatch(x, x_mask, x_pos, y) 127 | 128 | totalLoss = totalLoss + loss * x:size(2) 129 | local y_mask = x_mask[{ {2, -1}, {} }] 130 | 131 | local y_p = y_tmp:resize(y:size(1), y:size(2)) 132 | -- WARNING: y_preds start from 2! 133 | for t = 2, x:size(1) do 134 | local _, mi = y_preds[t]:max(2) 135 | if self.opts.useGPU then mi = mi:double() end 136 | y_p[{ t-1, {} }] = mi 137 | end 138 | 139 | for i = 1, y_mask:size(2) do 140 | local slen = y_mask[{ {}, i }]:sum() 141 | local sent_dep = {} 142 | local sent_graph = {} 143 | for j = 1, slen do 144 | sent_dep[#sent_dep + 1] = y_p[{ j, i }] - 1 145 | local tmp = y_preds[j+1][{ i, {1, slen + 1} }]:double() 146 | sent_graph[j] = tmp 147 | end 148 | sents_dep[#sents_dep + 1] = sent_dep 149 | sents_graph[#sents_graph + 1] = sent_graph 150 | end 151 | 152 | totalCnt = totalCnt + y_mask:sum() 153 | cnt = cnt + 1 154 | if cnt % 5 == 0 then 155 | collectgarbage() 156 | end 157 | end 158 | 159 | outConllxFile = outConllxFile or '__tmp__.dep' 160 | 161 | local dep_iter = DepPosDataIter.conllx_iter(validFile) 162 | local sent_idx = 0 163 | local connected_count = 0 164 | local sys_out = outConllxFile 165 | local fout = io.open(sys_out, 'w') 166 | for dsent in dep_iter do 167 | sent_idx = sent_idx + 1 168 | local sent_dep = sents_dep[sent_idx] 169 | assert(#sent_dep == #dsent) 170 | local sent_graph = sents_graph[sent_idx] 171 | assert(#sent_graph == #dsent) 172 | 173 | local new_dsent = {} 174 | for i, ditem in ipairs(dsent) do 175 | local new_ditem = {p1 = ditem.p1, wd = ditem.wd, pos = ditem.pos, p2 = sent_dep[i]} 176 | new_dsent[#new_dsent + 1] = new_ditem 177 | end 178 | 179 | -- check connectivity 180 | local dgraph = PostDepGraph(new_dsent) 181 | if not dgraph:checkConnectivity() then 182 | local N = #sent_graph + 1 183 | local edges = {} 184 | for i, sp in ipairs(sent_graph) do 185 | for j = 1, sp:size(1) do 186 | edges[#edges + 1] = {j, i+1, sp[j]} 187 | end 188 | end 189 | -- run ChuLiuEdmonds 190 | local cle = ChuLiuEdmonds() 191 | cle:load(N, edges) 192 | local _, selectedEdges = cle:solve(1, N) 193 | table.sort(selectedEdges, function(a, b) return a.v < b.v end) 194 | for i, ditem in ipairs(new_dsent) do 195 | local edge = selectedEdges[i] 196 | assert(edge.v == i+1) 197 | ditem.p2 = edge.u - 1 198 | ditem.p1 = edge.v - 1 199 | end 200 | 201 | local dgraph2 = PostDepGraph(new_dsent) 202 | assert(dgraph2:checkConnectivity()) 203 | else 204 | connected_count = connected_count + 1 205 | end 206 | 207 | for i, ditem in ipairs(new_dsent) do 208 | -- '%d\t%s\t_\t_\t%s\t_\t%d\t%s\t_\t_\n' 209 | -- 1 Influential _ JJ JJ _ 2 amod _ _ 210 | fout:write(string.format('%d\t%s\t_\t_\t%s\t_\t%d\tprep\t_\t_\n', ditem.p1, ditem.wd, ditem.pos, ditem.p2)) 211 | end 212 | fout:write('\n') 213 | end 214 | fout:close() 215 | printf('%d/%d = %f are connected graph\n', connected_count, sent_idx, connected_count/sent_idx) 216 | 217 | -- local conllx_eval = require 'conllx_eval' 218 | printf('evalType = %s\n', self.opts.evalType) 219 | local conllx_eval = self.opts.evalType == 'stanford' and require 'conllx_eval' or require 'conllx2006_eval' 220 | local LAS, UAS, noPunctLAS, noPunctUAS = conllx_eval.eval(sys_out, validFile) 221 | 222 | local entropy = totalLoss / totalCnt 223 | local ppl = torch.exp(entropy) 224 | 225 | return {entropy = entropy, ppl = ppl, UAS = noPunctUAS} 226 | end 227 | 228 | 229 | function MST:validWithMSTPostEisner(validFile, outConllxFile) 230 | local dataIter = DepPosDataIter.createBatch(self.opts.vocab, validFile, 
self.opts.batchSize, 150) 231 | local totalCnt = 0 232 | local totalLoss = 0 233 | local cnt = 0 234 | 235 | local sents_dep = {} 236 | local sents_graph = {} 237 | local y_tmp = torch.LongTensor(150, self.opts.batchSize) 238 | for x, x_mask, x_pos, y in dataIter do 239 | local loss, y_preds = self.net:validBatch(x, x_mask, x_pos, y) 240 | 241 | totalLoss = totalLoss + loss * x:size(2) 242 | local y_mask = x_mask[{ {2, -1}, {} }] 243 | 244 | local y_p = y_tmp:resize(y:size(1), y:size(2)) 245 | -- WARNING: y_preds start from 2! 246 | for t = 2, x:size(1) do 247 | local _, mi = y_preds[t]:max(2) 248 | if self.opts.useGPU then mi = mi:double() end 249 | y_p[{ t-1, {} }] = mi 250 | end 251 | 252 | for i = 1, y_mask:size(2) do 253 | local slen = y_mask[{ {}, i }]:sum() 254 | local sent_dep = {} 255 | local sent_graph = {} 256 | for j = 1, slen do 257 | sent_dep[#sent_dep + 1] = y_p[{ j, i }] - 1 258 | local tmp = y_preds[j+1][{ i, {1, slen + 1} }]:double() 259 | sent_graph[j] = tmp 260 | end 261 | sents_dep[#sents_dep + 1] = sent_dep 262 | sents_graph[#sents_graph + 1] = sent_graph 263 | end 264 | 265 | totalCnt = totalCnt + y_mask:sum() 266 | cnt = cnt + 1 267 | if cnt % 5 == 0 then 268 | collectgarbage() 269 | end 270 | end 271 | 272 | outConllxFile = outConllxFile or '__tmp__.dep' 273 | 274 | local dep_iter = DepPosDataIter.conllx_iter(validFile) 275 | local sent_idx = 0 276 | local connected_count = 0 277 | local sys_out = outConllxFile 278 | local fout = io.open(sys_out, 'w') 279 | for dsent in dep_iter do 280 | sent_idx = sent_idx + 1 281 | local sent_dep = sents_dep[sent_idx] 282 | assert(#sent_dep == #dsent) 283 | local sent_graph = sents_graph[sent_idx] 284 | assert(#sent_graph == #dsent) 285 | 286 | local new_dsent = {} 287 | for i, ditem in ipairs(dsent) do 288 | local new_ditem = {p1 = ditem.p1, wd = ditem.wd, pos = ditem.pos, p2 = sent_dep[i]} 289 | new_dsent[#new_dsent + 1] = new_ditem 290 | end 291 | 292 | -- check connectivity 293 | local dgraph = PostDepGraph(new_dsent) 294 | if not (dgraph:checkConnectivity() and dgraph:isProjective()) then 295 | local N = #sent_graph + 1 296 | local edges = {} 297 | for i, sp in ipairs(sent_graph) do 298 | for j = 1, sp:size(1) do 299 | edges[#edges + 1] = {j, i+1, sp[j]} 300 | end 301 | end 302 | -- run Eisner's algorithm 303 | local eisner = Eisner() 304 | eisner:load(N, edges) 305 | local _, selectedEdges = eisner:solve() 306 | table.sort(selectedEdges, function(a, b) return a.v < b.v end) 307 | for i, ditem in ipairs(new_dsent) do 308 | local edge = selectedEdges[i] 309 | assert(edge.v == i+1) 310 | ditem.p2 = edge.u - 1 311 | ditem.p1 = edge.v - 1 312 | end 313 | 314 | -- local dgraph2 = PostDepGraph(new_dsent) 315 | -- assert(dgraph2:checkConnectivity()) 316 | else 317 | connected_count = connected_count + 1 318 | end 319 | 320 | for i, ditem in ipairs(new_dsent) do 321 | -- '%d\t%s\t_\t_\t%s\t_\t%d\t%s\t_\t_\n' 322 | -- 1 Influential _ JJ JJ _ 2 amod _ _ 323 | fout:write(string.format('%d\t%s\t_\t_\t%s\t_\t%d\tprep\t_\t_\n', ditem.p1, ditem.wd, ditem.pos, ditem.p2)) 324 | end 325 | fout:write('\n') 326 | end 327 | fout:close() 328 | printf('%d/%d = %f are projective trees\n', connected_count, sent_idx, connected_count/sent_idx) 329 | 330 | -- local conllx_eval = require 'conllx_eval' 331 | printf('evalType = %s\n', self.opts.evalType) 332 | local conllx_eval = self.opts.evalType == 'stanford' and require 'conllx_eval' or require 'conllx2006_eval' 333 | local LAS, UAS, noPunctLAS, noPunctUAS = conllx_eval.eval(sys_out, validFile) 334 | 
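-- totalLoss accumulated the per-batch losses and totalCnt the number of
-- unmasked tokens, so entropy below is an average negative log-likelihood
-- per token; its exponential is reported as ppl alongside the
-- no-punctuation UAS computed by the evaluator above.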
335 | local entropy = totalLoss / totalCnt 336 | local ppl = torch.exp(entropy) 337 | 338 | return {entropy = entropy, ppl = ppl, UAS = noPunctUAS} 339 | end 340 | 341 | local function getOpts() 342 | local cmd = torch.CmdLine() 343 | cmd:option('--modelPath', '/disk/scratch/XingxingZhang/dep_parse/experiments/we_select_ft_pos/ft/model_0.001.std.ft0.t7', 'model path') 344 | cmd:option('--validout', 'valid', 'output conllx file for validation set') 345 | cmd:option('--testout', 'test', 'output conllx file for test set') 346 | cmd:option('--mstalg', 'ChuLiuEdmonds', 'MST algorithm: ChuLiuEdmonds or Eisner') 347 | 348 | return cmd:parse(arg) 349 | end 350 | 351 | local function main() 352 | local opts = getOpts() 353 | -- local modelPath = '/disk/scratch/XingxingZhang/dep_parse/experiments/we_select_ft_pos/ft/model_0.001.std.ft0.t7' 354 | -- local modelPath = '/disk/scratch/XingxingZhang/dep_parse/experiments/we_select_ft_pos_chinese/ft/model_0.001.std.ft0.lpos.dp.t7' 355 | local mst_post = MSTPostProcessor() 356 | mst_post:load(opts.modelPath) 357 | -- test performance on validation and test dataset 358 | print '==Valid Performance==' 359 | local vret = mst_post:validConllx(mst_post.opts.valid, opts.validout .. '.ori.dep') 360 | print(vret) 361 | print '--after post processing--' 362 | if opts.mstalg == 'ChuLiuEdmonds' then 363 | xprintln('Using ChuLiuEdmonds') 364 | vret = mst_post:validWithMSTPost(mst_post.opts.valid, opts.validout .. '.dep') 365 | elseif opts.mstalg == 'Eisner' then 366 | xprintln('Using Eisner') 367 | vret = mst_post:validWithMSTPostEisner(mst_post.opts.valid, opts.validout .. '.dep') 368 | else 369 | error(string.format('[%s] not supported!', opts.mstalg)) 370 | end 371 | 372 | print(vret) 373 | print '' 374 | 375 | print '==Test Performance==' 376 | local tret = mst_post:validConllx(mst_post.opts.test, opts.testout .. '.ori.dep') 377 | print(tret) 378 | print '--after post processing--' 379 | if opts.mstalg == 'ChuLiuEdmonds' then 380 | xprintln('Using ChuLiuEdmonds') 381 | tret = mst_post:validWithMSTPost(mst_post.opts.test, opts.testout .. '.dep') 382 | elseif opts.mstalg == 'Eisner' then 383 | xprintln('Using Eisner') 384 | tret = mst_post:validWithMSTPostEisner(mst_post.opts.test, opts.testout .. 
'.dep') 385 | else 386 | error(string.format('[%s] not supported!', opts.mstalg)) 387 | end 388 | print(tret) 389 | print '' 390 | end 391 | 392 | if not package.loaded['mst_postprocess'] then 393 | main() 394 | end 395 | -------------------------------------------------------------------------------- /nnets/MLP.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'basic' 3 | require 'shortcut' 4 | require 'optim' 5 | 6 | local mlp = torch.class('MLP', 'BModel') 7 | 8 | local function transferData(useGPU, data) 9 | if useGPU then 10 | return data:cuda() 11 | else 12 | return data 13 | end 14 | end 15 | 16 | function mlp:__init(opts) 17 | self.opts = opts 18 | self.model = nn.Sequential() 19 | local nhids = opts.snhids:splitc(',') 20 | opts.nhids = {} 21 | for _, snhid in ipairs(nhids) do 22 | table.insert(opts.nhids, tonumber(snhid)) 23 | end 24 | 25 | self.model = nn.Sequential() 26 | 27 | if opts.inDropout > 0 then 28 | self.model:add(nn.Dropout(opts.inDropout)) 29 | end 30 | 31 | local nlayers = #opts.nhids 32 | for i = 2, nlayers do 33 | self.model:add(nn.Linear(opts.nhids[i-1], opts.nhids[i])) 34 | if i ~= nlayers then 35 | if opts.batchNorm then 36 | self.model:add( nn.BatchNormalization(opts.nhids[i]) ) 37 | end 38 | 39 | if opts.activ == 'tanh' then 40 | self.model:add(nn.Tanh()) 41 | elseif opts.activ == 'relu' then 42 | self.model:add(nn.ReLU()) 43 | else 44 | error(opts.activ .. ' not supported!') 45 | end 46 | 47 | if opts.dropout > 0 then 48 | self.model:add(nn.Dropout(opts.dropout)) 49 | end 50 | 51 | end 52 | end 53 | self.model:add(nn.LogSoftMax()) 54 | print(self.model) 55 | self.criterion = nn.ClassNLLCriterion() 56 | 57 | if opts.useGPU then 58 | print 'use GPU!' 59 | self.model = self.model:cuda() 60 | self.criterion = self.criterion:cuda() 61 | end 62 | 63 | self.params, self.grads = self.model:getParameters() 64 | printf('#param %d\n', self.params:size(1)) 65 | 66 | -- self.model = require('weight-init')(self.model, 'kaiming') 67 | 68 | if opts.optimMethod == 'AdaGrad' then 69 | self.optimMethod = optim.adagrad 70 | elseif opts.optimMethod == 'SGD' then 71 | self.optimMethod = optim.sgd 72 | elseif opts.optimMethod == 'Adam' then 73 | self.optimMethod = optim.adam 74 | end 75 | 76 | end 77 | 78 | function mlp:trainBatch(x, y, sgd_params) 79 | self.model:training() 80 | if self.opts.useGPU then 81 | x = x:cuda() 82 | y = y:cuda() 83 | end 84 | 85 | local feval = function(newParam) 86 | if self.params ~= newParam then 87 | self.params:copy(newParam) 88 | end 89 | 90 | self.grads:zero() 91 | local output = self.model:forward(x) 92 | local loss = self.criterion:forward(output, y) 93 | local df = self.criterion:backward(output, y) 94 | self.model:backward(x, df) 95 | 96 | return loss, self.grads 97 | end 98 | 99 | local _, loss_ = self.optimMethod(feval, self.params, sgd_params) 100 | 101 | return loss_[1] 102 | end 103 | 104 | function mlp:validBatch(x, y) 105 | if self.opts.useGPU then 106 | y = y:cuda() 107 | end 108 | 109 | local yPred = self:predictBatch(x) 110 | local maxv, maxi = yPred:max(2) 111 | return torch.sum( torch.eq(maxi, y) ), x:size(1), maxi 112 | end 113 | 114 | function mlp:predictBatch(x) 115 | self.model:evaluate() 116 | if self.opts.useGPU then 117 | x = x:cuda() 118 | end 119 | 120 | return torch.exp(self.model:forward(x)) 121 | end 122 | 123 | function mlp:predictLabelBatch(x) 124 | self.model:evaluate() 125 | if self.opts.useGPU then 126 | x = x:cuda() 127 | end 128 | 129 | local yPred = 
torch.exp(self.model:forward(x)) 130 | local maxv, maxi = yPred:max(2) 131 | 132 | return maxi 133 | end 134 | 135 | -------------------------------------------------------------------------------- /nnets/SelectNet.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'basic' 3 | require 'PReplicate' 4 | require 'Linear3D' 5 | require 'Contiguous' 6 | require 'EMaskedClassNLLCriterion' 7 | require 'LookupTable_ft' 8 | 9 | local model_utils = require 'model_utils' 10 | 11 | local SelectNet = torch.class('SelectNet', 'BModel') 12 | 13 | function SelectNet:__init(opts) 14 | self.opts = opts 15 | self.name = 'SelectNet' 16 | opts.nvocab = opts.vocab.nvocab 17 | self:createNetwork(opts) 18 | 19 | if opts.optimMethod == 'AdaGrad' then 20 | self.optimMethod = optim.adagrad 21 | elseif opts.optimMethod == 'Adam' then 22 | self.optimMethod = optim.adam 23 | elseif opts.optimMethod == 'AdaDelta' then 24 | self.optimMethod = optim.adadelta 25 | elseif opts.optimMethod == 'SGD' then 26 | self.optimMethod = optim.sgd 27 | end 28 | end 29 | 30 | function SelectNet:transData(d, cpu_type) 31 | if self.opts.useGPU then 32 | return d:cuda() 33 | else 34 | if cpu_type then 35 | return d:type(cpu_type) 36 | else 37 | return d 38 | end 39 | end 40 | end 41 | 42 | function SelectNet:createLSTM(x_t, c_tm1, h_tm1, nin, nhid, label) 43 | -- compute activations of four gates all together 44 | local x2h = nn.Linear(nin, nhid * 4)(x_t) 45 | local h2h = nn.Linear(nhid, nhid * 4)(h_tm1) 46 | local allGatesActs = nn.CAddTable()({x2h, h2h}) 47 | local allGatesActsSplits = nn.SplitTable(2)( nn.Reshape(4, nhid)(allGatesActs) ) 48 | -- unpack all gate activations 49 | local i_t = nn.Sigmoid()( nn.SelectTable(1)( allGatesActsSplits ) ) 50 | local f_t = nn.Sigmoid()( nn.SelectTable(2)( allGatesActsSplits ) ) 51 | local o_t = nn.Sigmoid()( nn.SelectTable(3)( allGatesActsSplits ) ) 52 | local n_t = nn.Tanh()( nn.SelectTable(4)( allGatesActsSplits ) ) 53 | 54 | if self.opts.recDropout > 0 then 55 | n_t = nn.Dropout(self.opts.recDropout)(n_t) 56 | printf( 'lstm [%s], RECURRENT dropout = %f\n', label, self.opts.recDropout) 57 | end 58 | 59 | -- compute new cell 60 | local c_t = nn.CAddTable()({ 61 | nn.CMulTable()({ i_t, n_t }), 62 | nn.CMulTable()({ f_t, c_tm1 }) 63 | }) 64 | -- compute new hidden state 65 | local h_t = nn.CMulTable()({ o_t, nn.Tanh()( c_t ) }) 66 | 67 | return c_t, h_t 68 | end 69 | 70 | function SelectNet:createDeepLSTM(opts, label) 71 | local emb = (opts.embedOption ~= nil and opts.embedOption == 'fineTune') 72 | and LookupTable_ft(opts.nvocab, opts.nin) 73 | or nn.LookupTable(opts.nvocab, opts.nin) 74 | -- local emb = nn.LookupTable(opts.nvocab, opts.nin) 75 | local x_t = nn.Identity()() 76 | local s_tm1 = nn.Identity()() 77 | 78 | local in_t = { [0] = emb(x_t):annotate{name= label ..'lookup'} } 79 | local s_t = {} 80 | local splits_tm1 = {s_tm1:split(2 * opts.nlayers)} 81 | 82 | for i = 1, opts.nlayers do 83 | local c_tm1_i = splits_tm1[i + i - 1] 84 | local h_tm1_i = splits_tm1[i + i] 85 | local x_t_i = in_t[i - 1] 86 | local c_t_i, h_t_i = nil, nil 87 | 88 | if opts.dropout > 0 then 89 | printf( '%s lstm layer %d, dropout = %f\n', label, i, opts.dropout) 90 | x_t_i = nn.Dropout(opts.dropout)(x_t_i) 91 | end 92 | 93 | if i == 1 then 94 | c_t_i, h_t_i = self:createLSTM(x_t_i, c_tm1_i, h_tm1_i, opts.nin, opts.nhid, label .. i) 95 | else 96 | c_t_i, h_t_i = self:createLSTM(x_t_i, c_tm1_i, h_tm1_i, opts.nhid, opts.nhid, label .. 
i) 97 | end 98 | s_t[i+i-1] = c_t_i 99 | s_t[i+i] = h_t_i 100 | in_t[i] = h_t_i 101 | end 102 | 103 | local model = nn.gModule({x_t, s_tm1}, {nn.Identity()(s_t)}) 104 | return self:transData(model) 105 | end 106 | 107 | function SelectNet:createAttention(opts) 108 | -- note: both the forward and backward models are used here 109 | local nhid = opts.nhid * 2 110 | -- enc_hs shape: (bs, seqlen, nhid) 111 | local enc_hs = nn.Identity()() 112 | local hs = nn.Identity()() 113 | local seqlen = nn.Identity()() 114 | local mask = nn.Identity()() 115 | local mask_sub = nn.Identity()() 116 | 117 | local h1 = Linear3D(nhid, nhid)(enc_hs) 118 | local h2_ = nn.Linear(nhid, nhid)(hs) 119 | local h2 = Contiguous()( PReplicate(2){h2_, seqlen} ) 120 | local h = nn.Tanh()( nn.CAddTable(){h1, h2} ) 121 | local aout = nn.Sum(3)( Linear3D(nhid, 1)(h) ) 122 | 123 | aout = nn.CAddTable()({ 124 | nn.CMulTable()({ aout, mask }), 125 | mask_sub 126 | }) 127 | 128 | local y_prob = nn.LogSoftMax()(aout) 129 | 130 | local model = nn.gModule({enc_hs, hs, seqlen, mask, mask_sub}, 131 | {y_prob}) 132 | 133 | return self:transData(model) 134 | end 135 | 136 | function SelectNet:createNetwork(opts) 137 | self.forward_lstm_master = self:createDeepLSTM(opts, 'forward_') 138 | self.backward_lstm_master = self:createDeepLSTM(opts, 'backward_') 139 | self:print('create forward and backward LSTM done!') 140 | self.attention_master = self:createAttention(opts) 141 | self:print('create attention model done!') 142 | -- backward_lookup is ignored 143 | self.params, self.grads = model_utils.combine_selectnet_parameters(self.forward_lstm_master, 144 | self.backward_lstm_master, self.attention_master) 145 | self.params:uniform(-opts.initRange, opts.initRange) 146 | self:print('#params ' .. self.params:nElement()) 147 | 148 | -- share forward and backward lookupTable 149 | self.mod_map = BModel.get_module_map({self.forward_lstm_master, self.backward_lstm_master, self.attention_master}) 150 | self.mod_map.backward_lookup.weight:set( self.mod_map.forward_lookup.weight ) 151 | self.mod_map.backward_lookup.gradWeight:set( self.mod_map.forward_lookup.gradWeight ) 152 | collectgarbage() 153 | self:print('forward lstm and backward lstm share parameters') 154 | 155 | -- initialize with pre-trained word embedding 156 | if self.opts.wordEmbedding ~= nil and self.opts.wordEmbedding ~= '' then 157 | local net_lookup = self.mod_map.forward_lookup 158 | self.net_lookup = net_lookup 159 | if self.opts.embedOption == 'init' then 160 | model_utils.load_embedding_init(net_lookup, self.opts.vocab, self.opts.wordEmbedding) 161 | elseif self.opts.embedOption == 'fineTune' then 162 | model_utils.load_embedding_fine_tune(net_lookup, self.opts.vocab, self.opts.wordEmbedding, self.opts.fineTuneFactor) 163 | else 164 | error('invalid option -- ' .. 
self.opts.embedOption) 165 | end 166 | end 167 | 168 | if self.opts.embedOption == 'fineTune' then 169 | -- this will not copy updateMask 170 | self.forward_lstms = model_utils.clone_many_times_emb_ft(self.forward_lstm_master, opts.seqLen) 171 | self.backward_lstms = model_utils.clone_many_times_emb_ft(self.backward_lstm_master, opts.seqLen) 172 | else 173 | self.forward_lstms = model_utils.clone_many_times(self.forward_lstm_master, opts.seqLen) 174 | self.backward_lstms = model_utils.clone_many_times(self.backward_lstm_master, opts.seqLen) 175 | end 176 | self.attentions = model_utils.clone_many_times(self.attention_master, opts.seqLen) 177 | self:print('clone model done!') 178 | 179 | -- time for dealing with criterions 180 | self.criterions = {} 181 | for i = 1, opts.seqLen do 182 | self.criterions[i] = self:transData(EMaskedClassNLLCriterion()) 183 | end 184 | 185 | -- init model 186 | self:initModel(opts) 187 | end 188 | 189 | function SelectNet:initModel(opts) 190 | self.fwd_h0 = {} 191 | self.df_fwd_h = {} 192 | for i = 1, 2*opts.nlayers do 193 | self.fwd_h0[i] = self:transData( torch.zeros(opts.batchSize, opts.nhid) ) 194 | self.df_fwd_h[i] = self:transData( torch.zeros(opts.batchSize, opts.nhid) ) 195 | end 196 | -- print(self.fwd_h0) 197 | self.fwd_hs = {} 198 | for i = 0, opts.seqLen do 199 | local tmp = {} 200 | for j = 1, 2*opts.nlayers do 201 | tmp[j] = self:transData( torch.zeros(opts.batchSize, opts.nhid) ) 202 | end 203 | self.fwd_hs[i] = tmp 204 | end 205 | 206 | self.bak_h0 = {} 207 | self.df_bak_h = {} 208 | for i = 1, 2*opts.nlayers do 209 | self.bak_h0[i] = self:transData( torch.zeros(opts.batchSize, opts.nhid) ) 210 | self.df_bak_h[i] = self:transData( torch.zeros(opts.batchSize, opts.nhid) ) 211 | end 212 | -- print(self.bak_h0) 213 | self.bak_hs = {} 214 | for i = 1, opts.seqLen + 1 do 215 | local tmp = {} 216 | for j = 1, 2*opts.nlayers do 217 | tmp[j] = self:transData( torch.zeros(opts.batchSize, opts.nhid) ) 218 | end 219 | self.bak_hs[i] = tmp 220 | end 221 | 222 | -- this is for attention model 223 | self.all_fwd_bak_hs = self:transData( torch.zeros(opts.batchSize * opts.seqLen, 2 * opts.nhid) ) 224 | self.df_all_fwd_bak_hs = self:transData( torch.zeros(opts.batchSize * opts.seqLen, 2 * opts.nhid) ) 225 | end 226 | 227 | function SelectNet:training() 228 | self.forward_lstm_master:training() 229 | self.backward_lstm_master:training() 230 | self.attention_master:training() 231 | 232 | for i = 1, self.opts.seqLen do 233 | self.forward_lstms[i]:training() 234 | self.backward_lstms[i]:training() 235 | self.attentions[i]:training() 236 | end 237 | end 238 | 239 | function SelectNet:evaluate() 240 | self.forward_lstm_master:evaluate() 241 | self.backward_lstm_master:evaluate() 242 | self.attention_master:evaluate() 243 | 244 | for i = 1, self.opts.seqLen do 245 | self.forward_lstms[i]:evaluate() 246 | self.backward_lstms[i]:evaluate() 247 | self.attentions[i]:evaluate() 248 | end 249 | end 250 | 251 | function SelectNet:trainBatch(x, x_mask, y, sgdParam) 252 | self:training() 253 | 254 | x = self:transData(x) 255 | x_mask = self:transData(x_mask) 256 | y = self:transData(y) 257 | local x_mask_t = x_mask:t() 258 | local x_mask_sub = (-x_mask_t + 1) * -50 259 | x_mask_sub = self:transData( x_mask_sub ) 260 | 261 | local function feval(params_) 262 | if self.params ~= params_ then 263 | self.params:copy(params_) 264 | end 265 | self.grads:zero() 266 | 267 | -- forward pass for forward lstm 268 | local Tx = x:size(1) 269 | for i = 1, self.opts.nlayers * 2 do 270 | 
self.fwd_hs[0][i]:zero() 271 | end 272 | self.all_fwd_bak_hs:resize(self.opts.batchSize, Tx, self.opts.nhid * 2) 273 | for t = 1, Tx do 274 | self.fwd_hs[t] = self.forward_lstms[t]:forward({x[{ t, {} }], self.fwd_hs[t-1]}) 275 | self.all_fwd_bak_hs[{ {}, t, {1, self.opts.nhid} }] = self.fwd_hs[t][self.opts.nlayers*2] 276 | if self.opts.useGPU then cutorch.synchronize() end 277 | end 278 | 279 | -- forward pass for backward lstm 280 | for i = 1, 2*self.opts.nlayers do 281 | self.bak_hs[Tx+1][i]:zero() 282 | end 283 | for t = Tx, 1, -1 do 284 | self.bak_hs[t] = self.backward_lstms[t]:forward({x[{t, {} }], self.bak_hs[t+1]}) 285 | local cmask = x_mask[{ t, {} }]:view(self.opts.batchSize, 1):expand(self.opts.batchSize, self.opts.nhid) 286 | for i = 1, 2*self.opts.nlayers do 287 | self.bak_hs[t][i]:cmul( cmask ) 288 | end 289 | self.all_fwd_bak_hs[{ {}, t, {self.opts.nhid + 1, 2*self.opts.nhid} }] = self.bak_hs[t][2*self.opts.nlayers] 290 | if self.opts.useGPU then cutorch.synchronize() end 291 | end 292 | 293 | -- forward pass for attention model 294 | local loss = 0 295 | local y_preds = {} 296 | local Ty = y:size(1) 297 | assert(Ty + 1 == Tx, 'Tx words sentence must have Tx - 1 options') 298 | for t = 2, Tx do 299 | y_preds[t] = self.attentions[t]:forward({self.all_fwd_bak_hs, self.all_fwd_bak_hs[{ {}, t, {} }], Tx, 300 | x_mask_t, x_mask_sub}) 301 | local loss_ = self.criterions[t]:forward({y_preds[t], y[{ t-1, {} }], self.opts.batchSize}) 302 | loss = loss + loss_ 303 | end 304 | 305 | self.df_all_fwd_bak_hs:resize(self.opts.batchSize, Tx, self.opts.nhid * 2):zero() 306 | -- backward pass for attention model 307 | for t = Tx, 2, -1 do 308 | local df_crit = self.criterions[t]:backward({y_preds[t], y[{ t-1, {} }], self.opts.batchSize}) 309 | local tmp_df_all_hs, tmp_df_a_h, _, _, _ = unpack( 310 | self.attentions[t]:backward({self.all_fwd_bak_hs, self.all_fwd_bak_hs[{ {}, t, {} }], Tx, 311 | x_mask_t, x_mask_sub}, df_crit) 312 | ) 313 | self.df_all_fwd_bak_hs:add( tmp_df_all_hs ) 314 | self.df_all_fwd_bak_hs[{ {}, t, {} }]:add( tmp_df_a_h ) 315 | end 316 | 317 | -- prepare backward prop for forward and backward lstms 318 | for i = 1, 2 * self.opts.nlayers do 319 | self.df_bak_h[i]:zero() 320 | self.df_fwd_h[i]:zero() 321 | end 322 | 323 | -- backward pass for backward lstm 324 | for t = 1, Tx do 325 | -- no mask is needed: in the forward pass, some rows of self.bak_hs[t+1] have already been set to 0, 326 | -- so no error will be back-propagated through them 327 | self.df_bak_h[2*self.opts.nlayers]:add( self.df_all_fwd_bak_hs[{ {}, t, {self.opts.nhid + 1, 2*self.opts.nhid} }] ) 328 | local _, tmp = unpack( self.backward_lstms[t]:backward({x[{ t, {} }], self.bak_hs[t+1]}, self.df_bak_h) ) 329 | model_utils.copy_table(self.df_bak_h, tmp) 330 | end 331 | 332 | -- backward pass for forward lstm 333 | for t = Tx, 1, -1 do 334 | self.df_fwd_h[2*self.opts.nlayers]:add( self.df_all_fwd_bak_hs[{ {}, t, {1, self.opts.nhid} }] ) 335 | -- mask should be used here 336 | local cmask = x_mask[{ t, {} }]:view(self.opts.batchSize, 1):expand(self.opts.batchSize, self.opts.nhid) 337 | for i = 1, 2*self.opts.nlayers do 338 | self.df_fwd_h[i]:cmul( cmask ) 339 | end 340 | local _, tmp = unpack( self.forward_lstms[t]:backward( {x[{ t, {} }], self.fwd_hs[t-1]}, self.df_fwd_h ) ) 341 | model_utils.copy_table(self.df_fwd_h, tmp) 342 | end 343 | 344 | if self.opts.embedOption ~= nil and self.opts.embedOption == 'fineTune' then 345 | self.net_lookup:applyGradMask() 346 | end 347 | 348 | if self.opts.gradClip < 0 then 349 | local clip = 
-self.opts.gradClip 350 | self.grads:clamp(-clip, clip) 351 | elseif self.opts.gradClip > 0 then 352 | local maxGradNorm = self.opts.gradClip 353 | local gradNorm = self.grads:norm() 354 | if gradNorm > maxGradNorm then 355 | local shrinkFactor = maxGradNorm / gradNorm 356 | self.grads:mul(shrinkFactor) 357 | end 358 | end 359 | 360 | return loss, self.grads 361 | end 362 | 363 | local _, loss_ = self.optimMethod(feval, self.params, sgdParam) 364 | 365 | return loss_[1] 366 | end 367 | 368 | function SelectNet:validBatch(x, x_mask, y) 369 | self:evaluate() 370 | 371 | x = self:transData(x) 372 | x_mask = self:transData(x_mask) 373 | y = self:transData(y) 374 | local x_mask_t = x_mask:t() 375 | local x_mask_sub = (-x_mask_t + 1) * -50 376 | x_mask_sub = self:transData( x_mask_sub ) 377 | 378 | -- forward pass for forward lstm 379 | local Tx = x:size(1) 380 | for i = 1, self.opts.nlayers * 2 do 381 | self.fwd_hs[0][i]:zero() 382 | end 383 | self.all_fwd_bak_hs:resize(self.opts.batchSize, Tx, self.opts.nhid * 2) 384 | for t = 1, Tx do 385 | self.fwd_hs[t] = self.forward_lstms[t]:forward({x[{ t, {} }], self.fwd_hs[t-1]}) 386 | self.all_fwd_bak_hs[{ {}, t, {1, self.opts.nhid} }] = self.fwd_hs[t][self.opts.nlayers*2] 387 | if self.opts.useGPU then cutorch.synchronize() end 388 | end 389 | 390 | -- forward pass for backward lstm 391 | for i = 1, 2*self.opts.nlayers do 392 | self.bak_hs[Tx+1][i]:zero() 393 | end 394 | for t = Tx, 1, -1 do 395 | self.bak_hs[t] = self.backward_lstms[t]:forward({x[{t, {} }], self.bak_hs[t+1]}) 396 | local cmask = x_mask[{ t, {} }]:view(self.opts.batchSize, 1):expand(self.opts.batchSize, self.opts.nhid) 397 | for i = 1, 2*self.opts.nlayers do 398 | self.bak_hs[t][i]:cmul( cmask ) 399 | end 400 | self.all_fwd_bak_hs[{ {}, t, {self.opts.nhid + 1, 2*self.opts.nhid} }] = self.bak_hs[t][2*self.opts.nlayers] 401 | if self.opts.useGPU then cutorch.synchronize() end 402 | end 403 | 404 | -- forward pass for attention model 405 | local loss = 0 406 | local y_preds = {} 407 | local Ty = y:size(1) 408 | assert(Ty + 1 == Tx, 'Tx words sentence must have Tx - 1 options') 409 | for t = 2, Tx do 410 | y_preds[t] = self.attentions[t]:forward({self.all_fwd_bak_hs, self.all_fwd_bak_hs[{ {}, t, {} }], Tx, 411 | x_mask_t, x_mask_sub}) 412 | local loss_ = self.criterions[t]:forward({y_preds[t], y[{ t-1, {} }], self.opts.batchSize}) 413 | loss = loss + loss_ 414 | end 415 | 416 | return loss, y_preds 417 | end 418 | 419 | -------------------------------------------------------------------------------- /nnets/SelectNetPos.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'basic' 3 | require 'PReplicate' 4 | require 'Linear3D' 5 | require 'Contiguous' 6 | require 'EMaskedClassNLLCriterion' 7 | require 'LookupTable_ft' 8 | 9 | local model_utils = require 'model_utils' 10 | 11 | local SelectNet = torch.class('SelectNetPos', 'BModel') 12 | 13 | function SelectNet:__init(opts) 14 | self.opts = opts 15 | self.name = 'SelectNetPos' 16 | opts.nvocab = opts.vocab.nvocab 17 | opts.npos = opts.vocab.npos 18 | self:createNetwork(opts) 19 | 20 | if opts.optimMethod == 'AdaGrad' then 21 | self.optimMethod = optim.adagrad 22 | elseif opts.optimMethod == 'Adam' then 23 | self.optimMethod = optim.adam 24 | elseif opts.optimMethod == 'AdaDelta' then 25 | self.optimMethod = optim.adadelta 26 | elseif opts.optimMethod == 'SGD' then 27 | self.optimMethod = optim.sgd 28 | end 29 | end 30 | 31 | function SelectNet:transData(d, cpu_type) 32 | if self.opts.useGPU 
then 33 | return d:cuda() 34 | else 35 | if cpu_type then 36 | return d:type(cpu_type) 37 | else 38 | return d 39 | end 40 | end 41 | end 42 | 43 | function SelectNet:createLSTM(x_t, c_tm1, h_tm1, nin, nhid, label) 44 | -- compute activations of four gates all together 45 | local x2h = nn.Linear(nin, nhid * 4)(x_t) 46 | local h2h = nn.Linear(nhid, nhid * 4)(h_tm1) 47 | local allGatesActs = nn.CAddTable()({x2h, h2h}) 48 | local allGatesActsSplits = nn.SplitTable(2)( nn.Reshape(4, nhid)(allGatesActs) ) 49 | -- unpack all gate activations 50 | local i_t = nn.Sigmoid()( nn.SelectTable(1)( allGatesActsSplits ) ) 51 | local f_t = nn.Sigmoid()( nn.SelectTable(2)( allGatesActsSplits ) ) 52 | local o_t = nn.Sigmoid()( nn.SelectTable(3)( allGatesActsSplits ) ) 53 | local n_t = nn.Tanh()( nn.SelectTable(4)( allGatesActsSplits ) ) 54 | 55 | if self.opts.recDropout > 0 then 56 | n_t = nn.Dropout(self.opts.recDropout)(n_t) 57 | printf( 'lstm [%s], RECURRENT dropout = %f\n', label, self.opts.recDropout) 58 | end 59 | 60 | -- compute new cell 61 | local c_t = nn.CAddTable()({ 62 | nn.CMulTable()({ i_t, n_t }), 63 | nn.CMulTable()({ f_t, c_tm1 }) 64 | }) 65 | -- compute new hidden state 66 | local h_t = nn.CMulTable()({ o_t, nn.Tanh()( c_t ) }) 67 | 68 | return c_t, h_t 69 | end 70 | 71 | function SelectNet:createDeepLSTM(opts, label) 72 | local emb = (opts.embedOption ~= nil and opts.embedOption == 'fineTune') 73 | and LookupTable_ft(opts.nvocab, opts.nin) 74 | or nn.LookupTable(opts.nvocab, opts.nin) 75 | 76 | local pos_emb = nn.LookupTable(opts.npos, opts.npin) 77 | 78 | local x_t = nn.Identity()() 79 | local x_pos_t = nn.Identity()() 80 | local s_tm1 = nn.Identity()() 81 | 82 | local we = emb(x_t):annotate{name= label ..'lookup'} 83 | local pose = pos_emb(x_pos_t):annotate{name= label ..'pos_lookup'} 84 | 85 | local in_t = { [0] = nn.JoinTable(2){we, pose} } 86 | local s_t = {} 87 | local splits_tm1 = {s_tm1:split(2 * opts.nlayers)} 88 | 89 | for i = 1, opts.nlayers do 90 | local c_tm1_i = splits_tm1[i + i - 1] 91 | local h_tm1_i = splits_tm1[i + i] 92 | local x_t_i = in_t[i - 1] 93 | local c_t_i, h_t_i = nil, nil 94 | 95 | if opts.dropout > 0 then 96 | printf( '%s lstm layer %d, dropout = %f\n', label, i, opts.dropout) 97 | x_t_i = nn.Dropout(opts.dropout)(x_t_i) 98 | end 99 | 100 | if i == 1 then 101 | c_t_i, h_t_i = self:createLSTM(x_t_i, c_tm1_i, h_tm1_i, opts.nin + opts.npin, opts.nhid, label .. i) 102 | else 103 | c_t_i, h_t_i = self:createLSTM(x_t_i, c_tm1_i, h_tm1_i, opts.nhid, opts.nhid, label .. 
i) 104 | end 105 | s_t[i+i-1] = c_t_i 106 | s_t[i+i] = h_t_i 107 | in_t[i] = h_t_i 108 | end 109 | 110 | local model = nn.gModule({x_t, x_pos_t, s_tm1}, {nn.Identity()(s_t)}) 111 | return self:transData(model) 112 | end 113 | 114 | function SelectNet:createAttention(opts) 115 | -- note: both the forward and backward models are used here 116 | local nhid = opts.nhid * 2 117 | -- enc_hs shape: (bs, seqlen, nhid) 118 | local enc_hs = nn.Identity()() 119 | local hs = nn.Identity()() 120 | local seqlen = nn.Identity()() 121 | local mask = nn.Identity()() 122 | local mask_sub = nn.Identity()() 123 | 124 | local h1 = Linear3D(nhid, nhid)(enc_hs) 125 | local h2_ = nn.Linear(nhid, nhid)(hs) 126 | -- local h2 = Contiguous()( PReplicate(2){h2_, seqlen} ) 127 | local h2 = ( PReplicate(2){h2_, seqlen} ) 128 | local h = nn.Tanh()( nn.CAddTable(){h1, h2} ) 129 | local aout = nn.Sum(3)( Linear3D(nhid, 1)(h) ) 130 | 131 | aout = nn.CAddTable()({ 132 | nn.CMulTable()({ aout, mask }), 133 | mask_sub 134 | }) 135 | 136 | local y_prob = nn.LogSoftMax()(aout) 137 | 138 | local model = nn.gModule({enc_hs, hs, seqlen, mask, mask_sub}, 139 | {y_prob}) 140 | 141 | return self:transData(model) 142 | end 143 | 144 | function SelectNet:createNetwork(opts) 145 | self.forward_lstm_master = self:createDeepLSTM(opts, 'forward_') 146 | self.backward_lstm_master = self:createDeepLSTM(opts, 'backward_') 147 | self:print('create forward and backward LSTM done!') 148 | self.attention_master = self:createAttention(opts) 149 | self:print('create attention model done!') 150 | -- backward_lookup is ignored 151 | self.params, self.grads = model_utils.combine_selectnet_pos_parameters(self.forward_lstm_master, 152 | self.backward_lstm_master, self.attention_master) 153 | self.params:uniform(-opts.initRange, opts.initRange) 154 | self:print('#params ' .. self.params:nElement()) 155 | 156 | -- share forward and backward lookupTable 157 | self.mod_map = BModel.get_module_map({self.forward_lstm_master, self.backward_lstm_master, self.attention_master}) 158 | self.mod_map.backward_lookup.weight:set( self.mod_map.forward_lookup.weight ) 159 | self.mod_map.backward_lookup.gradWeight:set( self.mod_map.forward_lookup.gradWeight ) 160 | self.mod_map.backward_pos_lookup.weight:set( self.mod_map.forward_pos_lookup.weight ) 161 | self.mod_map.backward_pos_lookup.gradWeight:set( self.mod_map.forward_pos_lookup.gradWeight ) 162 | collectgarbage() 163 | self:print('forward lstm and backward lstm share parameters') 164 | 165 | -- initialize with pre-trained word embedding 166 | if self.opts.wordEmbedding ~= nil and self.opts.wordEmbedding ~= '' then 167 | local net_lookup = self.mod_map.forward_lookup 168 | self.net_lookup = net_lookup 169 | if self.opts.embedOption == 'init' then 170 | model_utils.load_embedding_init(net_lookup, self.opts.vocab, self.opts.wordEmbedding) 171 | elseif self.opts.embedOption == 'fineTune' then 172 | model_utils.load_embedding_fine_tune(net_lookup, self.opts.vocab, self.opts.wordEmbedding, self.opts.fineTuneFactor) 173 | else 174 | error('invalid option -- ' .. 
self.opts.embedOption) 175 | end 176 | end 177 | 178 | if self.opts.embedOption == 'fineTune' then 179 | -- this will not copy updateMask 180 | self.forward_lstms = model_utils.clone_many_times_emb_ft(self.forward_lstm_master, opts.seqLen) 181 | self.backward_lstms = model_utils.clone_many_times_emb_ft(self.backward_lstm_master, opts.seqLen) 182 | else 183 | self.forward_lstms = model_utils.clone_many_times(self.forward_lstm_master, opts.seqLen) 184 | self.backward_lstms = model_utils.clone_many_times(self.backward_lstm_master, opts.seqLen) 185 | end 186 | self.attentions = model_utils.clone_many_times(self.attention_master, opts.seqLen) 187 | self:print('clone model done!') 188 | 189 | -- time for dealing with criterions 190 | self.criterions = {} 191 | for i = 1, opts.seqLen do 192 | self.criterions[i] = self:transData(EMaskedClassNLLCriterion()) 193 | end 194 | 195 | -- init model 196 | self:initModel(opts) 197 | end 198 | 199 | function SelectNet:initModel(opts) 200 | self.fwd_h0 = {} 201 | self.df_fwd_h = {} 202 | for i = 1, 2*opts.nlayers do 203 | self.fwd_h0[i] = self:transData( torch.zeros(opts.batchSize, opts.nhid) ) 204 | self.df_fwd_h[i] = self:transData( torch.zeros(opts.batchSize, opts.nhid) ) 205 | end 206 | -- print(self.fwd_h0) 207 | self.fwd_hs = {} 208 | for i = 0, opts.seqLen do 209 | local tmp = {} 210 | for j = 1, 2*opts.nlayers do 211 | tmp[j] = self:transData( torch.zeros(opts.batchSize, opts.nhid) ) 212 | end 213 | self.fwd_hs[i] = tmp 214 | end 215 | 216 | self.bak_h0 = {} 217 | self.df_bak_h = {} 218 | for i = 1, 2*opts.nlayers do 219 | self.bak_h0[i] = self:transData( torch.zeros(opts.batchSize, opts.nhid) ) 220 | self.df_bak_h[i] = self:transData( torch.zeros(opts.batchSize, opts.nhid) ) 221 | end 222 | -- print(self.bak_h0) 223 | self.bak_hs = {} 224 | for i = 1, opts.seqLen + 1 do 225 | local tmp = {} 226 | for j = 1, 2*opts.nlayers do 227 | tmp[j] = self:transData( torch.zeros(opts.batchSize, opts.nhid) ) 228 | end 229 | self.bak_hs[i] = tmp 230 | end 231 | 232 | -- this is for attention model 233 | self.all_fwd_bak_hs = self:transData( torch.zeros(opts.batchSize * opts.seqLen, 2 * opts.nhid) ) 234 | self.df_all_fwd_bak_hs = self:transData( torch.zeros(opts.batchSize * opts.seqLen, 2 * opts.nhid) ) 235 | end 236 | 237 | function SelectNet:training() 238 | self.forward_lstm_master:training() 239 | self.backward_lstm_master:training() 240 | self.attention_master:training() 241 | 242 | for i = 1, self.opts.seqLen do 243 | self.forward_lstms[i]:training() 244 | self.backward_lstms[i]:training() 245 | self.attentions[i]:training() 246 | end 247 | end 248 | 249 | function SelectNet:evaluate() 250 | self.forward_lstm_master:evaluate() 251 | self.backward_lstm_master:evaluate() 252 | self.attention_master:evaluate() 253 | 254 | for i = 1, self.opts.seqLen do 255 | self.forward_lstms[i]:evaluate() 256 | self.backward_lstms[i]:evaluate() 257 | self.attentions[i]:evaluate() 258 | end 259 | end 260 | 261 | -- x, x_mask, x_pos, y 262 | function SelectNet:trainBatch(x, x_mask, x_pos, y, sgdParam) 263 | self:training() 264 | 265 | x = self:transData(x) 266 | x_mask = self:transData(x_mask) 267 | x_pos = self:transData(x_pos) 268 | y = self:transData(y) 269 | local x_mask_t = x_mask:t() 270 | local x_mask_sub = (-x_mask_t + 1) * -50 271 | x_mask_sub = self:transData( x_mask_sub ) 272 | 273 | local function feval(params_) 274 | if self.params ~= params_ then 275 | self.params:copy(params_) 276 | end 277 | self.grads:zero() 278 | 279 | -- forward pass for forward lstm 280 | 
local Tx = x:size(1) 281 | for i = 1, self.opts.nlayers * 2 do 282 | self.fwd_hs[0][i]:zero() 283 | end 284 | self.all_fwd_bak_hs:resize(self.opts.batchSize, Tx, self.opts.nhid * 2) 285 | for t = 1, Tx do 286 | self.fwd_hs[t] = self.forward_lstms[t]:forward({x[{ t, {} }], x_pos[{ t, {} }], self.fwd_hs[t-1]}) 287 | self.all_fwd_bak_hs[{ {}, t, {1, self.opts.nhid} }] = self.fwd_hs[t][self.opts.nlayers*2] 288 | if self.opts.useGPU then cutorch.synchronize() end 289 | end 290 | 291 | -- forward pass for backward lstm 292 | for i = 1, 2*self.opts.nlayers do 293 | self.bak_hs[Tx+1][i]:zero() 294 | end 295 | for t = Tx, 1, -1 do 296 | self.bak_hs[t] = self.backward_lstms[t]:forward({x[{t, {} }], x_pos[{ t, {} }], self.bak_hs[t+1]}) 297 | local cmask = x_mask[{ t, {} }]:view(self.opts.batchSize, 1):expand(self.opts.batchSize, self.opts.nhid) 298 | for i = 1, 2*self.opts.nlayers do 299 | self.bak_hs[t][i]:cmul( cmask ) 300 | end 301 | self.all_fwd_bak_hs[{ {}, t, {self.opts.nhid + 1, 2*self.opts.nhid} }] = self.bak_hs[t][2*self.opts.nlayers] 302 | if self.opts.useGPU then cutorch.synchronize() end 303 | end 304 | 305 | -- forward pass for attention model 306 | local loss = 0 307 | local y_preds = {} 308 | local Ty = y:size(1) 309 | assert(Ty + 1 == Tx, 'Tx words sentence must have Tx - 1 options') 310 | for t = 2, Tx do 311 | y_preds[t] = self.attentions[t]:forward({self.all_fwd_bak_hs, self.all_fwd_bak_hs[{ {}, t, {} }], Tx, 312 | x_mask_t, x_mask_sub}) 313 | local loss_ = self.criterions[t]:forward({y_preds[t], y[{ t-1, {} }], self.opts.batchSize}) 314 | loss = loss + loss_ 315 | end 316 | 317 | self.df_all_fwd_bak_hs:resize(self.opts.batchSize, Tx, self.opts.nhid * 2):zero() 318 | -- backward pass for attention model 319 | for t = Tx, 2, -1 do 320 | local df_crit = self.criterions[t]:backward({y_preds[t], y[{ t-1, {} }], self.opts.batchSize}) 321 | local tmp_df_all_hs, tmp_df_a_h, _, _, _ = unpack( 322 | self.attentions[t]:backward({self.all_fwd_bak_hs, self.all_fwd_bak_hs[{ {}, t, {} }], Tx, 323 | x_mask_t, x_mask_sub}, df_crit) 324 | ) 325 | self.df_all_fwd_bak_hs:add( tmp_df_all_hs ) 326 | self.df_all_fwd_bak_hs[{ {}, t, {} }]:add( tmp_df_a_h ) 327 | end 328 | 329 | -- prepare backward prop for forward and backward lstms 330 | for i = 1, 2 * self.opts.nlayers do 331 | self.df_bak_h[i]:zero() 332 | self.df_fwd_h[i]:zero() 333 | end 334 | 335 | -- backward pass for backward lstm 336 | for t = 1, Tx do 337 | -- no mask is needed: in the forward pass, some rows of self.bak_hs[t+1] have already been set to 0, 338 | -- so no error will be back-propagated through them 339 | self.df_bak_h[2*self.opts.nlayers]:add( self.df_all_fwd_bak_hs[{ {}, t, {self.opts.nhid + 1, 2*self.opts.nhid} }] ) 340 | local _, _, tmp = unpack( self.backward_lstms[t]:backward({x[{ t, {} }], x_pos[{ t, {} }], self.bak_hs[t+1]}, self.df_bak_h) ) 341 | model_utils.copy_table(self.df_bak_h, tmp) 342 | end 343 | 344 | -- backward pass for forward lstm 345 | for t = Tx, 1, -1 do 346 | self.df_fwd_h[2*self.opts.nlayers]:add( self.df_all_fwd_bak_hs[{ {}, t, {1, self.opts.nhid} }] ) 347 | -- mask should be used here 348 | local cmask = x_mask[{ t, {} }]:view(self.opts.batchSize, 1):expand(self.opts.batchSize, self.opts.nhid) 349 | for i = 1, 2*self.opts.nlayers do 350 | self.df_fwd_h[i]:cmul( cmask ) 351 | end 352 | local _, _, tmp = unpack( self.forward_lstms[t]:backward( {x[{ t, {} }], x_pos[{ t, {} }], self.fwd_hs[t-1]}, self.df_fwd_h ) ) 353 | model_utils.copy_table(self.df_fwd_h, tmp) 354 | end 355 | 356 | if self.opts.embedOption ~= nil and 
self.opts.embedOption == 'fineTune' then 357 | self.net_lookup:applyGradMask() 358 | end 359 | 360 | if self.opts.gradClip < 0 then 361 | local clip = -self.opts.gradClip 362 | self.grads:clamp(-clip, clip) 363 | elseif self.opts.gradClip > 0 then 364 | local maxGradNorm = self.opts.gradClip 365 | local gradNorm = self.grads:norm() 366 | if gradNorm > maxGradNorm then 367 | local shrinkFactor = maxGradNorm / gradNorm 368 | self.grads:mul(shrinkFactor) 369 | end 370 | end 371 | 372 | return loss, self.grads 373 | end 374 | 375 | local _, loss_ = self.optimMethod(feval, self.params, sgdParam) 376 | 377 | return loss_[1] 378 | end 379 | 380 | -- x, x_mask, x_pos, y 381 | function SelectNet:validBatch(x, x_mask, x_pos, y) 382 | self:evaluate() 383 | 384 | x = self:transData(x) 385 | x_mask = self:transData(x_mask) 386 | x_pos = self:transData(x_pos) 387 | y = self:transData(y) 388 | local x_mask_t = x_mask:t() 389 | local x_mask_sub = (-x_mask_t + 1) * -50 390 | x_mask_sub = self:transData( x_mask_sub ) 391 | 392 | -- forward pass for forward lstm 393 | local Tx = x:size(1) 394 | for i = 1, self.opts.nlayers * 2 do 395 | self.fwd_hs[0][i]:zero() 396 | end 397 | self.all_fwd_bak_hs:resize(self.opts.batchSize, Tx, self.opts.nhid * 2) 398 | for t = 1, Tx do 399 | self.fwd_hs[t] = self.forward_lstms[t]:forward({x[{ t, {} }], x_pos[{ t, {} }], self.fwd_hs[t-1]}) 400 | self.all_fwd_bak_hs[{ {}, t, {1, self.opts.nhid} }] = self.fwd_hs[t][self.opts.nlayers*2] 401 | if self.opts.useGPU then cutorch.synchronize() end 402 | end 403 | 404 | -- forward pass for backward lstm 405 | for i = 1, 2*self.opts.nlayers do 406 | self.bak_hs[Tx+1][i]:zero() 407 | end 408 | for t = Tx, 1, -1 do 409 | self.bak_hs[t] = self.backward_lstms[t]:forward({x[{t, {} }], x_pos[{ t, {} }], self.bak_hs[t+1]}) 410 | local cmask = x_mask[{ t, {} }]:view(self.opts.batchSize, 1):expand(self.opts.batchSize, self.opts.nhid) 411 | for i = 1, 2*self.opts.nlayers do 412 | self.bak_hs[t][i]:cmul( cmask ) 413 | end 414 | self.all_fwd_bak_hs[{ {}, t, {self.opts.nhid + 1, 2*self.opts.nhid} }] = self.bak_hs[t][2*self.opts.nlayers] 415 | if self.opts.useGPU then cutorch.synchronize() end 416 | end 417 | 418 | -- forward pass for attention model 419 | local loss = 0 420 | local y_preds = {} 421 | local Ty = y:size(1) 422 | assert(Ty + 1 == Tx, 'Tx words sentence must have Tx - 1 options') 423 | for t = 2, Tx do 424 | y_preds[t] = self.attentions[t]:forward({self.all_fwd_bak_hs, self.all_fwd_bak_hs[{ {}, t, {} }], Tx, 425 | x_mask_t, x_mask_sub}) 426 | local loss_ = self.criterions[t]:forward({y_preds[t], y[{ t-1, {} }], self.opts.batchSize}) 427 | loss = loss + loss_ 428 | end 429 | 430 | return loss, y_preds 431 | end 432 | 433 | 434 | -------------------------------------------------------------------------------- /nnets/basic.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'torch' 3 | 4 | local BModel = torch.class('BModel') 5 | 6 | function BModel:__init() 7 | self.name = 'Basic Model, name needed!' 8 | end 9 | 10 | function BModel:save(modelPath, saveOpts) 11 | if not modelPath:ends('.t7') then 12 | modelPath = modelPath .. '.t7' 13 | end 14 | 15 | if self.params:type() == 'torch.CudaTensor' then 16 | torch.save(modelPath, self.params:float()) 17 | else 18 | torch.save(modelPath, self.params) 19 | end 20 | 21 | if saveOpts then 22 | local optPath = modelPath:sub(1, -4) .. 
'.state.t7' 23 | torch.save(optPath, self.opts) 24 | end 25 | end 26 | 27 | function BModel:load(modelPath) 28 | self.params:copy( torch.load(modelPath) ) 29 | end 30 | 31 | function BModel:setModel(params) 32 | self.params:copy(params) 33 | end 34 | 35 | function BModel:getModel(outModel) 36 | return outModel:copy(self.params) 37 | end 38 | 39 | function BModel:print(msg) 40 | if msg == nil then 41 | xprint('the model is [%s]\n', self.name) 42 | else 43 | xprintln('[%s] %s', self.name, msg) 44 | end 45 | end 46 | 47 | function BModel.get_module_map(mods) 48 | local mdict = {} 49 | 50 | local function get_map(m) 51 | for _, node in ipairs(m.forwardnodes) do 52 | if node.data.annotations.name then 53 | mdict[node.data.annotations.name] = node.data.module 54 | end 55 | end 56 | end 57 | 58 | if torch.type(mods) == 'table' then 59 | for _, mod in ipairs(mods) do 60 | get_map(mod) 61 | end 62 | else 63 | get_map(mods) 64 | end 65 | 66 | return mdict 67 | end 68 | 69 | 70 | -------------------------------------------------------------------------------- /post_train.lua: -------------------------------------------------------------------------------- 1 | 2 | require '.' 3 | require 'shortcut' 4 | require 'SelectNetPos' 5 | require 'train' 6 | 7 | local Trainer = torch.class('SelectNetPostTrainer', 'SelectNetTrainer') 8 | 9 | local function getOpts() 10 | local cmd = torch.CmdLine() 11 | cmd:option('--load', '', 'model path') 12 | cmd:option('--save', 'model.t7', 'save model path') 13 | cmd:option('--lr', 0.001, 'learning rate') 14 | cmd:option('--maxEpoch', 30, 'maximum number of epochs') 15 | cmd:option('--optimMethod', 'SGD', 'optimization algorithm') 16 | cmd:option('--decay', 1, 'decay learning rate') 17 | 18 | local opts = cmd:parse(arg) 19 | 20 | return opts 21 | end 22 | 23 | function Trainer:main() 24 | local opts_ = getOpts() 25 | self.opts = torch.load( opts_.load:sub(1, -3) .. 'state.t7' ) 26 | assert(self.opts.save ~= opts_.save) 27 | self.opts.load = opts_.load 28 | self.opts.save = opts_.save 29 | self.opts.lr = opts_.lr 30 | self.opts.maxEpoch = opts_.maxEpoch 31 | self.opts.optimMethod = opts_.optimMethod 32 | local opts = self.opts 33 | 34 | torch.manualSeed(opts.seed + 1) 35 | if opts.useGPU then 36 | require 'cutorch' 37 | require 'cunn' 38 | cutorch.manualSeed(opts.seed + 1) 39 | end 40 | 41 | self.trainSize, self.validSize, self.testSize = unpack( DepPosDataIter.getDataSize({opts.train, opts.valid, opts.test}) ) 42 | xprintln('train size = %d, valid size = %d, test size = %d', self.trainSize, self.validSize, self.testSize) 43 | 44 | -- local vocabPath = opts.train .. '.tmp.vocab.t7' 45 | local vocabPath = opts.train .. 
'.tmp.pos.vocab.t7' 46 | local recreateVocab = true 47 | if paths.filep(vocabPath) then 48 | opts.vocab = torch.load(vocabPath) 49 | if opts.vocab.ignoreCase == opts.ignoreCase and opts.vocab.freqCut == opts.freqCut and opts.vocab.maxNVocab == opts.maxNVocab then 50 | recreateVocab = false 51 | DepPosDataIter.showVocab(opts.vocab) 52 | print '****load from existing vocab!!!****\n\n' 53 | end 54 | end 55 | if recreateVocab then 56 | opts.vocab = DepPosDataIter.createVocab(opts.train, opts.ignoreCase, opts.freqCut, opts.maxNVocab) 57 | torch.save(vocabPath, opts.vocab) 58 | xprintln('****create vocab from scratch****\n\n') 59 | end 60 | 61 | self.net = SelectNetPos(opts) 62 | self:showOpts() 63 | 64 | xprintln('load from %s ...', opts.load) 65 | self.net:load(opts.load) 66 | xprintln('load from %s done!', opts.load) 67 | 68 | self.train_all_sents = DepPosDataIter.loadAllSents(opts.vocab, opts.train, opts.maxTrainLen) 69 | local bestUAS = 0 70 | local bestModel = torch.FloatTensor(self.net.params:size()) 71 | local timer = torch.Timer() 72 | 73 | self.opts.sgdParam = {learningRate = opts.lr} 74 | local v = self:validConllx(opts.valid) 75 | print(v) 76 | bestUAS = v.UAS 77 | self.net:getModel(bestModel) 78 | 79 | for epoch = 1, self.opts.maxEpoch do 80 | self.iepoch = epoch 81 | local startTime = timer:time().real 82 | 83 | local train_nll, train_perp = self:train() 84 | xprintln('\nepoch %d TRAIN %f (%f) ', epoch, train_nll, train_perp) 85 | -- local vret = self:valid(opts.valid) 86 | local vret = self:validConllx(opts.valid) 87 | print 'Valid Performance' 88 | print(vret) 89 | local endTime = timer:time().real 90 | xprintln('time spent %s', readableTime(endTime - startTime)) 91 | 92 | if bestUAS < vret.UAS then 93 | bestUAS = vret.UAS 94 | self.net:getModel(bestModel) 95 | if opts.test and opts.test ~= '' then 96 | local vret = self:validConllx(opts.test) 97 | print 'Test Performance' 98 | print(vret) 99 | end 100 | else 101 | xprintln('UAS on valid did not increase! early stopping!') 102 | break 103 | end 104 | 105 | self.opts.sgdParam.learningRate = self.opts.sgdParam.learningRate * opts_.decay 106 | end 107 | 108 | -- save final model 109 | self.net:setModel(bestModel) 110 | opts.sgdParam = nil 111 | self.net:save(opts.save, true) 112 | xprintln('model saved at %s', opts.save) 113 | 114 | -- show final performance 115 | local vret = self:validConllx(opts.valid) 116 | print 'Final Valid Performance' 117 | print(vret) 118 | if opts.test and opts.test ~= '' then 119 | vret = self:validConllx(opts.test) 120 | print 'Final Test Performance' 121 | print(vret) 122 | end 123 | 124 | end 125 | 126 | local function main() 127 | local trainer = SelectNetPostTrainer() 128 | trainer:main() 129 | end 130 | 131 | if not package.loaded['post_train'] then 132 | main() 133 | else 134 | print '[post_train] loaded as package!' 135 | end 136 | 137 | -------------------------------------------------------------------------------- /train.lua: -------------------------------------------------------------------------------- 1 | 2 | require '.' 
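-- Overview (a hedged summary added for clarity; it describes the conventions used below and is not
-- part of the original source): DepPosDataIter yields batches x (seqLen x batchSize word ids, with
-- position 1 reserved for the artificial ROOT token), x_mask (1 for real tokens, 0 for padding),
-- x_pos (POS-tag ids) and y (gold head positions for words 2..seqLen). The network scores every
-- position of the sentence as a candidate head of each word, so decoding is a per-word argmax, e.g.
-- (the variable `head` is hypothetical, for illustration only):
--   local _, mi = y_preds[t]:max(2) -- most likely head of word t for every sentence in the batch
--   local head = mi - 1             -- shift to CoNLL-X indexing, where head 0 denotes ROOT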
3 | require 'shortcut' 4 | -- require 'SelectNet' 5 | require 'SelectNetPos' 6 | -- require 'DepDataIter' 7 | require 'DepPosDataIter' 8 | 9 | local Trainer = torch.class('SelectNetTrainer') 10 | 11 | function Trainer:showOpts() 12 | local tmp_vocab = self.opts.vocab 13 | self.opts.vocab = nil 14 | print(self.opts) 15 | self.opts.vocab = tmp_vocab 16 | end 17 | 18 | function Trainer:train() 19 | local dataIter = DepPosDataIter.createBatchShuffleSort(self.train_all_sents, self.opts.vocab, self.opts.batchSize, 20, true) 20 | 21 | local dataSize = self.trainSize 22 | local curDataSize = 0 23 | local percent, inc = 0.001, 0.001 24 | local timer = torch.Timer() 25 | local sgdParam = self.opts.sgdParam 26 | local cnt = 0 27 | local totalLoss = 0 28 | local totalCnt = 0 29 | 30 | for x, x_mask, x_pos, y in dataIter do 31 | local loss = self.net:trainBatch(x, x_mask, x_pos, y, sgdParam) 32 | local y_mask = x_mask[{ {2, -1}, {} }] 33 | 34 | totalLoss = totalLoss + loss * x:size(2) 35 | totalCnt = totalCnt + y_mask:sum() 36 | 37 | curDataSize = curDataSize + x:size(2) 38 | local ratio = curDataSize/dataSize 39 | if ratio >= percent then 40 | local wps = totalCnt / timer:time().real 41 | xprint( '\repoch %d %.3f %.4f (%s) / %.2f wps ... ', self.iepoch, ratio, totalLoss/totalCnt, readableTime(timer:time().real), wps ) 42 | percent = math.floor(ratio / inc) * inc 43 | percent = percent + inc 44 | end 45 | 46 | cnt = cnt + 1 47 | if cnt % 5 == 0 then 48 | collectgarbage() 49 | end 50 | end 51 | 52 | local nll = totalLoss / totalCnt 53 | return nll, math.exp(nll) 54 | end 55 | 56 | function Trainer:valid(validFile) 57 | local dataIter = DepPosDataIter.createBatchSort(self.opts.vocab, validFile, self.opts.batchSize, 150) 58 | local totalCnt = 0 59 | local totalLoss = 0 60 | local cnt = 0 61 | local UAS_c, UAS_t = 0, 0 62 | for x, x_mask, x_pos, y in dataIter do 63 | local loss, y_preds = self.net:validBatch(x, x_mask, x_pos, y) 64 | 65 | totalLoss = totalLoss + loss * x:size(2) 66 | local y_mask = x_mask[{ {2, -1}, {} }] 67 | 68 | local y_p = torch.LongTensor(y:size(1), y:size(2)) 69 | for t = 2, x:size(1) do 70 | local _, mi = y_preds[t]:max(2) 71 | if self.opts.useGPU then mi = mi:double() end 72 | y_p[{ t-1, {} }] = mi 73 | end 74 | UAS_c = UAS_c + y:eq(y_p):double():cmul(y_mask):sum() 75 | UAS_t = UAS_t + y_mask:sum() 76 | 77 | totalCnt = totalCnt + y_mask:sum() 78 | cnt = cnt + 1 79 | if cnt % 5 == 0 then 80 | collectgarbage() 81 | end 82 | end 83 | 84 | local entropy = totalLoss / totalCnt 85 | local ppl = torch.exp(entropy) 86 | 87 | return {entropy = entropy, ppl = ppl, UAS = (UAS_c / UAS_t)} 88 | end 89 | 90 | function Trainer:validConllx(validFile) 91 | local dataIter = DepPosDataIter.createBatch(self.opts.vocab, validFile, self.opts.batchSize, 150) 92 | local totalCnt = 0 93 | local totalLoss = 0 94 | local cnt = 0 95 | 96 | local sents_dep = {} 97 | local y_tmp = torch.LongTensor(150, self.opts.batchSize) 98 | for x, x_mask, x_pos, y in dataIter do 99 | local loss, y_preds = self.net:validBatch(x, x_mask, x_pos, y) 100 | 101 | totalLoss = totalLoss + loss * x:size(2) 102 | local y_mask = x_mask[{ {2, -1}, {} }] 103 | 104 | local y_p = y_tmp:resize(y:size(1), y:size(2)) 105 | for t = 2, x:size(1) do 106 | local _, mi = y_preds[t]:max(2) 107 | if self.opts.useGPU then mi = mi:double() end 108 | y_p[{ t-1, {} }] = mi 109 | end 110 | 111 | for i = 1, y_mask:size(2) do 112 | local slen = y_mask[{ {}, i }]:sum() 113 | local sent_dep = {} 114 | for j = 1, slen do 115 | sent_dep[#sent_dep + 1] = y_p[{ 
j, i }] - 1 116 | end 117 | sents_dep[#sents_dep + 1] = sent_dep 118 | end 119 | 120 | totalCnt = totalCnt + y_mask:sum() 121 | cnt = cnt + 1 122 | if cnt % 5 == 0 then 123 | collectgarbage() 124 | end 125 | end 126 | 127 | local dep_iter = DepPosDataIter.conllx_iter(validFile) 128 | local sent_idx = 0 129 | local sys_out = '__tmp__.dep' 130 | local fout = io.open(sys_out, 'w') 131 | for dsent in dep_iter do 132 | sent_idx = sent_idx + 1 133 | local sent_dep = sents_dep[sent_idx] 134 | assert(#sent_dep == #dsent) 135 | for i, ditem in ipairs(dsent) do 136 | -- 1 Influential _ JJ JJ _ 2 amod _ _ 137 | fout:write(string.format('%d\t%s\t_\t%s\t_\t_\t%d\tN_A\t_\t_\n', ditem.p1, ditem.wd, ditem.pos, sent_dep[i])) 138 | end 139 | fout:write('\n') 140 | end 141 | fout:close() 142 | 143 | local conllx_eval = self.opts.evalType == 'stanford' and require 'conllx_eval' or require 'conllx2006_eval' 144 | local LAS, UAS, noPunctLAS, noPunctUAS = conllx_eval.eval(sys_out, validFile) 145 | 146 | local entropy = totalLoss / totalCnt 147 | local ppl = torch.exp(entropy) 148 | 149 | return {entropy = entropy, ppl = ppl, UAS = noPunctUAS} 150 | end 151 | 152 | function Trainer:main() 153 | local model_opts = require 'model_opts' 154 | local opts = model_opts.getOpts() 155 | self.opts = opts 156 | 157 | self.trainSize, self.validSize, self.testSize = unpack( DepPosDataIter.getDataSize({opts.train, opts.valid, opts.test}) ) 158 | xprintln('train size = %d, valid size = %d, test size = %d', self.trainSize, self.validSize, self.testSize) 159 | 160 | local vocabPath = opts.train .. '.tmp.pos.vocab.t7' 161 | local recreateVocab = true 162 | if paths.filep(vocabPath) then 163 | opts.vocab = torch.load(vocabPath) 164 | if opts.vocab.ignoreCase == opts.ignoreCase and opts.vocab.freqCut == opts.freqCut and opts.vocab.maxNVocab == opts.maxNVocab then 165 | recreateVocab = false 166 | DepPosDataIter.showVocab(opts.vocab) 167 | print '****load from existing vocab!!!****\n\n' 168 | end 169 | end 170 | if recreateVocab then 171 | opts.vocab = DepPosDataIter.createVocab(opts.train, opts.ignoreCase, opts.freqCut, opts.maxNVocab) 172 | torch.save(vocabPath, opts.vocab) 173 | xprintln('****create vocab from scratch****\n\n') 174 | end 175 | 176 | self.net = SelectNetPos(opts) 177 | self:showOpts() 178 | 179 | self.train_all_sents = DepPosDataIter.loadAllSents(opts.vocab, opts.train, opts.maxTrainLen) 180 | local bestUAS = 0 181 | local bestModel = torch.FloatTensor(self.net.params:size()) 182 | local timer = torch.Timer() 183 | 184 | for epoch = 1, self.opts.maxEpoch do 185 | self.iepoch = epoch 186 | local startTime = timer:time().real 187 | 188 | local train_nll, train_perp = self:train() 189 | xprintln('\nepoch %d TRAIN %f (%f) ', epoch, train_nll, train_perp) 190 | -- local vret = self:valid(opts.valid) 191 | local vret = self:validConllx(opts.valid) 192 | print 'Valid Performance' 193 | print(vret) 194 | local endTime = timer:time().real 195 | xprintln('time spent %s', readableTime(endTime - startTime)) 196 | 197 | if bestUAS < vret.UAS then 198 | bestUAS = vret.UAS 199 | self.net:getModel(bestModel) 200 | if opts.test and opts.test ~= '' then 201 | local vret = self:validConllx(opts.test) 202 | print 'Test Performance' 203 | print(vret) 204 | end 205 | else 206 | if not opts.disableEearlyStopping then 207 | xprintln('UAS on valid did not increase! 
early stopping!') 208 | break 209 | end 210 | end 211 | end 212 | 213 | -- save final model 214 | self.net:setModel(bestModel) 215 | opts.sgdParam = nil 216 | self.net:save(opts.save, true) 217 | xprintln('model saved at %s', opts.save) 218 | 219 | -- show final performance 220 | local vret = self:validConllx(opts.valid) 221 | print 'Final Valid Performance' 222 | print(vret) 223 | if opts.test and opts.test ~= '' then 224 | vret = self:validConllx(opts.test) 225 | print 'Final Test Performance' 226 | print(vret) 227 | end 228 | 229 | end 230 | 231 | local function main() 232 | local trainer = SelectNetTrainer() 233 | trainer:main() 234 | end 235 | 236 | if not package.loaded['train'] then 237 | main() 238 | else 239 | print '[train] loaded as package!' 240 | end 241 | 242 | 243 | -------------------------------------------------------------------------------- /train_labeled.lua: -------------------------------------------------------------------------------- 1 | 2 | require '.' 3 | require 'shortcut' 4 | require 'SelectNetPos' 5 | require 'DepPosDataIter' 6 | require 'hdf5' 7 | require 'MLP' 8 | 9 | local LabeledTrainer = torch.class('LabeledModelTrainer') 10 | 11 | local function getOpts() 12 | local cmd = torch.CmdLine() 13 | cmd:option('--mode', 'train', 'two modes: [generate] generate training data; [train] train labeled model') 14 | -- /disk/scratch/XingxingZhang/dep_parse/experiments/we_select_ft_pos/ft/model_0.0001.std.ft0.t7 15 | cmd:option('--modelPath', '/disk/scratch/XingxingZhang/dep_parse/experiments/we_select_ft_pos/ft/model_0.001.std.ft0.t7', 'model path') 16 | cmd:option('--outTrainDataPath', '/disk/scratch/XingxingZhang/dep_parse/experiments/we_select_ft_pos_lbl_pos/label_0.001.std.ft0.pos.h5', 'where the generated training data will be saved') 17 | cmd:option('--inTrain', '/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dataset/train.autopos', 'input training data path') 18 | cmd:option('--inValid', '/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dataset/valid.autopos', 'input validation data path') 19 | cmd:option('--inTest', '/afs/inf.ed.ac.uk/group/project/img2txt/dep_parser/dataset/test.autopos', 'input test data path') 20 | cmd:option('--outValid', '', 'valid conllx file from last step') 21 | cmd:option('--outTest', '', 'test conllx file from last step') 22 | cmd:option('--language', 'English', 'English or Chinese or Other') 23 | 24 | cmd:text('') 25 | cmd:text('==Options for MLP==') 26 | cmd:option('--seed', 123, 'random seed') 27 | cmd:option('--useGPU', false, 'use gpu') 28 | cmd:option('--snhids', '1460,400,400,45', 'string hidden sizes for each layer') 29 | cmd:option('--ftype', '|x|', 'type: x, xe, xpe. 
For example: |x|xe|xpe|') 30 | cmd:option('--activ', 'relu', 'options: tanh, relu') 31 | cmd:option('--dropout', 0, 'dropout rate (dropping)') 32 | cmd:option('--inDropout', 0, 'dropout rate (dropping)') 33 | cmd:option('--batchNorm', false, 'add batch normalization') 34 | cmd:option('--maxEpoch', 10, 'max number of epochs') 35 | cmd:option('--dataset', 36 | '/disk/scratch/XingxingZhang/dep_parse/experiments/we_select_ft_pos_lbl_pos/label_0.001.std.ft0.h5', 37 | 'dataset') 38 | cmd:option('--batchSize', 256, '') 39 | cmd:option('--lr', 0.01, '') 40 | cmd:option('--optimMethod', 'AdaGrad', 'options: SGD, AdaGrad, Adam') 41 | cmd:option('--save', '/disk/scratch/XingxingZhang/dep_parse/experiments/we_select_ft_pos_lbl_pos/lclassifier_0.001.std.ft0.t7', 'save path') 42 | 43 | local opts = cmd:parse(arg) 44 | 45 | return opts 46 | end 47 | 48 | function LabeledTrainer:showOpts() 49 | local tmp_vocab = self.opts.vocab 50 | self.opts.vocab = nil 51 | print(self.opts) 52 | self.opts.vocab = tmp_vocab 53 | end 54 | 55 | function LabeledTrainer:validConllx(validFile, outputConllFile, defaultLabel) 56 | xprintln('default label is %s', defaultLabel) 57 | local dataIter = DepPosDataIter.createBatch(self.opts.vocab, validFile, self.opts.batchSize, 150) 58 | local totalCnt = 0 59 | local totalLoss = 0 60 | local cnt = 0 61 | 62 | local sents_dep = {} 63 | local y_tmp = torch.LongTensor(150, self.opts.batchSize) 64 | for x, x_mask, x_pos, y in dataIter do 65 | local loss, y_preds = self.net:validBatch(x, x_mask, x_pos, y) 66 | 67 | totalLoss = totalLoss + loss * x:size(2) 68 | local y_mask = x_mask[{ {2, -1}, {} }] 69 | 70 | local y_p = y_tmp:resize(y:size(1), y:size(2)) 71 | for t = 2, x:size(1) do 72 | local _, mi = y_preds[t]:max(2) 73 | if self.opts.useGPU then mi = mi:double() end 74 | y_p[{ t-1, {} }] = mi 75 | end 76 | 77 | for i = 1, y_mask:size(2) do 78 | local slen = y_mask[{ {}, i }]:sum() 79 | local sent_dep = {} 80 | for j = 1, slen do 81 | sent_dep[#sent_dep + 1] = y_p[{ j, i }] - 1 82 | end 83 | sents_dep[#sents_dep + 1] = sent_dep 84 | end 85 | 86 | totalCnt = totalCnt + y_mask:sum() 87 | cnt = cnt + 1 88 | if cnt % 5 == 0 then 89 | collectgarbage() 90 | end 91 | end 92 | 93 | local dep_iter = DepPosDataIter.conllx_iter(validFile) 94 | local sent_idx = 0 95 | local sys_out = outputConllFile 96 | local fout = io.open(sys_out, 'w') 97 | for dsent in dep_iter do 98 | sent_idx = sent_idx + 1 99 | local sent_dep = sents_dep[sent_idx] 100 | assert(#sent_dep == #dsent) 101 | for i, ditem in ipairs(dsent) do 102 | -- 1 Influential _ JJ JJ _ 2 amod _ _ 103 | fout:write(string.format('%d\t%s\t_\t_\t%s\t_\t%d\t%s\t_\t_\n', ditem.p1, ditem.wd, ditem.pos, sent_dep[i], defaultLabel)) 104 | end 105 | fout:write('\n') 106 | end 107 | fout:close() 108 | 109 | -- local conllx_eval = require 'conllx_eval' 110 | if self.opts.evalType == nil then 111 | self.opts.evalType = 'stanford' 112 | end 113 | xprintln('eval type = %s', self.opts.evalType) 114 | local conllx_eval = self.opts.evalType == 'stanford' and require 'conllx_eval' or require 'conllx2006_eval' 115 | local LAS, UAS, noPunctLAS, noPunctUAS = conllx_eval.eval(sys_out, validFile) 116 | 117 | local entropy = totalLoss / totalCnt 118 | local ppl = torch.exp(entropy) 119 | 120 | return {entropy = entropy, ppl = ppl, UAS = noPunctUAS} 121 | end 122 | 123 | function LabeledTrainer:load(model_path) 124 | local opts = torch.load( model_path:sub(1, -3) .. 
'state.t7' ) 125 | self.opts = opts 126 | torch.manualSeed(opts.seed) 127 | if opts.useGPU then 128 | require 'cutorch' 129 | require 'cunn' 130 | cutorch.manualSeed(opts.seed) 131 | end 132 | 133 | assert(opts.vocab ~= nil, 'We must have an existing vocabulary!') 134 | self.net = SelectNetPos(opts) 135 | self:showOpts() 136 | 137 | xprintln('load from %s ...', model_path) 138 | self.net:load(model_path) 139 | xprintln('load from %s done!', model_path) 140 | end 141 | 142 | function LabeledTrainer:createTrainData(indDtaPaths, outDataPath, language) 143 | self.rel_vocab = DepPosDataIter.createDepRelVocab(indDtaPaths.train) 144 | print(self.rel_vocab) 145 | local h5out = hdf5.open(outDataPath, 'w') 146 | 147 | local function generateSplit(slabel, infile, batchSize, maxlen) 148 | local gxdata = string.format('/%s/x', slabel) 149 | local gydata = string.format('/%s/y', slabel) 150 | local gxedata = string.format('/%s/xe', slabel) 151 | local gxpedata = string.format('/%s/xpe', slabel) 152 | 153 | local xOpt = hdf5.DataSetOptions() 154 | xOpt:setChunked(1024*10, self.opts.nhid * 4) 155 | xOpt:setDeflate() 156 | 157 | local xeOpt = hdf5.DataSetOptions() 158 | xeOpt:setChunked(1024*10, self.opts.nin * 2) 159 | xeOpt:setDeflate() 160 | 161 | local xpeOpt = hdf5.DataSetOptions() 162 | xpeOpt:setChunked(1024*10, self.opts.npin * 2) 163 | xpeOpt:setDeflate() 164 | 165 | local yOpt = hdf5.DataSetOptions() 166 | yOpt:setChunked(1024*10) 167 | yOpt:setDeflate() 168 | 169 | local isFirst = true 170 | local diter = DepPosDataIter.createBatchLabel(self.opts.vocab, self.rel_vocab, infile, batchSize, maxlen) 171 | local cnt = 0 172 | for x, x_mask, x_pos, y, sent_rels, sent_ori_rels in diter do 173 | self.net:validBatch(x, x_mask, x_pos, y) 174 | local dsize = x_mask:sum() - x_mask:size(2) 175 | assert(dsize == y:ne(0):sum(), 'size should be the same') 176 | local x_input = torch.zeros(dsize, self.opts.nhid * 4):float() 177 | local y_output = torch.zeros(dsize):int() 178 | local x_input_emb = torch.zeros(dsize, self.opts.nin * 2):float() 179 | local x_input_pos_emb = torch.zeros(dsize, self.opts.npin * 2):float() 180 | 181 | -- self.mod_map.forward_lookup 182 | -- self.mod_map.forward_pos_lookup 183 | local x_emb = self.net.mod_map.forward_lookup:forward(x) 184 | local x_pos_emb = self.net.mod_map.forward_pos_lookup:forward(x_pos) 185 | 186 | -- bs x seqlen x nhid 187 | -- self.net.all_fwd_bak_hs 188 | local example_cnt = 0 189 | for i, sent_rel in ipairs(sent_rels) do 190 | assert(x_mask[{ {}, i }]:sum() == #sent_rel + 1, 'MUST be the same length') 191 | for j, rel_id in ipairs(sent_rel) do 192 | local cur_id = j + 1 193 | local parent_id = y[{ j, i }] 194 | local cur_a = self.net.all_fwd_bak_hs[{ i, cur_id, {} }] 195 | local parent_a = self.net.all_fwd_bak_hs[{ i, parent_id, {} }] 196 | example_cnt = example_cnt + 1 197 | x_input[{ example_cnt, {1, 2 * self.opts.nhid} }] = cur_a:float() 198 | x_input[{ example_cnt, {2 * self.opts.nhid + 1, 4 * self.opts.nhid} }] = parent_a:float() 199 | y_output[{ example_cnt }] = rel_id 200 | 201 | local cur_emb = x_emb[{ cur_id, i, {} }] 202 | local parent_emb = x_emb[{ parent_id, i, {} }] 203 | local cur_pos_emb = x_pos_emb[{ cur_id, i, {} }] 204 | local parent_pos_emb = x_pos_emb[{ parent_id, i, {} }] 205 | x_input_emb[{ example_cnt, {1, self.opts.nin} }] = cur_emb:float() 206 | x_input_emb[{ example_cnt, {self.opts.nin + 1, 2*self.opts.nin} }] = parent_emb:float() 207 | x_input_pos_emb[{ example_cnt, {1, self.opts.npin} }] = cur_pos_emb:float() 208 | x_input_pos_emb[{ 
example_cnt, {self.opts.npin + 1, 2*self.opts.npin} }] = parent_pos_emb:float() 209 | end 210 | end 211 | 212 | if isFirst then 213 | h5out:write(gxdata, x_input, xOpt) 214 | h5out:write(gydata, y_output, yOpt) 215 | 216 | h5out:write(gxedata, x_input_emb, xeOpt) 217 | h5out:write(gxpedata, x_input_pos_emb, xpeOpt) 218 | 219 | isFirst = false 220 | else 221 | h5out:append(gxdata, x_input, xOpt) 222 | h5out:append(gydata, y_output, yOpt) 223 | 224 | h5out:append(gxedata, x_input_emb, xeOpt) 225 | h5out:append(gxpedata, x_input_pos_emb, xpeOpt) 226 | end 227 | 228 | cnt = cnt + 1 229 | if cnt % 5 == 0 then 230 | collectgarbage() 231 | end 232 | 233 | if cnt % 10 == 0 then 234 | xprint('cnt = %d\n', cnt) 235 | end 236 | end 237 | 238 | print( 'totally ' .. cnt .. ' batches' ) 239 | end 240 | 241 | local predictValidFile = outDataPath .. '.valid.conllx' 242 | local predictTestFile = outDataPath .. '.test.conllx' 243 | print(indDtaPaths) 244 | 245 | local dlabel = self.rel_vocab.idx2rel[1] 246 | xprintln('the default dependency label is %s', dlabel) 247 | 248 | if indDtaPaths.outvalid == '' then 249 | assert(language == 'English' or language == 'Chinese' or language == 'Other') 250 | --[[ 251 | local dlabel 252 | if language == 'English' then 253 | dlabel = 'pobj' 254 | elseif language == 'Chinese' then 255 | dlabel = 'ROOT' 256 | end 257 | --]] 258 | 259 | self:validConllx(indDtaPaths.valid, predictValidFile, dlabel) 260 | self:validConllx(indDtaPaths.test, predictTestFile, dlabel) 261 | else 262 | -- predictValidFile = indDtaPaths.outvalid 263 | -- predictTestFile = indDtaPaths.outtest 264 | assert(language == 'English' or language == 'Chinese' or language == 'Other') 265 | if language == 'English' then 266 | os.execute( string.format('cp %s %s', indDtaPaths.outvalid, predictValidFile) ) 267 | os.execute( string.format('cp %s %s', indDtaPaths.outtest, predictTestFile) ) 268 | else 269 | local replaceField = require 'replace_conllx_field' 270 | replaceField.replace(indDtaPaths.outvalid, predictValidFile, 8, dlabel) 271 | replaceField.replace(indDtaPaths.outtest, predictTestFile, 8, dlabel) 272 | xprintln('change field 8 to %s', dlabel) 273 | end 274 | 275 | if self.opts.evalType == nil then 276 | self.opts.evalType = 'stanford' 277 | end 278 | xprintln('eval type = %s', self.opts.evalType) 279 | local conllx_eval = self.opts.evalType == 'stanford' and require 'conllx_eval' or require 'conllx2006_eval' 280 | -- local conllx_eval = require 'conllx_eval' 281 | print '===Valid===' 282 | conllx_eval.eval(predictValidFile, indDtaPaths.valid) 283 | print '===Test===' 284 | conllx_eval.eval(predictTestFile, indDtaPaths.test) 285 | end 286 | 287 | assert(language == 'English' or language == 'Chinese' or language == 'Other') 288 | local maxTrainLen = language == 'English' and 100 or 140 289 | if language == 'Other' then 290 | maxTrainLen = 110 291 | end 292 | if self.opts.maxTrainLen ~= nil then 293 | maxTrainLen = self.opts.maxTrainLen 294 | print('maxTrainLen = ', maxTrainLen) 295 | end 296 | 297 | generateSplit('predict_valid', predictValidFile, self.opts.batchSize, 999999) 298 | generateSplit('predict_test', predictTestFile, self.opts.batchSize, 999999) 299 | generateSplit('valid', indDtaPaths.valid, self.opts.batchSize, 999999) 300 | generateSplit('test', indDtaPaths.test, self.opts.batchSize, 999999) 301 | generateSplit('train', indDtaPaths.train, self.opts.batchSize, maxTrainLen) 302 | 303 | h5out:close() 304 | end 305 | 306 | local DataIter = {} 307 | function DataIter.getNExamples(dataPath, label) 308 | 
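-- the datasets read here (and in DataIter.createBatch below) follow the
-- layout written by generateSplit above, one HDF5 group per split
-- (train / valid / test / predict_valid / predict_test):
--   /<split>/x   : N x (4 * nhid)  -- BiLSTM states of child and head word
--   /<split>/xe  : N x (2 * nin)   -- word embeddings of child and head
--   /<split>/xpe : N x (2 * npin)  -- POS embeddings of child and head
--   /<split>/y   : N               -- gold dependency relation ids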
local h5in = hdf5.open(dataPath, 'r') 309 | local x_data = h5in:read(string.format('/%s/x', label)) 310 | local N = x_data:dataspaceSize()[1] 311 | 312 | return N 313 | end 314 | 315 | function DataIter.createBatch(dataPath, label, batchSize, ftype) 316 | local h5in = hdf5.open(dataPath, 'r') 317 | 318 | local x_data = h5in:read(string.format('/%s/x', label)) 319 | local xe_data = h5in:read(string.format('/%s/xe', label)) 320 | local xpe_data = h5in:read(string.format('/%s/xpe', label)) 321 | 322 | local y_data = h5in:read(string.format('/%s/y', label)) 323 | local N = x_data:dataspaceSize()[1] 324 | local x_width = x_data:dataspaceSize()[2] 325 | local xe_width = xe_data:dataspaceSize()[2] 326 | local xpe_width = xpe_data:dataspaceSize()[2] 327 | 328 | local istart = 1 329 | 330 | return function() 331 | if istart <= N then 332 | local iend = math.min(istart + batchSize - 1, N) 333 | -- local x = x_data:partial({istart, iend}, {1, x_width}) 334 | local y = y_data:partial({istart, iend}) 335 | 336 | local widths = {x_width} 337 | local xdatas = {x_data} 338 | if ftype:find('|xe|') then 339 | widths[#widths + 1] = xe_width 340 | xdatas[#xdatas + 1] = xe_data 341 | end 342 | if ftype:find('|xpe|') then 343 | widths[#widths + 1] = xpe_width 344 | xdatas[#xdatas + 1] = xpe_data 345 | end 346 | 347 | local width = 0 348 | for _, w in ipairs(widths) do width = width + w end 349 | local x = torch.zeros(y:size(1), width):float() 350 | local s = 0 351 | for i, w in ipairs(widths) do 352 | x[{ {}, {s + 1, s + w} }] = xdatas[i]:partial({istart, iend}, {1, w}) 353 | s = s + w 354 | end 355 | 356 | istart = iend + 1 357 | 358 | return x, y 359 | else 360 | h5in:close() 361 | end 362 | end 363 | end 364 | 365 | 366 | local RndBatcher = torch.class('RandomBatcher') 367 | function RndBatcher:__init(h5in, x_data, xe_data, xpe_data, y_data, bufSize, ftype) 368 | self.h5in = h5in 369 | self.x_data = x_data 370 | self.xe_data = xe_data 371 | self.xpe_data = xpe_data 372 | self.y_data = y_data 373 | self.bufSize = bufSize 374 | self.N = x_data:dataspaceSize()[1] 375 | self.x_width = x_data:dataspaceSize()[2] 376 | self.xe_width = xe_data:dataspaceSize()[2] 377 | self.xpe_width = xpe_data:dataspaceSize()[2] 378 | 379 | self.ftype = ftype 380 | 381 | self.istart = 1 382 | self.idx_chunk = 1 383 | self.chunk_size = 0 384 | end 385 | 386 | function RndBatcher:nextChunk() 387 | if self.istart <= self.N then 388 | local iend = math.min( self.istart + self.bufSize - 1, self.N ) 389 | self.x_chunk = self.x_data:partial({self.istart, iend}, {1, self.x_width}) 390 | self.xe_chunk = self.xe_data:partial({self.istart, iend}, {1, self.xe_width}) 391 | self.xpe_chunk = self.xpe_data:partial({self.istart, iend}, {1, self.xpe_width}) 392 | 393 | self.y_chunk = self.y_data:partial({self.istart, iend}) 394 | 395 | self.chunk_size = iend - self.istart + 1 396 | 397 | self.istart = iend + 1 398 | 399 | return true 400 | else 401 | return false 402 | end 403 | end 404 | 405 | function RndBatcher:nextBatch(batchSize) 406 | if self.idx_chunk > self.chunk_size then 407 | if self:nextChunk() then 408 | self.idx_chunk = 1 409 | self.idxs_chunk = torch.randperm(self.chunk_size):long() 410 | else 411 | return nil 412 | end 413 | end 414 | 415 | local iend = math.min( self.idx_chunk + batchSize - 1, self.chunk_size ) 416 | local idxs = self.idxs_chunk[{ {self.idx_chunk, iend} }] 417 | 418 | local y = self.y_chunk:index(1, idxs) 419 | 420 | local xs = {} 421 | local widths = {} 422 | local width = 0 423 | if self.ftype:find('|x|') then 424 
| local x = self.x_chunk:index(1, idxs) 425 | width = width + self.x_width 426 | widths[#widths + 1] = self.x_width 427 | xs[#xs + 1] = x 428 | end 429 | 430 | if self.ftype:find('|xe|') then 431 | local xe = self.xe_chunk:index(1, idxs) 432 | width = width + self.xe_width 433 | widths[#widths + 1] = self.xe_width 434 | xs[#xs + 1] = xe 435 | end 436 | 437 | if self.ftype:find('|xpe|') then 438 | local xpe = self.xpe_chunk:index(1, idxs) 439 | width = width + self.xpe_width 440 | widths[#widths + 1] = self.xpe_width 441 | xs[#xs + 1] = xpe 442 | end 443 | 444 | local x_ = torch.zeros(y:size(1), width):float() 445 | local s = 0 446 | for i, w in ipairs(widths) do 447 | x_[{ {}, {s+1, s+w} }] = xs[i] 448 | s = s + w 449 | end 450 | 451 | self.idx_chunk = iend + 1 452 | 453 | return x_, y 454 | end 455 | 456 | function DataIter.createBatchShuffle(dataPath, label, batchSize, ftype) 457 | local h5in = hdf5.open(dataPath, 'r') 458 | 459 | local x_data = h5in:read(string.format('/%s/x', label)) 460 | local xe_data = h5in:read(string.format('/%s/xe', label)) 461 | local xpe_data = h5in:read(string.format('/%s/xpe', label)) 462 | 463 | local y_data = h5in:read(string.format('/%s/y', label)) 464 | 465 | local bufSize = 1000 * batchSize 466 | local rnd_batcher = RandomBatcher(h5in, x_data, xe_data, xpe_data, y_data, bufSize, ftype) 467 | 468 | return function() 469 | return rnd_batcher:nextBatch(batchSize) 470 | end 471 | end 472 | 473 | function LabeledTrainer:train_label() 474 | local dataIter = DataIter.createBatchShuffle(self.classifier_opts.dataset, 'train', 475 | self.classifier_opts.batchSize, self.classifier_opts.ftype) 476 | --[[ 477 | local dataIter = DataIter.createBatch(self.classifier_opts.dataset, 'train', 478 | self.classifier_opts.batchSize, self.classifier_opts.ftype) 479 | --]] 480 | 481 | local dataSize = DataIter.getNExamples(self.classifier_opts.dataset, 'train') 482 | local percent, inc = 0.001, 0.001 483 | local timer = torch.Timer() 484 | -- local sgdParam = {learningRate = opts.curLR} 485 | local sgdParam = self.classifier_opts.sgdParam 486 | local cnt = 0 487 | local totalLoss = 0 488 | local totalCnt = 0 489 | for x, y in dataIter do 490 | local loss = self.mlp:trainBatch(x, y, sgdParam) 491 | totalLoss = totalLoss + loss * x:size(1) 492 | totalCnt = totalCnt + x:size(1) 493 | 494 | local ratio = totalCnt/dataSize 495 | if ratio >= percent then 496 | local wps = totalCnt / timer:time().real 497 | xprint( '\repoch %d %.3f %.4f (%s) / %.2f wps ... 
', self.iepoch, ratio, totalLoss/totalCnt, readableTime(timer:time().real), wps ) 498 | percent = math.floor(ratio / inc) * inc 499 | percent = percent + inc 500 | end 501 | 502 | cnt = cnt + 1 503 | if cnt % 5 == 0 then 504 | collectgarbage() 505 | end 506 | end 507 | 508 | return totalLoss / totalCnt 509 | end 510 | 511 | function LabeledTrainer:valid_label(label) 512 | local dataIter = DataIter.createBatch(self.classifier_opts.dataset, label, 513 | self.classifier_opts.batchSize, self.classifier_opts.ftype) 514 | 515 | local cnt = 0 516 | local correct, total = 0, 0 517 | for x, y in dataIter do 518 | local correct_, total_ = self.mlp:validBatch(x, y) 519 | correct = correct + correct_ 520 | total = total + total_ 521 | cnt = cnt + 1 522 | if cnt % 5 == 0 then collectgarbage() end 523 | end 524 | 525 | return correct, total 526 | end 527 | 528 | function LabeledTrainer:valid_label_conllx(label, conllx_file, gold_file) 529 | local dataIter = DataIter.createBatch(self.classifier_opts.dataset, label, 530 | self.classifier_opts.batchSize, self.classifier_opts.ftype) 531 | 532 | local cnt = 0 533 | local correct, total = 0, 0 534 | local lbl_idxs = {} 535 | for x, y in dataIter do 536 | local correct_, total_, y_pred = self.mlp:validBatch(x, y) 537 | correct = correct + correct_ 538 | total = total + total_ 539 | cnt = cnt + 1 540 | if cnt % 5 == 0 then collectgarbage() end 541 | 542 | local y_pred_ = y_pred:view(-1) 543 | for i = 1, y_pred_:size(1) do 544 | lbl_idxs[#lbl_idxs + 1] = y_pred_[i] 545 | end 546 | end 547 | 548 | local ilbl = 0 549 | local conllx_file_out = conllx_file .. '.out' 550 | 551 | -- begin 552 | local dep_iter = DepPosDataIter.conllx_iter(conllx_file) 553 | local sys_out = conllx_file_out 554 | local fout = io.open(sys_out, 'w') 555 | for dsent in dep_iter do 556 | for _, ditem in ipairs(dsent) do 557 | -- 1 Influential _ JJ JJ _ 2 amod _ _ 558 | ilbl = ilbl + 1 559 | local lbl = self.rel_vocab.idx2rel[ lbl_idxs[ilbl] ] 560 | fout:write( string.format('%d\t%s\t_\t_\t%s\t_\t%d\t%s\t_\t_\n', ditem.p1, ditem.wd, ditem.pos, ditem.p2, lbl) ) 561 | end 562 | fout:write('\n') 563 | end 564 | fout:close() 565 | -- end 566 | 567 | -- local conllx_eval = require 'conllx_eval' 568 | local conllx_eval 569 | xprintln('language = %s', self.classifier_opts.language) 570 | if self.classifier_opts.language == 'Other' then 571 | conllx_eval = require 'conllx2006_eval' 572 | else 573 | conllx_eval = require 'conllx_eval' 574 | end 575 | 576 | -- xprintln('eval type = %s', self.opts.evalType) 577 | local LAS, UAS, noPunctLAS, noPunctUAS = conllx_eval.eval(sys_out, gold_file) 578 | 579 | return {LAS = noPunctLAS, UAS = noPunctUAS} 580 | end 581 | 582 | function LabeledTrainer:trainLabeledClassifier(opts) 583 | torch.manualSeed(opts.seed) 584 | if opts.useGPU then 585 | require 'cutorch' 586 | require 'cunn' 587 | cutorch.manualSeed(opts.seed) 588 | end 589 | local mlp = MLP(opts) 590 | opts.sgdParam = {learningRate = opts.lr} 591 | opts.curLR = opts.lr 592 | print(opts) 593 | 594 | self.classifier_opts = opts 595 | self.mlp = mlp 596 | 597 | local timer = torch.Timer() 598 | local bestAcc = 0 599 | local bestModel = torch.FloatTensor(mlp.params:size()) 600 | local bestLAS = 0 601 | 602 | self.rel_vocab = DepPosDataIter.createDepRelVocab(opts.inTrain) 603 | opts.rel_vocab = self.rel_vocab 604 | xprintln('load rel_vocab done!') 605 | self.predictValidFile = opts.dataset .. '.valid.conllx' 606 | self.predictTestFile = opts.dataset .. 
'.test.conllx' 607 | 608 | for epoch = 1, opts.maxEpoch do 609 | self.iepoch = epoch 610 | -- EPOCH_INFO = string.format('epoch %d', epoch) 611 | local startTime = timer:time().real 612 | local trainCost = self:train_label() 613 | xprint('\repoch %d TRAIN nll %f ', epoch, trainCost) 614 | -- local validCor, validTot = valid(mlp, 'valid', opts) 615 | local validCor, validTot = self:valid_label('valid') 616 | local validAcc = validCor/validTot 617 | xprint('VALID %d/%d = %f ', validCor, validTot, validAcc) 618 | local endTime = timer:time().real 619 | xprintln('lr = %.4g (%s)', opts.curLR, readableTime(endTime - startTime)) 620 | 621 | local v_ret = self:valid_label_conllx('predict_valid', self.predictValidFile, self.classifier_opts.inValid) 622 | print '==Valid Perf==' 623 | print(v_ret) 624 | print '\n' 625 | 626 | if v_ret.LAS > bestLAS then 627 | bestLAS = v_ret.LAS 628 | mlp:getModel(bestModel) 629 | 630 | local t_ret = self:valid_label_conllx('predict_test', self.predictTestFile, self.classifier_opts.inTest) 631 | print '==Test Perf==' 632 | print(t_ret) 633 | print '\n' 634 | end 635 | end 636 | 637 | mlp:setModel(bestModel) 638 | opts.sgdParam = nil 639 | mlp:save(opts.save, true) 640 | xprintln('model saved at %s', opts.save) 641 | 642 | local v_ret = self:valid_label_conllx('predict_valid', self.predictValidFile, self.classifier_opts.inValid) 643 | print '==Valid Perf==' 644 | print(v_ret) 645 | print '\n' 646 | 647 | local t_ret = self:valid_label_conllx('predict_test', self.predictTestFile, self.classifier_opts.inTest) 648 | print '==Test Perf==' 649 | print(t_ret) 650 | print '\n' 651 | end 652 | 653 | local function main() 654 | local opts = getOpts() 655 | local trainer = LabeledModelTrainer() 656 | if opts.mode == 'generate' then 657 | xprintln('This is generate mode!') 658 | trainer:load(opts.modelPath) 659 | local inDataPaths = {train = opts.inTrain, valid = opts.inValid, test = opts.inTest} 660 | inDataPaths.outvalid = opts.outValid 661 | inDataPaths.outtest = opts.outTest 662 | trainer:createTrainData(inDataPaths, opts.outTrainDataPath, opts.language) 663 | xprintln('create training data done!') 664 | elseif opts.mode == 'train' then 665 | xprintln('This is train mode!') 666 | trainer:trainLabeledClassifier(opts) 667 | xprintln('Training done!') 668 | else 669 | error('only [generate] and [train] modes are supported') 670 | end 671 | end 672 | 673 | main() 674 | -------------------------------------------------------------------------------- /utils/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XingxingZhang/dense_parser/750c7d00ab33defd61e04739ef12ff8273afe304/utils/.DS_Store -------------------------------------------------------------------------------- /utils/model_utils.lua: -------------------------------------------------------------------------------- 1 | 2 | local model_utils = {} 3 | 4 | function model_utils.combine_selectnet_parameters(forward_lstm, backward_lstm, attention) 5 | -- This method only works for fflstm -- 6 | -- get parameters. Note we will ignore the lookup table "ngram_lookup" 7 | local parameters = {} 8 | local gradParameters = {} 9 | 10 | -- get LSTM parameters. 
IGNORE ngram_lookup LookupTable 11 | local function getLSTMParameters(fflstm, parameters, gradParameters) 12 | 13 | for _, node in ipairs(fflstm.forwardnodes) do 14 | -- check IF this is a module and the module is not a lookup table 15 | if node.data.module then 16 | -- 'enc_ngram_lookup' or node.data.annotations.name == 'dec_ngram_lookup' 17 | if node.data.annotations.name ~= 'backward_lookup' then 18 | local mp,mgp = node.data.module:parameters() 19 | if mp and mgp then 20 | for i = 1,#mp do 21 | table.insert(parameters, mp[i]) 22 | table.insert(gradParameters, mgp[i]) 23 | end 24 | end 25 | else 26 | print('[combine_selectnet_parameters] found backward_lookup! ' .. node.data.annotations.name) 27 | end 28 | end 29 | end 30 | 31 | end 32 | 33 | getLSTMParameters(forward_lstm, parameters, gradParameters) 34 | getLSTMParameters(backward_lstm, parameters, gradParameters) 35 | getLSTMParameters(attention, parameters, gradParameters) 36 | 37 | local function storageInSet(set, storage) 38 | local storageAndOffset = set[torch.pointer(storage)] 39 | if storageAndOffset == nil then 40 | return nil 41 | end 42 | local _, offset = unpack(storageAndOffset) 43 | return offset 44 | end 45 | 46 | -- this function flattens arbitrary lists of parameters, 47 | -- even complex shared ones 48 | local function flatten(parameters) 49 | if not parameters or #parameters == 0 then 50 | return torch.Tensor() 51 | end 52 | local Tensor = parameters[1].new 53 | 54 | local storages = {} 55 | local nParameters = 0 56 | for k = 1,#parameters do 57 | local storage = parameters[k]:storage() 58 | if not storageInSet(storages, storage) then 59 | storages[torch.pointer(storage)] = {storage, nParameters} 60 | nParameters = nParameters + storage:size() 61 | end 62 | end 63 | 64 | local flatParameters = Tensor(nParameters):fill(1) 65 | local flatStorage = flatParameters:storage() 66 | 67 | for k = 1,#parameters do 68 | local storageOffset = storageInSet(storages, parameters[k]:storage()) 69 | parameters[k]:set(flatStorage, 70 | storageOffset + parameters[k]:storageOffset(), 71 | parameters[k]:size(), 72 | parameters[k]:stride()) 73 | parameters[k]:zero() 74 | end 75 | 76 | local maskParameters= flatParameters:float():clone() 77 | local cumSumOfHoles = flatParameters:float():cumsum(1) 78 | local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles] 79 | local flatUsedParameters = Tensor(nUsedParameters) 80 | local flatUsedStorage = flatUsedParameters:storage() 81 | 82 | for k = 1,#parameters do 83 | local offset = cumSumOfHoles[parameters[k]:storageOffset()] 84 | parameters[k]:set(flatUsedStorage, 85 | parameters[k]:storageOffset() - offset, 86 | parameters[k]:size(), 87 | parameters[k]:stride()) 88 | end 89 | 90 | for _, storageAndOffset in pairs(storages) do 91 | local k, v = unpack(storageAndOffset) 92 | flatParameters[{{v+1,v+k:size()}}]:copy(Tensor():set(k)) 93 | end 94 | 95 | if cumSumOfHoles:sum() == 0 then 96 | flatUsedParameters:copy(flatParameters) 97 | else 98 | local counter = 0 99 | for k = 1,flatParameters:nElement() do 100 | if maskParameters[k] == 0 then 101 | counter = counter + 1 102 | flatUsedParameters[counter] = flatParameters[counter+cumSumOfHoles[k]] 103 | end 104 | end 105 | assert (counter == nUsedParameters) 106 | end 107 | return flatUsedParameters 108 | end 109 | 110 | -- flatten parameters and gradients 111 | local flatParameters = flatten(parameters) 112 | local flatGradParameters = flatten(gradParameters) 113 | 114 | -- return new flat vector that contains all discrete parameters 115 | 
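-- note: flatten() above pre-fills one big storage with 1s, maps every
-- parameter tensor into it and zeroes the slots that are actually used, so
-- the cumulative sum of the surviving 1s counts the "holes" (slack left by
-- shared or strided tensors) and gives the shift each used element needs
-- when the holes are squeezed out into flatUsedParameters.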
return flatParameters, flatGradParameters 116 | end 117 | 118 | 119 | function model_utils.combine_selectnet_pos_parameters(forward_lstm, backward_lstm, attention) 120 | -- This method only works for fflstm -- 121 | -- get parameters. Note we will ignore the lookup table "ngram_lookup" 122 | local parameters = {} 123 | local gradParameters = {} 124 | 125 | -- get LSTM parameters. IGNORE ngram_lookup LookupTable 126 | local function getLSTMParameters(fflstm, parameters, gradParameters) 127 | 128 | for _, node in ipairs(fflstm.forwardnodes) do 129 | -- check IF this is a module and the module is not a lookup table 130 | if node.data.module then 131 | -- 'enc_ngram_lookup' or node.data.annotations.name == 'dec_ngram_lookup' 132 | if node.data.annotations.name ~= 'backward_lookup' and node.data.annotations.name ~= 'backward_pos_lookup' then 133 | local mp,mgp = node.data.module:parameters() 134 | if mp and mgp then 135 | for i = 1,#mp do 136 | table.insert(parameters, mp[i]) 137 | table.insert(gradParameters, mgp[i]) 138 | end 139 | end 140 | else 141 | print('[combine_selectnet_pos_parameters] found backward_lookup! ' .. node.data.annotations.name) 142 | end 143 | end 144 | end 145 | 146 | end 147 | 148 | getLSTMParameters(forward_lstm, parameters, gradParameters) 149 | getLSTMParameters(backward_lstm, parameters, gradParameters) 150 | getLSTMParameters(attention, parameters, gradParameters) 151 | 152 | local function storageInSet(set, storage) 153 | local storageAndOffset = set[torch.pointer(storage)] 154 | if storageAndOffset == nil then 155 | return nil 156 | end 157 | local _, offset = unpack(storageAndOffset) 158 | return offset 159 | end 160 | 161 | -- this function flattens arbitrary lists of parameters, 162 | -- even complex shared ones 163 | local function flatten(parameters) 164 | if not parameters or #parameters == 0 then 165 | return torch.Tensor() 166 | end 167 | local Tensor = parameters[1].new 168 | 169 | local storages = {} 170 | local nParameters = 0 171 | for k = 1,#parameters do 172 | local storage = parameters[k]:storage() 173 | if not storageInSet(storages, storage) then 174 | storages[torch.pointer(storage)] = {storage, nParameters} 175 | nParameters = nParameters + storage:size() 176 | end 177 | end 178 | 179 | local flatParameters = Tensor(nParameters):fill(1) 180 | local flatStorage = flatParameters:storage() 181 | 182 | for k = 1,#parameters do 183 | local storageOffset = storageInSet(storages, parameters[k]:storage()) 184 | parameters[k]:set(flatStorage, 185 | storageOffset + parameters[k]:storageOffset(), 186 | parameters[k]:size(), 187 | parameters[k]:stride()) 188 | parameters[k]:zero() 189 | end 190 | 191 | local maskParameters= flatParameters:float():clone() 192 | local cumSumOfHoles = flatParameters:float():cumsum(1) 193 | local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles] 194 | local flatUsedParameters = Tensor(nUsedParameters) 195 | local flatUsedStorage = flatUsedParameters:storage() 196 | 197 | for k = 1,#parameters do 198 | local offset = cumSumOfHoles[parameters[k]:storageOffset()] 199 | parameters[k]:set(flatUsedStorage, 200 | parameters[k]:storageOffset() - offset, 201 | parameters[k]:size(), 202 | parameters[k]:stride()) 203 | end 204 | 205 | for _, storageAndOffset in pairs(storages) do 206 | local k, v = unpack(storageAndOffset) 207 | flatParameters[{{v+1,v+k:size()}}]:copy(Tensor():set(k)) 208 | end 209 | 210 | if cumSumOfHoles:sum() == 0 then 211 | flatUsedParameters:copy(flatParameters) 212 | else 213 | local counter = 0 214 | for 
k = 1,flatParameters:nElement() do 215 | if maskParameters[k] == 0 then 216 | counter = counter + 1 217 | flatUsedParameters[counter] = flatParameters[counter+cumSumOfHoles[k]] 218 | end 219 | end 220 | assert (counter == nUsedParameters) 221 | end 222 | return flatUsedParameters 223 | end 224 | 225 | -- flatten parameters and gradients 226 | local flatParameters = flatten(parameters) 227 | local flatGradParameters = flatten(gradParameters) 228 | 229 | -- return new flat vector that contains all discrete parameters 230 | return flatParameters, flatGradParameters 231 | end 232 | 233 | function model_utils.clone_many_times(net, T) 234 | local clones = {} 235 | 236 | local params, gradParams 237 | if net.parameters then 238 | params, gradParams = net:parameters() 239 | if params == nil then 240 | params = {} 241 | end 242 | end 243 | 244 | local paramsNoGrad 245 | if net.parametersNoGrad then 246 | paramsNoGrad = net:parametersNoGrad() 247 | end 248 | 249 | local mem = torch.MemoryFile("w"):binary() 250 | mem:writeObject(net) 251 | 252 | for t = 1, T do 253 | -- We need to use a new reader for each clone. 254 | -- We don't want to use the pointers to already read objects. 255 | local reader = torch.MemoryFile(mem:storage(), "r"):binary() 256 | local clone = reader:readObject() 257 | reader:close() 258 | 259 | if net.parameters then 260 | local cloneParams, cloneGradParams = clone:parameters() 261 | local cloneParamsNoGrad 262 | for i = 1, #params do 263 | cloneParams[i]:set(params[i]) 264 | cloneGradParams[i]:set(gradParams[i]) 265 | end 266 | if paramsNoGrad then 267 | cloneParamsNoGrad = clone:parametersNoGrad() 268 | for i =1,#paramsNoGrad do 269 | cloneParamsNoGrad[i]:set(paramsNoGrad[i]) 270 | end 271 | end 272 | end 273 | 274 | clones[t] = clone 275 | collectgarbage() 276 | end 277 | 278 | mem:close() 279 | return clones 280 | end 281 | 282 | -- share parameters of two lookupTables: forward lstm and backward lstm 283 | function model_utils.share_fflstm_lookup(fflstm) 284 | local lstm_lookup, ngram_lookup 285 | for _, node in ipairs(fflstm.forwardnodes) do 286 | if node.data.module then 287 | if node.data.annotations.name ~= nil and node.data.annotations.name:ends('ngram_lookup') then 288 | ngram_lookup = node.data.module 289 | print('[model_utils.share_fflstm_lookup] ngram_lookup found!') 290 | elseif node.data.annotations.name ~= nil and node.data.annotations.name:ends('lstm_lookup') then 291 | lstm_lookup = node.data.module 292 | print('[model_utils.share_fflstm_lookup] lstm_lookup found!') 293 | end 294 | end 295 | end 296 | 297 | ngram_lookup.weight:set(lstm_lookup.weight) 298 | ngram_lookup.gradWeight:set(lstm_lookup.gradWeight) 299 | 300 | collectgarbage() 301 | 302 | return lstm_lookup, ngram_lookup 303 | end 304 | 305 | function model_utils.load_embedding_init(emb, vocab, embedPath) 306 | require 'wordembedding' 307 | local wordEmbed = WordEmbedding(embedPath) 308 | wordEmbed:initMat(emb.weight, vocab) 309 | wordEmbed:releaseMemory() 310 | vocab = nil 311 | wordEmbed = nil 312 | collectgarbage() 313 | end 314 | 315 | function model_utils.load_embedding_fine_tune(emb, vocab, embedPath, ftFactor) 316 | require 'wordembedding_ft' 317 | local wordEmbed = WordEmbeddingFT(embedPath) 318 | local mask = wordEmbed:initMatFT(emb.weight, vocab, ftFactor) 319 | emb:setUpdateMask(mask) 320 | wordEmbed:releaseMemory() 321 | vocab = nil 322 | wordEmbed = nil 323 | collectgarbage() 324 | end 325 | 326 | 327 | function model_utils.clone_many_times_emb_ft(net, T) 328 | local clones = {} 329 | 330 
| local params, gradParams 331 | if net.parameters then 332 | params, gradParams = net:parameters() 333 | if params == nil then 334 | params = {} 335 | end 336 | end 337 | 338 | local paramsNoGrad 339 | if net.parametersNoGrad then 340 | paramsNoGrad = net:parametersNoGrad() 341 | end 342 | 343 | local mem = torch.MemoryFile("w"):binary() 344 | mem:writeObject(net) 345 | 346 | local master_map = BModel.get_module_map(net) 347 | local lt_names = {} 348 | for k, v in pairs(master_map) do 349 | if k:find('lookup') ~= nil then 350 | table.insert(lt_names, k) 351 | end 352 | end 353 | 354 | for t = 1, T do 355 | -- We need to use a new reader for each clone. 356 | -- We don't want to use the pointers to already read objects. 357 | local reader = torch.MemoryFile(mem:storage(), "r"):binary() 358 | local clone = reader:readObject() 359 | reader:close() 360 | 361 | if net.parameters then 362 | local cloneParams, cloneGradParams = clone:parameters() 363 | local cloneParamsNoGrad 364 | for i = 1, #params do 365 | cloneParams[i]:set(params[i]) 366 | cloneGradParams[i]:set(gradParams[i]) 367 | end 368 | if paramsNoGrad then 369 | cloneParamsNoGrad = clone:parametersNoGrad() 370 | for i =1,#paramsNoGrad do 371 | cloneParamsNoGrad[i]:set(paramsNoGrad[i]) 372 | end 373 | end 374 | end 375 | 376 | local clone_map = BModel.get_module_map(clone) 377 | for _, k in ipairs(lt_names) do 378 | if clone_map[k].updateMask then 379 | clone_map[k].updateMask:set( master_map[k].updateMask ) 380 | end 381 | end 382 | 383 | clones[t] = clone 384 | collectgarbage() 385 | end 386 | 387 | mem:close() 388 | return clones 389 | end 390 | 391 | function model_utils.copy_table(to, from) 392 | assert(#to == #from) 393 | for i = 1, #to do 394 | to[i]:copy(from[i]) 395 | end 396 | end 397 | 398 | return model_utils 399 | -------------------------------------------------------------------------------- /utils/shortcut.lua: -------------------------------------------------------------------------------- 1 | 2 | function printf(s, ...) 3 | return io.write(s:format(...)) 4 | end 5 | 6 | function xprint(s, ...) 7 | local ret = io.write(s:format(...)) 8 | io.flush() 9 | return ret 10 | end 11 | 12 | function xprintln(s, ...) 13 | return xprint(s .. '\n', ...) 
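-- e.g. xprintln('epoch %d done in %s', 3, readableTime(125)) writes
-- "epoch 3 done in 2.08min" plus a trailing newline; unlike printf above,
-- xprint/xprintln flush stdout immediately, so progress shows up right away.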
14 | end 15 | 16 | -- time -- 17 | function readableTime(stime) 18 | local intervals = {1, 60, 3600} 19 | local units = {"s", "min", "h"} 20 | local i = 2 21 | while i <= #intervals do 22 | if stime < intervals[i] then 23 | break 24 | end 25 | i = i + 1 26 | end 27 | return string.format( '%.2f%s', stime/intervals[i-1], units[i-1] ) 28 | end 29 | 30 | -- for tables -- 31 | function table.keys(t) 32 | local ks = {} 33 | for k, _ in pairs(t) do 34 | ks[#ks + 1] = k 35 | end 36 | return ks 37 | end 38 | 39 | function table.len(t) 40 | local size = 0 41 | for _ in pairs(t) do 42 | size = size + 1 43 | end 44 | return size 45 | end 46 | 47 | -- for strings -- 48 | function xtoCharSet(s) 49 | local set = {} 50 | local i = 1 51 | local ch = '' 52 | while true do 53 | ch = s:sub(i, i) 54 | if ch == '' then break end 55 | set[ch] = true 56 | i = i + 1 57 | end 58 | return set 59 | end 60 | 61 | function string.splitc(s, cseps) 62 | local strs = {} 63 | local cset = xtoCharSet(cseps) 64 | local i, ch = 1, ' ' 65 | while ch ~= '' do 66 | while true do 67 | ch = s:sub(i, i) 68 | if ch == '' or not cset[ch] then break end 69 | i = i + 1 70 | end 71 | local chs = {} 72 | while true do 73 | ch = s:sub(i, i) 74 | if ch == '' or cset[ch] then break end 75 | chs[#chs + 1] = ch 76 | i = i + 1 77 | end 78 | if #chs > 0 then strs[#strs + 1] = table.concat(chs) end 79 | end 80 | 81 | return strs 82 | end 83 | 84 | function string.starts(s, pat) 85 | return s:sub(1, string.len(pat)) == pat 86 | end 87 | 88 | function string.ends(s, pat) 89 | return pat == '' or s:sub(-string.len(pat)) == pat 90 | end 91 | 92 | function string.trim(s) 93 | return s:match'^()%s*$' and '' or s:match'^%s*(.*%S)' 94 | end 95 | 96 | function string.rfind(s, sub, istart, iend, isNotPlain) 97 | istart = istart or 1 98 | iend = iend or #s 99 | if isNotPlain == nil then isNotPlain = false end 100 | local sub_, seg = sub:reverse(), s:sub(istart, iend) 101 | local pos1, pos2 = seg:reverse():find(sub_, 1, not isNotPlain) 102 | if pos1 ~= nil then 103 | return istart + #seg - pos2, istart + #seg - pos1 104 | end 105 | end 106 | 107 | -- the following is for arrays -- 108 | function table.extend(a, b) 109 | for _, v in ipairs(b) do 110 | a[#a + 1] = v 111 | end 112 | return a 113 | end 114 | 115 | function table.subtable(t, istart, iend) 116 | local N = #t 117 | assert(istart <= iend and istart >= 1 and iend <= N, 118 | 'invalid istart or iend') 119 | local subT = {} 120 | for i = istart, iend do 121 | subT[#subT + 1] = t[i] 122 | end 123 | 124 | return subT 125 | end 126 | 127 | function table.contains(t, key) 128 | for _, v in ipairs(t) do 129 | if v == key then return true end 130 | end 131 | return false 132 | end 133 | 134 | function table.find(t, key) 135 | for i, v in ipairs(t) do 136 | if v == key then return i end 137 | end 138 | return nil 139 | end 140 | 141 | function table.clear(t) 142 | for i, _ in ipairs(t) do 143 | t[i] = nil 144 | end 145 | end 146 | 147 | -- the following is for IOs -- 148 | function xreadlines(infile) 149 | local fin = io.open(infile, 'r') 150 | local lines = {} 151 | while true do 152 | local line = fin:read() 153 | if line == nil then break end 154 | lines[#lines + 1] = line 155 | end 156 | fin:close() 157 | 158 | return lines 159 | end 160 | 161 | function xcountlines(infile) 162 | local fin = io.open(infile, 'r') 163 | local cnt = 0 164 | while true do 165 | local line = fin:read() 166 | if line == nil then break end 167 | cnt = cnt + 1 168 | end 169 | fin:close() 170 | 171 | return cnt 172 | end 173 | 174 | function xmatches(s, reg) 175 
| local istart, iend = s:find(reg) 176 | return istart == 1 and iend == s:len() 177 | end 178 | 179 | 180 | 181 | 182 | -------------------------------------------------------------------------------- /utils/wordembedding.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'shortcut' 3 | 4 | local WordEmbed = torch.class('WordEmbedding') 5 | 6 | -- input should be torch7 file 7 | function WordEmbed:__init(embedFile) 8 | self.embed, self.word2idx, self.idx2word = unpack( torch.load(embedFile) ) 9 | xprintln('load embedding done!') 10 | self.lowerCase = true 11 | for _, word in ipairs(self.idx2word) do 12 | if word ~= word:lower() then 13 | self.lowerCase = false 14 | end 15 | end 16 | print('lower case: ') 17 | print(self.lowerCase) 18 | end 19 | 20 | function WordEmbed:releaseMemory() 21 | self.embed = nil 22 | self.word2idx = nil 23 | self.idx2word = nil 24 | collectgarbage() 25 | end 26 | 27 | function WordEmbed:initMat(mat, vocab) 28 | assert(mat:size(2) == self.embed:size(2)) 29 | local idx2word = vocab.idx2word 30 | local nvocab = #idx2word 31 | local cnt = 0 32 | for wid = 1, nvocab do 33 | local word = idx2word[wid] 34 | word = self.lowerCase and word:lower() or word 35 | local wid_ = self.word2idx[word] 36 | if wid_ ~= nil then 37 | mat[wid] = self.embed[wid_] 38 | cnt = cnt + 1 39 | -- else 40 | -- print(word) 41 | end 42 | end 43 | print(string.format('word embedding coverage: %d / %d = %f', cnt, nvocab, cnt / nvocab)) 44 | end 45 | 46 | -------------------------------------------------------------------------------- /utils/wordembedding_ft.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'shortcut' 3 | 4 | local WordEmbed = torch.class('WordEmbeddingFT') 5 | 6 | -- input should be torch7 file 7 | function WordEmbed:__init(embedFile) 8 | self.embed, self.word2idx, self.idx2word = unpack( torch.load(embedFile) ) 9 | xprintln('load embedding done!') 10 | self.lowerCase = true 11 | for _, word in ipairs(self.idx2word) do 12 | if word ~= word:lower() then 13 | self.lowerCase = false 14 | end 15 | end 16 | print('lower case: ') 17 | print(self.lowerCase) 18 | end 19 | 20 | function WordEmbed:releaseMemory() 21 | self.embed = nil 22 | self.word2idx = nil 23 | self.idx2word = nil 24 | collectgarbage() 25 | end 26 | 27 | function WordEmbed:initMatFT(mat, vocab, ftFactor) 28 | assert(mat:size(2) == self.embed:size(2)) 29 | local mask = torch.ones(mat:size(1)) 30 | 31 | local idx2word = vocab.idx2word 32 | local nvocab = #idx2word 33 | local cnt = 0 34 | for wid = 1, nvocab do 35 | local word = idx2word[wid] 36 | word = self.lowerCase and word:lower() or word 37 | local wid_ = self.word2idx[word] 38 | if wid_ ~= nil then 39 | mat[wid] = self.embed[wid_] 40 | cnt = cnt + 1 41 | mask[wid] = ftFactor 42 | -- else 43 | -- print(word) 44 | end 45 | end 46 | print(string.format('word embedding coverage: %d / %d = %f', cnt, nvocab, cnt / nvocab)) 47 | 48 | return mask 49 | end 50 | 51 | 52 | --------------------------------------------------------------------------------
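-- A minimal usage sketch (not part of the repo; the LookupTableFT class name
-- and constructor below are assumptions, the real module lives in
-- layers/LookupTable_ft.lua) of how the fine-tune mask returned by
-- WordEmbeddingFT:initMatFT is consumed, mirroring
-- model_utils.load_embedding_fine_tune above: rows covered by pre-trained
-- vectors have their gradients scaled by ftFactor (e.g. 0.1, or 0 to freeze
-- them), while out-of-vocabulary rows keep mask value 1 and train at the
-- full rate.
--
--   require 'wordembedding_ft'
--   local emb = LookupTableFT(#vocab.idx2word, 300)    -- hypothetical constructor
--   local wordEmbed = WordEmbeddingFT('glove.300d.t7') -- t7 file from extract_embed.lua
--   local mask = wordEmbed:initMatFT(emb.weight, vocab, 0.1)
--   emb:setUpdateMask(mask)  -- per-row gradient multiplier
--   wordEmbed:releaseMemory()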