├── .gitignore ├── CMakeLists.txt ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── PATENTS ├── README.md ├── data └── prepare-iwslt14.sh ├── fairseq.gif ├── fairseq ├── clib │ ├── bleu.cpp │ ├── bleu.lua │ ├── init.lua │ ├── logsoftmax.cpp │ └── topk.cpp ├── init.lua ├── models │ ├── avgpool_model.lua │ ├── blstm_model.lua │ ├── conv_model.lua │ ├── ensemble_model.lua │ ├── fconv_model.lua │ ├── init.lua │ ├── model.lua │ ├── selection_blstm_model.lua │ └── utils.lua ├── modules │ ├── AppendBias.lua │ ├── BeamableMM.lua │ ├── CAddTableMulConstant.lua │ ├── CLSTM.lua │ ├── CudnnRnnTable.lua │ ├── GradMultiply.lua │ ├── LinearizedConvolution.lua │ ├── SeqMultiply.lua │ ├── TrainTestLayer.lua │ ├── ZipAlong.lua │ └── init.lua ├── optim │ └── nag.lua ├── search.lua ├── text │ ├── Dictionary.lua │ ├── bleu.lua │ ├── init.lua │ ├── lm_corpus.lua │ ├── pretty.lua │ └── tokenizer.lua ├── torchnet │ ├── MaxBatchDataset.lua │ ├── ResumableDPOptimEngine.lua │ ├── ShardedDatasetIterator.lua │ ├── SingleParallelIterator.lua │ ├── data.lua │ ├── hooks.lua │ └── init.lua └── utils.lua ├── generate-lines.lua ├── generate.lua ├── help.lua ├── optimize-fconv.lua ├── preprocess.lua ├── rocks ├── fairseq-cpu-scm-1.rockspec └── fairseq-scm-1.rockspec ├── run.lua ├── score.lua ├── scripts ├── binarize.lua ├── build_sym_alignment.py ├── make_fconv_vocsel.lua ├── makealigndict.lua ├── makedict.lua └── unkreplace.lua ├── test ├── test.lua ├── test_appendbias.lua ├── test_dictionary.lua ├── test_logsoftmax.lua ├── test_tokenizer.lua ├── test_topk.lua └── test_zipalong.lua ├── tofloat.lua └── train.lua /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | build 3 | test/tst2012.en 4 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 
2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. An additional grant 6 | # of patent rights can be found in the PATENTS file in the same directory. 7 | 8 | CMAKE_MINIMUM_REQUIRED(VERSION 2.6 FATAL_ERROR) 9 | CMAKE_POLICY(VERSION 2.6) 10 | 11 | FIND_PACKAGE(Torch REQUIRED) 12 | FIND_PACKAGE(OpenMP) 13 | 14 | SET(CMAKE_CXX_FLAGS "-std=c++11 -Ofast") 15 | IF(OpenMP_FOUND) 16 | SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") 17 | SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") 18 | ENDIF() 19 | 20 | # C++ library 21 | IF(APPLE) 22 | SET(CMAKE_SHARED_LIBRARY_SUFFIX ".so") 23 | ENDIF(APPLE) 24 | FILE(GLOB CPPSRC fairseq/clib/*.cpp) 25 | ADD_LIBRARY(fairseq_clib SHARED ${CPPSRC}) 26 | INSTALL(TARGETS fairseq_clib DESTINATION "${ROCKS_LIBDIR}") 27 | 28 | # Lua library 29 | INSTALL(DIRECTORY "fairseq" DESTINATION "${ROCKS_LUADIR}" FILES_MATCHING PATTERN "*.lua") 30 | 31 | # Scripts and main executable 32 | FOREACH(SCRIPT preprocess train tofloat generate generate-lines score optimize-fconv help) 33 | INSTALL(FILES "${SCRIPT}.lua" DESTINATION "${ROCKS_LUADIR}/fairseq/scripts") 34 | ENDFOREACH(SCRIPT) 35 | INSTALL(FILES "run.lua" DESTINATION "${ROCKS_BINDIR}" RENAME "fairseq") 36 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. 
6 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to fairseq (Lua) 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | First, note that there is now a PyTorch version 7 | [fairseq-py](https://github.com/facebookresearch/fairseq-py) of this toolkit and 8 | new development efforts will focus on it. That being said, we actively welcome 9 | your pull requests: 10 | 11 | 1. Fork the repo and create your branch from `main`. 12 | 2. If you've added code that should be tested, add tests. 13 | 3. If you've changed APIs, update the documentation. 14 | 4. Ensure the test suite passes. 15 | 5. If you haven't already, complete the Contributor License Agreement ("CLA"). 16 | 17 | ## Contributor License Agreement ("CLA") 18 | In order to accept your pull request, we need you to submit a CLA. You only need 19 | to do this once to work on any of Facebook's open source projects. 20 | 21 | Complete your CLA here: 22 | 23 | ## Issues 24 | We use GitHub issues to track public bugs. Please ensure your description is 25 | clear and has sufficient instructions to be able to reproduce the issue. 26 | 27 | ## Coding Style 28 | * 4 spaces for indentation rather than tabs 29 | * 80 character line length 30 | * CamelCase; capitalized class names and lower-case function names 31 | 32 | ## License 33 | By contributing to fairseq, you agree that your contributions will be licensed 34 | under the LICENSE file in the root directory of this source tree. 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For fairseq software 4 | 5 | Copyright (c) 2017-present, Facebook, Inc. All rights reserved. 
6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /PATENTS: -------------------------------------------------------------------------------- 1 | Additional Grant of Patent Rights Version 2 2 | 3 | "Software" means the fairseq software distributed by Facebook, Inc. 4 | 5 | Facebook, Inc. 
("Facebook") hereby grants to each recipient of the Software 6 | ("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable 7 | (subject to the termination provision below) license under any Necessary 8 | Claims, to make, have made, use, sell, offer to sell, import, and otherwise 9 | transfer the Software. For avoidance of doubt, no license is granted under 10 | Facebook’s rights in any patent claims that are infringed by (i) modifications 11 | to the Software made by you or any third party or (ii) the Software in 12 | combination with any software or other technology. 13 | 14 | The license granted hereunder will terminate, automatically and without notice, 15 | if you (or any of your subsidiaries, corporate affiliates or agents) initiate 16 | directly or indirectly, or take a direct financial interest in, any Patent 17 | Assertion: (i) against Facebook or any of its subsidiaries or corporate 18 | affiliates, (ii) against any party if such Patent Assertion arises in whole or 19 | in part from any software, technology, product or service of Facebook or any of 20 | its subsidiaries or corporate affiliates, or (iii) against any party relating 21 | to the Software. Notwithstanding the foregoing, if Facebook or any of its 22 | subsidiaries or corporate affiliates files a lawsuit alleging patent 23 | infringement against you in the first instance, and you respond by filing a 24 | patent infringement counterclaim in that lawsuit against that party that is 25 | unrelated to the Software, the license granted hereunder will not terminate 26 | under section (i) of this paragraph due to such counterclaim. 27 | 28 | A "Necessary Claim" is a claim of a patent owned by Facebook that is 29 | necessarily infringed by the Software standing alone. 30 | 31 | A "Patent Assertion" is any lawsuit or other action alleging direct, indirect, 32 | or contributory infringement or inducement to infringe any patent, including a 33 | cross-claim or counterclaim. 
34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | 3 | ***Note***: there is now a PyTorch version of this toolkit ([fairseq-py](https://github.com/pytorch/fairseq)) and new development efforts will focus on it. The Lua version is preserved here, but is provided without any support. 4 | 5 | This is fairseq, a sequence-to-sequence learning toolkit for [Torch](http://torch.ch/) from Facebook AI Research tailored to Neural Machine Translation (NMT). 6 | It implements the convolutional NMT models proposed in [Convolutional Sequence to Sequence Learning](https://arxiv.org/abs/1705.03122) and [A Convolutional Encoder Model for Neural Machine Translation](https://arxiv.org/abs/1611.02344) as well as a standard LSTM-based model. 7 | It features multi-GPU training on a single machine as well as fast beam search generation on both CPU and GPU. 8 | We provide pre-trained models for English to French, English to German and English to Romanian translation. 
9 | 10 | ![Model](fairseq.gif) 11 | 12 | # Citation 13 | 14 | If you use the code in your paper, then please cite it as: 15 | 16 | ``` 17 | @article{gehring2017convs2s, 18 | author = {Gehring, Jonas and Auli, Michael and Grangier, David and Yarats, Denis and Dauphin, Yann N}, 19 | title = "{Convolutional Sequence to Sequence Learning}", 20 | journal = {ArXiv e-prints}, 21 | archivePrefix = "arXiv", 22 | eprinttype = {arxiv}, 23 | eprint = {1705.03122}, 24 | primaryClass = "cs.CL", 25 | keywords = {Computer Science - Computation and Language}, 26 | year = 2017, 27 | month = May, 28 | } 29 | ``` 30 | 31 | and 32 | 33 | ``` 34 | @article{gehring2016convenc, 35 | author = {Gehring, Jonas and Auli, Michael and Grangier, David and Dauphin, Yann N}, 36 | title = "{A Convolutional Encoder Model for Neural Machine Translation}", 37 | journal = {ArXiv e-prints}, 38 | archivePrefix = "arXiv", 39 | eprinttype = {arxiv}, 40 | eprint = {1611.02344}, 41 | primaryClass = "cs.CL", 42 | keywords = {Computer Science - Computation and Language}, 43 | year = 2016, 44 | month = Nov, 45 | } 46 | ``` 47 | 48 | # Requirements and Installation 49 | * A computer running macOS or Linux 50 | * For training new models, you'll also need a NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl) 51 | * A [Torch installation](http://torch.ch/docs/getting-started.html). For maximum speed, we recommend using LuaJIT and [Intel MKL](https://software.intel.com/en-us/intel-mkl). 52 | * A recent version [nn](https://github.com/torch/nn). The minimum required version is from May 5th, 2017. A simple `luarocks install nn` is sufficient to update your locally installed version. 53 | 54 | Install fairseq by cloning the GitHub repository and running 55 | ``` 56 | luarocks make rocks/fairseq-scm-1.rockspec 57 | ``` 58 | LuaRocks will fetch and build any additional dependencies that may be missing. 
59 | In order to install the CPU-only version (which is only useful for translating new data with an existing model), do 60 | ``` 61 | luarocks make rocks/fairseq-cpu-scm-1.rockspec 62 | ``` 63 | 64 | The LuaRocks installation provides a command-line tool that includes the following functionality: 65 | * `fairseq preprocess`: Data pre-processing: build vocabularies and binarize training data 66 | * `fairseq train`: Train a new model on one or multiple GPUs 67 | * `fairseq generate`: Translate pre-processed data with a trained model 68 | * `fairseq generate-lines`: Translate raw text with a trained model 69 | * `fairseq score`: BLEU scoring of generated translations against reference translations 70 | * `fairseq tofloat`: Convert a trained model to a CPU model 71 | * `fairseq optimize-fconv`: Optimize a fully convolutional model for generation. This can also be achieved by passing the `-fconvfast` flag to the generation scripts. 72 | 73 | # Quick Start 74 | 75 | ## Training a New Model 76 | 77 | ### Data Pre-processing 78 | The fairseq source distribution contains an example pre-processing script for 79 | the IWSLT14 German-English corpus. 80 | Pre-process and binarize the data as follows: 81 | ``` 82 | $ cd data/ 83 | $ bash prepare-iwslt14.sh 84 | $ cd .. 85 | $ TEXT=data/iwslt14.tokenized.de-en 86 | $ fairseq preprocess -sourcelang de -targetlang en \ 87 | -trainpref $TEXT/train -validpref $TEXT/valid -testpref $TEXT/test \ 88 | -thresholdsrc 3 -thresholdtgt 3 -destdir data-bin/iwslt14.tokenized.de-en 89 | ``` 90 | This will write binarized data that can be used for model training to data-bin/iwslt14.tokenized.de-en. 91 | 92 | ### Training 93 | Use `fairseq train` to train a new model. 
94 | Here a few example settings that work well for the IWSLT14 dataset: 95 | ``` 96 | # Standard bi-directional LSTM model 97 | $ mkdir -p trainings/blstm 98 | $ fairseq train -sourcelang de -targetlang en -datadir data-bin/iwslt14.tokenized.de-en \ 99 | -model blstm -nhid 512 -dropout 0.2 -dropout_hid 0 -optim adam -lr 0.0003125 -savedir trainings/blstm 100 | 101 | # Fully convolutional sequence-to-sequence model 102 | $ mkdir -p trainings/fconv 103 | $ fairseq train -sourcelang de -targetlang en -datadir data-bin/iwslt14.tokenized.de-en \ 104 | -model fconv -nenclayer 4 -nlayer 3 -dropout 0.2 -optim nag -lr 0.25 -clip 0.1 \ 105 | -momentum 0.99 -timeavg -bptt 0 -savedir trainings/fconv 106 | 107 | # Convolutional encoder, LSTM decoder 108 | $ mkdir -p trainings/convenc 109 | $ fairseq train -sourcelang de -targetlang en -datadir data-bin/iwslt14.tokenized.de-en \ 110 | -model conv -nenclayer 6 -dropout 0.2 -dropout_hid 0 -savedir trainings/convenc 111 | ``` 112 | 113 | By default, `fairseq train` will use all available GPUs on your machine. 114 | Use the [CUDA_VISIBLE_DEVICES](http://acceleware.com/blog/cudavisibledevices-masking-gpus) environment variable to select specific GPUs or `-ngpus` to change the number of GPU devices that will be used. 115 | 116 | ### Generation 117 | Once your model is trained, you can translate with it using `fairseq generate` (for binarized data) or `fairseq generate-lines` (for text). 
118 | Here, we'll do it for a fully convolutional model: 119 | ``` 120 | # Optional: optimize for generation speed 121 | $ fairseq optimize-fconv -input_model trainings/fconv/model_best.th7 -output_model trainings/fconv/model_best_opt.th7 122 | 123 | # Translate some text 124 | $ DATA=data-bin/iwslt14.tokenized.de-en 125 | $ fairseq generate-lines -sourcedict $DATA/dict.de.th7 -targetdict $DATA/dict.en.th7 \ 126 | -path trainings/fconv/model_best_opt.th7 -beam 10 -nbest 2 127 | | [target] Dictionary: 24738 types 128 | | [source] Dictionary: 35474 types 129 | > eine sprache ist ausdruck des menschlichen geistes . 130 | S eine sprache ist ausdruck des menschlichen geistes . 131 | O eine sprache ist ausdruck des menschlichen geistes . 132 | H -0.23804219067097 a language is expression of human mind . 133 | A 2 2 3 4 5 6 7 8 9 134 | H -0.23861141502857 a language is expression of the human mind . 135 | A 2 2 3 4 5 7 6 7 9 9 136 | ``` 137 | 138 | ### CPU Generation 139 | Use `fairseq tofloat` to convert a trained model to use CPU-only operations (this has to be done on a GPU machine): 140 | ``` 141 | # Optional: optimize for generation speed 142 | $ fairseq optimize-fconv -input_model trainings/fconv/model_best.th7 -output_model trainings/fconv/model_best_opt.th7 143 | 144 | # Convert to float 145 | $ fairseq tofloat -input_model trainings/fconv/model_best_opt.th7 \ 146 | -output_model trainings/fconv/model_best_opt-float.th7 147 | 148 | # Translate some text 149 | $ fairseq generate-lines -sourcedict $DATA/dict.de.th7 -targetdict $DATA/dict.en.th7 \ 150 | -path trainings/fconv/model_best_opt-float.th7 -beam 10 -nbest 2 151 | > eine sprache ist ausdruck des menschlichen geistes . 152 | S eine sprache ist ausdruck des menschlichen geistes . 153 | O eine sprache ist ausdruck des menschlichen geistes . 154 | H -0.2380430996418 a language is expression of human mind . 155 | A 2 2 3 4 5 6 7 8 9 156 | H -0.23861189186573 a language is expression of the human mind . 
157 | A 2 2 3 4 5 7 6 7 9 9 158 | ``` 159 | 160 | # Pre-trained Models 161 | 162 | Generation with the binarized test sets can be run in batch mode as follows, e.g. for English-French on a GTX-1080ti: 163 | ``` 164 | $ fairseq generate -sourcelang en -targetlang fr -datadir data-bin/wmt14.en-fr -dataset newstest2014 \ 165 | -path wmt14.en-fr.fconv-cuda/model.th7 -beam 5 -batchsize 128 | tee /tmp/gen.out 166 | ... 167 | | Translated 3003 sentences (95451 tokens) in 136.3s (700.49 tokens/s) 168 | | Timings: setup 0.1s (0.1%), encoder 1.9s (1.4%), decoder 108.9s (79.9%), search_results 0.0s (0.0%), search_prune 12.5s (9.2%) 169 | | BLEU4 = 43.43, 68.2/49.2/37.4/28.8 (BP=0.996, ratio=1.004, sys_len=92087, ref_len=92448) 170 | 171 | # Word-level BLEU scoring: 172 | $ grep ^H /tmp/gen.out | cut -f3- | sed 's/@@ //g' > /tmp/gen.out.sys 173 | $ grep ^T /tmp/gen.out | cut -f2- | sed 's/@@ //g' > /tmp/gen.out.ref 174 | $ fairseq score -sys /tmp/gen.out.sys -ref /tmp/gen.out.ref 175 | BLEU4 = 40.55, 67.6/46.5/34.0/25.3 (BP=1.000, ratio=0.998, sys_len=81369, ref_len=81194) 176 | ``` 177 | 178 | # Join the fairseq community 179 | 180 | * Facebook page: https://www.facebook.com/groups/fairseq.users 181 | * Google group: https://groups.google.com/forum/#!forum/fairseq-users 182 | * Contact: [jgehring@fb.com](mailto:jgehring@fb.com), [michaelauli@fb.com](mailto:michaelauli@fb.com) 183 | 184 | # License 185 | fairseq is BSD-licensed. 186 | The license applies to the pre-trained models as well. 187 | We also provide an additional patent grant. 188 | -------------------------------------------------------------------------------- /data/prepare-iwslt14.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. 
An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 8 | # 9 | # Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh 10 | 11 | echo 'Cloning Moses github repository (for tokenization scripts)...' 12 | git clone https://github.com/moses-smt/mosesdecoder.git 13 | 14 | SCRIPTS=mosesdecoder/scripts 15 | TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl 16 | LC=$SCRIPTS/tokenizer/lowercase.perl 17 | CLEAN=$SCRIPTS/training/clean-corpus-n.perl 18 | 19 | URL="https://wit3.fbk.eu/archive/2014-01/texts/de/en/de-en.tgz" 20 | GZ=de-en.tgz 21 | 22 | if [ ! -d "$SCRIPTS" ]; then 23 | echo "Please set SCRIPTS variable correctly to point to Moses scripts." 24 | exit 25 | fi 26 | 27 | src=de 28 | tgt=en 29 | lang=de-en 30 | prep=iwslt14.tokenized.de-en 31 | tmp=$prep/tmp 32 | orig=orig 33 | 34 | mkdir -p $orig $tmp $prep 35 | 36 | echo "Downloading data from ${URL}..." 37 | cd $orig 38 | wget "$URL" 39 | 40 | if [ -f $GZ ]; then 41 | echo "Data successfully downloaded." 42 | else 43 | echo "Data not successfully downloaded." 44 | exit 45 | fi 46 | 47 | tar zxvf $GZ 48 | cd .. 49 | 50 | echo "pre-processing train data..." 51 | for l in $src $tgt; do 52 | f=train.tags.$lang.$l 53 | tok=train.tags.$lang.tok.$l 54 | 55 | cat $orig/$lang/$f | \ 56 | grep -v '<url>' | \ 57 | grep -v '<talkid>' | \ 58 | grep -v '<keywords>' | \ 59 | sed -e 's/<title>//g' | \ 60 | sed -e 's/<\/title>//g' | \ 61 | sed -e 's/<description>//g' | \ 62 | sed -e 's/<\/description>//g' | \ 63 | perl $TOKENIZER -threads 8 -l $l > $tmp/$tok 64 | echo "" 65 | done 66 | perl $CLEAN -ratio 1.5 $tmp/train.tags.$lang.tok $src $tgt $tmp/train.tags.$lang.clean 1 175 67 | for l in $src $tgt; do 68 | perl $LC < $tmp/train.tags.$lang.clean.$l > $tmp/train.tags.$lang.$l 69 | done 70 | 71 | echo "pre-processing valid/test data..."
72 | for l in $src $tgt; do 73 | for o in `ls $orig/$lang/IWSLT14.TED*.$l.xml`; do 74 | fname=${o##*/} 75 | f=$tmp/${fname%.*} 76 | echo $o $f 77 | grep '<seg id' $o | \ 78 | sed -e 's/<seg id="[0-9]*">\s*//g' | \ 79 | sed -e 's/\s*<\/seg>\s*//g' | \ 80 | sed -e "s/\’/\'/g" | \ 81 | perl $TOKENIZER -threads 8 -l $l | \ 82 | perl $LC > $f 83 | echo "" 84 | done 85 | done 86 | 87 | 88 | echo "creating train, valid, test..." 89 | for l in $src $tgt; do 90 | awk '{if (NR%23 == 0) print $0; }' $tmp/train.tags.de-en.$l > $prep/valid.$l 91 | awk '{if (NR%23 != 0) print $0; }' $tmp/train.tags.de-en.$l > $prep/train.$l 92 | 93 | cat $tmp/IWSLT14.TED.dev2010.de-en.$l \ 94 | $tmp/IWSLT14.TEDX.dev2012.de-en.$l \ 95 | $tmp/IWSLT14.TED.tst2010.de-en.$l \ 96 | $tmp/IWSLT14.TED.tst2011.de-en.$l \ 97 | $tmp/IWSLT14.TED.tst2012.de-en.$l \ 98 | > $prep/test.$l 99 | done 100 | -------------------------------------------------------------------------------- /fairseq.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/fairseq-lua/7a6fff0f3647c6a4c8c59b4253004c4761eb11b3/fairseq.gif -------------------------------------------------------------------------------- /fairseq/clib/bleu.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2017-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include <map> 10 | #include <array> 11 | #include <cstring> 12 | #include <cstdio> 13 | 14 | typedef struct 15 | { 16 | size_t reflen; 17 | size_t predlen; 18 | size_t match1; 19 | size_t count1; 20 | size_t match2; 21 | size_t count2; 22 | size_t match3; 23 | size_t count3; 24 | size_t match4; 25 | size_t count4; 26 | } bleu_stat; 27 | 28 | // left trim (remove pad) 29 | void bleu_ltrim(size_t* len, int** sent, int pad) { 30 | size_t start = 0; 31 | while(start < *len) { 32 | if (*(*sent + start) != pad) { break; } 33 | start++; 34 | } 35 | *sent += start; 36 | *len -= start; 37 | } 38 | 39 | // right trim remove (eos) 40 | void bleu_rtrim(size_t* len, int** sent, int pad, int eos) { 41 | size_t end = *len - 1; 42 | while (end > 0) { 43 | if (*(*sent + end) != eos && *(*sent + end) != pad) { break; } 44 | end--; 45 | } 46 | *len = end + 1; 47 | } 48 | 49 | // left and right trim 50 | void bleu_trim(size_t* len, int** sent, int pad, int eos) { 51 | bleu_ltrim(len, sent, pad); 52 | bleu_rtrim(len, sent, pad, eos); 53 | } 54 | 55 | size_t bleu_hash(int len, int* data) { 56 | size_t h = 14695981039346656037ul; 57 | size_t prime = 0x100000001b3; 58 | char* b = (char*) data; 59 | size_t blen = sizeof(int) * len; 60 | 61 | while (blen-- > 0) { 62 | h ^= *b++; 63 | h *= prime; 64 | } 65 | 66 | return h; 67 | } 68 | 69 | void bleu_addngram( 70 | size_t *ntotal, size_t *nmatch, int n, 71 | size_t reflen, int* ref, size_t predlen, int* pred) { 72 | 73 | if (predlen < n) { return; } 74 | 75 | predlen = predlen - n + 1; 76 | (*ntotal) += (size_t)(predlen); 77 | 78 | if (reflen < n) { return; } 79 | 80 | reflen = reflen - n + 1; 81 | 82 | std::map<size_t, size_t> count; 83 | while (predlen > 0) { 84 | size_t w = bleu_hash(n, pred++); 85 | count[w]++; 86 | predlen--; 87 | } 88 | 89 | while (reflen > 0) { 90 | size_t w = bleu_hash(n, ref++); 91 | if (count[w] > 0) { 92 | (*nmatch)++; 93 | count[w] -=1; 94 | } 95 | reflen--; 96 | } 97 | } 98 | 99 | extern "C" { 100 
| 101 | void bleu_zero_init(bleu_stat* stat) { 102 | std::memset(stat, 0, sizeof(bleu_stat)); 103 | } 104 | 105 | void bleu_one_init(bleu_stat* stat) { 106 | bleu_zero_init(stat); 107 | stat->count1 = 1; 108 | stat->count2 = 1; 109 | stat->count3 = 1; 110 | stat->count4 = 1; 111 | stat->match1 = 1; 112 | stat->match2 = 1; 113 | stat->match3 = 1; 114 | stat->match4 = 1; 115 | } 116 | 117 | void bleu_add( 118 | bleu_stat* stat, 119 | size_t reflen, int* ref, size_t predlen, int* pred, int pad, int eos) { 120 | 121 | bleu_trim(&reflen, &ref, pad, eos); 122 | bleu_trim(&predlen, &pred, pad, eos); 123 | stat->reflen += reflen; 124 | stat->predlen += predlen; 125 | 126 | bleu_addngram(&stat->count1, &stat->match1, 1, reflen, ref, predlen, pred); 127 | bleu_addngram(&stat->count2, &stat->match2, 2, reflen, ref, predlen, pred); 128 | bleu_addngram(&stat->count3, &stat->match3, 3, reflen, ref, predlen, pred); 129 | bleu_addngram(&stat->count4, &stat->match4, 4, reflen, ref, predlen, pred); 130 | } 131 | 132 | } 133 | -------------------------------------------------------------------------------- /fairseq/clib/bleu.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 
7 | -- 8 | --[[ 9 | -- 10 | -- BLEU scorer that operates directly over tensors 11 | -- as opposed to bleu.lua which is string based and 12 | -- takes tables as inputs 13 | -- 14 | --]] 15 | 16 | local ffi = require 'ffi' 17 | local function initBleu(C) 18 | local bleu = torch.class('Bleu') 19 | 20 | function Bleu:__init(pad, eos) 21 | self.stat = ffi.new('bleu_stat') 22 | self.pad = pad or 2 23 | self.eos = eos or 3 24 | self:reset() 25 | end 26 | 27 | function Bleu:reset(oneinit) 28 | self.nsent = 0 29 | if oneinit then 30 | C.bleu_one_init(self.stat) 31 | else 32 | C.bleu_zero_init(self.stat) 33 | end 34 | return self 35 | end 36 | 37 | function Bleu:add(ref, pred) 38 | local nogc = {ref, pred} -- keep pointers to prevent gc 39 | 40 | local reflen, refdata = ref:size(1), ref:data() 41 | local predlen, preddata = pred:size(1), pred:data() 42 | 43 | C.bleu_add( 44 | self.stat, reflen, refdata, predlen, preddata, self.pad, self.eos) 45 | self.nsent = self.nsent + 1 46 | 47 | table.unpack(nogc) 48 | return self 49 | end 50 | 51 | function Bleu:precision(n) 52 | local function ratio(a, b) 53 | return tonumber(b) > 0 and (tonumber(a) / tonumber(b)) or 0 54 | end 55 | local precision = { 56 | ratio(self.stat.match1, self.stat.count1), 57 | ratio(self.stat.match2, self.stat.count2), 58 | ratio(self.stat.match3, self.stat.count3), 59 | ratio(self.stat.match4, self.stat.count4), 60 | } 61 | return n and precision[n] or precision 62 | end 63 | 64 | function Bleu:brevity() 65 | local r = tonumber(self.stat.reflen)/tonumber(self.stat.predlen) 66 | return math.min(1, math.exp(1 - r)) 67 | end 68 | 69 | function Bleu:score() 70 | local psum = 0 71 | for _, p in ipairs(self:precision()) do 72 | psum = psum + math.log(p) 73 | end 74 | 75 | return self:brevity() * math.exp(psum / 4) * 100 76 | end 77 | 78 | function Bleu:results() 79 | return { 80 | bleu = self:score(), 81 | precision = self:precision(), 82 | brevPenalty = self:brevity(), 83 | totalSys = 
tonumber(self.stat.predlen), 84 | totalRef = tonumber(self.stat.reflen), 85 | } 86 | end 87 | 88 | function Bleu:resultString() 89 | local r = self:results() 90 | local str = string.format('BLEU4 = %.2f, ', r.bleu) 91 | local precs = {} 92 | for i = 1, 4 do 93 | precs[i] = string.format('%.1f', r.precision[i] * 100) 94 | end 95 | str = str .. table.concat(precs, '/') 96 | str = str .. string.format( 97 | ' (BP=%.3f, ratio=%.3f, sys_len=%d, ref_len=%d)', 98 | r.brevPenalty, r.totalRef / r.totalSys, r.totalSys, r.totalRef 99 | ) 100 | return str 101 | end 102 | 103 | return function(...) return Bleu(...) end 104 | end 105 | 106 | return initBleu 107 | -------------------------------------------------------------------------------- /fairseq/clib/init.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 
7 | -- 8 | 9 | local nn = require 'nn' 10 | local ffi = require 'ffi' 11 | local bleu = require 'fairseq.clib.bleu' 12 | local so = package.searchpath('libfairseq_clib', package.cpath) 13 | local cdef =[[ 14 | void ctopk2d(float* top, long* idx, int k, float* values, int len, int n); 15 | void logsoftmax2d(float* input, float* output, int sz1, int sz2); 16 | 17 | typedef struct 18 | { 19 | size_t reflen; 20 | size_t predlen; 21 | size_t match1; 22 | size_t count1; 23 | size_t match2; 24 | size_t count2; 25 | size_t match3; 26 | size_t count3; 27 | size_t match4; 28 | size_t count4; 29 | } bleu_stat; 30 | 31 | void bleu_zero_init(bleu_stat* self); 32 | void bleu_one_init(bleu_stat* self); 33 | 34 | void bleu_add( 35 | bleu_stat* stat, size_t reflen, int* ref, size_t predlen, int* pred, 36 | int pad, int eos); 37 | ]] 38 | 39 | ffi.cdef(cdef) 40 | local C = ffi.load(so) 41 | 42 | local function topk(top, ind, val, k) 43 | assert(val:dim() == 2 and k <= val:size(2)) 44 | if not(val:type() == 'torch.FloatTensor' and val:isContiguous()) then 45 | -- use torch for GPU, non contiguous tensors 46 | return torch.topk(top, ind, val, k, 2, true, true) 47 | else 48 | top, ind = top or torch.FloatTensor(), ind or torch.LongTensor() 49 | local len, n = val:size(2), val:size(1) 50 | top:resize(n, k) 51 | ind:resize(n, k) 52 | assert(top:isContiguous() and ind:isContiguous()) 53 | C.ctopk2d(top:data(), ind:data(), k, val:data(), len, n) 54 | end 55 | return top, ind 56 | end 57 | 58 | local function logsoftmax() 59 | local output = torch.FloatTensor() 60 | local lsm = nn.LogSoftMax() 61 | 62 | return function(input) 63 | if input:type()=='torch.FloatTensor' 64 | and input:dim()==2 65 | and input:isContiguous() 66 | then 67 | output:resizeAs(input) 68 | C.logsoftmax2d( 69 | input:data(), output:data(), input:size(1), input:size(2)) 70 | return output 71 | else 72 | lsm:type(input:type()) 73 | return lsm:updateOutput(input) 74 | end 75 | end 76 | end 77 | 78 | return { 79 | topk = 
topk, 80 | logsoftmax = logsoftmax, 81 | bleu = bleu(C), 82 | } 83 | -------------------------------------------------------------------------------- /fairseq/clib/logsoftmax.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2017-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include <cfloat> 10 | #include <cmath> 11 | 12 | extern "C" { 13 | 14 | /* Credits to Leon Bottou */ 15 | double approxexpminus(const double x) 16 | { 17 | /* fast approximation of exp(-x) for x positive */ 18 | const double a0 = 1.0; 19 | const double a1 = 0.125; 20 | const double a2 = 0.0078125; 21 | const double a3 = 0.00032552083; 22 | const double a4 = 1.0172526e-5; 23 | if (x < 13.0) 24 | { 25 | double y; 26 | y = a0+x*(a1+x*(a2+x*(a3+x*a4))); 27 | y *= y; 28 | y *= y; 29 | y *= y; 30 | y = 1/y; 31 | return y; 32 | } 33 | return 0; 34 | } 35 | 36 | void logsoftmax1d(float* input, float* output, int sz1) { 37 | float max = -FLT_MAX; 38 | float* in = input; 39 | for (int i = 0; i < sz1; i++) { 40 | float v = *in++; 41 | if (max < v) {max = v;} 42 | } 43 | 44 | double logsum = 0; 45 | in = input; 46 | for (int i = 0; i < sz1; i++) { 47 | logsum += approxexpminus(max - *in++); 48 | } 49 | logsum = max + log(logsum); 50 | 51 | for (int i = 0; i < sz1; i++) { 52 | *output++ = *input++ - logsum; 53 | } 54 | } 55 | 56 | void logsoftmax2d(float* input, float* output, int sz1, int sz2) { 57 | #pragma omp parallel for 58 | for (int i = 0; i < sz1; i++) { 59 | logsoftmax1d(input + i * sz2, output + i * sz2, sz2); 60 | } 61 | } 62 | 63 | } 64 | -------------------------------------------------------------------------------- /fairseq/clib/topk.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2017-present, Facebook, Inc. 
3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include <algorithm> // std::partial_sort_copy 10 | #include <vector> // std::vector 11 | 12 | typedef struct { 13 | int key; 14 | float value; 15 | } kvp; 16 | 17 | bool comparekvp(const kvp& a, const kvp& b) { return (a.value > b.value); } 18 | 19 | void topk(std::vector<kvp>& top, const std::vector<kvp>& list) { 20 | std::partial_sort_copy( 21 | list.begin(), list.end(), top.begin(), top.end(), 22 | comparekvp 23 | ); 24 | } 25 | 26 | void rangetopk( 27 | std::vector<kvp>& top, const std::vector<kvp>& list, int start, int len) { 28 | std::partial_sort_copy( 29 | list.begin() + start, list.begin() + start + len, top.begin(), top.end(), 30 | comparekvp 31 | ); 32 | } 33 | 34 | void multithreadtopk( 35 | std::vector<kvp>& top, const std::vector<kvp>& list, int nthr) { 36 | int k = top.size(); 37 | int len = list.size(); 38 | 39 | // does multi-threading worth it? 
40 | if ((nthr < 2) || (len / nthr < 2*k)) { 41 | topk(top, list); 42 | return; 43 | } 44 | 45 | // map 46 | std::vector<std::vector<kvp>> mtop(nthr); 47 | #pragma omp parallel for 48 | for (int i = 0; i < nthr; i++) { 49 | mtop[i].resize(k); 50 | int start = i * len / nthr; 51 | int end = (i + 1) * len / nthr; 52 | rangetopk(mtop[i], list, start, end - start); 53 | } 54 | 55 | // reduce 56 | std::vector<kvp> tinylist; 57 | for (int i = 0; i < nthr; i++) { 58 | tinylist.insert(tinylist.end(), mtop[i].begin(), mtop[i].end()); 59 | } 60 | topk(top, tinylist); 61 | } 62 | 63 | extern "C" { 64 | 65 | void ctopk1d(float* top, long* ind, int k, float* values, int len, int nthr) { 66 | std::vector<kvp> list(len); 67 | std::vector<kvp> vtop(k); 68 | for (int i = 0; i < len; i++) { 69 | kvp elt {i, *values++}; 70 | list[i] = elt; 71 | } 72 | multithreadtopk(vtop, list, nthr); 73 | for (const auto& elt : vtop) { 74 | *ind++ = elt.key + 1; 75 | *top++ = elt.value; 76 | } 77 | } 78 | 79 | void ctopk2d(float* top, long* ind, int k, float* values, int len, int n) { 80 | if (n == 1) { 81 | // parallel at the sort level 82 | ctopk1d(top, ind, k, values, len, 20); 83 | } else { 84 | // parallel at the batch level 85 | #pragma omp parallel for 86 | for (int i = 0; i < n; i++) { 87 | ctopk1d(top + i * k, ind + i * k, k, values + i * len, len, 1); 88 | } 89 | } 90 | } 91 | 92 | } 93 | -------------------------------------------------------------------------------- /fairseq/init.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | -- 8 | --[[ 9 | -- 10 | -- base init file. 
11 | -- 12 | --]] 13 | 14 | require 'fairseq.models' 15 | require 'fairseq.modules' 16 | require 'fairseq.torchnet' 17 | require 'fairseq.text' 18 | -------------------------------------------------------------------------------- /fairseq/models/blstm_model.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | -- 8 | --[[ 9 | -- 10 | -- This model uses a bi-directional LSTM encoder. The direction is reversed 11 | -- between layers and two separate columns run in parallel: one on the normal 12 | -- input and one on the reversed input (as described in 13 | -- http://arxiv.org/abs/1606.04199). 14 | -- 15 | -- The attention mechanism and the decoder setup are identical to the avgpool 16 | -- model. 
17 | -- 18 | --]] 19 | 20 | require 'nn' 21 | require 'rnnlib' 22 | local usecudnn = pcall(require, 'cudnn') 23 | local argcheck = require 'argcheck' 24 | local mutils = require 'fairseq.models.utils' 25 | local rutils = require 'rnnlib.mutils' 26 | 27 | local BLSTMModel = torch.class('BLSTMModel', 'AvgpoolModel') 28 | 29 | BLSTMModel.makeEncoderColumn = argcheck{ 30 | {name='self', type='BLSTMModel'}, 31 | {name='config', type='table'}, 32 | {name='inith', type='nngraph.Node'}, 33 | {name='input', type='nngraph.Node'}, 34 | {name='nlayers', type='number'}, 35 | call = function(self, config, inith, input, nlayers) 36 | local rnnconfig = { 37 | inputsize = config.nembed, 38 | hidsize = config.nhid, 39 | nlayer = 1, 40 | winitfun = function(network) 41 | rutils.defwinitfun(network, config.init_range) 42 | end, 43 | usecudnn = usecudnn, 44 | } 45 | 46 | local rnn = nn.LSTM(rnnconfig) 47 | rnn.saveHidden = false 48 | local output = nn.SelectTable(-1)(nn.SelectTable(2)( 49 | rnn({inith, input}):annotate{name = 'encoderRNN'} 50 | )) 51 | rnnconfig.inputsize = config.nhid 52 | 53 | for i = 2, nlayers do 54 | if config.dropout_hid > 0 then 55 | output = nn.MapTable(nn.Dropout(config.dropout_hid))(output) 56 | end 57 | local rnn = nn.LSTM(rnnconfig) 58 | rnn.saveHidden = false 59 | output = nn.SelectTable(-1)(nn.SelectTable(2)( 60 | rnn({ 61 | inith, 62 | nn.ReverseTable()(output), 63 | }) 64 | )) 65 | end 66 | return output 67 | end 68 | } 69 | 70 | BLSTMModel.makeEncoder = argcheck{ 71 | doc=[[ 72 | This encoder runs a forward and backward LSTM network and concatenates their 73 | top-most hidden states. 
74 | ]], 75 | {name='self', type='BLSTMModel'}, 76 | {name='config', type='table'}, 77 | call = function(self, config) 78 | local sourceIn = nn.Identity()() 79 | local inith, tokens = sourceIn:split(2) 80 | 81 | local dict = config.srcdict 82 | local lut = mutils.makeLookupTable(config, dict:size(), 83 | dict.pad_index) 84 | local embed 85 | if config.dropout_src > 0 then 86 | embed = nn.MapTable(nn.Sequential() 87 | :add(lut) 88 | :add(nn.Dropout(config.dropout_src)))(tokens) 89 | else 90 | embed = nn.MapTable(lut)(tokens) 91 | end 92 | 93 | local col1 = self:makeEncoderColumn{ 94 | config = config, 95 | inith = inith, 96 | input = embed, 97 | nlayers = config.nenclayer, 98 | } 99 | local col2 = self:makeEncoderColumn{ 100 | config = config, 101 | inith = inith, 102 | input = nn.ReverseTable()(embed), 103 | nlayers = config.nenclayer, 104 | } 105 | 106 | -- Each column will switch direction between layers. Before merging, 107 | -- they should both run in the same direction (here: forward). 108 | if config.nenclayer % 2 == 0 then 109 | col1 = nn.ReverseTable()(col1) 110 | else 111 | col2 = nn.ReverseTable()(col2) 112 | end 113 | 114 | local prepare = nn.Sequential() 115 | -- Concatenate forward and backward states 116 | prepare:add(nn.JoinTable(2, 2)) 117 | -- Scale down to nhid for further processing 118 | prepare:add(nn.Linear(config.nhid * 2, config.nembed, false)) 119 | -- Add singleton dimension for subsequent joining 120 | prepare:add(nn.View(-1, 1, config.nembed)) 121 | 122 | local joinedOutput = nn.JoinTable(1, 2)( 123 | nn.MapTable(prepare)( 124 | nn.ZipTable()({col1, col2}) 125 | ) 126 | ) 127 | if config.dropout_hid > 0 then 128 | joinedOutput = nn.Dropout(config.dropout_hid)(joinedOutput) 129 | end 130 | 131 | -- avgpool_model.makeDecoder() expects two encoder outputs, one for 132 | -- attention score computation and the other one for applying them. 133 | -- We'll just use the same output for both. 
134 | return nn.gModule({sourceIn}, { 135 | joinedOutput, nn.Identity()(joinedOutput) 136 | }) 137 | end 138 | } 139 | 140 | BLSTMModel.prepareSource = argcheck{ 141 | {name='self', type='BLSTMModel'}, 142 | call = function(self) 143 | -- Device buffers for samples 144 | local buffers = { 145 | source = {}, 146 | } 147 | 148 | -- NOTE: It's assumed that all encoders start from the same hidden 149 | -- state. 150 | local encoderRNN = mutils.findAnnotatedNode( 151 | self:network(), 'encoderRNN' 152 | ) 153 | assert(encoderRNN ~= nil) 154 | 155 | return function(sample) 156 | -- Encoder input 157 | local source = {} 158 | for i = 1, sample.source:size(1) do 159 | buffers.source[i] = buffers.source[i] 160 | or torch.Tensor():type(self:type()) 161 | source[i] = mutils.sendtobuf(sample.source[i], 162 | buffers.source[i]) 163 | end 164 | 165 | local initialHidden = encoderRNN:initializeHidden(sample.bsz) 166 | return {initialHidden, source} 167 | end 168 | end 169 | } 170 | 171 | return BLSTMModel 172 | -------------------------------------------------------------------------------- /fairseq/models/conv_model.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | -- 8 | --[[ 9 | -- 10 | -- A model similar to AvgpoolModel, but with an encoder consisting of 11 | -- 2 parallel stacks of convolutional layers. 
12 | -- 13 | --]] 14 | 15 | require 'nn' 16 | require 'nngraph' 17 | local argcheck = require 'argcheck' 18 | local utils = require 'fairseq.utils' 19 | local mutils = require 'fairseq.models.utils' 20 | 21 | local cuda = utils.loadCuda() 22 | 23 | local ConvModel, parent = torch.class('ConvModel', 'AvgpoolModel') 24 | 25 | ConvModel.__init = argcheck{ 26 | {name='self', type='ConvModel'}, 27 | {name='config', type='table', opt=true}, 28 | call = function(self, config) 29 | parent.__init(self, config) 30 | end 31 | } 32 | 33 | ConvModel.makeTemporalConvolution = argcheck{ 34 | {name='self', type='ConvModel'}, 35 | {name='config', type='table'}, 36 | {name='ninput', type='number'}, 37 | {name='kwidth', type='number'}, 38 | {name='nhid', type='number'}, 39 | call = function(self, config, ninput, kwidth, nhid) 40 | local pad = (kwidth - 1) / 2 41 | local conv 42 | if config.cudnnconv then 43 | conv = cuda.cudnn.TemporalConvolution(ninput, nhid, kwidth, 1, pad) 44 | else 45 | conv = nn.TemporalConvolutionTBC(ninput, nhid, kwidth, pad) 46 | end 47 | 48 | -- Initialize weights using the nn implementation 49 | local nnconv = nn.TemporalConvolution(ninput, nhid, 50 | kwidth, 1) 51 | conv.weight:copy(nnconv.weight) 52 | conv.bias:copy(nnconv.bias) 53 | 54 | -- Scale gradients by sqrt(ninput) to make learning more stable 55 | conv = nn.GradMultiply(conv, 1 / math.sqrt(ninput)) 56 | 57 | return conv 58 | end 59 | } 60 | 61 | ConvModel.makeEncoder = argcheck{ 62 | {name='self', type='ConvModel'}, 63 | {name='config', type='table'}, 64 | call = function(self, config) 65 | local sourceIn = nn.Identity()() 66 | 67 | -- First, computing embeddings for input tokens and their positions 68 | local tokens, positions = sourceIn:split(2) 69 | local dict = config.srcdict 70 | local embedToken = mutils.makeLookupTable(config, dict:size(), 71 | dict:getPadIndex()) 72 | -- XXX Assumes source sentence length < 1024 73 | local embedPosition = 74 | mutils.makeLookupTable(config, 1024, 
dict:getPadIndex()) 75 | local embed = 76 | nn.CAddTable()({embedToken(tokens), embedPosition(positions)}) 77 | if config.dropout_src > 0 then 78 | embed = nn.Dropout(config.dropout_src)(embed) 79 | end 80 | if not config.cudnnconv then 81 | embed = nn.Transpose({1, 2})(embed) 82 | end 83 | 84 | -- This stack is used for computing attention scores 85 | local cnnA = nn.Sequential() 86 | if config.nembed ~= config.nhid then 87 | -- Up-projection for producing nembed-sized output 88 | cnnA:add(nn.Bottle( 89 | nn.Linear(config.nembed, config.nhid) 90 | )) 91 | -- Bottle requires a continuous gradOutput 92 | cnnA:add(nn.Contiguous()) 93 | end 94 | 95 | for i = 1, config.nenclayer-1 do 96 | -- Residual connections 97 | cnnA:add(nn.ConcatTable() 98 | :add(self:makeTemporalConvolution(config, config.nhid, 99 | config.kwidth, config.nhid)) 100 | :add(nn.Identity())) 101 | cnnA:add(nn.CAddTable()) 102 | cnnA:add(nn.Tanh()) 103 | end 104 | cnnA:add(self:makeTemporalConvolution(config, config.nhid, 105 | config.kwidth, config.nhid)) 106 | cnnA:add(nn.Tanh()) 107 | 108 | if config.nembed ~= config.nhid then 109 | -- Down-projection for producing nembed-sized output 110 | cnnA:add(nn.Bottle( 111 | nn.Linear(config.nhid, config.nembed) 112 | )) 113 | end 114 | if not config.cudnnconv then 115 | cnnA:add(nn.Transpose({1, 2})) 116 | end 117 | 118 | -- This stack is used for aggregating the context for the decoder (using 119 | -- the attention scores) 120 | local cnnC = nn.Sequential() 121 | local nagglayer = config.nagglayer 122 | if nagglayer < 0 then 123 | -- By default, use fewer layers for aggregation than for attention 124 | nagglayer = math.floor(config.nenclayer / 2) 125 | nagglayer = math.max(1, math.min(nagglayer, 5)) 126 | end 127 | for i = 1, nagglayer-1 do 128 | -- Residual connections 129 | cnnC:add(nn.ConcatTable() 130 | :add(self:makeTemporalConvolution(config, config.nembed, 131 | config.kwidth, config.nembed)) 132 | :add(nn.Identity())) 133 | 
cnnC:add(nn.CAddTable()) 134 | cnnC:add(nn.Tanh()) 135 | end 136 | cnnC:add(self:makeTemporalConvolution(config, config.nembed, 137 | config.kwidth, config.nembed)) 138 | cnnC:add(nn.Tanh()) 139 | if not config.cudnnconv then 140 | cnnC:add(nn.Transpose({1, 2})) 141 | end 142 | 143 | return nn.gModule({sourceIn}, {cnnA(embed), cnnC(embed)}) 144 | end 145 | } 146 | 147 | function ConvModel:float(...) 148 | self.module:replace(function(m) 149 | if torch.isTypeOf(m, 'cudnn.TemporalConvolution') then 150 | return mutils.moveTemporalConvolutionToCPU(m) 151 | end 152 | return m 153 | end) 154 | return parent.float(self, ...) 155 | end 156 | 157 | return ConvModel 158 | -------------------------------------------------------------------------------- /fairseq/models/ensemble_model.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | -- 8 | --[[ 9 | -- 10 | -- Ensemble model. 11 | -- Can only be used for inference. 
12 | -- 13 | --]] 14 | 15 | require 'nn' 16 | require 'rnnlib' 17 | local argcheck = require 'argcheck' 18 | local plstringx = require 'pl.stringx' 19 | local utils = require 'fairseq.utils' 20 | 21 | local cuda = utils.loadCuda() 22 | 23 | local EnsembleModel = torch.class('EnsembleModel', 'Model') 24 | 25 | EnsembleModel.__init = argcheck{ 26 | {name='self', type='EnsembleModel'}, 27 | {name='config', type='table'}, 28 | call = function(self, config) 29 | local paths = plstringx.split(config.path, ',') 30 | self.models = {} 31 | for i, path in pairs(paths) do 32 | self.models[i] = torch.load(path) 33 | end 34 | end 35 | } 36 | 37 | EnsembleModel.type = argcheck{ 38 | doc=[[ 39 | Shorthand for network():type() 40 | ]], 41 | {name='self', type='EnsembleModel'}, 42 | {name='type', type='string', opt=true}, 43 | {name='tensorCache', type='table', opt=true}, 44 | call = function(self, type, tensorCache) 45 | local ret = nil 46 | for _, model in ipairs(self.models) do 47 | ret = model:type(type, tensorCache) 48 | end 49 | return not type and ret or self 50 | end 51 | } 52 | 53 | EnsembleModel.make = argcheck{ 54 | {name='self', type='EnsembleModel'}, 55 | {name='config', type='table'}, 56 | call = function(self, config) 57 | error('Cannot construct a nn.Module instance for ensemble modules') 58 | end 59 | } 60 | 61 | EnsembleModel.generate = argcheck{ 62 | {name='self', type='EnsembleModel'}, 63 | {name='config', type='table'}, 64 | {name='sample', type='table'}, 65 | {name='search', type='table'}, 66 | call = function(self, config, sample, search) 67 | local dict = config.dict 68 | local minlen = config.minlen 69 | local maxlen = config.maxlen 70 | local sourceLen = sample.source:size(1) 71 | local bsz = sample.source:size(2) 72 | local bbsz = config.beam * bsz 73 | 74 | local timers = { 75 | setup = torch.Timer(), 76 | encoder = torch.Timer(), 77 | decoder = torch.Timer(), 78 | search_prune = torch.Timer(), 79 | search_results = torch.Timer(), 80 | } 81 | 82 | local 
callbacks = {} 83 | for _, model in ipairs(self.models) do 84 | table.insert(callbacks, model:generationCallbacks(config, bsz)) 85 | end 86 | 87 | for _, timer in pairs(timers) do 88 | timer:stop() 89 | timer:reset() 90 | end 91 | 92 | local states = {} 93 | for i = 1, #self.models do 94 | states[i] = { 95 | sample = sample, 96 | } 97 | end 98 | 99 | timers.setup:resume() 100 | local states = {} 101 | for i = 1, #self.models do 102 | states[i] = callbacks[i].setup(sample) 103 | end 104 | timers.setup:stop() 105 | 106 | timers.encoder:resume() 107 | for i = 1, #self.models do 108 | callbacks[i].encode(states[i]) 109 | end 110 | if cuda.cutorch then 111 | cuda.cutorch.synchronize() 112 | end 113 | timers.encoder:stop() 114 | 115 | -- <eos> is used as a start-of-sentence marker 116 | local targetIns = {} 117 | for i = 1, #self.models do 118 | targetIns[i] = torch.Tensor(bbsz):type(self:type()) 119 | targetIns[i]:fill(dict:getEosIndex()) 120 | end 121 | 122 | search.init(bsz, sample) 123 | local vocabsize = 124 | sample.targetVocab and sample.targetVocab:size(1) or dict:size() 125 | local aggSoftmax = torch.zeros(bbsz, vocabsize):type(self:type()) 126 | local aggAttnScores = torch.zeros(bbsz, sourceLen):type(self:type()) 127 | -- We do maxlen + 1 steps to give model a chance to predict EOS 128 | for step = 1, maxlen + 1 do 129 | timers.decoder:resume() 130 | aggSoftmax:zero() 131 | aggAttnScores:zero() 132 | for i = 1, #self.models do 133 | local softmax = callbacks[i].decode(states[i], targetIns[i]) 134 | aggSoftmax:add(softmax) 135 | if callbacks[i].attention then 136 | aggAttnScores:add(callbacks[i].attention(states[i])) 137 | end 138 | end 139 | -- Average softmax and attention scores. 
140 | aggSoftmax:div(#self.models) 141 | aggAttnScores:div(#self.models) 142 | if cuda.cutorch then 143 | cuda.cutorch.synchronize() 144 | end 145 | timers.decoder:stop() 146 | 147 | local aggLogSoftmax = aggSoftmax:log() 148 | self:updateMinMaxLenProb(aggLogSoftmax, dict, step, minlen, maxlen) 149 | 150 | timers.search_prune:resume() 151 | local pruned = search.prune(step, aggLogSoftmax, aggAttnScores) 152 | timers.search_prune:stop() 153 | 154 | for i = 1, #self.models do 155 | targetIns[i]:copy(pruned.nextIn) 156 | callbacks[i].update(states[i], pruned.nextHid) 157 | end 158 | 159 | if pruned.eos then 160 | break 161 | end 162 | end 163 | 164 | timers.search_results:resume() 165 | local results = table.pack(search.results()) 166 | -- This is pretty hacky, but basically we can't run finalize for 167 | -- the selection models many times, because it will remap ids many times 168 | -- TODO: refactor this 169 | callbacks[1].finalize(states[1], sample, results) 170 | timers.search_results:stop() 171 | 172 | local times = {} 173 | for k, v in pairs(timers) do 174 | times[k] = v:time() 175 | end 176 | table.insert(results, times) 177 | return table.unpack(results) 178 | end 179 | } 180 | 181 | EnsembleModel.extend = argcheck{ 182 | {name='self', type='EnsembleModel'}, 183 | {name='n', type='number'}, 184 | call = function(self, n) 185 | for _, model in ipairs(self.models) do 186 | model:extend(n) 187 | end 188 | end 189 | } 190 | 191 | return EnsembleModel 192 | -------------------------------------------------------------------------------- /fairseq/models/init.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 
7 | -- 8 | --[[ 9 | -- 10 | -- init files for the models. 11 | -- 12 | --]] 13 | 14 | require 'fairseq.models.model' 15 | require 'fairseq.models.avgpool_model' 16 | require 'fairseq.models.blstm_model' 17 | require 'fairseq.models.fconv_model' 18 | require 'fairseq.models.selection_blstm_model' 19 | require 'fairseq.models.conv_model' 20 | require 'fairseq.models.ensemble_model' 21 | -------------------------------------------------------------------------------- /fairseq/models/model.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | -- 8 | --[[ 9 | -- 10 | -- Base class for models, outlining the basic interface. 11 | -- 12 | --]] 13 | 14 | local argcheck = require 'argcheck' 15 | local utils = require 'fairseq.utils' 16 | 17 | local cuda = utils.loadCuda() 18 | 19 | local Model = torch.class('Model') 20 | 21 | Model.__init = argcheck{ 22 | doc=[[ 23 | Default constructor. This will construct a network by calling `make()`. 24 | ]], 25 | {name='self', type='Model'}, 26 | {name='config', type='table', opt=true}, 27 | call = function(self, config) 28 | self.module = self:make(config) 29 | end 30 | } 31 | 32 | Model.network = argcheck{ 33 | doc=[[ 34 | Returns the encapsulated nn.Module instance. 
35 | ]], 36 | {name='self', type='Model'}, 37 | call = function(self) 38 | return self.module 39 | end 40 | } 41 | 42 | Model.type = argcheck{ 43 | doc=[[ 44 | Shorthand for network():type() 45 | ]], 46 | {name='self', type='Model'}, 47 | {name='type', type='string', opt=true}, 48 | {name='tensorCache', type='table', opt=true}, 49 | call = function(self, type, tensorCache) 50 | local ret = self.module:type(type, tensorCache) 51 | if ret == self.module then 52 | return self 53 | end 54 | return ret 55 | end 56 | } 57 | 58 | function Model:float(...) 59 | return self:type('torch.FloatTensor',...) 60 | end 61 | 62 | function Model:double(...) 63 | return self:type('torch.DoubleTensor',...) 64 | end 65 | 66 | function Model:cuda(...) 67 | return self:type('torch.CudaTensor',...) 68 | end 69 | 70 | Model.make = argcheck{ 71 | doc=[[ 72 | Constructs a new network as a nn.Module instance. 73 | ]], 74 | {name='self', type='Model'}, 75 | call = function(self, config) 76 | error('Implementation expected') 77 | end 78 | } 79 | 80 | Model.resizeCriterionWeights = argcheck{ 81 | doc=[[ 82 | Resize criterion weights to accomadate per sample target vocabulary. 83 | 84 | The default implementation is a no-op. 85 | ]], 86 | {name='self', type='Model'}, 87 | {name='criterion', type='nn.Criterion'}, 88 | {name='critweights', type='torch.CudaTensor'}, 89 | {name='sample', type='table'}, 90 | call = function(self, criterion, critweights, sample) 91 | end 92 | } 93 | 94 | Model.prepareSample = argcheck{ 95 | doc=[[ 96 | Returns a function that prepares a data sample inside a tochnet engine state. 97 | The nn.Module returned by `network()` is expected to be able to compute a 98 | forward pass on `state.sample.input`. 99 | 100 | The default implementation is a no-op. 
101 | ]], 102 | {name='self', type='Model'}, 103 | call = function(self) 104 | return function(sample) 105 | end 106 | end 107 | } 108 | 109 | Model.generationCallbacks = argcheck{ 110 | doc=[[ 111 | Returns 5 callback functions to be used during the generation step: 112 | - `setup` will be called before the generation step in order 113 | to prepare a source data sample, extract the attention scores and 114 | initialize the decoder hidden state 115 | - `encode` will be called before the generation step to run 116 | the encoder forward pass and duplicate the output for each beam hypotheses 117 | - `decode` will be called at each step of the generation, 118 | it performs the decoder forward pass, and then applies nn.SoftMax 119 | - `attention` will be called at each step of the generation to 120 | aquire attention scores 121 | - `update` will be called at each step of the generation to 122 | update the decoder hidden state 123 | ]], 124 | {name='self', type='Model'}, 125 | {name='config', type='table'}, 126 | {name='bsz', type='number'}, 127 | call = function(self, config, bsz) 128 | return { 129 | setup = self:generationSetup(config, bsz), 130 | encode = self:generationEncode(config, bsz), 131 | decode = self:generationDecode(config, bsz), 132 | attention = self:generationAttention(config, bsz), 133 | update = self:generationUpdate(config, bsz), 134 | finalize = self:generationFinalize(config), 135 | } 136 | end 137 | } 138 | 139 | Model.generationSetup = argcheck{ 140 | doc=[[ 141 | Returns a function that does some preparation before the generation step. 142 | It converts a data sample into required format, extracts the attention 143 | scores node from the graph and initializes hidden state for the decoder. 
144 | ]], 145 | {name='self', type='Model'}, 146 | {name='config', type='table'}, 147 | {name='bsz', type='number'}, 148 | call = function(self, config, bsz) 149 | error('Implementation expected') 150 | end 151 | } 152 | 153 | Model.generationEncode = argcheck{ 154 | doc=[[ 155 | Returns a function that performs the encoder forward pass. 156 | After that it duplicates the encoder output for each 157 | beam hypotheses. 158 | ]], 159 | {name='self', type='Model'}, 160 | {name='config', type='table'}, 161 | {name='bsz', type='number'}, 162 | call = function(self, config, bsz) 163 | error('Implementation expected') 164 | end 165 | } 166 | 167 | Model.generationDecode = argcheck{ 168 | doc=[[ 169 | Returns a function that performs the decoder forward pass, and then 170 | applies nn.SoftMax on the result. 171 | ]], 172 | {name='self', type='Model'}, 173 | {name='config', type='table'}, 174 | {name='bsz', type='number'}, 175 | call = function(self, config, bsz) 176 | error('Implementation expected') 177 | end 178 | } 179 | 180 | Model.generationAttention = argcheck{ 181 | doc=[[ 182 | Returns a function that returns attention scores over the source sentences. 183 | Called after the decode callback. This function can be nil. 184 | ]], 185 | {name='self', type='Model'}, 186 | {name='config', type='table'}, 187 | {name='bsz', type='number'}, 188 | call = function(self, config, bsz) 189 | return nil 190 | end 191 | } 192 | 193 | Model.generationUpdate = argcheck{ 194 | doc=[[ 195 | Returns a function that updates the decoder hidden state. 196 | ]], 197 | {name='self', type='Model'}, 198 | {name='config', type='table'}, 199 | {name='bsz', type='number'}, 200 | call = function(self, config, bsz) 201 | error('Implementation expected') 202 | end 203 | } 204 | 205 | Model.generationFinalize = argcheck{ 206 | doc=[[ 207 | Returns a function that finalizes generation by performing some transformations. 
208 | ]], 209 | {name='self', type='Model'}, 210 | {name='config', type='table'}, 211 | call = function(self, config) 212 | return function(state, sample, results) 213 | -- Do nothing 214 | end 215 | end 216 | } 217 | 218 | Model.generate = argcheck{ 219 | doc=[[ 220 | Sentence generation. See search.lua for a description of search functions. 221 | ]], 222 | {name='self', type='Model'}, 223 | {name='config', type='table'}, 224 | {name='sample', type='table'}, 225 | {name='search', type='table'}, 226 | call = function(self, config, sample, search) 227 | local dict = config.dict 228 | local minlen = config.minlen 229 | local maxlen = config.maxlen 230 | local bsz = sample.source:size(2) 231 | local bbsz = config.beam * bsz 232 | local callbacks = self:generationCallbacks(config, bsz) 233 | 234 | local timers = { 235 | setup = torch.Timer(), 236 | encoder = torch.Timer(), 237 | decoder = torch.Timer(), 238 | search_prune = torch.Timer(), 239 | search_results = torch.Timer(), 240 | } 241 | 242 | for k, v in pairs(timers) do 243 | v:stop() 244 | v:reset() 245 | end 246 | 247 | timers.setup:resume() 248 | local state = callbacks.setup(sample) 249 | if cuda.cutorch then 250 | cuda.cutorch.synchronize() 251 | end 252 | timers.setup:stop() 253 | 254 | timers.encoder:resume() 255 | callbacks.encode(state) 256 | timers.encoder:stop() 257 | 258 | -- <eos> is used as a start-of-sentence marker 259 | local targetIn = torch.Tensor(bbsz):type(self:type()) 260 | targetIn:fill(dict:getEosIndex()) 261 | local sourceLen = sample.source:size(1) 262 | local attnscores = torch.zeros(bbsz, sourceLen):type(self:type()) 263 | 264 | search.init(bsz, sample) 265 | -- We do maxlen + 1 steps to give model a chance to 266 | -- predict EOS 267 | for step = 1, maxlen + 1 do 268 | timers.decoder:resume() 269 | local softmax = callbacks.decode(state, targetIn) 270 | local logsoftmax = softmax:log() 271 | if cuda.cutorch then 272 | cuda.cutorch.synchronize() 273 | end 274 | timers.decoder:stop() 275 
| 276 | if callbacks.attention then 277 | attnscores:copy(callbacks.attention(state)) 278 | end 279 | 280 | self:updateMinMaxLenProb(logsoftmax, dict, step, minlen, maxlen) 281 | 282 | timers.search_prune:resume() 283 | local pruned = search.prune(step, logsoftmax, attnscores) 284 | targetIn:copy(pruned.nextIn) 285 | callbacks.update(state, pruned.nextHid) 286 | timers.search_prune:stop() 287 | 288 | if pruned.eos then 289 | break 290 | end 291 | end 292 | 293 | timers.search_results:resume() 294 | local results = table.pack(search.results()) 295 | callbacks.finalize(state, sample, results) 296 | timers.search_results:stop() 297 | 298 | local times = {} 299 | for k, v in pairs(timers) do 300 | times[k] = v:time() 301 | end 302 | table.insert(results, times) 303 | return table.unpack(results) 304 | end 305 | } 306 | 307 | Model.extend = argcheck{ 308 | doc=[[ 309 | Ensures that recurrent parts of the model are unrolled for a given 310 | number of time-steps. 311 | ]], 312 | {name='self', type='Model'}, 313 | {name='n', type='number'}, 314 | call = function(self, n) 315 | self:network():apply(function(module) 316 | if torch.isTypeOf(module, 'nn.Recurrent') then 317 | module:extend(n) 318 | elseif torch.isTypeOf(module, 'nn.MapTable') then 319 | module:resize(n) 320 | end 321 | end) 322 | end 323 | } 324 | 325 | function Model:updateMinMaxLenProb(ldist, dict, step, minlen, maxlen) 326 | local eos = dict:getEosIndex() 327 | -- Up until we reach minlen, EOS should never be selected 328 | -- Here we make the probability of chosing EOS -inf 329 | if step <= minlen then 330 | ldist:narrow(2, eos, 1):fill(-math.huge) 331 | end 332 | 333 | -- After reaching maxlen, we need to make sure EOS is selected 334 | -- so, we make probabilities of everything else -inf 335 | if step > maxlen then 336 | local eos = dict:getEosIndex() 337 | local vocabsize = ldist:size(2) 338 | ldist:narrow(2, 1, eos - 1):fill(-math.huge) 339 | ldist:narrow(2, eos + 1, vocabsize - eos):fill(-math.huge) 
340 | end 341 | end 342 | 343 | return Model 344 | -------------------------------------------------------------------------------- /fairseq/models/selection_blstm_model.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | -- 8 | --[[ 9 | -- 10 | -- The BLSTM model that uses words alignment to reduce the target 11 | -- vocabulary size. 12 | -- 13 | --]] 14 | 15 | require 'nn' 16 | require 'rnnlib' 17 | local argcheck = require 'argcheck' 18 | local mutils = require 'fairseq.models.utils' 19 | 20 | local SelectionBLSTMModel = torch.class('SelectionBLSTMModel', 'BLSTMModel') 21 | 22 | SelectionBLSTMModel.make = argcheck{ 23 | {name='self', type='SelectionBLSTMModel'}, 24 | {name='config', type='table'}, 25 | call = function(self, config) 26 | local encoder = self:makeEncoder(config) 27 | local decoder = self:makeDecoder(config) 28 | 29 | -- Wire up encoder and decoder 30 | local input = nn.Identity()() 31 | local prevhIn, targetIn, targetVocab, sourceIn = input:split(4) 32 | local output = decoder({ 33 | prevhIn, 34 | targetIn, 35 | targetVocab, 36 | encoder(sourceIn):annotate{name = 'encoder'}, 37 | }):annotate{name = 'decoder'} 38 | 39 | return nn.gModule({input}, {output}) 40 | end 41 | } 42 | 43 | SelectionBLSTMModel.makeDecoder = argcheck{ 44 | doc=[[ 45 | Constructs a conditional LSTM decoder with soft attention. 46 | It also takes an additional input targetVocab to reduce the 47 | target vocabulary size. 
]],
    {name='self', type='SelectionBLSTMModel'},
    {name='config', type='table'},
    call = function(self, config)
        local input = nn.Identity()()
        local prevhIn, targetIn, targetVocab, encoderOut = input:split(4)
        local decoderRNNOut = self:makeDecoderRNN(
            config, prevhIn, targetIn, encoderOut)
        local output = mutils.makeTargetMappingWithSelection(
            config, config.dict:size(), decoderRNNOut, targetVocab)
        return nn.gModule({input}, {output})
    end
}

-- Shrinks the criterion weight vector so it matches the reduced per-batch
-- target vocabulary.
SelectionBLSTMModel.resizeCriterionWeights = argcheck{
    {name='self', type='SelectionBLSTMModel'},
    {name='criterion', type='nn.Criterion'},
    {name='critweights', type='torch.CudaTensor'},
    {name='sample', type='table'},
    call = function(self, criterion, critweights, sample)
        local vocsize = sample.targetVocab:size(1)
        -- Note: special weights (different from 1.0) are only used for a
        -- few symbols (like pad), and those symbols are guaranteed to keep
        -- the same ids from batch to batch, so no remapping is needed here.
        criterion.nll.weights = critweights:narrow(1, 1, vocsize)
    end
}

SelectionBLSTMModel.prepareSample = argcheck{
    {name='self', type='SelectionBLSTMModel'},
    call = function(self)
        local buffers = {
            targetVocab = torch.Tensor():type(self:type()),
        }

        local prepareSource = self:prepareSource()
        local prepareHidden = self:prepareHidden()
        local prepareInput = self:prepareInput()
        local prepareTarget = self:prepareTarget()

        -- Assembles sample.input = {hid, input, targetVocab, source} and
        -- sample.target in-place.
        return function(sample)
            local source = prepareSource(sample)
            local hid = prepareHidden(sample)
            local input = prepareInput(sample)
            local target = prepareTarget(sample)
            local targetVocab = mutils.sendtobuf(
                sample.targetVocab, buffers.targetVocab)

            sample.target = target
            sample.input = {hid, input, targetVocab, source}
        end
    end
}

SelectionBLSTMModel.generationSetup = argcheck{
    {name='self', type='SelectionBLSTMModel'},
    {name='config', type='table'},
    {name='bsz', type='number'},
    call = function(self, config, bsz)
        local bbsz = config.beam * bsz
        local m = self:network()
        local prepareSource = self:prepareSource()
        local decoderRNN = mutils.findAnnotatedNode(m, 'decoderRNN')
        assert(decoderRNN ~= nil)
        local targetVocabBuffer = torch.Tensor():type(self:type())

        -- Produces the initial decoding state for a new sample; remapFn
        -- translates reduced-vocabulary indices back into full-dictionary
        -- ids via sample.targetVocab.
        return function(sample)
            m:evaluate()
            return {
                remapFn = function(idx) return sample.targetVocab[idx] end,
                sourceIn = prepareSource(sample),
                prevhIn = decoderRNN:initializeHidden(bbsz),
                targetVocab = mutils.sendtobuf(sample.targetVocab,
                    targetVocabBuffer),
            }
        end
    end
}

SelectionBLSTMModel.generationDecode = argcheck{
    {name='self', type='SelectionBLSTMModel'},
    {name='config', type='table'},
    {name='bsz', type='number'},
    call = function(self, config, bsz)
        local softmax = nn.SoftMax():type(self:type())
        local m = self:network()
        local decoder = mutils.findAnnotatedNode(m, 'decoder')
        -- One decoding step: map targetIn through state.remapFn, run the
        -- decoder and return normalized scores.
        return function(state, targetIn)
            targetIn:apply(state.remapFn)
            local out = decoder:forward({
                state.prevhIn, {targetIn}, state.targetVocab, state.encoderOut})
            return softmax:forward(out)
        end
    end
}

SelectionBLSTMModel.generationFinalize = argcheck{
    {name='self', type='SelectionBLSTMModel'},
    {name='config', type='table'},
    call = function(self, config)
        -- Translate hypotheses and the reference back to dictionary ids.
        return function(state, sample, results)
            local hypos = unpack(results)
            for _, hypo in ipairs(hypos) do
                hypo:apply(state.remapFn)
            end
            sample.target:apply(state.remapFn)
        end
    end
}

return SelectionBLSTMModel
--------------------------------------------------------------------------------
-- fairseq/modules/AppendBias.lua
--------------------------------------------------------------------------------
-- Copyright (c) 2017-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the LICENSE file in
-- the root directory of this source tree. An additional grant of patent rights
-- can be found in the PATENTS file in the same directory.
--
--[[
--
-- Extends the last dimension of the input tensor by one element and fills
-- the new slot with 1:
--
-- Input:
--  x_11 x_12 x_13
--  x_21 x_22 x_23
--
-- Output:
--  x_11 x_12 x_13 x_14(=1)
--  x_21 x_22 x_23 x_24(=1)
--
--]]

local AppendBias, _ = torch.class('nn.AppendBias', 'nn.Module')

-- Forward: copy the input and append a slice of ones along the last dim.
function AppendBias:updateOutput(input)
    local lastdim = input:dim()
    local outsize = input:size()
    outsize[lastdim] = outsize[lastdim] + 1
    self.output:resize(outsize)
    -- copy input into the leading slots
    self.output:narrow(lastdim, 1, outsize[lastdim] - 1):copy(input)
    -- fill the appended slot with 1
    self.output:select(lastdim, outsize[lastdim]):fill(1)
    return self.output
end

-- Backward: drop the gradient of the appended ones, pass the rest through.
function AppendBias:updateGradInput(input, gradOutput)
    local lastdim = input:dim()
    local insize = input:size()
    self.gradInput:resize(insize)
    self.gradInput:copy(gradOutput:narrow(lastdim, 1, insize[lastdim]))
    return self.gradInput
end
--------------------------------------------------------------------------------
-- fairseq/modules/BeamableMM.lua
--------------------------------------------------------------------------------
-- Copyright (c) 2017-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the LICENSE file in
-- the root directory of this source tree. An additional grant of patent rights
-- can be found in the PATENTS file in the same directory.
--
--[[
--
-- This module specialize MM for beam decoding by attention modules.
-- It leverage the fact that the source-side of the input is replicated beam
-- times and that the target-side of the input is of width one.
-- This layer speeds
-- up inference by replacing the inputs {(bsz x 1 x nhu), (bsz x sz2 x nhu)}
-- with smaller inputs {(bsz/beam x beam x nhu), (bsz/beam x sz2 x nhu)}
--
--]]

local BeamableMM, parent = torch.class('nn.BeamableMM', 'nn.MM')

function BeamableMM:__init(...)
    parent.__init(self, ...)
    -- beam == 0 disables the specialized inference path
    self.beam = 0
end

function BeamableMM:updateOutput(input)
    -- Bug fix: the original condition `not(self.train == false)` was
    -- inverted — it enabled the beam-folded path during *training* (where
    -- it would break backprop) and never at inference. The fast path must
    -- run in evaluate mode only.
    local fastpath = self.train == false    -- test mode
        and self.beam > 0                   -- beam size is set
        and input[1]:dim() == 3             -- only support batched inputs
        and input[1]:size(2) == 1           -- single time step update
    if not fastpath then
        return parent.updateOutput(self, input)
    end

    local bsz, beam = input[1]:size(1), self.beam

    -- bsz x 1 x nhu --> bsz/beam x beam x nhu
    local in1 = input[1]:select(2, 1):unfold(1, beam, beam):transpose(3, 2)
    -- bsz x sz2 x nhu --> bsz/beam x sz2 x nhu
    local in2 = input[2]:unfold(1, beam, beam):select(4, 1)
    -- use non batched operation if bsz = beam
    if in1:size(1) == 1 then
        in1, in2 = in1[1], in2[1]
    end

    -- forward and restore correct size
    parent.updateOutput(self, {in1, in2})
    self.output = self.output:view(bsz, 1, -1)
    return self.output
end

-- Sets the beam size used by the fast path; nil or 0 disables it.
function BeamableMM:setBeamSize(beam)
    self.beam = beam or 0
end
--------------------------------------------------------------------------------
-- fairseq/modules/CAddTableMulConstant.lua
--------------------------------------------------------------------------------
-- Copyright (c) 2017-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the LICENSE file in
-- the root directory of this source tree. An additional grant of patent rights
-- can be found in the PATENTS file in the same directory.
--
--[[
--
-- CAddTable that scales the output by a constant factor.
--
--]]

local CAddTableMulConstant, parent =
    torch.class('nn.CAddTableMulConstant', 'nn.CAddTable')

function CAddTableMulConstant:__init(constant_scalar)
    parent.__init(self)
    self.constant_scalar = constant_scalar
end

-- Forward: element-wise sum of the inputs, scaled by the constant.
function CAddTableMulConstant:updateOutput(input)
    parent.updateOutput(self, input)
    self.output:mul(self.constant_scalar)
    return self.output
end

-- Backward: every input receives the gradient scaled by the constant.
function CAddTableMulConstant:updateGradInput(input, gradOutput)
    parent.updateGradInput(self, input, gradOutput)
    for i = 1, #self.gradInput do
        self.gradInput[i]:mul(self.constant_scalar)
    end
    return self.gradInput
end
--------------------------------------------------------------------------------
-- fairseq/modules/CLSTM.lua
--------------------------------------------------------------------------------
-- Copyright (c) 2017-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the LICENSE file in
-- the root directory of this source tree. An additional grant of patent rights
-- can be found in the PATENTS file in the same directory.
--
--[[
--
-- A conditional LSTM network.
--
--]]

require 'nn'

local argcheck = require 'argcheck'
local rmutils = require 'rnnlib.mutils'
local rnnlib = require 'rnnlib.env'

-- A conditional LSTM cell: the bottom-layer LSTM consumes the
-- concatenation of the regular input and an attention summary of the
-- conditioning input, hence the doubled input width.
rnnlib.cell.CLSTM = function(nin, nhid, attention)
    local makeVanilla, initVanilla = rnnlib.cell.LSTM(nin * 2, nhid)

    local make = function(prevch, input)
        local input, cond = input:split(2)
        local prevc, prevh = prevch:split(2)

        return makeVanilla(
            nn.Identity()({prevc, prevh}),
            nn.JoinTable(1, 1)({
                input,
                attention({input, prevh, cond}),
            })
        )
    end

    return make, initVanilla
end

-- Wraps a cell constructor so that its input passes through dropout first.
local function addDropoutToInput(make, init, prob)
    return function(state, input)
        return make(state, nn.Dropout(prob)(input))
    end, init
end

nn.CLSTM = argcheck{
    { name = "inputsize" , type = "number" , },
    { name = "hidsize" , type = "number" , },
    { name = "nlayer" , type = "number" , },
    { name = "attention" , type = "nn.Module" , },
    { name = "hinitfun" , type = "function" , opt = true },
    { name = "winitfun" , type = "function" , default = rmutils.defwinitfun },
    { name = "savehidden" , type = "boolean" , default = true },
    { name = "dropout" , type = "number" , default = 0 },
    { name = "usecudnn" , type = "boolean" , default = false },
    call = function(inputsize, hidsize, nlayer, attention, hinitfun, winitfun,
        savehidden, dropout, usecudnn)

        local modules, initfs = {}, {}

        -- Layer 1 is always the (non-cudnn) conditional LSTM cell.
        local c, f = rnnlib.cell.CLSTM(inputsize, hidsize, attention)
        modules[1] = nn.RecurrentTable{dim = 2, module = rnnlib.cell.gModule(c)}
        initfs[1] = f

        if usecudnn and nlayer > 1 then
            -- Layers 2..nlayer are collapsed into a single cudnn LSTM.
            modules[2] = nn.CudnnRnnTable{
                module = cudnn.LSTM(hidsize, hidsize, nlayer - 1, false,
                    dropout),
                inputsize = hidsize,
                dropoutin = dropout,
            }
            initfs[2] = modules[2]:makeInitializeHidden()
        else
            for i = 2, nlayer do
                local c, f = rnnlib.cell.LSTM(hidsize, hidsize)
                if dropout > 0 then
                    c, f = addDropoutToInput(c, f, dropout)
                end
                modules[i] = nn.RecurrentTable{
                    dim = 2,
                    module = rnnlib.cell.gModule(c)
                }
                initfs[i] = f
            end
        end

        local network = nn.SequenceTable{
            dim = 1,
            modules = modules,
        }

        local rnn = rnnlib.setupRecurrent{
            network = network,
            initfs = initfs,
            hinitfun = hinitfun,
            winitfun = winitfun,
            savehidden = savehidden,
        }

        -- Collects the final hidden state of every layer; cudnn layers
        -- provide it directly, rnnlib layers expose it via self.output.
        rnn.getLastHidden = function(self)
            local hids = {}
            for i = 1, #self.modules do
                if self.modules[i].getLastHidden then
                    hids[i] = self.modules[i]:getLastHidden()
                else
                    local hout = self.output[1]
                    hids[i] = hout[i][#hout[i]]
                end
            end
            return hids
        end

        return rnn
    end
}
--------------------------------------------------------------------------------
-- fairseq/modules/CudnnRnnTable.lua
--------------------------------------------------------------------------------
-- Copyright (c) 2017-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the LICENSE file in
-- the root directory of this source tree. An additional grant of patent rights
-- can be found in the PATENTS file in the same directory.
--
--[[
--
-- Wraps a Cudnn.RNN so that the API is the same as Rnnlib's.
--
--]]

require 'nn'

local argcheck = require 'argcheck'
local _, cutils = pcall(require, 'rnnlib.cudnnutils')

local CudnnRnnTable, parent = torch.class('nn.CudnnRnnTable', 'nn.Sequential')

-- Builds the wrapper around an existing cudnn RNN module.
CudnnRnnTable.__init = argcheck{
    { name = 'self' , type = 'nn.CudnnRnnTable' },
    { name = 'module' , type = 'nn.Module' },
    { name = 'inputsize' , type = 'number' },
    { name = 'dropoutin' , type = 'number', default = 0},
    call = function(self, module, inputsize, dropoutin)
        parent.__init(self)

        -- Joins across the table dimension of the input (which has
        -- dimension {bptt} x bsz x emsize) to create a tensor of
        -- dimension bptt x bsz x emsize.
        self
            :add(nn.MapTable(nn.View(1, -1, inputsize)))
            :add(nn.JoinTable(1))
        if dropoutin > 0 then
            self:add(nn.Dropout(dropoutin))
        end
        self
            :add(module)
            :add(nn.SplitTable(1))

        self.rnn = module
        self.output = {}
        self.gradInput = {}
    end
}

-- Overload: converts an existing rnnlib nn.SequenceTable into a cudnn RNN,
-- copying its parameters layer by layer.
CudnnRnnTable.__init = argcheck{
    { name = 'self' , type = 'nn.CudnnRnnTable' },
    { name = 'model' , type = 'nn.SequenceTable' },
    { name = 'cellstring' , type = 'string' },
    { name = 'inputsize' , type = 'number' },
    { name = 'hiddensize' , type = 'number' },
    { name = 'nlayer' , type = 'number' },
    { name = 'dropoutin' , type = 'number', default = 0},
    overload = CudnnRnnTable.__init,
    call = function(self, model, cellstring, inputsize, hiddensize, nlayer,
        dropoutin)
        local oldparams = model:parameters()
        local rnn = cudnn[cellstring](inputsize, hiddensize, nlayer)
        for l = 1, nlayer do
            cutils.copyParams(
                rnn, cutils.offsets[cellstring],
                oldparams[2*l-1], oldparams[2*l],
                hiddensize, l
            )
        end
        return self.__init(self, rnn, inputsize, dropoutin)
    end,
}

-- Forward: input is {hidden, sequence}; hidden is {cell, hidden} for LSTMs.
CudnnRnnTable.updateOutput = function(self, input)
    local rnn = self.rnn
    local hidinput, seqinput = input[1], input[2]
    if rnn.mode:find('LSTM') then
        rnn.cellInput = hidinput[1]
        rnn.hiddenInput = hidinput[2]
    else
        rnn.hiddenInput = hidinput
    end

    local seqoutput = parent.updateOutput(self, seqinput)
    local hidoutput
    if rnn.mode:find('LSTM') then
        hidoutput = self.hidoutput or {}
        hidoutput[1] = rnn.cellOutput
        hidoutput[2] = rnn.hiddenOutput
    else
        hidoutput = rnn.hiddenOutput
    end
    self.hidoutput = hidoutput

    self.output = { hidoutput, seqoutput }
    return self.output
end

CudnnRnnTable.updateGradInput = function(self, input, gradOutput)
    local rnn = self.rnn
    local seqinput = input[2]
    local hidgradoutput, seqgradoutput = gradOutput[1], gradOutput[2]
    if rnn.mode:find('LSTM') then
        rnn.gradCellOutput = hidgradoutput[1]
        rnn.gradHiddenOutput = hidgradoutput[2]
    else
        rnn.gradHiddenOutput = hidgradoutput
    end

    local seqgradinput = parent.updateGradInput(self, seqinput, seqgradoutput)
    local hidgradinput
    if rnn.mode:find('LSTM') then
        hidgradinput = self.hidgradinput or {}
        hidgradinput[1] = rnn.gradCellInput
        hidgradinput[2] = rnn.gradHiddenInput
    else
        hidgradinput = rnn.gradHiddenInput
    end
    self.hidgradinput = hidgradinput

    self.gradInput = { hidgradinput, seqgradinput }
    return self.gradInput
end

CudnnRnnTable.accGradParameters = function(self, input, gradOutput, scale)
    local rnn = self.rnn
    local seqinput = input[2]
    local hidgradoutput, seqgradoutput = gradOutput[1], gradOutput[2]
    if rnn.mode:find('LSTM') then
        rnn.gradCellOutput = hidgradoutput[1]
        rnn.gradHiddenOutput = hidgradoutput[2]
    else
        rnn.gradHiddenOutput = hidgradoutput
    end

    parent.accGradParameters(self, seqinput, seqgradoutput, scale)
    -- Zero out gradBias to conform with Rnnlib standard which does
    -- not use biases in linear projections.
    cutils.zeroField(rnn, 'gradBias')
end

-- | The backward must be overloaded because nn.Sequential's backward does not
-- actually call updateGradInput or accGradParameters.
CudnnRnnTable.backward = function(self, input, gradOutput, scale)
    local gradInput = self:updateGradInput(input, gradOutput)
    self:accGradParameters(input, gradOutput, scale)
    return gradInput
end

-- | Get the last hidden state.
CudnnRnnTable.getLastHidden = function(self)
    return self.hidoutput
end

-- Returns a function producing a zeroed hidden state for a batch of size
-- bsz, reusing the `cache` tensors when provided.
CudnnRnnTable.makeInitializeHidden = function(self)
    local rnn = self.rnn
    return function(bsz, t, cache)
        local dim = {
            rnn.numLayers,
            bsz,
            rnn.hiddenSize,
        }
        if rnn.mode:find('LSTM') then
            cache = cache or { torch.CudaTensor(), torch.CudaTensor() }
            cache[1]:resize(table.unpack(dim)):fill(0)
            cache[2]:resize(table.unpack(dim)):fill(0)
        else
            cache = cache or torch.CudaTensor()
            cache:resize(table.unpack(dim)):fill(0)
        end
        return cache
    end
end
--------------------------------------------------------------------------------
-- fairseq/modules/GradMultiply.lua
--------------------------------------------------------------------------------
-- Copyright (c) 2017-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the LICENSE file in
-- the root directory of this source tree. An additional grant of patent rights
-- can be found in the PATENTS file in the same directory.
--
--[[
--
-- A container that simply sets the scaling factor for accGradParameters
-- of the wrapped module.
--
--]]

local GradMultiply, parent = torch.class('nn.GradMultiply', 'nn.Container')

function GradMultiply:__init(module, factor)
    parent.__init(self)
    self.modules[1] = module
    self.factor = factor
end

function GradMultiply:updateOutput(input)
    return self.modules[1]:updateOutput(input)
end

function GradMultiply:updateGradInput(input, gradOutput)
    return self.modules[1]:updateGradInput(input, gradOutput)
end

-- Parameter gradients are accumulated with scale * factor.
function GradMultiply:accGradParameters(input, gradOutput, scale)
    return self.modules[1]:accGradParameters(input, gradOutput,
        (scale or 1) * self.factor)
end

function GradMultiply:accUpdateGradParameters(input, gradOutput, lr)
    return self.modules[1]:accUpdateGradParameters(input, gradOutput,
        lr * self.factor)
end
--------------------------------------------------------------------------------
-- fairseq/modules/LinearizedConvolution.lua
--------------------------------------------------------------------------------
-- Copyright (c) 2017-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the LICENSE file in
-- the root directory of this source tree. An additional grant of patent rights
-- can be found in the PATENTS file in the same directory.
--
--[[
--
-- This module allows to perform temporal convolution one time step at a time.
-- It maintains an internal state to buffer signal and accept a single frame
-- as input.
-- This module is for forward evaluation **only** and does not support
-- backpropagation
--
--]]

local LinearizedConvolution, parent =
    torch.class('nn.LinearizedConvolution', 'nn.Linear')

function LinearizedConvolution:__init()
    parent.__init(self, 1, 1)
    self.kw = 0  -- kernel width; set via setParameters()
    self.weight = self.output.new()
    self.bias = self.output.new()
    -- Rolling buffer holding the last kw input frames.
    self.inputBuffer = self.output.new()
    self.tmpBuffer = self.output.new()
end

-- Consumes a single time step (bsz x 1 x nin) and returns bsz x 1 x nout.
function LinearizedConvolution:updateOutput(input)
    assert(input:dim() == 3, 'only support batched inputs')
    local bsz = input:size(1)
    local buf = input

    if self.kw > 1 then
        buf = self.inputBuffer
        if buf:dim() == 0 then
            buf:resize(bsz, self.kw, input:size(3)):zero()
        end
        -- Store the newest frame in the last buffer slot.
        buf:select(2, buf:size(2)):copy(input:select(2, input:size(2)))
    end

    parent.updateOutput(self, buf:view(bsz, -1))
    self.output = self.output:view(bsz, 1, -1)
    return self.output
end

function LinearizedConvolution:resetState()
    self.inputBuffer:resize(0)
end

function LinearizedConvolution:clearState()
    -- Bug fix: the temporary buffer field is 'tmpBuffer' (see __init and
    -- shiftState); the original cleared a non-existent 'tmp' field, so the
    -- temporary buffer was never released.
    return nn.utils.clear(self, 'output', 'gradInput', 'tmpBuffer',
        'inputBuffer')
end

-- Shifts the buffered frames left by one slot, optionally reordering the
-- batch dimension first (used when beam search reorders hypotheses).
function LinearizedConvolution:shiftState(reorder)
    if self.kw > 1 then
        local buf = self.inputBuffer
        local dst = buf:narrow(2, 1, self.kw - 1)
        local src = buf:narrow(2, 2, self.kw - 1)
        local tmp = self.tmpBuffer
        tmp:resizeAs(src):copy(src)
        dst:copy(reorder and tmp:index(1, reorder) or tmp)
    end
end

-- Loads convolution parameters; weights should be nout x kw x nin.
function LinearizedConvolution:setParameters(weight, bias)
    local nout, kw, nin = weight:size(1), weight:size(2), weight:size(3)
    self.kw = kw
    self.weight:resize(nout, kw * nin):copy(weight)
    self.bias:resize(bias:size()):copy(bias)
end

function LinearizedConvolution:updateGradInput(input, gradOutput)
    error 'Not supported'
end

function LinearizedConvolution:accGradParameters(input, gradOutput, scale)
    error 'Not supported'
end

function LinearizedConvolution:zeroGradParameters()
    error 'Not supported'
end

function LinearizedConvolution:updateParameters(lr)
    error 'Not supported'
end
--------------------------------------------------------------------------------
-- fairseq/modules/SeqMultiply.lua
--------------------------------------------------------------------------------
-- Copyright (c) 2017-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the LICENSE file in
-- the root directory of this source tree. An additional grant of patent rights
-- can be found in the PATENTS file in the same directory.
--
--[[
--
-- Like MulConstant, but the factor is slen * sqrt(1 / slen). slen is the
-- current sequence length (determined by size(2) of the second input element).
--
--]]

local SeqMultiply, parent = torch.class('nn.SeqMultiply', 'nn.Module')

function SeqMultiply:__init()
    parent.__init(self)
    self.scale = 1
end

-- input = {data, reference}: the output is data scaled by
-- slen * sqrt(1 / slen), where slen = reference:size(2).
SeqMultiply.updateOutput = function(self, input)
    local slen = input[2]:size(2)
    self.scale = slen * math.sqrt(1 / slen)
    self.output:resizeAs(input[1]):copy(input[1]):mul(self.scale)
    return self.output
end

-- The second input only determines the length, so it receives zero grads.
SeqMultiply.updateGradInput = function(self, input, gradOutput)
    self.zeroGrads = self.zeroGrads or input[2].new()
    self.zeroGrads:resizeAs(input[2]):zero()
    self.grads = self.grads or input[1].new()
    self.grads:resizeAs(gradOutput):copy(gradOutput):mul(self.scale)
    self.gradInput = {self.grads, self.zeroGrads}
    return self.gradInput
end
--------------------------------------------------------------------------------
-- fairseq/modules/TrainTestLayer.lua
--------------------------------------------------------------------------------
-- Copyright (c) 2017-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the LICENSE file in
-- the root directory of this source tree. An additional grant of patent rights
-- can be found in the PATENTS file in the same directory.
--
--[[
--
-- This module contains 2 modules: one used in train mode, one in evaluate
-- (test) mode.
--
--]]

local TrainTestLayer, parent = torch.class('nn.TrainTestLayer', 'nn.Container')

-- onTrain/onEvaluate are callbacks invoked on mode switches with
-- (trainModule, evalModule).
function TrainTestLayer:__init(trainModule, evalModule, onTrain, onEvaluate)
    parent.__init(self)
    self.modules[1] = trainModule
    self.modules[2] = evalModule
    self.onTrain = onTrain
    self.onEvaluate = onEvaluate
    self.train = true
end

function TrainTestLayer:evaluate()
    if self.train then
        parent.evaluate(self)
        self.onEvaluate(self.modules[1], self.modules[2])
    end
end

function TrainTestLayer:training()
    if not self.train then
        parent.training(self)
        self.onTrain(self.modules[1], self.modules[2])
    end
end

-- Dispatches to the train module (index 1) or the eval module (index 2).
function TrainTestLayer:updateOutput(input)
    local active = self.train and 1 or 2
    self.output:set(self.modules[active]:updateOutput(input))
    return self.output
end

function TrainTestLayer:updateGradInput(input, gradOutput)
    assert(self.train, 'updateGradInput only in training mode')
    self.gradInput:set(self.modules[1]:updateGradInput(input, gradOutput))
    return self.gradInput
end

function TrainTestLayer:accGradParameters(input, gradOutput, scale)
    assert(self.train, 'accGradParameters only in training mode')
    self.modules[1]:accGradParameters(input, gradOutput, scale)
end

function TrainTestLayer:zeroGradParameters()
    assert(self.train, 'zeroGradParameters only in training mode')
    self.modules[1]:zeroGradParameters()
end

function TrainTestLayer:updateParameters(lr)
    assert(self.train, 'updateParameters only in training mode')
    self.modules[1]:updateParameters(lr)
end

function TrainTestLayer:parameters()
    assert(self.train, 'TrainTestLayer support parameters in train mode')
    return self.modules[1]:parameters()
end

function TrainTestLayer:__tostring__()
    local fmt = 'nn.TrainTestLayer [ %s ; %s ]'
    return fmt:format(tostring(self.modules[1]), tostring(self.modules[2]))
end
--------------------------------------------------------------------------------
-- fairseq/modules/ZipAlong.lua
--------------------------------------------------------------------------------
-- Copyright (c) 2017-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the LICENSE file in
-- the root directory of this source tree. An additional grant of patent rights
-- can be found in the PATENTS file in the same directory.
--
--[[
--
-- This module expects a table with two entries, and returns a single table of
-- pairs as follows:
--   Input:  {{x_1, x_2, ... x_n}, y}
--   Output: {{x_1, y}, {x_2, y}, ..., {x_n, y}}
--
--]]

local ZipAlong, parent = torch.class('nn.ZipAlong', 'nn.Module')

function ZipAlong:__init()
    parent.__init(self)
    self.output = {}
    self.gradInputBase = {}
end

function ZipAlong:updateOutput(input)
    local base, dup = input[1], input[2]
    self.output = {}
    for i = 1, #base do
        self.output[i] = {base[i], dup}
    end
    return self.output
end

-- Recursively (re)allocates dest to mirror src's nested table/tensor
-- structure and fills it with zeros.
local function zeroTT(dest, src)
    if type(src) == 'table' then
        if not dest or type(dest) ~= 'table' then
            dest = {}
        end
        for k, v in pairs(src) do
            dest[k] = zeroTT(dest[k], v)
        end
    else
        if not dest or not torch.isTypeOf(dest, src) then
            dest = src.new()
        end
        dest:resizeAs(src)
        dest:zero()
    end
    return dest
end

-- Recursively accumulates src into dest (mirroring structures).
local function addTT(dest, src)
    if type(src) == 'table' then
        for k, v in pairs(src) do
            addTT(dest[k], v)
        end
    else
        dest:add(src)
    end
end

function ZipAlong:updateGradInput(input, gradOutput)
    local basein, dupin = input[1], input[2]

    self.gradInputBase = {}
    -- The replicated entry collects the sum of its gradients over all pairs.
    self.gradInputDup = zeroTT(self.gradInputDup, dupin)
    for i = 1, #basein do
        self.gradInputBase[i] = gradOutput[i][1]
        addTT(self.gradInputDup, gradOutput[i][2])
    end

    self.gradInput = {self.gradInputBase, self.gradInputDup}
    return self.gradInput
end

function ZipAlong:clearState()
    return nn.utils.clear(self, 'output', 'gradInputBase', 'gradInputDup',
        'gradInput')
end
--------------------------------------------------------------------------------
-- fairseq/modules/init.lua
--------------------------------------------------------------------------------
-- Copyright (c) 2017-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the LICENSE file in
-- the root directory of this source tree. An additional grant of patent rights
-- can be found in the PATENTS file in the same directory.
--
--[[
--
-- init file for some extra modules.
--
--]]

require 'fairseq.modules.AppendBias'
require 'fairseq.modules.BeamableMM'
require 'fairseq.modules.CAddTableMulConstant'
require 'fairseq.modules.CLSTM'
require 'fairseq.modules.CudnnRnnTable'
require 'fairseq.modules.GradMultiply'
require 'fairseq.modules.LinearizedConvolution'
require 'fairseq.modules.SeqMultiply'
require 'fairseq.modules.TrainTestLayer'
require 'fairseq.modules.ZipAlong'
--------------------------------------------------------------------------------
-- fairseq/optim/nag.lua
--------------------------------------------------------------------------------
-- Copyright (c) 2017-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the LICENSE file in
-- the root directory of this source tree.
-- An additional grant of patent rights
-- can be found in the PATENTS file in the same directory.
--
--[[ A plain implementation of Nesterov's momentum.
Implements Nesterov's momentum using the simplified formulation of
https://arxiv.org/pdf/1212.0901.pdf
ARGS:
- `opfunc` : a function that takes a single input (X), the point of
             evaluation, and returns f(X) and df/dX
- `x`      : the initial point
- `config` : a table with configuration parameters for the optimizer
- `config.learningRate` : learning rate
- `config.momentum`     : momentum
- `state`  : a table describing the state of the optimizer; after each
             call the state is modified
- `state.evalCounter`   : evaluation counter (optional: 0, by default)
RETURN:
- `x`    : the new x vector
- `f(x)` : the function, evaluated before the update
(Yann Dauphin, 2016)
]]

local function nag(opfunc, x, config, state)
    -- (0) get/update state
    local config = config or {}
    local state = state or config
    local lr = config.learningRate or 1e-3
    local l2 = config.l2 or 0
    local mom = config.momentum or 0
    state.evalCounter = state.evalCounter or 0

    -- (1) evaluate f(x) and df/dx
    local fx, dfdx = opfunc(x)

    -- lazily allocate the momentum buffer
    if not state.dfdx then
        state.dfdx = torch.Tensor():typeAs(dfdx):resizeAs(dfdx):fill(0)
    end

    -- (2) weight decay
    if l2 ~= 0 then
        dfdx:add(l2, x)
    end

    -- (3) apply update: x <- x + mom^2 * v - (1 + mom) * lr * dfdx
    x:add(mom * mom, state.dfdx):add(-(1 + mom) * lr, dfdx)

    -- (4) update momentum buffer: v <- mom * v - lr * dfdx
    state.dfdx:mul(mom):add(-lr, dfdx)

    -- (5) update evaluation counter
    state.evalCounter = state.evalCounter + 1

    -- return x*, f(x) before optimization
    return x, {fx}
end

return nag
--------------------------------------------------------------------------------
-- fairseq/text/Dictionary.lua
-------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | -- 8 | --[[ 9 | -- 10 | -- A dictionary class that manages symbol (e.g. words) to index mapping and vice 11 | -- versa. Building a dictionary is done by repeatedly calling addSymbol() for 12 | -- all symbols in a given corpus and then calling finalize(). 13 | -- 14 | --]] 15 | 16 | local tds = require 'tds' 17 | local argcheck = require 'argcheck' 18 | 19 | local Dictionary = torch.class('Dictionary') 20 | 21 | 22 | Dictionary.__init = argcheck{ 23 | {name='self', type='Dictionary'}, 24 | {name='threshold', type='number', default=0}, 25 | {name='unk', type='string', default='<unk>'}, 26 | {name='pad', type='string', default='<pad>'}, 27 | {name='eos', type='string', default='</s>'}, 28 | call = function(self, threshold, unk, pad, eos) 29 | self.symbol_to_index = tds.Hash() 30 | self.index_to_symbol = tds.Vec() 31 | self.index_to_freq = tds.Vec() 32 | self.cutoff = math.huge 33 | 34 | -- Pre-populate with unk/pad/eos 35 | self.unk, self.pad, self.eos = unk, pad, eos 36 | self:addSymbol(self.unk) 37 | self.unk_index = self:getIndex(self.unk) 38 | self:addSymbol(self.pad) 39 | self.pad_index = self:getIndex(self.pad) 40 | self:addSymbol(self.eos) 41 | self.eos_index = self:getIndex(self.eos) 42 | self.threshold = threshold 43 | -- It's assumed that indices until and including to self.nspecial are 44 | -- occupied by special symbols. 
-- Add `symbol` to the dictionary, creating a new entry with frequency 1,
-- or incrementing the frequency of an existing entry.
function Dictionary:addSymbol(symbol)
    if self.symbol_to_index[symbol] == nil then
        local index = #self.index_to_symbol + 1
        self.symbol_to_index[symbol] = index
        self.index_to_symbol[index] = symbol
        self.index_to_freq[index] = 1
    else
        local index = self.symbol_to_index[symbol]
        self.index_to_freq[index] = self.index_to_freq[index] + 1
    end
end

-- Compute the cutoff index implied by the frequency threshold. Assumes
-- finalize() has sorted regular symbols by descending frequency, so the
-- first symbol whose frequency is below the threshold marks the boundary.
-- BUGFIX: when *no* symbol falls below the threshold the loop never
-- breaks; in that case the whole dictionary must be retained
-- (#self.index_to_symbol). The previous fallback of self.nspecial
-- discarded every regular symbol in exactly the situation where all of
-- them qualify. The "retain specials only" case is already produced by
-- the loop itself (break at idx == nspecial + 1 yields cutoff == nspecial).
function Dictionary:_applyFrequencyThreshold()
    local cutoff = math.huge
    for idx, freq in ipairs(self.index_to_freq) do
        if idx > self.nspecial and freq < self.threshold then
            cutoff = idx - 1
            break
        end
    end

    if cutoff == math.huge then
        -- Every symbol is at or above the threshold: keep all of them.
        cutoff = #self.index_to_symbol
    end
    return cutoff
end

-- Sort symbols by descending frequency (special symbols stay first, in
-- their original order), rebuild the index/frequency mappings, and set
-- self.cutoff according to the frequency threshold.
function Dictionary:finalize()
    -- Sort symbols by frequency in descending order, ignoring special ones.
    -- The comparator reads the *old* symbol_to_index/index_to_freq tables,
    -- which stay valid for the duration of the sort.
    self.index_to_symbol:sort(function(i, j)
        local idxi = self.symbol_to_index[i]
        local idxj = self.symbol_to_index[j]
        if idxi <= self.nspecial or idxj <= self.nspecial then
            return idxi < idxj
        end
        return self.index_to_freq[idxi] > self.index_to_freq[idxj]
    end)

    -- Update symbol_to_index and index_to_freq mappings to the new order.
    local new_freq = tds.Vec()
    for idx, sym in ipairs(self.index_to_symbol) do
        local prev = self.symbol_to_index[sym]
        new_freq[idx] = self.index_to_freq[prev]
        self.symbol_to_index[sym] = idx
    end
    self.index_to_freq = new_freq

    collectgarbage()
    if self.threshold > 0 then
        self.cutoff = self:_applyFrequencyThreshold()
    else
        self.cutoff = #self.index_to_symbol
    end
end

-- Return the index of `symbol`, or the unk index if the symbol is unknown
-- or was pruned by the frequency cutoff.
function Dictionary:getIndex(symbol)
    local idx = self.symbol_to_index[symbol]
    if idx and idx <= self.cutoff then
        return idx
    end
    return self.unk_index
end
function Dictionary:getSymbol(idx) 114 | return self.index_to_symbol[idx] 115 | end 116 | 117 | function Dictionary:size() 118 | assert(self.cutoff ~= math.huge, 'Dictionary not finalized') 119 | return self.cutoff 120 | end 121 | 122 | function Dictionary:getUnkIndex() 123 | return self.unk_index 124 | end 125 | 126 | function Dictionary:getPadIndex() 127 | return self.pad_index 128 | end 129 | 130 | function Dictionary:getEosIndex() 131 | return self.eos_index 132 | end 133 | 134 | -- Returns the string of symbols whose indices are provided in vec 135 | function Dictionary:getString(vec) 136 | local out_tbl = {} 137 | for i = 1, vec:size(1) do 138 | table.insert(out_tbl, self:getSymbol(vec[i])) 139 | end 140 | local str = table.concat(out_tbl, ' ') 141 | return str 142 | end 143 | -------------------------------------------------------------------------------- /fairseq/text/bleu.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | -- 8 | --[[ 9 | -- 10 | -- BLEU scoring 11 | -- 12 | --]] 13 | 14 | local bleu = {} 15 | 16 | local function countNGrams(tokens, order) 17 | local ngramCounts = {} 18 | local orderString = tostring(order) 19 | local len = #tokens 20 | for i = 1, len - order + 1 do 21 | local ngram = orderString 22 | for j = 1, order do 23 | ngram = ngram .. ' ' .. 
local bleu = {}

-- Count n-grams of the given `order` in a token list. Each key is prefixed
-- with the order so counts of different orders can share a single table.
local function countNGrams(tokens, order)
    local counts = {}
    local prefix = tostring(order)
    for start = 1, #tokens - order + 1 do
        local key = prefix
        for offset = 0, order - 1 do
            key = key .. ' ' .. tostring(tokens[start + offset])
        end
        counts[key] = (counts[key] or 0) + 1
    end
    return counts
end

-- Incremental corpus-level BLEU scorer (up to `maxOrder`-grams, default 4).
-- Call f.update(sysTokens, refTokens) once per sentence pair;
-- f.results() returns {bleu, precision, brevPenalty, totalSys, totalRef};
-- f.resultString() formats them for printing.
function bleu.scorer(maxOrder)
    local maxOrder = maxOrder or 4
    local sysLen, refLen = 0, 0
    local totals, clipped = {}, {}
    local sentences = 0

    for order = 1, maxOrder do
        totals[order], clipped[order] = 0, 0
    end

    local f = {}

    f.update = function(sys, ref)
        -- Reference n-gram counts for all orders in one table (the keys
        -- carry the order prefix, so they cannot collide).
        local refCounts = {}
        for order = 1, maxOrder do
            for key, count in pairs(countNGrams(ref, order)) do
                refCounts[key] = count
            end
        end

        -- Accumulate totals and clipped matches for the system output.
        for order = 1, maxOrder do
            for key, count in pairs(countNGrams(sys, order)) do
                totals[order] = totals[order] + count
                local refCount = refCounts[key]
                if refCount ~= nil then
                    clipped[order] = clipped[order]
                        + math.min(count, refCount)
                end
            end
        end

        sysLen = sysLen + #sys
        refLen = refLen + #ref
        sentences = sentences + 1
    end

    local results = function()
        local precision = {}
        local logSum = 0
        for order = 1, maxOrder do
            precision[order] = totals[order] > 0
                and (clipped[order] / totals[order]) or 0
            logSum = logSum + math.log(precision[order])
        end

        -- Brevity penalty punishes outputs shorter than the reference.
        local brevPenalty = 1
        if sysLen < refLen then
            brevPenalty = math.exp(1 - refLen / sysLen)
        end

        return {
            bleu = brevPenalty * math.exp(logSum / maxOrder) * 100,
            precision = precision,
            brevPenalty = brevPenalty,
            totalSys = sysLen,
            totalRef = refLen,
        }
    end
    f.results = results

    f.resultString = function()
        local r = results()
        local precs = {}
        for order = 1, maxOrder do
            precs[order] = string.format('%.1f', r.precision[order] * 100)
        end
        return string.format('BLEU%d = %.2f, ', maxOrder, r.bleu)
            .. table.concat(precs, '/')
            .. string.format(
                ' (BP=%.3f, ratio=%.3f, sys_len=%d, ref_len=%d)',
                r.brevPenalty, r.totalRef / r.totalSys,
                r.totalSys, r.totalRef)
    end

    return f
end

return bleu
-- Build the bptt batching pipeline shared by the text and indexed iterators:
-- sentences are batched across `batchsize` streams, each sample gets its
-- next-token target, and bptt steps are merged into one batch.
local function makeDataPipeline(ds, batchsize, bptt)
    -- bptt batching
    return tnt.BatchDataset{
        -- Add targets
        dataset = tnt.TargetNextDataset{
            -- Place tensor in table
            dataset = tnt.TransformDataset{
                -- Batching across sentences
                dataset = tnt.SequenceBatchDataset{
                    dataset = ds,
                    batchsize = batchsize,
                    policy = 'skip-remainder',
                },
                transform = function(sample)
                    return {input = sample}
                end,
            },
        },
        batchsize = bptt,
        merge = function(sample)
            -- Merge target vectors into a tensor, keep input as a table
            local targets = sample.target
            sample.target = targets[1].new()
            torch.cat(sample.target, targets)
            return sample
        end,
    }
end

-- Iterator over a binarized (indexed) corpus.
-- BUGFIX: accept an optional `dict` argument. loadBinarizedCorpus passes
-- `dict = dict` when a dictionary file is given, which argcheck would
-- otherwise reject as an unknown field. The dictionary is not needed for
-- iteration itself, so it is accepted and ignored; appending it keeps the
-- positional signature backward-compatible.
lmc.iteratorFromIndexed = argcheck{
    {name='indexfilename', type='string'},
    {name='datafilename', type='string'},
    {name='batchsize', type='number'},
    {name='bptt', type='number'},
    {name='dict', type='Dictionary', opt=true},
    call = function(indexfilename, datafilename, batchsize, bptt, dict)
        return tnt.ParallelDatasetIterator{
            nthread = 1,
            init = function()
                require 'torchnet'
            end,
            closure = function()
                local ds = tnt.FlatIndexedDataset{
                    indexfilename = indexfilename,
                    datafilename = datafilename,
                }
                return makeDataPipeline(ds, batchsize, bptt)
            end,
        }
    end
}
| require 'torchnet' 83 | end, 84 | closure = function() 85 | local ds = tnt.TableDataset(data.words:totable()) 86 | collectgarbage() 87 | return makeDataPipeline(ds, batchsize, bptt) 88 | end, 89 | } 90 | end 91 | } 92 | 93 | lmc.loadTextCorpus = argcheck{ 94 | {name='trainfilename', type='string'}, 95 | {name='validfilename', type='string', opt=true}, 96 | {name='testfilename', type='string', opt=true}, 97 | {name='batchsize', type='number'}, 98 | {name='bptt', type='number'}, 99 | {name='dict', type='Dictionary', opt=true}, 100 | call = function(trainfilename, validfilename, testfilename, batchsize, 101 | bptt, dict) 102 | if not dict then 103 | dict = tokenizer.buildDictionary{ 104 | filename = trainfilename, 105 | threshold = 0, 106 | } 107 | end 108 | 109 | local train, valid, test = nil, nil, nil 110 | if trainfilename then 111 | train = lmc.iteratorFromText{ 112 | filename = trainfilename, 113 | dict = dict, 114 | batchsize = batchsize, 115 | bptt = bptt, 116 | } 117 | end 118 | if validfilename then 119 | valid = lmc.iteratorFromText{ 120 | filename = validfilename, 121 | dict = dict, 122 | batchsize = batchsize, 123 | bptt = bptt, 124 | } 125 | end 126 | if testfilename then 127 | test = lmc.iteratorFromText{ 128 | filename = testfilename, 129 | dict = dict, 130 | batchsize = batchsize, 131 | bptt = bptt, 132 | } 133 | end 134 | 135 | return { 136 | dict = dict, 137 | train = train, 138 | valid = valid, 139 | test = test, 140 | } 141 | end 142 | } 143 | 144 | lmc.loadBinarizedCorpus = argcheck{ 145 | {name='trainprefix', type='string'}, 146 | {name='validprefix', type='string', opt=true}, 147 | {name='testprefix', type='string', opt=true}, 148 | {name='dictfilename', type='string', opt=true}, 149 | {name='batchsize', type='number'}, 150 | {name='bptt', type='number'}, 151 | call = function(trainprefix, validprefix, testprefix, dictfilename, 152 | batchsize, bptt) 153 | 154 | local dict = nil 155 | if dictfilename then 156 | dict = torch.load(dictfilename) 
157 | end 158 | 159 | local train, valid, test = nil, nil, nil 160 | if trainprefix then 161 | train = lmc.iteratorFromIndexed{ 162 | indexfilename = trainprefix .. '.idx', 163 | datafilename = trainprefix .. '.bin', 164 | dict = dict, 165 | batchsize = batchsize, 166 | bptt = bptt, 167 | } 168 | end 169 | if validprefix then 170 | valid = lmc.iteratorFromIndexed{ 171 | indexfilename = validprefix .. '.idx', 172 | datafilename = validprefix .. '.bin', 173 | dict = dict, 174 | batchsize = batchsize, 175 | bptt = bptt, 176 | } 177 | end 178 | if testprefix then 179 | test = lmc.iteratorFromIndexed{ 180 | indexfilename = testprefix .. '.idx', 181 | datafilename = testprefix .. '.bin', 182 | dict = dict, 183 | batchsize = batchsize, 184 | bptt = bptt, 185 | } 186 | end 187 | 188 | return { 189 | dict = dict, 190 | train = train, 191 | valid = valid, 192 | test = test, 193 | } 194 | end 195 | } 196 | 197 | lmc.binarizeCorpus = argcheck{ 198 | {name='files', type='table'}, 199 | {name='dict', type='Dictionary'}, 200 | call = function(files, dict) 201 | local res = {} 202 | for _, f in ipairs(files) do 203 | local r = tokenizer.binarize{ 204 | filename = f.src, 205 | dict = dict, 206 | indexfilename = f.dest .. '.idx', 207 | datafilename = f.dest .. '.bin', 208 | } 209 | table.insert(res, r) 210 | collectgarbage() 211 | end 212 | return res 213 | end 214 | } 215 | 216 | return lmc 217 | -------------------------------------------------------------------------------- /fairseq/text/pretty.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 
local argcheck = require 'argcheck'

-- Display source, target, hypothesis and attention for generated samples.
--
-- Returns a closure taking (sample, hypos, scores, attns) that prints, for
-- each sentence in the batch: the source (S-), the reference target (T-),
-- and up to `nbest` hypotheses (H-) with their attention maxima (A-).
--
-- BUGFIX: argcheck passes arguments to `call` in *declaration* order
-- (srcdict, dict, nbest, beam). The previous signature named them
-- (dict, srcdict, ...), which silently swapped the source and target
-- dictionaries throughout the body. The local names now match the
-- declaration order; the external (named-argument) interface is unchanged.
local displayResults = argcheck{
    {name='srcdict', type='Dictionary'},
    {name='dict', type='Dictionary'},
    {name='nbest', type='number'},
    {name='beam', type='number'},
    call = function(srcdict, dict, nbest, beam)
        local eos = dict:getSymbol(dict:getEosIndex())
        local unk = dict:getSymbol(dict:getUnkIndex())
        local seos = srcdict:getSymbol(srcdict:getEosIndex())
        local runk = unk
        repeat -- select unk token for reference different from hypothesis
            runk = string.format('<%s>', runk)
        until dict:getIndex(runk) == dict:getUnkIndex()

        return function(sample, hypos, scores, attns)
            local src, tgt = sample.source:t(), sample.target:t()
            for i = 1, sample.bsz do
                local sourceString = srcdict:getString(src[i]):gsub(seos, '')
                print('S-' .. sample.index[i], sourceString)

                local ref = dict:getString(tgt[i])
                    :gsub(eos .. '.*', ''):gsub(unk, runk) --ref may contain pad
                print('T-' .. sample.index[i], ref)

                for j = 1, math.min(nbest, beam) do
                    local idx = (i - 1) * beam + j
                    local hypo = dict:getString(hypos[idx]):gsub(eos, '')
                    print('H-' .. sample.index[i], scores[idx], hypo)
                    -- NOTE: This will print #hypo + 1 attention maxima. The
                    -- last one is the attention that was used to generate the
                    -- <eos> symbol.
                    local _, maxattns = torch.max(attns[idx], 2)
                    print('A-' .. sample.index[i],
                        table.concat(maxattns:squeeze(2):totable(), ' '))
                end
            end
        end
    end
}

return {
    displayResults = displayResults,
}
3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | -- 8 | --[[ 9 | -- 10 | -- Build the word based dataset for a text corpus. 11 | -- 12 | --]] 13 | 14 | 15 | local argcheck = require 'argcheck' 16 | local plutils = require 'pl.utils' 17 | local tds = require 'tds' 18 | local tnt = require 'torchnet' 19 | 20 | local tokenizer = {} 21 | 22 | tokenizer.pad = '<pad>' 23 | tokenizer.eos = '</s>' 24 | tokenizer.unk = '<unk>' 25 | 26 | tokenizer.tokenize = function(line) 27 | -- Remove extra whitespace 28 | local s = line:gsub("\t", ""):gsub("^%s+", ""):gsub("%s+$", ""):gsub("%s+", " ") 29 | return plutils.split(s, ' ') 30 | end 31 | 32 | local addFileToDictionary = argcheck{ 33 | {name='filename', type='string'}, 34 | {name='dict', type='Dictionary'}, 35 | {name='tokenize', type='function', default=tokenizer.tokenize}, 36 | call = function(filename, dict, tokenize) 37 | for s in io.lines(filename) do 38 | for i, word in pairs(tokenize(s)) do 39 | dict:addSymbol(word) 40 | end 41 | dict:addSymbol(dict.eos) 42 | end 43 | end 44 | } 45 | 46 | local lineCount = argcheck{ 47 | {name = 'fpath', type = 'string'}, 48 | call = function(fpath) 49 | local nlines = 0 50 | for _ in io.lines(fpath) do 51 | nlines = nlines + 1 52 | end 53 | return nlines 54 | end 55 | } 56 | 57 | tokenizer.makeDictionary = argcheck{ 58 | {name='threshold', type='number', default=1}, 59 | call = function(threshold) 60 | return Dictionary{ 61 | threshold = threshold, 62 | unk = tokenizer.unk, 63 | pad = tokenizer.pad, 64 | eos = tokenizer.eos, 65 | } 66 | end 67 | } 68 | 69 | tokenizer.buildDictionary = argcheck{ 70 | {name='filename', type='string'}, 71 | {name='threshold', type='number', default=1}, 72 | {name='tokenize', type='function', default=tokenizer.tokenize}, 73 | noordered = true, 74 | call = 
function(filename, threshold, tokenize) 75 | local dict = tokenizer.makeDictionary{ 76 | threshold = threshold, 77 | } 78 | addFileToDictionary(filename, dict, tokenize) 79 | dict:finalize() 80 | return dict 81 | end 82 | } 83 | 84 | tokenizer.buildDictionary = argcheck{ 85 | {name='filenames', type='table'}, 86 | {name='threshold', type='number', default=1}, 87 | {name='tokenize', type='function', default=tokenizer.tokenize}, 88 | overload = tokenizer.buildDictionary, 89 | noordered = true, 90 | call = function(filenames, threshold, tokenize) 91 | local dict = tokenizer.makeDictionary{ 92 | threshold = threshold, 93 | } 94 | for i, filename in pairs(filenames) do 95 | addFileToDictionary(filename, dict, tokenize) 96 | end 97 | dict:finalize() 98 | return dict 99 | end 100 | } 101 | 102 | tokenizer.buildAlignFreqMap = argcheck{ 103 | {name='alignfile', type='string'}, 104 | {name='srcfile', type='string'}, 105 | {name='tgtfile', type='string'}, 106 | {name='srcdict', type='Dictionary'}, 107 | {name='tgtdict', type='Dictionary'}, 108 | {name='tokenize', type='function', default=tokenizer.tokenize}, 109 | call = function(alignfile, srcfile, tgtfile, srcdict, 110 | tgtdict, tokenize) 111 | local freqmap = tds.Vec() 112 | local srccorp = io.lines(srcfile) 113 | local tgtcorp = io.lines(tgtfile) 114 | 115 | local function addalignment(alignment, src, tgt) 116 | if alignment:dim() == 0 then 117 | return 118 | end 119 | 120 | -- Compute src-tgt pair frequencies 121 | for i = 1, alignment:size(1) do 122 | local srcidx = src[alignment[i][1]] 123 | assert(srcidx ~= srcdict:getEosIndex()) 124 | assert(srcidx ~= srcdict:getPadIndex()) 125 | 126 | local tgtidx = tgt[alignment[i][2]] 127 | assert(tgtidx ~= tgtdict:getEosIndex()) 128 | assert(tgtidx ~= tgtdict:getPadIndex()) 129 | 130 | if srcidx ~= srcdict:getUnkIndex() and 131 | tgtidx ~= tgtdict:getUnkIndex() then 132 | if not freqmap[srcidx] then 133 | freqmap[srcidx] = tds.Hash() 134 | end 135 | if not 
-- Convert a line of text into a 1-D LongTensor of dictionary indices,
-- with the eos index appended as the final element.
tokenizer.tensorizeString = argcheck{
    {name='text', type='string'},
    {name='dict', type='Dictionary'},
    {name='tokenize', type='function', default=tokenizer.tokenize},
    call = function(text, dict, tokenize)
        local tokens = tokenize(text)
        local ntokens = #tokens
        local ids = torch.LongTensor(ntokens + 1)
        for pos, token in pairs(tokens) do
            ids[pos] = dict:getIndex(token)
        end
        ids[ntokens + 1] = dict:getEosIndex()
        return ids
    end
}

-- Parse a line of "src-tgt" alignment pairs into an (#pairs x 2) IntTensor.
-- Input alignments are zero-based; stored values are shifted to one-based.
tokenizer.tensorizeAlignment = argcheck{
    {name='text', type='string'},
    {name='tokenize', type='function', default=tokenizer.tokenize},
    call = function(text, tokenize)
        local tokens = tokenize(text)
        local alignment = torch.IntTensor(#tokens, 2)
        -- Note that alignments are zero-based
        for row, token in ipairs(tokens) do
            local pair = plutils.split(token, '-')
            for col, id in ipairs(pair) do
                alignment[row][col] = tonumber(id) + 1
            end
        end
        return alignment
    end
}
torch.IntTensor(nSequence, 2) 197 | local ids = torch.LongTensor(nSequence * 20) 198 | local nseq, nunk = 0, 0 199 | local woffset = 1 200 | local replaced = tds.Hash() 201 | for s in io.lines(filename) do 202 | local words = tokenize(s) 203 | local nwords = #words 204 | nseq = nseq + 1 205 | smap[nseq][1] = woffset 206 | smap[nseq][2] = nwords + 1 -- +1 for the additional </s> character 207 | 208 | while woffset + nwords + 1 > ids:nElement() do 209 | ids:resize(math.floor(ids:nElement() * 1.5)) 210 | end 211 | 212 | for i, word in pairs(words) do 213 | local idx = dict:getIndex(word) 214 | if idx == dict.unk_index and word ~= dict.unk then 215 | nunk = nunk + 1 216 | if not replaced[word] then 217 | replaced[word] = 1 218 | else 219 | replaced[word] = replaced[word] + 1 220 | end 221 | end 222 | ids[woffset] = idx 223 | woffset = woffset + 1 224 | end 225 | ids[woffset] = dict.eos_index 226 | woffset = woffset + 1 227 | end 228 | smap = smap:narrow(1, 1, nseq):clone() 229 | ids = ids:narrow(1, 1, woffset - 1):clone() 230 | 231 | return {smap = smap, words = ids}, { 232 | nseq = nseq, 233 | nunk = nunk, 234 | ntok = ids:nElement(), 235 | replaced = replaced, 236 | } 237 | end 238 | } 239 | 240 | tokenizer.binarize = argcheck{ 241 | {name='filename', type='string'}, 242 | {name='dict', type='Dictionary'}, 243 | {name='indexfilename', type='string'}, 244 | {name='datafilename', type='string'}, 245 | {name='tokenize', type='function', default=tokenizer.tokenize}, 246 | call = function(filename, dict, indexfilename, datafilename, tokenize) 247 | local writer = tnt.IndexedDatasetWriter{ 248 | indexfilename = indexfilename, 249 | datafilename = datafilename, 250 | type = 'int', 251 | } 252 | local nseq, ntok, nunk = 0, 0, 0 253 | local ids = torch.IntTensor() 254 | local replaced = tds.Hash() 255 | for s in io.lines(filename) do 256 | local words = tokenize(s) 257 | local nwords = #words 258 | ids:resize(nwords + 1) 259 | nseq = nseq + 1 260 | for i, word in pairs(words) 
do 261 | local idx = dict:getIndex(word) 262 | if idx == dict.unk_index and word ~= dict.unk then 263 | nunk = nunk + 1 264 | if not replaced[word] then 265 | replaced[word] = 1 266 | else 267 | replaced[word] = replaced[word] + 1 268 | end 269 | end 270 | ids[i] = idx 271 | end 272 | ids[nwords + 1] = dict.eos_index 273 | writer:add(ids) 274 | ntok = ntok + ids:nElement() 275 | end 276 | writer:close() 277 | 278 | return { 279 | nseq = nseq, 280 | nunk = nunk, 281 | ntok = ntok, 282 | replaced = replaced, 283 | } 284 | end 285 | } 286 | 287 | tokenizer.binarizeAlignFreqMap = argcheck{ 288 | {name='freqmap', type='tds.Vec'}, 289 | {name='srcdict', type='Dictionary'}, 290 | {name='indexfilename', type='string'}, 291 | {name='datafilename', type='string'}, 292 | {name='ncandidates', type='number'}, 293 | {name='tokenize', type='function', default=tokenizer.tokenize}, 294 | call = function(freqmap, srcdict, indexfilename, datafilename, 295 | ncandidates, tokenize) 296 | local writer = tnt.IndexedDatasetWriter{ 297 | indexfilename = indexfilename, 298 | datafilename = datafilename, 299 | type = 'int', 300 | } 301 | local empty = torch.IntTensor() 302 | local cands = torch.IntTensor() 303 | local npairs = 0 304 | for srcidx = 1, #srcdict.index_to_symbol do 305 | local ncands = freqmap[srcidx] and #freqmap[srcidx] or 0 306 | if ncands > 0 then 307 | cands:resize(ncands, 2) 308 | local j = 1 309 | for tgtidx, freq in pairs(freqmap[srcidx]) do 310 | cands[j][1] = tgtidx 311 | cands[j][2] = freq 312 | j = j + 1 313 | end 314 | ncands = math.min(ncands, ncandidates) 315 | npairs = npairs + ncands 316 | local _, indices = torch.topk(cands:narrow(2, 2, 1), 317 | ncands, 1, true, true) 318 | writer:add(cands:index(1, indices:squeeze(2))) 319 | else 320 | -- Add empty tensor if there are no candidates for given srcidx 321 | writer:add(empty) 322 | end 323 | end 324 | writer:close() 325 | 326 | return { 327 | npairs = npairs 328 | } 329 | end 330 | } 331 | 332 | return tokenizer 
333 | -------------------------------------------------------------------------------- /fairseq/torchnet/MaxBatchDataset.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | -- 8 | --[[ 9 | -- 10 | -- MaxBatchDataset builds batches of up to and including maxbatch tokens. 11 | -- 12 | --]] 13 | 14 | 15 | local tnt = require 'torchnet.env' 16 | local argcheck = require 'argcheck' 17 | local transform = require 'torchnet.transform' 18 | local vector = require 'vector' 19 | 20 | local MaxBatchDataset, _ = 21 | torch.class('tnt.MaxBatchDataset', 'tnt.Dataset', tnt) 22 | 23 | MaxBatchDataset.__init = argcheck{ 24 | doc = [[ 25 | <a name="MaxBatchDataset"> 26 | #### tnt.MaxBatchDataset(@ARGP) 27 | @ARGT 28 | Given a `dataset`, `tnt.MaxBatchDataset` merges samples from this dataset into 29 | a batch such that the total size of the batch does not exceed maxbatch. 
30 | ]], 31 | {name='self', type='tnt.MaxBatchDataset'}, 32 | {name='dataset', type='tnt.Dataset'}, 33 | {name='maxbatch', type='number'}, 34 | {name='samplesize', type='function'}, 35 | {name='merge', type='function', opt=true}, 36 | call = 37 | function(self, dataset, maxbatch, samplesize, merge) 38 | assert(maxbatch > 0 and math.floor(maxbatch) == maxbatch, 39 | 'maxbatch should be a positive integer number') 40 | self.dataset = dataset 41 | self.maxbatch = maxbatch 42 | self.samplesize = samplesize 43 | self.makebatch = transform.makebatch{merge=merge} 44 | self:_buildIndex() 45 | end 46 | } 47 | 48 | MaxBatchDataset._buildIndex = argcheck{ 49 | {name='self', type='tnt.MaxBatchDataset'}, 50 | call = function(self) 51 | self.offset = vector.tensor.new_long() 52 | self.offset[1] = 1 53 | local size = self.dataset:size() 54 | local maxssz, maxtsz = 0, 0 55 | local nstok, nttok = 0, 0 56 | 57 | for i = 1, size do 58 | local _, ssz, tsz = self.samplesize(self.dataset, i) 59 | if math.max(ssz, tsz) > self.maxbatch then 60 | print("warning: found sample that exceeds maxbatch size") 61 | end 62 | 63 | maxtsz = math.max(maxtsz, tsz) 64 | maxssz = math.max(maxssz, ssz) 65 | local nsamples = i - self.offset[#self.offset] + 1 66 | local tottsz = nsamples * maxtsz 67 | local totssz = nsamples * maxssz 68 | nstok = nstok + ssz 69 | nttok = nttok + tsz 70 | 71 | if i > 1 and math.max(tottsz, totssz) > self.maxbatch then 72 | self.offset[#self.offset + 1] = i 73 | maxssz = ssz 74 | maxtsz = tsz 75 | nstok = ssz 76 | nttok = tsz 77 | end 78 | end 79 | self.offset = self.offset:getTensor() 80 | end 81 | } 82 | 83 | MaxBatchDataset.size = argcheck{ 84 | {name='self', type='tnt.MaxBatchDataset'}, 85 | call = 86 | function(self) 87 | return self.offset:size(1) 88 | end 89 | } 90 | 91 | MaxBatchDataset.get = argcheck{ 92 | {name='self', type='tnt.MaxBatchDataset'}, 93 | {name='idx', type='number'}, 94 | call = 95 | function(self, idx) 96 | assert(idx >= 1 and idx <= self:size(), 
'index out of bound') 97 | local samples = {} 98 | local first = self.offset[idx] 99 | local last = idx < self:size() and 100 | self.offset[idx + 1] - 1 or self.dataset:size() 101 | for i = first, last do 102 | local sample = self.dataset:get(i) 103 | table.insert(samples, sample) 104 | end 105 | samples = self.makebatch(samples) 106 | collectgarbage() 107 | collectgarbage() 108 | return samples 109 | end 110 | } 111 | -------------------------------------------------------------------------------- /fairseq/torchnet/ShardedDatasetIterator.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | -- 8 | --[[ 9 | -- 10 | -- Iterator that prepares the input data for data parallel training. 11 | -- 12 | --]] 13 | 14 | local tnt = require 'torchnet' 15 | local argcheck = require 'argcheck' 16 | 17 | local ShardedDatasetIterator, DatasetIterator = 18 | torch.class('tnt.ShardedDatasetIterator', 'tnt.DatasetIterator', tnt) 19 | 20 | ShardedDatasetIterator.__init = argcheck{ 21 | doc = [[ 22 | <a name="ShardedDatasetIterator"> 23 | #### tnt.ShardedDatasetIterator(@ARGP) 24 | @ARGT 25 | 26 | This iterator is useful when you want to split out your input batch of samples 27 | into evenly sized sub batches across a specified `dimension`, so you can later 28 | use those for your dataparallel training. 29 | 30 | You can specify the exact set of fields that you want to split out by passing 31 | the`fields` argument, the remaining fields will be duplicated. 32 | 33 | The output will be in a form of a table of size `nshards` that contains all 34 | the sub batches. 
35 | 36 | ]], 37 | {name='self', type='tnt.ShardedDatasetIterator'}, 38 | {name='dataset', type='tnt.Dataset'}, 39 | {name='nshards', type='number'}, 40 | {name='dimension', type='number', default=1}, 41 | {name='fields', type='table', default={}}, 42 | call = function(self, dataset, nshards, dimension, fields) 43 | DatasetIterator.__init(self, dataset) 44 | self:_setup(nshards, dimension, fields) 45 | end 46 | } 47 | 48 | ShardedDatasetIterator.__init = argcheck{ 49 | {name='self', type='tnt.ShardedDatasetIterator'}, 50 | {name='iterator', type='tnt.DatasetIterator'}, 51 | {name='nshards', type='number'}, 52 | {name='dimension', type='number', default=1}, 53 | {name='fields', type='table', default={}}, 54 | overload = ShardedDatasetIterator.__init, 55 | call = function(self, iterator, nshards, dimension, fields) 56 | DatasetIterator.__init(self, iterator) 57 | self:_setup(nshards, dimension, fields) 58 | end 59 | } 60 | 61 | ShardedDatasetIterator._setup = argcheck{ 62 | {name='self', type='tnt.ShardedDatasetIterator'}, 63 | {name='nshards', type='number'}, 64 | {name='dimension', type='number', default=1}, 65 | {name='fields', type='table', default={}}, 66 | call = function(self, nshards, dimension, fields) 67 | self.nshards = nshards 68 | self.dimension = dimension 69 | self.fields_map = {} 70 | for _, v in ipairs(fields) do 71 | self.fields_map[v] = true 72 | end 73 | 74 | self.base_run = self.run 75 | self.run = self:_run() 76 | end 77 | } 78 | 79 | local function shouldSplit(k, v, fields) 80 | local typename = torch.typename(v) 81 | return fields[k] and typename and typename:match('Tensor') 82 | end 83 | 84 | local function inferSize(sample, dimension, fields) 85 | local size = -1 86 | for k, v in pairs(sample) do 87 | if shouldSplit(k, v, fields) then 88 | local cursize = v:size(dimension) 89 | assert(size == -1 or size == cursize) 90 | size = cursize 91 | end 92 | end 93 | assert(size ~= -1) 94 | return size 95 | end 96 | 97 | function 
ShardedDatasetIterator:_run() 98 | return function() 99 | local next_from_base = self.base_run() 100 | 101 | return function() 102 | local sample = next_from_base() 103 | if not sample then 104 | return sample 105 | end 106 | local size = inferSize(sample, self.dimension, self.fields_map) 107 | local shardsz = math.ceil(size / self.nshards) 108 | local offset = 1 109 | local result = {} 110 | for shardid = 1, self.nshards do 111 | if offset > size then 112 | break 113 | end 114 | local curshardsz = math.min(shardsz, size - offset + 1) 115 | result[shardid] = {} 116 | local shard = result[shardid] 117 | for k, v in pairs(sample) do 118 | if shouldSplit(k, v, self.fields_map) then 119 | shard[k] = v:narrow(self.dimension, offset, curshardsz) 120 | else 121 | shard[k] = v 122 | end 123 | end 124 | offset = offset + curshardsz 125 | end 126 | return result 127 | end 128 | end 129 | end 130 | -------------------------------------------------------------------------------- /fairseq/torchnet/SingleParallelIterator.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | -- 8 | --[[ 9 | -- 10 | -- A parallel dataset iterator that can wrap another iterator and is thus 11 | -- limited to using a single thread only. 
--
--]]

local tnt = require 'torchnet'
local argcheck = require 'argcheck'
local Threads = require 'threads'
local doc = require 'argcheck.doc'

local SingleParallelIterator = torch.class('tnt.SingleParallelIterator', 'tnt.DatasetIterator', tnt)

-- Runs the wrapped iterator on a single background thread and pipelines its
-- samples to the caller: while the caller consumes sample k, the worker is
-- already producing sample k+1.
-- NOTE: `giterator` and `gloop` are deliberately globals *inside the worker
-- thread's Lua state* so that successive addjob closures can share them.
SingleParallelIterator.__init = argcheck{
    {name='self', type='tnt.SingleParallelIterator'},
    {name='init', type='function', default=function(idx) end},
    {name='closure', type='function'},
    call = function(self, init, closure)
        local function main(idx)
            giterator = closure(idx)
            assert(torch.isTypeOf(giterator, 'tnt.DatasetIterator'),
                'closure should return a DatasetIterator class')
            gloop = nil
        end
        Threads.serialization('threads.sharedserialize')
        local threads = Threads(1, init, main)
        self.__threads = threads
        local sample -- beware: do not put this line in loop()
        function self.run()
            -- make sure we are not in the middle of something
            threads:synchronize()
            -- Ask the worker for the next sample; the endcallback stores it
            -- into the upvalue `sample` once the job completes.
            local function enqueue()
                threads:addjob(
                    function()
                        if not gloop then
                            gloop = giterator:run()
                        end
                        local sample = gloop()
                        collectgarbage()
                        collectgarbage()
                        if not sample then
                            -- Epoch finished; restart the loop on next run().
                            gloop = nil
                        end
                        return sample
                    end,
                    function(_sample_)
                        sample = _sample_
                    end)
            end

            enqueue()
            local iterFunction = function()
                -- One pending job at a time: wait for it, then immediately
                -- queue the next fetch (unless the epoch is exhausted) and
                -- hand the completed sample to the caller.
                while threads:hasjob() do
                    threads:dojob()
                    if threads:haserror() then
                        threads:synchronize()
                    end
                    if sample then
                        enqueue()
                    end
                    return sample
                end
            end

            return iterFunction
        end
    end
}

-- Forwards an exec() call to the wrapped iterator on the worker thread and
-- returns its results. Must not be called while a fetch job is in flight.
SingleParallelIterator.exec =
    function(self, name, ...)
        assert(not self.__threads:hasjob(), 'cannot exec during loop')
        local args = {...}
        local res
        self.__threads:addjob(
            function()
                return giterator:exec(name, table.unpack(args))
            end,
            function(...)
                res = {...}
            end)
        self.__threads:synchronize()
        return table.unpack(res)
    end
-------------------------------------------------------------------------------- /fairseq/torchnet/init.lua: --------------------------------------------------------------------------------
-- Copyright (c) 2017-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the LICENSE file in
-- the root directory of this source tree. An additional grant of patent rights
-- can be found in the PATENTS file in the same directory.
--
--[[
--
-- init file for the torchnet extensions.
--
--]]

require 'torchnet.sequential.dataset'
require 'fairseq.torchnet.MaxBatchDataset'
require 'fairseq.torchnet.ResumableDPOptimEngine'
require 'fairseq.torchnet.ShardedDatasetIterator'
require 'fairseq.torchnet.SingleParallelIterator'
-------------------------------------------------------------------------------- /fairseq/utils.lua: --------------------------------------------------------------------------------
-- Copyright (c) 2017-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the LICENSE file in
-- the root directory of this source tree. An additional grant of patent rights
-- can be found in the PATENTS file in the same directory.
--
--[[
--
-- Helper functions.
--
--]]

local argcheck = require 'argcheck'
local stringx = require 'pl.stringx'

local util = {}

-- True iff x is a number with no fractional part; nil and non-numbers
-- yield false.
util.isint = function(x)
    return x ~= nil and type(x) == 'number' and x == math.floor(x)
end

-- Attempts to load the optional CUDA-related packages and returns a table
-- mapping each successfully loaded package name to its module; packages
-- that fail to load are simply absent from the result.
util.loadCuda = argcheck{
    call = function()
        local cudaPackages = {'cutorch', 'cudnn', 'cunn', 'tbc', 'nccl'}
        local loaded = {}
        for _, pkg in ipairs(cudaPackages) do
            local ok, mod = pcall(require, pkg)
            loaded[pkg] = ok and mod or nil
        end
        return loaded
    end
}

-- Parses a delimiter-separated list of numbers from `str`. When `str` is
-- empty, returns a list of n copies of `val` instead.
util.parseListOrDefault = argcheck{
    {name='str', type='string'},
    {name='n', type='number'},
    {name='val', type='number'},
    {name='del', type='string', default=','},
    call = function(str, n, val, del)
        if str == '' then
            local defaults = {}
            for i = 1, n do
                defaults[i] = val
            end
            return defaults
        end
        local parts = stringx.split(str, del)
        for k, v in pairs(parts) do
            parts[k] = tonumber(v)
        end
        return parts
    end
}

-- Copies the tensor `data` into `data_gpu` (allocating a fresh CudaTensor
-- when none is supplied) and returns the GPU-resident copy.
util.sendtogpu = function(data, data_gpu)
    data_gpu = data_gpu or torch.CudaTensor()
    assert(data_gpu and torch.type(data_gpu) == 'torch.CudaTensor')
    assert(data and torch.isTensor(data))
    data_gpu:resize(data:size()):copy(data)
    return data_gpu
end

-- Invokes pcall(...) up to n times, returning true on the first success
-- and false when every attempt fails.
util.retry = function(n, ...)
    for attempt = 1, n do
        local status, err = pcall(...)
67 | if status then 68 | return true 69 | end 70 | print(err) 71 | end 72 | return false 73 | end 74 | 75 | util.RecyclableSet = function(n) 76 | local buffer = torch.IntTensor(n):zero() 77 | local t = 1 78 | return { 79 | set = function(self, idx) 80 | buffer[idx] = t 81 | end, 82 | isset = function(self, idx) 83 | return buffer[idx] == t 84 | end, 85 | clear = function(self) 86 | t = t + 1 87 | end, 88 | } 89 | end 90 | 91 | return util 92 | -------------------------------------------------------------------------------- /generate-lines.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | -- 8 | --[[ 9 | -- 10 | -- Hypothesis generation script with text file input, processed line-by-line. 11 | -- By default, this will run in interactive mode. 
12 | -- 13 | --]] 14 | 15 | require 'fairseq' 16 | 17 | local tnt = require 'torchnet' 18 | local tds = require 'tds' 19 | local argcheck = require 'argcheck' 20 | local plstringx = require 'pl.stringx' 21 | local data = require 'fairseq.torchnet.data' 22 | local search = require 'fairseq.search' 23 | local tokenizer = require 'fairseq.text.tokenizer' 24 | local mutils = require 'fairseq.models.utils' 25 | 26 | local cmd = torch.CmdLine() 27 | cmd:option('-path', 'model1.th7,model2.th7', 'path to saved model(s)') 28 | cmd:option('-beam', 1, 'search beam width') 29 | cmd:option('-lenpen', 1, 30 | 'length penalty: <1.0 favors shorter, >1.0 favors longer sentences') 31 | cmd:option('-unkpen', 0, 32 | 'unknown word penalty: <0 produces more, >0 produces less unknown words') 33 | cmd:option('-subwordpen', 0, 34 | 'subword penalty: <0 favors longer, >0 favors shorter words') 35 | cmd:option('-covpen', 0, 36 | 'coverage penalty: favor hypotheses that cover all source tokens') 37 | cmd:option('-nbest', 1, 'number of candidate hypotheses') 38 | cmd:option('-minlen', 1, 'minimum length of generated hypotheses') 39 | cmd:option('-maxlen', 500, 'maximum length of generated hypotheses') 40 | cmd:option('-input', '-', 'source language input text file') 41 | cmd:option('-sourcedict', '', 'source language dictionary') 42 | cmd:option('-targetdict', '', 'target language dictionary') 43 | cmd:option('-vocab', '', 'restrict output to target vocab') 44 | cmd:option('-visdom', '', 'visualize with visdom: (host:port)') 45 | cmd:option('-model', '', 'model type for legacy models') 46 | cmd:option('-aligndictpath', '', 'path to an alignment dictionary (optional)') 47 | cmd:option('-nmostcommon', 500, 48 | 'the number of most common words to keep when using alignment') 49 | cmd:option('-topnalign', 100, 'the number of the most common alignments to use') 50 | cmd:option('-freqthreshold', -1, 51 | 'the minimum frequency for an alignment candidate in order' .. 
52 | 'to be considered (default no limit)') 53 | cmd:option('-fconvfast', false, 'make fconv model faster') 54 | 55 | local config = cmd:parse(arg) 56 | 57 | ------------------------------------------------------------------- 58 | -- Load data 59 | ------------------------------------------------------------------- 60 | config.dict = torch.load(config.targetdict) 61 | print(string.format('| [target] Dictionary: %d types', config.dict:size())) 62 | config.srcdict = torch.load(config.sourcedict) 63 | print(string.format('| [source] Dictionary: %d types', config.srcdict:size())) 64 | 65 | if config.aligndictpath ~= '' then 66 | config.aligndict = tnt.IndexedDatasetReader{ 67 | indexfilename = config.aligndictpath .. '.idx', 68 | datafilename = config.aligndictpath .. '.bin', 69 | mmap = true, 70 | mmapidx = true, 71 | } 72 | config.nmostcommon = math.max(config.nmostcommon, config.dict.nspecial) 73 | config.nmostcommon = math.min(config.nmostcommon, config.dict:size()) 74 | end 75 | 76 | local TextFileIterator, _ = 77 | torch.class('tnt.TextFileIterator', 'tnt.DatasetIterator', tnt) 78 | 79 | TextFileIterator.__init = argcheck{ 80 | {name='self', type='tnt.TextFileIterator'}, 81 | {name='path', type='string'}, 82 | {name='transform', type='function', 83 | default=function(sample) return sample end}, 84 | call = function(self, path, transform) 85 | function self.run() 86 | local fd 87 | if path == '-' then 88 | fd = io.stdin 89 | else 90 | fd = io.open(path) 91 | end 92 | return function() 93 | if torch.isatty(fd) then 94 | io.stdout:write('> ') 95 | io.stdout:flush() 96 | end 97 | local line = fd:read() 98 | if line ~= nil then 99 | return transform(line) 100 | elseif fd ~= io.stdin then 101 | fd:close() 102 | end 103 | end 104 | end 105 | end 106 | } 107 | 108 | local dataset = tnt.DatasetIterator{ 109 | iterator = tnt.TextFileIterator{ 110 | path = config.input, 111 | transform = function(line) 112 | return { 113 | bin = tokenizer.tensorizeString(line, 
config.srcdict), 114 | text = line, 115 | } 116 | end 117 | }, 118 | transform = function(sample) 119 | local source = sample.bin:view(-1, 1):int() 120 | local sourcePos = data.makePositions(source, 121 | config.srcdict:getPadIndex()):view(-1, 1) 122 | local sample = { 123 | source = source, 124 | sourcePos = sourcePos, 125 | text = sample.text, 126 | target = torch.IntTensor(1, 1), -- a stub 127 | } 128 | if config.aligndict then 129 | sample.targetVocab, sample.targetVocabMap, 130 | sample.targetVocabStats 131 | = data.getTargetVocabFromAlignment{ 132 | dictsize = config.dict:size(), 133 | unk = config.dict:getUnkIndex(), 134 | aligndict = config.aligndict, 135 | set = 'test', 136 | source = sample.source, 137 | target = sample.target, 138 | nmostcommon = config.nmostcommon, 139 | topnalign = config.topnalign, 140 | freqthreshold = config.freqthreshold, 141 | } 142 | end 143 | return sample 144 | end, 145 | } 146 | 147 | local model 148 | if config.model ~= '' then 149 | model = mutils.loadLegacyModel(config.path, config.model) 150 | else 151 | model = require( 152 | 'fairseq.models.ensemble_model' 153 | ).new(config) 154 | if config.fconvfast then 155 | local nfconv = 0 156 | for _, fconv in ipairs(model.models) do 157 | if torch.typename(fconv) == 'FConvModel' then 158 | fconv:makeDecoderFast() 159 | nfconv = nfconv + 1 160 | end 161 | end 162 | assert(nfconv > 0, '-fconvfast requires an fconv model in the ensemble') 163 | end 164 | end 165 | 166 | local vocab = nil 167 | if config.vocab ~= '' then 168 | vocab = tds.Hash() 169 | local fd = io.open(config.vocab) 170 | while true do 171 | local line = fd:read() 172 | if line == nil then 173 | break 174 | end 175 | -- Add word on this line together with all prefixes 176 | for i = 1, line:len() do 177 | vocab[line:sub(1, i)] = 1 178 | end 179 | end 180 | end 181 | local searchf = search.beam{ 182 | ttype = model:type(), 183 | dict = config.dict, 184 | srcdict = config.srcdict, 185 | beam = config.beam, 186 | 
lenPenalty = config.lenpen, 187 | unkPenalty = config.unkpen, 188 | subwordPenalty = config.subwordpen, 189 | coveragePenalty = config.covpen, 190 | vocab = vocab, 191 | } 192 | 193 | if config.visdom ~= '' then 194 | local host, port = table.unpack(plstringx.split(config.visdom, ':')) 195 | searchf = search.visualize{ 196 | sf = searchf, 197 | dict = config.dict, 198 | sourceDict = config.srcdict, 199 | host = host, 200 | port = tonumber(port), 201 | } 202 | end 203 | 204 | local dict, srcdict = config.dict, config.srcdict 205 | local eos = dict:getSymbol(dict:getEosIndex()) 206 | local seos = srcdict:getSymbol(srcdict:getEosIndex()) 207 | local unk = dict:getSymbol(dict:getUnkIndex()) 208 | 209 | -- Select unknown token for reference that can't be produced by the model so 210 | -- that the program output can be scored correctly. 211 | local runk = unk 212 | repeat 213 | runk = string.format('<%s>', runk) 214 | until dict:getIndex(runk) == dict:getUnkIndex() 215 | 216 | for sample in dataset() do 217 | sample.bsz = 1 218 | local hypos, scores, attns = model:generate(config, sample, searchf) 219 | 220 | -- Print results 221 | local sourceString = config.srcdict:getString(sample.source:t()[1]) 222 | sourceString = sourceString:gsub(seos .. '.*', '') 223 | print('S', sourceString) 224 | print('O', sample.text) 225 | 226 | for i = 1, math.min(config.nbest, config.beam) do 227 | local hypo = config.dict:getString(hypos[i]):gsub(eos .. '.*', '') 228 | print('H', scores[i], hypo) 229 | -- NOTE: This will print #hypo + 1 attention maxima. The last one is the 230 | -- attention that was used to generate the <eos> symbol. 
231 | local _, maxattns = torch.max(attns[i], 2) 232 | print('A', table.concat(maxattns:squeeze(2):totable(), ' ')) 233 | end 234 | 235 | io.stdout:flush() 236 | collectgarbage() 237 | end 238 | -------------------------------------------------------------------------------- /generate.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | -- 8 | --[[ 9 | -- 10 | -- Batch hypothesis generation script. 11 | -- 12 | --]] 13 | 14 | require 'nn' 15 | require 'xlua' 16 | require 'fairseq' 17 | 18 | local tnt = require 'torchnet' 19 | local tds = require 'tds' 20 | local plpath = require 'pl.path' 21 | local hooks = require 'fairseq.torchnet.hooks' 22 | local data = require 'fairseq.torchnet.data' 23 | local search = require 'fairseq.search' 24 | local clib = require 'fairseq.clib' 25 | local mutils = require 'fairseq.models.utils' 26 | local utils = require 'fairseq.utils' 27 | local pretty = require 'fairseq.text.pretty' 28 | 29 | local cmd = torch.CmdLine() 30 | cmd:option('-path', 'model1.th7,model2.th7', 'path to saved model(s)') 31 | cmd:option('-nobleu', false, 'don\'t produce final bleu score') 32 | cmd:option('-quiet', false, 'don\'t print generated text') 33 | cmd:option('-beam', 1, 'beam width') 34 | cmd:option('-lenpen', 1, 35 | 'length penalty: <1.0 favors shorter, >1.0 favors longer sentences') 36 | cmd:option('-unkpen', 0, 37 | 'unknown word penalty: <0 produces more, >0 produces less unknown words') 38 | cmd:option('-subwordpen', 0, 39 | 'subword penalty: <0 favors longer, >0 favors shorter words') 40 | cmd:option('-covpen', 0, 41 | 'coverage penalty: favor hypotheses that cover all source tokens') 42 | 
cmd:option('-nbest', 1, 'number of candidate hypotheses') 43 | cmd:option('-batchsize', 16, 'batch size') 44 | cmd:option('-minlen', 1, 'minimum length of generated hypotheses') 45 | cmd:option('-maxlen', 500, 'maximum length of generated hypotheses') 46 | cmd:option('-sourcelang', 'de', 'source language') 47 | cmd:option('-targetlang', 'en', 'target language') 48 | cmd:option('-datadir', 'data-bin') 49 | cmd:option('-dataset', 'test', 'data subset') 50 | cmd:option('-partial', '1/1', 51 | 'decode only part of the dataset, syntax: part_index/num_parts') 52 | cmd:option('-vocab', '', 'restrict output to target vocab') 53 | cmd:option('-seed', 1111, 'random number seed (for dataset)') 54 | cmd:option('-model', '', 'model type for legacy models') 55 | cmd:option('-ndatathreads', 0, 'number of threads for data preparation') 56 | cmd:option('-aligndictpath', '', 'path to an alignment dictionary (optional)') 57 | cmd:option('-nmostcommon', 500, 58 | 'the number of most common words to keep when using alignment') 59 | cmd:option('-topnalign', 100, 'the number of the most common alignments to use') 60 | cmd:option('-freqthreshold', -1, 61 | 'the minimum frequency for an alignment candidate in order' .. 
62 | 'to be considered (default no limit)') 63 | cmd:option('-fconvfast', false, 'make fconv model faster') 64 | 65 | local cuda = utils.loadCuda() 66 | 67 | local config = cmd:parse(arg) 68 | torch.manualSeed(config.seed) 69 | if cuda.cutorch then 70 | cutorch.manualSeed(config.seed) 71 | end 72 | 73 | local function accTime() 74 | local total = {} 75 | return function(times) 76 | for k, v in pairs(times or {}) do 77 | if not total[k] then 78 | total[k] = {real = 0, sys = 0, user = 0} 79 | end 80 | for l, w in pairs(v) do 81 | total[k][l] = total[k][l] + w 82 | end 83 | end 84 | return total 85 | end 86 | end 87 | 88 | local function accBleu(beam, dict) 89 | local scorer = clib.bleu(dict:getPadIndex(), dict:getEosIndex()) 90 | local unkIndex = dict:getUnkIndex() 91 | local refBuf, hypoBuf = torch.IntTensor(), torch.IntTensor() 92 | return function(sample, hypos) 93 | if sample then 94 | local tgtT = sample.target:t() 95 | local ref = refBuf:resizeAs(tgtT):copy(tgtT) 96 | :apply(function(x) 97 | return x == unkIndex and -unkIndex or x 98 | end) 99 | for i = 1, sample.bsz do 100 | local hypoL = hypos[(i - 1) * beam + 1] 101 | local hypo = hypoBuf:resize(hypoL:size()):copy(hypoL) 102 | scorer:add(ref[i], hypo) 103 | end 104 | end 105 | return scorer 106 | end 107 | end 108 | 109 | ------------------------------------------------------------------- 110 | -- Load data 111 | ------------------------------------------------------------------- 112 | config.dict = torch.load(plpath.join(config.datadir, 113 | 'dict.' .. config.targetlang .. '.th7')) 114 | print(string.format('| [%s] Dictionary: %d types', config.targetlang, 115 | config.dict:size())) 116 | config.srcdict = torch.load(plpath.join(config.datadir, 117 | 'dict.' .. config.sourcelang .. 
'.th7')) 118 | print(string.format('| [%s] Dictionary: %d types', config.sourcelang, 119 | config.srcdict:size())) 120 | 121 | if config.aligndictpath ~= '' then 122 | config.aligndict = tnt.IndexedDatasetReader{ 123 | indexfilename = config.aligndictpath .. '.idx', 124 | datafilename = config.aligndictpath .. '.bin', 125 | mmap = true, 126 | mmapidx = true, 127 | } 128 | config.nmostcommon = math.max(config.nmostcommon, config.dict.nspecial) 129 | config.nmostcommon = math.min(config.nmostcommon, config.dict:size()) 130 | end 131 | 132 | local _, test = data.loadCorpus{config = config, testsets = {config.dataset}} 133 | local dataset = test[config.dataset] 134 | 135 | local model 136 | if config.model ~= '' then 137 | model = mutils.loadLegacyModel(config.path, config.model) 138 | else 139 | model = require( 140 | 'fairseq.models.ensemble_model' 141 | ).new(config) 142 | if config.fconvfast then 143 | local nfconv = 0 144 | for _, fconv in ipairs(model.models) do 145 | if torch.typename(fconv) == 'FConvModel' then 146 | fconv:makeDecoderFast() 147 | nfconv = nfconv + 1 148 | end 149 | end 150 | assert(nfconv > 0, '-fconvfast requires an fconv model in the ensemble') 151 | end 152 | end 153 | 154 | local vocab = nil 155 | if config.vocab ~= '' then 156 | vocab = tds.Hash() 157 | local fd = io.open(config.vocab) 158 | while true do 159 | local line = fd:read() 160 | if line == nil then 161 | break 162 | end 163 | -- Add word on this line together with all prefixes 164 | for i = 1, line:len() do 165 | vocab[line:sub(1, i)] = 1 166 | end 167 | end 168 | end 169 | 170 | local searchf = search.beam{ 171 | ttype = model:type(), 172 | dict = config.dict, 173 | srcdict = config.srcdict, 174 | beam = config.beam, 175 | lenPenalty = config.lenpen, 176 | unkPenalty = config.unkpen, 177 | subwordPenalty = config.subwordpen, 178 | coveragePenalty = config.covpen, 179 | vocab = vocab, 180 | } 181 | 182 | local dict, srcdict = config.dict, config.srcdict 183 | local display = 
pretty.displayResults(dict, srcdict, config.nbest, config.beam) 184 | local computeSampleStats = hooks.computeSampleStats(dict) 185 | 186 | -- Ensure that the model is fully unrolled for the maximum source sentence 187 | -- length in the test set. Lazy unrolling might otherwise distort the generation 188 | -- time measurements. 189 | local maxlen = 1 190 | for samples in dataset() do 191 | for _, sample in ipairs(samples) do 192 | maxlen = math.max(maxlen, sample.source:size(1)) 193 | end 194 | end 195 | model:extend(maxlen) 196 | 197 | -- allow to decode only part of the set k/N means decode part k of N 198 | local partidx, nparts = config.partial:match('(%d+)/(%d+)') 199 | partidx, nparts = tonumber(partidx), tonumber(nparts) 200 | 201 | -- let's decode 202 | local addBleu = accBleu(config.beam, dict) 203 | local addTime = accTime() 204 | local timer = torch.Timer() 205 | local nsents, ntoks, nbatch = 0, 0, 0 206 | local state = {} 207 | for samples in dataset() do 208 | if (nbatch % nparts == partidx - 1) then 209 | assert(#samples == 1, 'can\'t handle multiple samples') 210 | state.samples = samples 211 | computeSampleStats(state) 212 | local sample = state.samples[1] 213 | local hypos, scores, attns, t = model:generate(config, sample, searchf) 214 | nsents = nsents + sample.bsz 215 | ntoks = ntoks + sample.ntokens 216 | addTime(t) 217 | 218 | -- print results 219 | if not config.quiet then 220 | display(sample, hypos, scores, attns) 221 | end 222 | 223 | -- accumulate bleu 224 | if (not config.nobleu) then 225 | addBleu(sample, hypos) 226 | end 227 | end 228 | nbatch = nbatch + 1 229 | end 230 | 231 | -- report overall stats 232 | local elapsed = timer:time().real 233 | local statmsg = 234 | ('| Translated %d sentences (%d tokens) in %.1fs (%.2f tokens/s)') 235 | :format(nsents, ntoks, elapsed, ntoks / elapsed) 236 | if state.dictstats then 237 | local avg = state.dictstats.size / state.dictstats.n 238 | statmsg = ('%s with avg dict of size 
%.1f'):format(statmsg, avg) 239 | end 240 | print(statmsg) 241 | 242 | local timings = '| Timings:' 243 | local totalTime = addTime() 244 | for k, v in pairs(totalTime) do 245 | local percent = 100 * v.real / elapsed 246 | timings = ('%s %s %.1fs (%.1f%%),'):format(timings, k, v.real, percent) 247 | end 248 | print(timings:sub(1, -2)) 249 | 250 | if not config.nobleu then 251 | local bleu = addBleu() 252 | print(('| %s'):format(bleu:resultString())) 253 | end 254 | -------------------------------------------------------------------------------- /help.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the 5 | -- LICENSE file in the root directory of this source tree. 6 | -- 7 | --[[ 8 | -- 9 | -- List available scripts. 10 | -- 11 | --]] 12 | 13 | local dir = require 'pl.dir' 14 | local path = require 'pl.path' 15 | 16 | local scriptdir = path.abspath(path.dirname(debug.getinfo(1).source:sub(2))) 17 | local tools = {} 18 | local maxbase = 0 19 | for _, file in pairs(dir.getfiles(scriptdir, '*.lua')) do 20 | local base, _ = path.splitext(path.basename(file)) 21 | 22 | local f = io.open(file) 23 | local source = f:read("*all") 24 | f:close() 25 | -- First sentence of first multi-line comment block is regarded as a brief 26 | -- description 27 | local m = source:gsub('\n', ' '):match('%-%-%[%[.*%-%-%]%]') 28 | local description = m:match('(%w+[^%.]*)%.') 29 | 30 | table.insert(tools, {base = base, description = description}) 31 | if #base > maxbase then 32 | maxbase = #base 33 | end 34 | end 35 | 36 | print('Available tools:') 37 | for i, tool in ipairs(tools) do 38 | io.stdout:write(' ') 39 | io.stdout:write(tool.base) 40 | for j = #tool.base, maxbase + 2 do 41 | io.stdout:write(' ') 42 | end 43 | print(tool.description) 44 | end 45 | 
-------------------------------------------------------------------------------- /optimize-fconv.lua: --------------------------------------------------------------------------------
-- Copyright (c) 2017-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the LICENSE file in
-- the root directory of this source tree. An additional grant of patent rights
-- can be found in the PATENTS file in the same directory.
--
--[[
--
-- Optimize a fconv model for fast generation.
--
--]]

require 'fairseq'

local cmd = torch.CmdLine()
cmd:option('-input_model', 'fconv_model.th7',
    'a th7 file that contains a fconv model')
cmd:option('-output_model', 'fconv_model_opt.th7',
    'an output file that will contain an optimized version')
local config = cmd:parse(arg)

local model = torch.load(config.input_model)
if torch.typename(model) ~= 'FConvModel' then
    -- Report the offending type to ease debugging of wrong inputs.
    error(string.format('"FConvModel" expected, got "%s"',
        tostring(torch.typename(model))))
end

-- Enable faster decoding
model:makeDecoderFast()

-- Clear output buffers and zero gradients for better compressability
model.module:clearState()
local _, gparams = model.module:parameters()
for i = 1, #gparams do
    gparams[i]:zero()
end

torch.save(config.output_model, model)
-------------------------------------------------------------------------------- /preprocess.lua: --------------------------------------------------------------------------------
-- Copyright (c) 2017-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the LICENSE file in
-- the root directory of this source tree. An additional grant of patent rights
-- can be found in the PATENTS file in the same directory.
--
--[[
--
-- Data pre-processing (binarization).
Create dictionary and store parallel data 11 | -- as binary, indexed torchnet datasets, one dataset per language and subset 12 | -- (train/valid/test). 13 | -- 14 | -- The following naming scheme is assumed for the parallel corpus: 15 | -- 16 | -- $trainpref.$sourcelang - train source language file 17 | -- $trainpref.$targetlang - train target language file 18 | -- $validpref.$sourcelang - validation source language file 19 | -- $validpref.$targetlang - validation target language file 20 | -- $testpref.sourcelang - test source language file 21 | -- $testpref.targetlang - test target language file 22 | -- 23 | -- For example: 24 | -- -sourcelang de \ 25 | -- -targetlang en \ 26 | -- -trainpref ./data/iwslt14.tokenized.de-en/train 27 | -- assumes that there are two files present: 28 | -- ./data/iwslt14.tokenized.de-en/train.de 29 | -- ./data/iwslt14.tokenized.de-en/train.en 30 | -- 31 | -- If a file with alignments is given (-alignfile) this script also produces a 32 | -- list (of length -ncandidates) of most common words from the target language 33 | -- for each word from the source language. 34 | -- 35 | -- The alignemnt file uses "Pharaoh format", where a pair i-j (zero based) 36 | -- indicates that the ith word of the source language is aligned to the jth 37 | -- word of the target language. 
For example: 38 | -- 39 | -- 0-0 1-1 2-4 3-2 4-3 5-5 6-6 40 | -- 0-0 1-1 2-2 2-3 3-4 4-5 41 | -- 42 | --]] 43 | 44 | require 'fairseq' 45 | local tok = require 'fairseq.text.tokenizer' 46 | local lmc = require 'fairseq.text.lm_corpus' 47 | local plpath = require 'pl.path' 48 | local pldir = require 'pl.dir' 49 | 50 | local cmd = torch.CmdLine() 51 | cmd:option('-sourcelang', 'de', 'source language') 52 | cmd:option('-targetlang', 'en', 'target language') 53 | cmd:option('-trainpref', 'train', 'training file prefix') 54 | cmd:option('-validpref', 'valid', 'validation file prefix') 55 | cmd:option('-testpref', 'test', 'testing file prefix') 56 | cmd:option('-alignfile', '', 'an alignment file (optional)') 57 | cmd:option('-ncandidates', 1000, 'number of candidates per a source word') 58 | cmd:option('-thresholdtgt', 0, 59 | 'map words appearing less than threshold times to unknown') 60 | cmd:option('-thresholdsrc', 0, 61 | 'map words appearing less than threshold times to unknown') 62 | cmd:option('-nwordstgt', -1, 63 | 'number of target words to retain') 64 | cmd:option('-nwordssrc', -1, 65 | 'number of source words to retain') 66 | cmd:option('-destdir', 'data-bin') 67 | 68 | local config = cmd:parse(arg) 69 | assert(not (config.nwordstgt >= 0 and config.thresholdtgt > 0), 70 | 'Specify either a frequency threshold or a word count') 71 | assert(not (config.nwordssrc >= 0 and config.thresholdsrc > 0), 72 | 'Specify either a frequency threshold or a word count') 73 | 74 | local langcode = config.sourcelang .. '-' .. config.targetlang 75 | local srcext = string.format('.%s.%s', langcode, config.sourcelang) 76 | local tgtext = string.format('.%s.%s', langcode, config.targetlang) 77 | pldir.makepath(config.destdir) 78 | 79 | local src = { 80 | lang = config.sourcelang, 81 | threshold = config.thresholdsrc, 82 | nwords = config.nwordssrc, 83 | dictbin = plpath.join( 84 | config.destdir, 'dict.' .. config.sourcelang .. '.th7' 85 | ), 86 | traintxt = config.trainpref .. 
'.' .. config.sourcelang, 87 | validtxt = config.validpref .. '.' .. config.sourcelang, 88 | testtxt = config.testpref .. '.' .. config.sourcelang, 89 | trainbin = plpath.join(config.destdir, 'train' .. srcext), 90 | validbin = plpath.join(config.destdir, 'valid' .. srcext), 91 | testbin = plpath.join(config.destdir, 'test' .. srcext), 92 | } 93 | 94 | local tgt = { 95 | lang = config.targetlang, 96 | threshold = config.thresholdtgt, 97 | nwords = config.nwordstgt, 98 | dictbin = plpath.join( 99 | config.destdir, 'dict.' .. config.targetlang .. '.th7' 100 | ), 101 | traintxt = config.trainpref .. '.' .. config.targetlang, 102 | validtxt = config.validpref .. '.' .. config.targetlang, 103 | testtxt = config.testpref .. '.' .. config.targetlang, 104 | trainbin = plpath.join(config.destdir, 'train' .. tgtext), 105 | validbin = plpath.join(config.destdir, 'valid' .. tgtext), 106 | testbin = plpath.join(config.destdir, 'test' .. tgtext), 107 | } 108 | 109 | for _, lang in ipairs({src, tgt}) do 110 | lang.dict = tok.buildDictionary{ 111 | filename = lang.traintxt, 112 | threshold = lang.threshold, 113 | } 114 | if lang.nwords >= 0 then 115 | lang.dict.cutoff = lang.nwords + lang.dict.nspecial 116 | end 117 | 118 | print(string.format('| [%s] Dictionary: %d types', 119 | lang.lang, lang.dict:size())) 120 | torch.save(lang.dictbin, lang.dict, 'binary', false) 121 | collectgarbage() 122 | 123 | local res = lmc.binarizeCorpus{ 124 | files = { 125 | {dest=lang.trainbin, src=lang.traintxt}, 126 | {dest=lang.validbin, src=lang.validtxt}, 127 | {dest=lang.testbin, src=lang.testtxt}, 128 | }, 129 | dict = lang.dict, 130 | } 131 | 132 | local files = {lang.traintxt, lang.validtxt, lang.testtxt} 133 | for i = 1, #files do 134 | print(string.format( 135 | '| [%s] %s: %d sents, %d tokens, %.2f%% replaced by %s', 136 | lang.lang, files[i], res[i].nseq, res[i].ntok, 137 | 100 * res[i].nunk / res[i].ntok, lang.dict.unk)) 138 | end 139 | 140 | print(string.format('| [%s] Wrote 
preprocessed data to %s', 141 | lang.lang, config.destdir)) 142 | collectgarbage() 143 | end 144 | 145 | if config.alignfile ~= '' then 146 | -- Process the alignment file 147 | local alignfreqmap = tok.buildAlignFreqMap{ 148 | alignfile = config.alignfile, 149 | srcfile = src.traintxt, 150 | tgtfile = tgt.traintxt, 151 | srcdict = src.dict, 152 | tgtdict = tgt.dict, 153 | } 154 | local dest = plpath.join(config.destdir, 'alignment.' .. langcode) 155 | local stats = tok.binarizeAlignFreqMap{ 156 | freqmap = alignfreqmap, 157 | srcdict = src.dict, 158 | indexfilename = dest .. '.idx', 159 | datafilename = dest .. '.bin', 160 | ncandidates = config.ncandidates, 161 | } 162 | print(string.format( 163 | '| [%s] Alignments: %d valid pairs', langcode, stats.npairs)) 164 | print(string.format('| [%s] Wrote preprocessed data to %s', langcode, dest)) 165 | collectgarbage() 166 | end 167 | -------------------------------------------------------------------------------- /rocks/fairseq-cpu-scm-1.rockspec: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 
7 | -- 8 | package = 'fairseq-cpu' 9 | version = 'scm-1' 10 | source = { 11 | url = 'git://github.com/facebookresearch/fairseq', 12 | tag = 'main', 13 | } 14 | description = { 15 | summary = 'Facebook AI Research Sequence-to-Sequence Toolkit', 16 | homepage = 'https://github.com/facebookresearch/fairseq', 17 | license = 'BSD 3-clause', 18 | } 19 | dependencies = { 20 | 'argcheck', 21 | 'lua-cjson', 22 | 'nn', 23 | 'nngraph', 24 | 'penlight', 25 | 'rnnlib', 26 | 'tbc', 27 | 'tds', 28 | 'threads', 29 | 'torch >= 7.0', 30 | 'torchnet', 31 | 'torchnet-sequential', 32 | 'visdom', 33 | } 34 | build = { 35 | type = "cmake", 36 | variables = { 37 | CMAKE_BUILD_TYPE="Release", 38 | ROCKS_PREFIX="$(PREFIX)", 39 | ROCKS_LUADIR="$(LUADIR)", 40 | ROCKS_LIBDIR="$(LIBDIR)", 41 | ROCKS_BINDIR="$(BINDIR)", 42 | CMAKE_PREFIX_PATH="$(LUA_BINDIR)/..", 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /rocks/fairseq-scm-1.rockspec: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 
7 | -- 8 | package = 'fairseq' 9 | version = 'scm-1' 10 | source = { 11 | url = 'git://github.com/facebookresearch/fairseq', 12 | tag = 'main', 13 | } 14 | description = { 15 | summary = 'Facebook AI Research Sequence-to-Sequence Toolkit', 16 | homepage = 'https://github.com/facebookresearch/fairseq', 17 | license = 'BSD 3-clause', 18 | } 19 | dependencies = { 20 | 'argcheck', 21 | 'cudnn', 22 | 'cunn', 23 | 'lua-cjson', 24 | 'nccl', 25 | 'nn', 26 | 'nngraph', 27 | 'penlight', 28 | 'rnnlib', 29 | 'tbc', 30 | 'tds', 31 | 'threads', 32 | 'torch >= 7.0', 33 | 'torchnet', 34 | 'torchnet-sequential', 35 | 'visdom', 36 | } 37 | build = { 38 | type = "cmake", 39 | variables = { 40 | CMAKE_BUILD_TYPE="Release", 41 | ROCKS_PREFIX="$(PREFIX)", 42 | ROCKS_LUADIR="$(LUADIR)", 43 | ROCKS_LIBDIR="$(LIBDIR)", 44 | ROCKS_BINDIR="$(BINDIR)", 45 | CMAKE_PREFIX_PATH="$(LUA_BINDIR)/..", 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /run.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the 5 | -- LICENSE file in the root directory of this source tree. 6 | -- 7 | --[[ 8 | -- 9 | -- A "master" script that can launch a given script. 10 | -- 11 | --]] 12 | 13 | if #arg > 0 then 14 | if arg[1] == '--help' or arg[1] == '-h' or arg[1] == '-?' then -- any help flag as first argument selects the 'help' tool 15 | arg[1] = 'help' 16 | end 17 | require('fairseq.scripts.' .. table.remove(arg, 1)) 18 | else 19 | print('Usage: fairseq <tool> [options]') 20 | end 21 | -------------------------------------------------------------------------------- /score.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 
3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | -- 8 | --[[ 9 | -- 10 | -- Command-line script for BLEU scoring. 11 | -- 12 | --]] 13 | 14 | local tok = require 'fairseq.text.tokenizer' 15 | local plpath = require 'pl.path' 16 | local bleu = require 'fairseq.text.bleu' 17 | 18 | local cmd = torch.CmdLine() 19 | cmd:option('-sys', '-', 'system output') 20 | cmd:option('-ref', '', 'references') 21 | cmd:option('-order', 4, 'consider ngrams up to this order') 22 | cmd:option('-ignore_case', false, 'case-insensitive scoring') 23 | 24 | local config = cmd:parse(arg) 25 | 26 | assert(config.sys == '-' or plpath.exists(config.sys)) 27 | local fdsys = config.sys == '-' and io.stdin or io.open(config.sys) 28 | assert(plpath.exists(config.ref)) 29 | local fdref = io.open(config.ref) 30 | 31 | local function readLine(fd) 32 | local s = fd:read() 33 | if s == nil then 34 | return nil 35 | end 36 | 37 | if config.ignore_case then 38 | s = string.lower(s) 39 | end 40 | return tok.tokenize(s) 41 | end 42 | 43 | local scorer = bleu.scorer(config.order) 44 | 45 | -- Process system output and reference file 46 | while true do 47 | local sysTok = readLine(fdsys) 48 | local refTok = readLine(fdref) 49 | if sysTok == nil and refTok ~= nil then 50 | error 'Insufficient number of lines in system output' 51 | elseif refTok == nil and sysTok ~= nil then 52 | error 'Insufficient number of lines in reference output' 53 | elseif sysTok == nil and refTok == nil then 54 | break 55 | end 56 | 57 | scorer.update(sysTok, refTok) 58 | end 59 | 60 | print(scorer.resultString()) 61 | -------------------------------------------------------------------------------- /scripts/binarize.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, 
Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | -- 8 | --[[ 9 | -- 10 | -- 11 | --]] 12 | 13 | require 'fairseq.text' 14 | local tok = require 'fairseq.text.tokenizer' 15 | 16 | local dict = torch.load(arg[1]) 17 | local text = arg[2] 18 | local dest = arg[3] 19 | 20 | local stats = tok.binarize{ 21 | filename = text, 22 | dict = dict, 23 | indexfilename = dest .. '.idx', 24 | datafilename = dest .. '.bin', 25 | } 26 | print(stats) 27 | -------------------------------------------------------------------------------- /scripts/build_sym_alignment.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | # 8 | 9 | #!/usr/bin/env python 10 | 11 | """ 12 | 13 | Use this script in order to build symmetric alignments for your translation 14 | dataset. 15 | 16 | This script depends on fast_align and mosesdecoder tools. You will need to 17 | build those before running the script. 18 | 19 | fast_align: 20 | github: http://github.com/clab/fast_align 21 | instructions: follow the instructions in README.md 22 | 23 | mosesdecoder: 24 | github: http://github.com/moses-smt/mosesdecoder 25 | instructions: http://www.statmt.org/moses/?n=Development.GetStarted 26 | 27 | The script produces the following files under --output_dir: 28 | 29 | text.joined - concatenation of lines from the source_file and the 30 | target_file. 31 | 32 | align.forward - forward pass of fast_align. 33 | 34 | align.backward - backward pass of fast_align. 
35 | 36 | aligned.sym_heuristic - symmetrized alignment. 37 | """ 38 | 39 | import argparse 40 | import os 41 | from itertools import izip 42 | 43 | 44 | def main(): 45 | parser = argparse.ArgumentParser(description='symmetric alignment builer') 46 | parser.add_argument('--fast_align_dir', 47 | help='path to fast_align build directory') 48 | parser.add_argument('--mosesdecoder_dir', 49 | help='path to mosesdecoder root directory') 50 | parser.add_argument('--sym_heuristic', 51 | help='heuristic to use for symmetrization', 52 | default='grow-diag-final-and') 53 | parser.add_argument('--source_file', 54 | help='path to a file with sentences ' 55 | 'in the source language') 56 | parser.add_argument('--target_file', 57 | help='path to a file with sentences ' 58 | 'in the target language') 59 | parser.add_argument('--output_dir', 60 | help='output directory') 61 | args = parser.parse_args() 62 | 63 | fast_align_bin = os.path.join(args.fast_align_dir, 'fast_align') 64 | symal_bin = os.path.join(args.mosesdecoder_dir, 'bin', 'symal') 65 | sym_fast_align_bin = os.path.join( 66 | args.mosesdecoder_dir, 'scripts', 'ems', 67 | 'support', 'symmetrize-fast-align.perl') 68 | 69 | # create joined file 70 | joined_file = os.path.join(args.output_dir, 'text.joined') 71 | with open(args.source_file, 'r') as src, open(args.target_file, 'r') as tgt: 72 | with open(joined_file, 'w') as joined: 73 | for s, t in izip(src, tgt): 74 | print >> joined, '%s ||| %s' % (s.strip(), t.strip()) 75 | 76 | bwd_align_file = os.path.join(args.output_dir, 'align.backward') 77 | 78 | # run forward alignment 79 | fwd_align_file = os.path.join(args.output_dir, 'align.forward') 80 | fwd_fast_align_cmd = '{FASTALIGN} -i {JOINED} -d -o -v > {FWD}'.format( 81 | FASTALIGN=fast_align_bin, 82 | JOINED=joined_file, 83 | FWD=fwd_align_file) 84 | assert os.system(fwd_fast_align_cmd) == 0 85 | 86 | # run backward alignment 87 | bwd_align_file = os.path.join(args.output_dir, 'align.backward') 88 | bwd_fast_align_cmd 
= '{FASTALIGN} -i {JOINED} -d -o -v -r > {BWD}'.format( 89 | FASTALIGN=fast_align_bin, 90 | JOINED=joined_file, 91 | BWD=bwd_align_file) 92 | assert os.system(bwd_fast_align_cmd) == 0 93 | 94 | # run symmetrization 95 | sym_out_file = os.path.join(args.output_dir, 'aligned') 96 | sym_cmd = '{SYMFASTALIGN} {FWD} {BWD} {SRC} ' \ 97 | '{TGT} {OUT} {HEURISTIC} {SYMAL}'.format( 98 | SYMFASTALIGN=sym_fast_align_bin, 99 | FWD=fwd_align_file, 100 | BWD=bwd_align_file, 101 | SRC=args.source_file, 102 | TGT=args.target_file, 103 | OUT=sym_out_file, 104 | HEURISTIC=args.sym_heuristic, 105 | SYMAL=symal_bin) 106 | assert os.system(sym_cmd) == 0 107 | 108 | 109 | if __name__ == '__main__': 110 | main() 111 | -------------------------------------------------------------------------------- /scripts/make_fconv_vocsel.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | -- 8 | --[[ 9 | -- 10 | -- Converts a fconv_model with a full softmax to a vocabulary selection model 11 | -- for CPU decoding. 
12 | -- 13 | --]] 14 | 15 | local tnt = require 'torchnet' 16 | local plpath = require 'pl.path' 17 | local utils = require 'fairseq.utils' 18 | local mutils = require 'fairseq.models.utils' 19 | 20 | local cmd = torch.CmdLine() 21 | cmd:option('-modelin', '', 'path to model with flat softmax output layer') 22 | cmd:option('-modelout', '', 'path where vocab selection model is to be written') 23 | cmd:option('-model', '', 'model type {avgpool|blstm|conv|fconv}') 24 | cmd:option('-sourcelang', 'en', 'source language') 25 | cmd:option('-targetlang', 'de', 'target language') 26 | cmd:option('-datadir', '') 27 | cmd:option('-aligndictpath', '', 'path to an alignment dictionary (optional)') 28 | cmd:option('-nembed', 256, 'dimension of embeddings and attention') 29 | cmd:option('-noutembed', 256, 'dimension of the output embeddings') 30 | cmd:option('-nhid', 256, 'number of hidden units per layer') 31 | cmd:option('-nlayer', 1, 'number of hidden layers in decoder') 32 | cmd:option('-nenclayer', 1, 'number of hidden layers in encoder') 33 | cmd:option('-nagglayer', -1, 34 | 'number of layers for conv encoder aggregation stack (CNN-c)') 35 | cmd:option('-kwidth', 3, 'kernel width for conv encoder') 36 | cmd:option('-klmwidth', 3, 'kernel width for convolutional language models') 37 | 38 | cmd:option('-cudnnconv', false, 'use cudnn.TemporalConvolution (slower)') 39 | cmd:option('-attnlayers', '-1', 'decoder layers with attention (-1: all)') 40 | cmd:option('-bfactor', 0, 'factor to divide nhid in bottleneck structure') 41 | cmd:option('-fconv_nhids', '', 42 | 'comma-separated list of hidden units for each encoder layer') 43 | cmd:option('-fconv_nlmhids', '', 44 | 'comma-separated list of hidden units for each decoder layer') 45 | cmd:option('-fconv_kwidths', '', 46 | 'comma-separated list of kernel widths for conv encoder') 47 | cmd:option('-fconv_klmwidths', '', 48 | 'comma-separated list of kernel widths for convolutional language model') 49 | 50 | local config = 
cmd:parse(arg) 51 | 52 | assert(config.model == 'fconv', 'only conversion for fconv models supported') 53 | 54 | -- parse hidden sizes and kernel widths 55 | -- encoder 56 | config.nhids = utils.parseListOrDefault( 57 | config.fconv_nhids, config.nenclayer, config.nhid) 58 | config.kwidths = utils.parseListOrDefault( 59 | config.fconv_kwidths, config.nenclayer, config.kwidth) 60 | 61 | -- decoder 62 | config.nlmhids = utils.parseListOrDefault( 63 | config.fconv_nlmhids, config.nlayer, config.nhid) 64 | config.klmwidths = utils.parseListOrDefault( 65 | config.fconv_klmwidths, config.nlayer, config.klmwidth) 66 | 67 | 68 | ------------------------------------------------------------------- 69 | -- Load data 70 | ------------------------------------------------------------------- 71 | config.dict = torch.load(plpath.join(config.datadir, 72 | 'dict.' .. config.targetlang .. '.th7')) 73 | print(string.format('| [%s] Dictionary: %d types', config.targetlang, 74 | config.dict:size())) 75 | config.srcdict = torch.load(plpath.join(config.datadir, 76 | 'dict.' .. config.sourcelang .. '.th7')) 77 | print(string.format('| [%s] Dictionary: %d types', config.sourcelang, 78 | config.srcdict:size())) 79 | 80 | -- augment config with aligndictpath 81 | config.aligndict = tnt.IndexedDatasetReader{ 82 | indexfilename = config.aligndictpath .. '.idx', 83 | datafilename = config.aligndictpath .. 
'.bin', 84 | mmap = true, 85 | mmapidx = true, 86 | } 87 | 88 | -- load existing model and build vocab selection model 89 | local model = torch.load(config.modelin) 90 | local selmodel = require( 91 | string.format('fairseq.models.%s_model', 92 | config.model)).new(config) 93 | 94 | -- convert both models to CPU 95 | model:float() 96 | selmodel:float() 97 | 98 | model.module:evaluate() 99 | selmodel.module:evaluate() 100 | model.module:training() 101 | selmodel.module:training() 102 | local p, _ = model.module:parameters() 103 | local sp, _ = selmodel.module:parameters() 104 | assert(#p - 2 == #sp, 'Number of parameters do not match') 105 | 106 | -- copy parameters which should match 107 | for i = 1, #p - 3 do 108 | sp[i]:copy(p[i]:typeAs(sp[i])) 109 | end 110 | 111 | -- find Linear/WeightNorm and LookupTable in both models 112 | local lutm = mutils.findAnnotatedNode(selmodel.module, 'outmodule') 113 | :get(2):get(1):get(2) 114 | local linm = mutils.findAnnotatedNode(model.module, 'outmodule'):get(4) 115 | 116 | 117 | -- Next, copy the weights computed by WeightNorm(Linear) to LookupTable 118 | -- Note: we cannot copy the parameters as WeightNorm:parameters() does 119 | -- not return the weight of the wrapped module but only the direction (v) and 120 | -- length (g). So we find and copy the weight tensor (lutm.weight) instead. 
121 | 122 | -- copy Linear.weight to LookupTable 123 | lutm.weight:narrow(2, 1, config.noutembed):copy( 124 | linm.weight:typeAs(lutm.weight)) 125 | 126 | -- copy Linear.bias to LookupTable 127 | lutm.weight:narrow(2, config.noutembed + 1, 1):copy( 128 | linm.bias:typeAs(lutm.weight)) 129 | 130 | 131 | -- check that the norms of the old and new output word embeddings match 132 | local lut = sp[#sp] 133 | print('compare norms of weight/bias of output layers in each model + params:') 134 | print(string.format('bias norms: %f, %f, %f', 135 | linm.bias:norm(), lutm.weight:narrow(2, config.noutembed + 1, 1):norm(), 136 | lut:narrow(2, config.noutembed + 1, 1):norm())) 137 | print(string.format('weight norms: %f, %f, %f', 138 | linm.weight:norm(), lutm.weight:narrow(2, 1, config.noutembed):norm(), 139 | lut:narrow(2, 1, config.noutembed):norm())) 140 | 141 | assert(linm.bias:ne(lutm.weight:narrow(2, config.noutembed + 1, 1)):sum() == 0) 142 | assert(linm.weight:ne(lutm.weight:narrow(2, 1, config.noutembed)):sum() == 0) 143 | 144 | -- save vocab selection model 145 | print(string.format('saving vocab selection model to %s', config.modelout)) 146 | torch.save(config.modelout, selmodel) 147 | -------------------------------------------------------------------------------- /scripts/makealigndict.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | -- 8 | --[[ 9 | -- 10 | -- Generates an alignment dictionary for usage in scripts/unkreplace.lua. 11 | -- Use build_sym_alignment.py to generate a symmetrized alignment file 12 | -- for a given corpus. 
13 | -- 14 | --]] 15 | 16 | require 'fairseq' 17 | 18 | local tds = require 'tds' 19 | local stringx = require 'pl.stringx' 20 | local tablex = require 'pl.tablex' 21 | local tok = require 'fairseq.text.tokenizer' 22 | 23 | local cmd = torch.CmdLine() 24 | cmd:option('-source', '', 'source text') 25 | cmd:option('-target', '', 'target text') 26 | cmd:option('-alignment', '', 'alignment file') 27 | cmd:option('-output', 'aligndict.th7', 'destination file') 28 | 29 | local config = cmd:parse(arg) 30 | local tokenize = tok.tokenize 31 | local dict = tds.Hash() 32 | 33 | -- Count alignment frequencies 34 | local source = io.open(config.source) 35 | local target = io.open(config.target) 36 | local alignment = io.open(config.alignment) 37 | local n = 0 38 | while true do 39 | local s = source:read() 40 | if s == nil then 41 | break 42 | end 43 | local t = target:read() 44 | local a = alignment:read() 45 | 46 | local stoks = tokenize(s) 47 | local ttoks = tokenize(t) 48 | local atoks = tokenize(a) 49 | for _, atok in ipairs(atoks) do 50 | local apair = tablex.map(tonumber, stringx.split(atok, '-')) 51 | local stok = stoks[apair[1] + 1] 52 | local ttok = ttoks[apair[2] + 1] 53 | if not dict[stok] then 54 | dict[stok] = tds.Hash() 55 | end 56 | if not dict[stok][ttok] then 57 | dict[stok][ttok] = 1 58 | else 59 | dict[stok][ttok] = dict[stok][ttok] + 1 60 | end 61 | end 62 | 63 | n = n + 1 64 | if n % 25000 == 0 then 65 | print(string.format('Processed %d sentences', n)) 66 | end 67 | end 68 | print(string.format('Processed %d sentences', n)) 69 | 70 | -- Only keep the most frequently aligned words 71 | local adict = tds.Hash() 72 | for stok, v in pairs(dict) do 73 | local maxf = -1 74 | for ttok, f in pairs(v) do 75 | if f > maxf then 76 | maxf = f 77 | adict[stok] = ttok 78 | end 79 | end 80 | end 81 | 82 | torch.save(config.output, adict) 83 | -------------------------------------------------------------------------------- /scripts/makedict.lua: 
-------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | -- 8 | --[[ 9 | -- 10 | -- Creates a dictionary from raw text. 11 | -- 12 | --]] 13 | 14 | require 'fairseq.text' 15 | local tok = require 'fairseq.text.tokenizer' 16 | 17 | local cmd = torch.CmdLine() 18 | cmd:option('-text', 'source text') 19 | cmd:option('-out', 'target path') 20 | cmd:option('-threshold', 0, 21 | 'map words appearing less than threshold times to unknown') 22 | cmd:option('-nwords', -1, 'number of non-control target words to retain') 23 | local config = cmd:parse(arg) 24 | 25 | assert(not (config.nwords >= 0 and config.threshold > 0), 26 | 'Specify either a frequency threshold or a word count') 27 | 28 | local dict = tok.buildDictionary{ 29 | filename = config.text, 30 | threshold = config.threshold, 31 | } 32 | if config.nwords >= 0 then 33 | dict.cutoff = config.nwords + dict.nspecial 34 | end 35 | 36 | print(string.format('| Dictionary: %d types', dict:size())) 37 | torch.save(config.out, dict, 'binary', false) 38 | -------------------------------------------------------------------------------- /scripts/unkreplace.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | -- 8 | --[[ 9 | -- 10 | -- Performs unknown word replacement on output from generation.lua. 
Use 11 | -- 'makealigndict' to generate an alignment dictionary. 12 | -- Prints post-processed hypotheses in sorted order. 13 | -- 14 | --]] 15 | 16 | require 'fairseq' 17 | 18 | local stringx = require 'pl.stringx' 19 | local tablex = require 'pl.tablex' 20 | 21 | local cmd = torch.CmdLine() 22 | cmd:option('-genout', '-', 'generation output') 23 | cmd:option('-source', '', 'source language input') 24 | cmd:option('-dict', '', 'path to alignment dictionary') 25 | cmd:option('-unk', '<unk>', 'unknown word marker') 26 | cmd:option('-offset', 0, 'apply offset to attention maxima') 27 | 28 | local config = cmd:parse(arg) 29 | local dict = torch.load(config.dict) 30 | 31 | local function readFile(path) 32 | local lines = {} 33 | local fd = io.open(path) 34 | while true do 35 | local line = fd:read() 36 | if line == nil then 37 | break 38 | end 39 | table.insert(lines, line) 40 | end 41 | return lines 42 | end 43 | local srcs = readFile(config.source) 44 | 45 | local fd 46 | if config.genout == '-' then 47 | fd = io.stdin 48 | else 49 | fd = io.open(config.genout) 50 | end 51 | 52 | local hypos = {} 53 | local attns = {} 54 | while true do 55 | local line = fd:read() 56 | if line == nil then 57 | break 58 | end 59 | local parts = stringx.split(line, '\t') 60 | 61 | local num = parts[1]:match('^H%-(%d+)') 62 | if num then 63 | num = tonumber(num) 64 | hypos[num] = parts[3] 65 | else 66 | num = parts[1]:match('^A%-(%d+)') 67 | if num then 68 | num = tonumber(num) 69 | attns[num] = tablex.map(tonumber, stringx.split(parts[2])) 70 | end 71 | end 72 | end 73 | 74 | assert(#hypos == #attns, 75 | 'Number of hypotheses and attention scores does not match') 76 | assert(#hypos == #srcs, 77 | 'Number of hypotheses and source sentences does not match') 78 | 79 | for i = 1, #hypos do 80 | local htoks = stringx.split(hypos[i]) 81 | local stoks = stringx.split(srcs[i]) 82 | for j = 1, #htoks do 83 | if htoks[j] == config.unk then 84 | local attn = attns[i][j] + config.offset 85 | if 
attn == #stoks + 1 then 86 | -- format string previously had a stray second %d with no matching argument, which makes Lua's string.format error out 87 | io.stderr:write(string.format( 88 | 'Sentence %d: <unk> was predicted to EOS.\n', 89 | i)) 90 | break 91 | elseif attn < 1 or attn > #stoks then 92 | io.stderr:write(string.format( 93 | 'Sentence %d: attention index out of bound: %d\n', 94 | i, attn)) 95 | else 96 | local stok = stoks[attn] 97 | if dict[stok] then 98 | htoks[j] = dict[stok] 99 | else 100 | htoks[j] = stok 101 | end 102 | end 103 | end 104 | end 105 | print(stringx.join(' ', htoks)) 106 | end 107 | -------------------------------------------------------------------------------- /test/test.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | 8 | if not package.loaded['fairseq'] then 9 | __main__ = true 10 | require 'fairseq' 11 | end 12 | local dir = require 'pl.dir' 13 | local path = require 'pl.path' 14 | 15 | local tester = torch.Tester() 16 | -- Collect all the tests 17 | local testdir = path.abspath(path.dirname(debug.getinfo(1).source:sub(2))) 18 | for _, file in pairs(dir.getfiles(testdir, 'test_*.lua')) do 19 | tester:add(paths.dofile(file)(tester)) 20 | end 21 | 22 | local function dotest(tests) 23 | tester:run(tests) 24 | end 25 | 26 | if __main__ then 27 | if #arg > 0 then 28 | dotest(arg) 29 | else 30 | dotest() 31 | end 32 | else 33 | return tester 34 | end 35 | -------------------------------------------------------------------------------- /test/test_appendbias.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 
3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | -- 8 | --[[ 9 | -- 10 | -- Tests for the nn.AppendBias module. 11 | -- 12 | --]] 13 | 14 | require 'fairseq.modules' 15 | 16 | local tester 17 | local test = torch.TestSuite() 18 | 19 | function test.AppendBias_Forward() 20 | local m = nn.AppendBias() 21 | local input = torch.Tensor{{2, 3}, {4, 5}} 22 | local output = torch.Tensor{{2, 3, 1}, {4, 5, 1}} 23 | tester:assert(torch.all(torch.eq(m:forward(input), output))) 24 | end 25 | 26 | function test.AppendBias_Backward() 27 | local m = nn.AppendBias() 28 | local input = torch.Tensor{{2, 3}, {4, 5}} 29 | local gradOutput = torch.Tensor{{7, 8, 2}, {9, 10, 2}} 30 | local gradInput = torch.Tensor{{7, 8}, {9, 10}} 31 | tester:assert(torch.all(torch.eq(m:backward(input, gradOutput), gradInput))) 32 | end 33 | 34 | return function(_tester_) 35 | tester = _tester_ 36 | return test 37 | end 38 | -------------------------------------------------------------------------------- /test/test_dictionary.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | -- 8 | --[[ 9 | -- 10 | -- Tests for the Dictionary class. 
11 | -- 12 | --]] 13 | 14 | require 'fairseq.text' 15 | 16 | local tester 17 | local test = torch.TestSuite() 18 | 19 | function test.Dictionary_Simple() 20 | local dict = Dictionary{} 21 | dict:addSymbol('foo') 22 | dict:addSymbol('bar') 23 | dict:addSymbol('baz') 24 | dict:addSymbol('foo') 25 | dict:finalize() 26 | tester:assertGeneralEq(6, dict:size()) -- 3 special tokens for unk/pad/eos 27 | tester:assertGeneralNe(dict.unk_index, dict:getIndex('foo')) 28 | tester:assertGeneralEq('foo', dict:getSymbol(dict:getIndex('foo'))) 29 | tester:assertGeneralEq(dict.unk_index, dict:getIndex('???')) 30 | tester:assertGeneralEq(dict.unk, dict:getSymbol(dict:getIndex('???'))) 31 | tester:assertGeneralEq('foo bar </s>', dict:getString(torch.IntTensor{ 32 | dict:getIndex('foo'), 33 | dict:getIndex('bar'), 34 | dict.eos_index, 35 | })) 36 | -- Dictionary is sorted by frequency 37 | tester:assert(dict:getIndex('foo') < dict:getIndex('bar')) 38 | end 39 | 40 | function test.Dictionary_NoFinalize() 41 | local dict = Dictionary{} 42 | dict:addSymbol('foo') 43 | dict:addSymbol('bar') 44 | tester:assertError(function() return dict:size() end) 45 | dict:finalize() 46 | tester:assertGeneralEq(5, dict:size()) 47 | end 48 | 49 | function test.Dictionary_Thresholding() 50 | local dict = Dictionary{threshold=3} 51 | dict:addSymbol('baz') 52 | dict:addSymbol('foo') 53 | dict:addSymbol('foo') 54 | dict:addSymbol('foo') 55 | dict:addSymbol('bar') 56 | dict:addSymbol('bar') 57 | dict:finalize() 58 | tester:assertGeneralEq(dict.unk_index, dict:getIndex(dict.unk)) 59 | tester:assertGeneralEq(dict.pad_index, dict:getIndex(dict.pad)) 60 | tester:assertGeneralEq(dict.eos_index, dict:getIndex(dict.eos)) 61 | tester:assertGeneralEq(4, dict:size()) 62 | tester:assertGeneralEq(4, dict:getIndex('foo')) 63 | tester:assertGeneralEq(dict.unk_index, dict:getIndex('baz')) 64 | tester:assertGeneralEq(dict.unk_index, dict:getIndex('bar')) 65 | 66 | local dict2 = Dictionary{threshold=2} 67 | 
dict2:addSymbol('foo') 68 | dict2:addSymbol('bar') 69 | dict2:addSymbol('baz') 70 | dict2:finalize() 71 | tester:assertGeneralEq(3, dict2:size()) 72 | end 73 | 74 | function test.Dictionary_CustomSpecialSymbols() 75 | local dict = Dictionary{unk='UNK', pad='PAD', eos='EOS'} 76 | tester:assertGeneralEq('UNK', dict.unk) 77 | tester:assertGeneralEq('PAD', dict.pad) 78 | tester:assertGeneralEq('EOS', dict.eos) 79 | tester:assertGeneralEq(1, dict:getIndex('UNK')) 80 | tester:assertGeneralEq(2, dict:getIndex('PAD')) 81 | tester:assertGeneralEq(3, dict:getIndex('EOS')) 82 | end 83 | 84 | return function(_tester_) 85 | tester = _tester_ 86 | return test 87 | end 88 | -------------------------------------------------------------------------------- /test/test_logsoftmax.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 
7 | -- 8 | --[[ 9 | -- 10 | -- Tests for the log-softmax C implementation 11 | -- 12 | --]] 13 | 14 | local clib = require 'fairseq.clib' 15 | local nn = require 'nn' 16 | 17 | local tester 18 | local test = torch.TestSuite() 19 | 20 | local function timing(f, N, showtime) 21 | local timer = torch.Timer() 22 | for i = 1, N do f() end 23 | local elapsed = timer:time().real 24 | if showtime then 25 | print(('lsm %.4f msec (ran %.4f sec)'):format(elapsed*1000/N, elapsed)) 26 | end 27 | end 28 | 29 | local function dotest(N, bsz, beam, v, showtime) 30 | local x = torch.FloatTensor(bsz * beam, v):uniform() 31 | 32 | if showtime then 33 | print('nn') 34 | end 35 | local lsm = nn.LogSoftMax():float() 36 | timing(function() lsm:forward(x) end, N, showtime) 37 | 38 | if showtime then 39 | print('cpp') 40 | end 41 | local lsm2 = clib.logsoftmax() 42 | local y 43 | timing(function() y = lsm2(x) end, N, showtime) 44 | 45 | if N == 1 then 46 | local err = y:clone():add(-1, lsm.output):abs():max() 47 | tester:assert(err < 1e-6) 48 | end 49 | end 50 | 51 | function test.LogSoftmax_Accuracy() 52 | dotest(1, 32, 5, 40*1000, false) 53 | end 54 | 55 | --[[ Disable speed test since they're time-consuming 56 | function test.LogSoftmax_SingleSpeed() 57 | dotest(5000, 1, 5, 40*1000, true) 58 | end 59 | 60 | function test.LogSoftmax_Batch() 61 | dotest(5000/32, 32, 5, 40*1000, true) 62 | end 63 | --]] 64 | 65 | return function(_tester_) 66 | tester = _tester_ 67 | return test 68 | end 69 | -------------------------------------------------------------------------------- /test/test_tokenizer.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 
7 | -- 8 | --[[ 9 | -- 10 | -- Tests for tokenizer. 11 | -- 12 | --]] 13 | 14 | require 'fairseq' 15 | local tokenizer = require 'fairseq.text.tokenizer' 16 | local tnt = require 'torchnet' 17 | local path = require 'pl.path' 18 | local pltablex = require 'pl.tablex' 19 | local plutils = require 'pl.utils' 20 | 21 | local tester 22 | local test = torch.TestSuite() 23 | 24 | local testdir = path.abspath(path.dirname(debug.getinfo(1).source:sub(2))) 25 | local testdata = testdir .. '/tst2012.en' 26 | local testdataUrl = 'https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/tst2012.en' 27 | if not path.exists(testdata) then 28 | require 'os' 29 | os.execute('curl ' .. testdataUrl .. ' > ' .. testdata) 30 | if path.getsize(testdata) ~= 140250 then 31 | error('Failed to download test data from ' .. testdataUrl) 32 | end 33 | local head = io.open(testdata):read(15) 34 | if head ~= 'How can I speak' then 35 | error('Failed to download test data from ' .. testdataUrl) 36 | end 37 | end 38 | 39 | function test.Tokenizer_BuildDictionary() 40 | local dict = tokenizer.buildDictionary{ 41 | filename = testdata, 42 | threshold = 0, 43 | } 44 | tester:assertGeneralEq(3730, dict:size()) 45 | tester:assertGeneralEq(dict.unk_index, dict:getIndex('NotInCorpus')) 46 | 47 | local dict2 = tokenizer.buildDictionary{ 48 | filename = testdata, 49 | threshold = 100, 50 | } 51 | tester:assertGeneralEq(38, dict2:size()) 52 | 53 | -- Use a custom tokenizer that removes all 'the's 54 | local dict3 = tokenizer.buildDictionary{ 55 | filename = testdata, 56 | tokenize = function(line) 57 | local words = tokenizer.tokenize(line) 58 | return pltablex.filter(words, function (w) return w ~= 'the' end) 59 | end, 60 | threshold = 0, 61 | } 62 | tester:assertGeneralEq(dict3.unk_index, dict3:getIndex('the')) 63 | tester:assertGeneralEq(3729, dict3:size()) 64 | end 65 | 66 | function test.Tokenizer_BuildDictionaryMultipleFiles() 67 | local dict2 = tokenizer.buildDictionary{ 68 | filenames = {testdata, 
testdata, testdata, testdata}, 69 | threshold = 100 * 4, 70 | } 71 | tester:assertGeneralEq(38, dict2:size()) 72 | end 73 | 74 | function test.Tokenizer_Tensorize() 75 | local dict = tokenizer.buildDictionary{ 76 | filename = testdata, 77 | threshold = 0, 78 | } 79 | local data, stats = tokenizer.tensorize{ 80 | filename = testdata, 81 | dict = dict, 82 | } 83 | local smap, words = data.smap, data.words 84 | tester:assertGeneralEq({1553, 2}, smap:size():totable()) 85 | tester:assertGeneralEq({29536}, words:size():totable()) 86 | tester:assertGeneralEq(1553, stats.nseq) 87 | tester:assertGeneralEq(29536, stats.ntok) 88 | tester:assertGeneralEq(0, stats.nunk) 89 | 90 | tester:assertGeneralEq('He is my grandfather . </s>', dict:getString( 91 | words:narrow(1, smap[11][1], smap[11][2]) 92 | )) 93 | end 94 | 95 | function test.Tokenizer_TensorizeString() 96 | local dict = tokenizer.makeDictionary{ 97 | threshold = 0, 98 | } 99 | local tokens = plutils.split('aa bb cc', ' ') 100 | for _, token in ipairs(tokens) do 101 | dict:addSymbol(token) 102 | end 103 | local text = 'aa cc' 104 | local tensor = tokenizer.tensorizeString{ 105 | text = text, 106 | dict = dict, 107 | } 108 | for i, token in ipairs(plutils.split(text, ' ')) do 109 | tester:assertGeneralEq(dict:getIndex(token), tensor[i]) 110 | end 111 | end 112 | 113 | function test.Tokenizer_TensorizeAlignment() 114 | local alignmenttext = '0-1 1-2 2-1' 115 | local tensor = tokenizer.tensorizeAlignment{ 116 | text = alignmenttext, 117 | } 118 | tester:assertGeneralEq(tensor:size(1), 3) 119 | tester:assertGeneralEq(tensor:size(2), 2) 120 | tester:assertGeneralEq(tensor[1][1], 0 + 1) 121 | tester:assertGeneralEq(tensor[1][2], 1 + 1) 122 | tester:assertGeneralEq(tensor[2][1], 1 + 1) 123 | tester:assertGeneralEq(tensor[2][2], 2 + 1) 124 | tester:assertGeneralEq(tensor[3][1], 2 + 1) 125 | tester:assertGeneralEq(tensor[3][2], 1 + 1) 126 | end 127 | 128 | function test.Tokenizer_TensorizeThresh() 129 | local dict = 
tokenizer.buildDictionary{ 130 | filename = testdata, 131 | threshold = 50, 132 | } 133 | local data, stats = tokenizer.tensorize{ 134 | filename = testdata, 135 | dict = dict, 136 | } 137 | local smap, words = data.smap, data.words 138 | tester:assertGeneralEq({1553, 2}, smap:size():totable()) 139 | tester:assertGeneralEq({29536}, words:size():totable()) 140 | tester:assertGeneralEq(1553, stats.nseq) 141 | tester:assertGeneralEq(29536, stats.ntok) 142 | tester:assertGeneralEq(11485, stats.nunk) 143 | 144 | tester:assertGeneralEq('<unk> is my <unk> . </s>', dict:getString( 145 | words:narrow(1, smap[11][1], smap[11][2]) 146 | )) 147 | end 148 | 149 | function test.Tokenizer_Binarize() 150 | local dict = tokenizer.buildDictionary{ 151 | filename = testdata, 152 | threshold = 0, 153 | } 154 | 155 | -- XXX A temporary directory function would be great 156 | local dest = os.tmpname() 157 | 158 | local res = tokenizer.binarize{ 159 | filename = testdata, 160 | dict = dict, 161 | indexfilename = dest .. '.idx', 162 | datafilename = dest .. '.bin', 163 | } 164 | tester:assertGeneralEq(1553, res.nseq) 165 | tester:assertGeneralEq(29536, res.ntok) 166 | tester:assertGeneralEq(0, res.nunk) 167 | 168 | local field = path.basename(dest) 169 | local ds = tnt.IndexedDataset{ 170 | fields = {field}, 171 | path = paths.dirname(dest), 172 | } 173 | tester:assertGeneralEq(1553, ds:size()) 174 | tester:assertGeneralEq('He is my grandfather . </s>', dict:getString( 175 | ds:get(11)[field] 176 | )) 177 | end 178 | 179 | function test.Tokenizer_BinarizeThresh() 180 | local dict = tokenizer.buildDictionary{ 181 | filename = testdata, 182 | threshold = 50, 183 | } 184 | 185 | -- XXX A temporary directory function would be great 186 | local dest = os.tmpname() 187 | 188 | local res = tokenizer.binarize{ 189 | filename = testdata, 190 | dict = dict, 191 | indexfilename = dest .. '.idx', 192 | datafilename = dest .. 
'.bin', 193 | } 194 | tester:assertGeneralEq(1553, res.nseq) 195 | tester:assertGeneralEq(29536, res.ntok) 196 | tester:assertGeneralEq(11485, res.nunk) 197 | 198 | local field = path.basename(dest) 199 | local ds = tnt.IndexedDataset{ 200 | fields = {field}, 201 | path = paths.dirname(dest), 202 | } 203 | tester:assertGeneralEq(1553, ds:size()) 204 | tester:assertGeneralEq('<unk> is my <unk> . </s>', dict:getString( 205 | ds:get(11)[field] 206 | )) 207 | end 208 | 209 | function test.Tokenizer_BinarizeAlignment() 210 | local function makeFile(line) 211 | local filename = os.tmpname() 212 | local file = io.open(filename, 'w') 213 | file:write(line .. '\n') 214 | file:close(file) 215 | return filename 216 | end 217 | 218 | local srcfile = makeFile('a b c a') 219 | local srcdict = tokenizer.buildDictionary{ 220 | filename = srcfile, 221 | threshold = 0, 222 | } 223 | 224 | local tgtfile = makeFile('x y z w x') 225 | local tgtdict = tokenizer.buildDictionary{ 226 | filename = tgtfile, 227 | threshold = 0, 228 | } 229 | 230 | local alignfile = makeFile('0-0 0-1 1-1 2-2 2-4 3-1 3-3') 231 | local alignfreqmap = tokenizer.buildAlignFreqMap{ 232 | alignfile = alignfile, 233 | srcfile = srcfile, 234 | tgtfile = tgtfile, 235 | srcdict = srcdict, 236 | tgtdict = tgtdict, 237 | } 238 | 239 | tester:assertGeneralEq(alignfreqmap[srcdict:getEosIndex()], nil) 240 | tester:assertGeneralEq(alignfreqmap[srcdict:getPadIndex()], nil) 241 | tester:assertGeneralEq(alignfreqmap[srcdict:getUnkIndex()], nil) 242 | 243 | tester:assertGeneralEq( 244 | alignfreqmap[srcdict:getIndex('a')][tgtdict:getIndex('x')], 1) 245 | tester:assertGeneralEq( 246 | alignfreqmap[srcdict:getIndex('a')][tgtdict:getIndex('y')], 2) 247 | tester:assertGeneralEq( 248 | alignfreqmap[srcdict:getIndex('a')][tgtdict:getIndex('w')], 1) 249 | 250 | tester:assertGeneralEq( 251 | alignfreqmap[srcdict:getIndex('b')][tgtdict:getIndex('y')], 1) 252 | 253 | tester:assertGeneralEq( 254 | 
alignfreqmap[srcdict:getIndex('c')][tgtdict:getIndex('x')], 1) 255 | tester:assertGeneralEq( 256 | alignfreqmap[srcdict:getIndex('c')][tgtdict:getIndex('z')], 1) 257 | 258 | local dest = os.tmpname() 259 | local stats = tokenizer.binarizeAlignFreqMap{ 260 | freqmap = alignfreqmap, 261 | srcdict = srcdict, 262 | indexfilename = dest .. '.idx', 263 | datafilename = dest .. '.bin', 264 | ncandidates = 2, 265 | } 266 | 267 | tester:assertGeneralEq(stats.npairs, 5) 268 | 269 | local reader = tnt.IndexedDatasetReader{ 270 | indexfilename = dest .. '.idx', 271 | datafilename = dest .. '.bin', 272 | } 273 | 274 | tester:assertGeneralEq(reader:get(srcdict:getEosIndex()):dim(), 0) 275 | tester:assertGeneralEq(reader:get(srcdict:getPadIndex()):dim(), 0) 276 | tester:assertGeneralEq(reader:get(srcdict:getUnkIndex()):dim(), 0) 277 | tester:assert(torch.all(torch.eq( 278 | reader:get(srcdict:getIndex('a')), 279 | torch.IntTensor{ 280 | {tgtdict:getIndex('y'), 2}, 281 | {tgtdict:getIndex('x'), 1}}))) 282 | tester:assert(torch.all(torch.eq( 283 | reader:get(srcdict:getIndex('b')), 284 | torch.IntTensor{{tgtdict:getIndex('y'), 1}}))) 285 | tester:assert(torch.all(torch.eq( 286 | reader:get(srcdict:getIndex('c')), 287 | torch.IntTensor{ 288 | {tgtdict:getIndex('z'), 1}, 289 | {tgtdict:getIndex('x'), 1}}))) 290 | end 291 | 292 | return function(_tester_) 293 | tester = _tester_ 294 | return test 295 | end 296 | -------------------------------------------------------------------------------- /test/test_topk.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 
7 | -- 8 | --[[ 9 | -- 10 | -- Tests for the topk C implementation 11 | -- 12 | --]] 13 | 14 | local clib = require 'fairseq.clib' 15 | 16 | local tester 17 | local test = torch.TestSuite() 18 | 19 | local function timing(f, N, showtime) 20 | local timer = torch.Timer() 21 | for i = 1, N do f() end 22 | local elapsed = timer:time().real 23 | if showtime then 24 | print(('topk %.4f msec (ran %.4f sec)'):format(elapsed*1000/N, elapsed)) 25 | end 26 | end 27 | 28 | local function dotest(N, bsz, beam, n, v, showtime) 29 | local t = torch.FloatTensor(bsz, beam * v):uniform() 30 | 31 | if showtime then 32 | print('torch') 33 | end 34 | local top, ind = torch.FloatTensor(), torch.LongTensor() 35 | timing(function() torch.topk(top, ind, t, n, 2, true, true) end, N, showtime) 36 | 37 | if showtime then 38 | print('cpp') 39 | end 40 | local top2, ind2 = torch.FloatTensor(), torch.LongTensor() 41 | timing(function() clib.topk(top2, ind2, t, n) end, N, showtime) 42 | 43 | if (v <= 100) then 44 | -- equality happens if too many samples accuracy is only tested with 45 | -- short lists. 46 | tester:assert(ind2:clone():add(-1, ind):abs():sum() == 0) 47 | end 48 | end 49 | 50 | function test.TopK_Accuracy() 51 | dotest(1, 32, 5, 10, 100, false) 52 | end 53 | 54 | --[[ Disable speed test since they're time-consuming 55 | function test.TopK_SingleSpeed() 56 | dotest(10000/32, 32, 5, 10, 40*1000, true) 57 | end 58 | 59 | function test.TopK_Batch() 60 | dotest(10000, 1, 5, 10, 40*1000, true) 61 | end 62 | --]] 63 | 64 | return function(_tester_) 65 | tester = _tester_ 66 | return test 67 | end 68 | -------------------------------------------------------------------------------- /test/test_zipalong.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 
3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | -- 8 | --[[ 9 | -- 10 | -- Tests for the nn.ZipAlong module. 11 | -- 12 | --]] 13 | 14 | require 'fairseq.modules' 15 | 16 | local tester 17 | local test = torch.TestSuite() 18 | 19 | function test.ZipAlong_Simple() 20 | local m = nn.ZipAlong() 21 | tester:assertGeneralEq( 22 | {{1, 4}, {2, 4}, {3, 4}}, 23 | m:forward{{1, 2, 3}, 4} 24 | ) 25 | tester:assertGeneralEq( 26 | {{1, {4, 4}}, {2, {4, 4}}, {3, {4, 4}}}, 27 | m:forward{{1, 2, 3}, {4, 4}} 28 | ) 29 | end 30 | 31 | local function toTable(t) 32 | if type(t) == 'table' then 33 | local tb = {} 34 | for k,v in pairs(t) do 35 | tb[k] = toTable(v) 36 | end 37 | return tb 38 | end 39 | return t:totable() 40 | end 41 | 42 | function test.ZipAlong_Tensor() 43 | local t = torch.Tensor({0, 1, 2}) 44 | local m = nn.ZipAlong() 45 | tester:assertGeneralEq( 46 | toTable({{t, t}, {t*2, t}, {t*4, t}}), 47 | toTable(m:forward{{t, t*2, t*4}, t}) 48 | ) 49 | tester:assertGeneralEq( 50 | toTable({{t, t*2, t*4}, t*3}), 51 | toTable(m:backward({{t, t*2, t*4}, t}, {{t, t}, {t*2, t}, {t*4, t}})) 52 | ) 53 | 54 | -- Add table along 55 | tester:assertGeneralEq( 56 | toTable({{t, {t, t*2}}, {t*2, {t, t*2}}, {t*4, {t, t*2}}}), 57 | toTable(m:forward{{t, t*2, t*4}, {t, t*2}}) 58 | ) 59 | tester:assertGeneralEq( 60 | toTable({{t, t*2, t*4}, {t*3, t*6}}), 61 | toTable(m:backward({{t, t*2, t*4}, {t, t*2}}, 62 | {{t, {t, t*2}}, {t*2, {t, t*2}}, {t*4, {t, t*2}}})) 63 | ) 64 | end 65 | 66 | -- Test in combination with map and add 67 | function test.ZipAlong_UpdateOutputAdd() 68 | local tensor = torch.Tensor({1, 2, 3, 4}) 69 | local add = torch.Tensor({3, 2, 2, 2}) 70 | local m = nn.Sequential() 71 | m:add(nn.ZipAlong()) 72 | m:add(nn.MapTable(nn.CAddTable())) 73 | 74 | local input = {{tensor, tensor * 2, 
tensor * 4}, add} 75 | local result = {tensor + add, tensor * 2 + add, tensor * 4 + add} 76 | local output = m:forward(input) 77 | for i = 1, #input[1] do 78 | tester:assertGeneralEq(result[i]:totable(), output[i]:totable()) 79 | end 80 | end 81 | 82 | -- Test in combination with map and add 83 | function test.ZipAlong_UpdateGradInputAdd() 84 | local tensor = torch.Tensor({1, 2, 3, 4}) 85 | local add = torch.Tensor({3, 2, 2, 2}) 86 | local m = nn.Sequential() 87 | m:add(nn.ZipAlong()) 88 | m:add(nn.MapTable(nn.CAddTable())) 89 | 90 | local input = {{tensor, tensor * 2, tensor * 4}, add} 91 | local gradients = {torch.Tensor({1, 0, 0, 0}), torch.Tensor({0, 1, 0, 0}), 92 | torch.Tensor({0, 0, 1, 0})} 93 | 94 | m:forward(input) 95 | local gradInput = m:backward(input, gradients) 96 | for i = 1, #input[1] do 97 | tester:assertGeneralEq( 98 | gradients[i]:totable(), 99 | gradInput[1][i]:totable() 100 | ) 101 | tester:assertGeneralEq( 102 | torch.add(gradients[1], gradients[2]):add(gradients[3]):totable(), 103 | gradInput[2]:totable() 104 | ) 105 | end 106 | end 107 | 108 | return function(_tester_) 109 | tester = _tester_ 110 | return test 111 | end 112 | -------------------------------------------------------------------------------- /tofloat.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | -- 8 | --[[ 9 | -- 10 | -- A helper script to convert a CUDA model into a CPU variant. 
11 | -- 12 | --]] 13 | 14 | require 'fairseq' 15 | local utils = require 'fairseq.utils' 16 | 17 | local cuda = utils.loadCuda() 18 | assert(cuda.cutorch) 19 | 20 | local cmd = torch.CmdLine() 21 | cmd:option('-input_model', 'cuda_model.th7', 22 | 'a th7 file that contains a CUDA model') 23 | cmd:option('-output_model', 'float_model.th7', 24 | 'an output file that will contain the CPU verion of the model') 25 | local config = cmd:parse(arg) 26 | 27 | local model = torch.load(config.input_model) 28 | model:float() 29 | model.module:getParameters() 30 | torch.save(config.output_model, model) 31 | --------------------------------------------------------------------------------