├── .DS_Store ├── README.md ├── get-data-back-translate.sh ├── get-data-nmt-local.sh ├── preprocess.py ├── requirements.txt ├── src ├── .DS_Store ├── __init__.py ├── __init__.pyc ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── logger.cpython-36.pyc ├── data │ ├── __init__.py │ ├── __init__.pyc │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ └── dictionary.cpython-36.pyc │ ├── dataset.py │ ├── dictionary.py │ ├── dictionary.pyc │ └── loader.py ├── evaluation │ ├── __init__.py │ ├── evaluator.py │ ├── glue.py │ ├── multi-bleu.perl │ └── xnli.py ├── logger.py ├── logger.pyc ├── model │ ├── __init__.py │ ├── embedder.py │ ├── memory │ │ ├── __init__.py │ │ ├── memory.py │ │ ├── query.py │ │ └── utils.py │ ├── pretrain.py │ └── transformer.py ├── optim.py ├── slurm.py ├── trainer.py └── utils.py ├── tools ├── .DS_Store ├── README.md ├── lowercase_and_remove_accent.py ├── segment_th.py └── tokenize.sh ├── train.py ├── train_IBT.sh ├── train_IBT_plus_BACK.sh ├── train_IBT_plus_SRC.sh ├── train_sup.sh ├── translate.py └── translate_exe.sh /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jind11/DAMT/3caa22822b936137e5da3e827d7a5a2078c3115e/.DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DAMT 2 | A new method for semi-supervised domain adaptation of Neural Machine Translation (NMT) 3 | 4 | This is the source code for the paper: [Jin, D., Jin, Z., Zhou, J.T., & Szolovits, P. (2020). Unsupervised Domain Adaptation for Neural Machine Translation with Iterative Back Translation. ArXiv, abs/2001.08140.](https://arxiv.org/abs/2001.08140). If you use the code, please cite the paper: 5 | 6 | ``` 7 | @article{Jin2020UnsupervisedDA, 8 | title={Unsupervised Domain Adaptation for Neural Machine Translation with Iterative Back Translation}, 9 | author={Di Jin and Zhijing Jin and Joey Tianyi Zhou and Peter Szolovits}, 10 | journal={ArXiv}, 11 | year={2020}, 12 | volume={abs/2001.08140} 13 | } 14 | ``` 15 | 16 | ## Prerequisites: 17 | Run the following command to install the prerequisite packages: 18 | ``` 19 | pip install -r requirements.txt 20 | ``` 21 | You should also install the Moses tokenizer and the fastBPE tool in the "tools" folder by running the following commands: 22 | ``` 23 | cd tools 24 | git clone https://github.com/moses-smt/mosesdecoder 25 | git clone https://github.com/glample/fastBPE 26 | cd fastBPE 27 | g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o fast 28 | cd ../.. 29 | ``` 30 | 31 | ## Data: 32 | Please download the data from the [Google Drive](https://drive.google.com/file/d/1aQOXfcGpPbQemG4mQQuiy6ZrCRn6WiDj/view?usp=sharing) and unzip it to the main directory of this repository. The downloaded data cover the MED (EMEA), IT, LAW (ACQUIS), and TED domains for the DE-EN language pair and the MED, LAW, and TED domains for the EN-RO language pair. WMT14 DE-EN data can be downloaded [here](https://nlp.stanford.edu/projects/nmt/) and WMT16 EN-RO data can be downloaded from [here](https://www.statmt.org/wmt16/translation-task.html). 33 | 34 | ## How to use 35 | 1. First, we need to download the pretrained model parameter files from the [XLM repository](https://github.com/facebookresearch/XLM#pretrained-xlmmlm-models). 36 | 37 | 2. Then we need to process the data.
Suppose we want to train an NMT model from German (de) to English (en), where the source domain is Law (dataset name: acquis) and the target domain is IT; then run the following commands: 38 | ``` 39 | ./get-data-nmt-local.sh --src de --tgt en --data_name it --data_path ./data/de-en/it --reload_codes PATH_TO_PRETRAINED_MODEL_CODES --reload_vocab PATH_TO_PRETRAINED_MODEL_VOCAB 40 | ./get-data-nmt-local.sh --src de --tgt en --data_name acquis --data_path ./data/de-en/acquis --reload_codes PATH_TO_PRETRAINED_MODEL_CODES --reload_vocab PATH_TO_PRETRAINED_MODEL_VOCAB 41 | ``` 42 | 43 | 3. After data processing, to reproduce the "IBT" setting described in the paper, run the following command: 44 | ``` 45 | ./train_IBT.sh --src de --tgt en --data_name it --pretrained_model_dir DIR_TO_PRETRAINED_MODEL 46 | ``` 47 | 48 | 4. To reproduce the "IBT+SRC" setting, recall that we want to adapt from the Law domain (dataset name: acquis) to the IT domain; run the following command: 49 | ``` 50 | ./train_IBT_plus_SRC.sh --src de --tgt en --src_data_name acquis --tgt_data_name it --pretrained_model_dir DIR_TO_PRETRAINED_MODEL 51 | ``` 52 | 53 | 5. To reproduce the "IBT+Back" setting, we need to go through several steps. 54 | 55 | * First, we need to train an NMT model that translates from en to de on the source domain data (acquis) by running the following command: 56 | ``` 57 | ./train_sup.sh --src en --tgt de --data_name acquis --pretrained_model_dir DIR_TO_PRETRAINED_MODEL 58 | ``` 59 | 60 | * After training this model, we use it to translate the English sentences of the target domain (it) into German; these translations are used as the back-translated data: 61 | ``` 62 | ./translate_exe.sh --src en --tgt de --data_name it --model_name acquis --model_dir DIR_TO_TRAINED_MODEL 63 | ./get-data-back-translate.sh --src en --tgt de --data_name it --model_name acquis 64 | ``` 65 | 66 | * When the back-translated data is ready, we can finally run this command: 67 | ``` 68 | ./train_IBT_plus_BACK.sh --src de --tgt en --src_data_name acquis --tgt_data_name it --pretrained_model_dir DIR_TO_PRETRAINED_MODEL 69 | ``` 70 | -------------------------------------------------------------------------------- /get-data-back-translate.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | # 7 | 8 | set -e 9 | 10 | # 11 | # Read arguments 12 | # 13 | POSITIONAL=() 14 | while [[ $# -gt 0 ]] 15 | do 16 | key="$1" 17 | case $key in 18 | --src) 19 | SRC="$2"; shift 2;; 20 | --tgt) 21 | TGT="$2"; shift 2;; 22 | --data_name) 23 | DATA_NAME="$2"; shift 2;; 24 | --model_name) 25 | MODEL_NAME="$2"; shift 2;; 26 | *) 27 | POSITIONAL+=("$1") 28 | shift 29 | ;; 30 | esac 31 | done 32 | set -- "${POSITIONAL[@]}" 33 | 34 | if [ "$SRC" \> "$TGT" ]; then echo "please ensure SRC < TGT"; exit; fi 35 | 36 | MAIN_PATH=$PWD 37 | DATA_PATH=data/$SRC-$TGT/$DATA_NAME 38 | PROC_PATH=$DATA_PATH/processed/$SRC-$TGT 39 | BACK_DATA_DIR=$DATA_PATH/back_translate/$MODEL_NAME 40 | FULL_VOCAB=$PROC_PATH/vocab.$SRC-$TGT 41 | 42 | $MAIN_PATH/preprocess.py $FULL_VOCAB $BACK_DATA_DIR/train.$SRC-$TGT.$SRC 43 | $MAIN_PATH/preprocess.py $FULL_VOCAB $BACK_DATA_DIR/train.$SRC-$TGT.$TGT 44 | -------------------------------------------------------------------------------- /get-data-nmt-local.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | set -e 9 | 10 | 11 | # 12 | # Data preprocessing configuration 13 | # 14 | CODES=60000 # number of BPE codes 15 | N_THREADS=16 # number of threads in data preprocessing 16 | 17 | 18 | # 19 | # Read arguments 20 | # 21 | POSITIONAL=() 22 | while [[ $# -gt 0 ]] 23 | do 24 | key="$1" 25 | case $key in 26 | --src) 27 | SRC="$2"; shift 2;; 28 | --tgt) 29 | TGT="$2"; shift 2;; 30 | --data_name) 31 | DATA_NAME="$2"; shift 2;; 32 | --data_path) 33 | DATA_PATH="$2"; shift 2;; 34 | --reload_codes) 35 | RELOAD_CODES="$2"; shift 2;; 36 | --reload_vocab) 37 | RELOAD_VOCAB="$2"; shift 2;; 38 | *) 39 | POSITIONAL+=("$1") 40 | shift 41 | ;; 42 | esac 43 | done 44 | set -- "${POSITIONAL[@]}" 45 | 46 | 47 | # 48 | # Check parameters 49 | # 50 | if [ "$SRC" == "" ]; then echo "--src not provided"; exit; fi 51 | if [ "$TGT" == "" ]; then echo "--tgt not provided"; exit; fi 52 | if [ "$SRC" != "de" -a "$SRC" != "en" -a "$SRC" != "fr" -a "$SRC" != "ro" ]; then echo "unknown source language"; exit; fi 53 | if [ "$TGT" != "de" -a "$TGT" != "en" -a "$TGT" != "fr" -a "$TGT" != "ro" ]; then echo "unknown target language"; exit; fi 54 | if [ "$SRC" == "$TGT" ]; then echo "source and target cannot be identical"; exit; fi 55 | if [ "$SRC" \> "$TGT" ]; then echo "please ensure SRC < TGT"; exit; fi 56 | if [ "$RELOAD_CODES" != "" ] && [ ! -f "$RELOAD_CODES" ]; then echo "cannot locate BPE codes"; exit; fi 57 | if [ "$RELOAD_VOCAB" != "" ] && [ ! 
-f "$RELOAD_VOCAB" ]; then echo "cannot locate vocabulary"; exit; fi 58 | if [ "$RELOAD_CODES" == "" -a "$RELOAD_VOCAB" != "" -o "$RELOAD_CODES" != "" -a "$RELOAD_VOCAB" == "" ]; then echo "BPE codes should be provided if and only if vocabulary is also provided"; exit; fi 59 | 60 | 61 | # 62 | # Initialize tools and data paths 63 | # 64 | 65 | # main paths 66 | MAIN_PATH=$PWD 67 | TOOLS_PATH=$PWD/tools 68 | PROC_PATH=$DATA_PATH/processed/$SRC-$TGT 69 | 70 | # create paths 71 | mkdir -p $PROC_PATH 72 | 73 | # moses 74 | MOSES=$TOOLS_PATH/mosesdecoder 75 | REPLACE_UNICODE_PUNCT=$MOSES/scripts/tokenizer/replace-unicode-punctuation.perl 76 | NORM_PUNC=$MOSES/scripts/tokenizer/normalize-punctuation.perl 77 | REM_NON_PRINT_CHAR=$MOSES/scripts/tokenizer/remove-non-printing-char.perl 78 | TOKENIZER=$MOSES/scripts/tokenizer/tokenizer.perl 79 | INPUT_FROM_SGM=$MOSES/scripts/ems/support/input-from-sgm.perl 80 | 81 | # fastBPE 82 | FASTBPE_DIR=$TOOLS_PATH/fastBPE 83 | FASTBPE=$TOOLS_PATH/fastBPE/fast 84 | 85 | # Sennrich's WMT16 scripts for Romanian preprocessing 86 | WMT16_SCRIPTS=$TOOLS_PATH/wmt16-scripts 87 | NORMALIZE_ROMANIAN=$WMT16_SCRIPTS/preprocess/normalise-romanian.py 88 | REMOVE_DIACRITICS=$WMT16_SCRIPTS/preprocess/remove-diacritics.py 89 | 90 | # raw and tokenized files 91 | SRC_RAW=$DATA_PATH/train.$SRC.mono 92 | TGT_RAW=$DATA_PATH/train.$TGT.mono 93 | SRC_TOK=$SRC_RAW.tok 94 | TGT_TOK=$TGT_RAW.tok 95 | 96 | # BPE / vocab files 97 | BPE_CODES=$PROC_PATH/codes 98 | SRC_VOCAB=$PROC_PATH/vocab.$SRC 99 | TGT_VOCAB=$PROC_PATH/vocab.$TGT 100 | FULL_VOCAB=$PROC_PATH/vocab.$SRC-$TGT 101 | 102 | # train / valid / test monolingual BPE data 103 | SRC_TRAIN_BPE=$PROC_PATH/train.$SRC 104 | TGT_TRAIN_BPE=$PROC_PATH/train.$TGT 105 | SRC_VALID_BPE=$PROC_PATH/valid.$SRC 106 | TGT_VALID_BPE=$PROC_PATH/valid.$TGT 107 | SRC_TEST_BPE=$PROC_PATH/test.$SRC 108 | TGT_TEST_BPE=$PROC_PATH/test.$TGT 109 | 110 | # valid / test parallel BPE data 111 | PARA_SRC_TRAIN_BPE=$PROC_PATH/train.$SRC-$TGT.$SRC 112 | PARA_TGT_TRAIN_BPE=$PROC_PATH/train.$SRC-$TGT.$TGT 113 | PARA_SRC_VALID_BPE=$PROC_PATH/valid.$SRC-$TGT.$SRC 114 | PARA_TGT_VALID_BPE=$PROC_PATH/valid.$SRC-$TGT.$TGT 115 | PARA_SRC_TEST_BPE=$PROC_PATH/test.$SRC-$TGT.$SRC 116 | PARA_TGT_TEST_BPE=$PROC_PATH/test.$SRC-$TGT.$TGT 117 | 118 | # valid / test file raw data 119 | unset PARA_SRC_VALID PARA_TGT_VALID PARA_SRC_TEST PARA_TGT_TEST 120 | PARA_SRC_TRAIN=$DATA_PATH/train.$SRC 121 | PARA_TGT_TRAIN=$DATA_PATH/train.$TGT 122 | PARA_SRC_VALID=$DATA_PATH/dev.$SRC 123 | PARA_TGT_VALID=$DATA_PATH/dev.$TGT 124 | PARA_SRC_TEST=$DATA_PATH/test.$SRC 125 | PARA_TGT_TEST=$DATA_PATH/test.$TGT 126 | 127 | #cd $DATA_PATH 128 | 129 | # preprocessing commands - special case for Romanian 130 | if [ "$SRC" == "ro" ]; then 131 | SRC_PREPROCESSING="$REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $SRC | $REM_NON_PRINT_CHAR | $NORMALIZE_ROMANIAN | $REMOVE_DIACRITICS | $TOKENIZER -l $SRC -no-escape -threads $N_THREADS" 132 | else 133 | SRC_PREPROCESSING="$REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $SRC | $REM_NON_PRINT_CHAR | $TOKENIZER -l $SRC -no-escape -threads $N_THREADS" 134 | fi 135 | if [ "$TGT" == "ro" ]; then 136 | TGT_PREPROCESSING="$REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $TGT | $REM_NON_PRINT_CHAR | $NORMALIZE_ROMANIAN | $REMOVE_DIACRITICS | $TOKENIZER -l $TGT -no-escape -threads $N_THREADS" 137 | else 138 | TGT_PREPROCESSING="$REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $TGT | $REM_NON_PRINT_CHAR | $TOKENIZER -l $TGT -no-escape -threads $N_THREADS" 139 | fi 140 | 141 | # tokenize data 142 | if ! 
[[ -f "$SRC_TOK" ]]; then 143 | echo "Tokenize $SRC monolingual data..." 144 | eval "cat $SRC_RAW | $SRC_PREPROCESSING > $SRC_TOK" 145 | fi 146 | 147 | if ! [[ -f "$TGT_TOK" ]]; then 148 | echo "Tokenize $TGT monolingual data..." 149 | eval "cat $TGT_RAW | $TGT_PREPROCESSING > $TGT_TOK" 150 | fi 151 | echo "$SRC monolingual data tokenized in: $SRC_TOK" 152 | echo "$TGT monolingual data tokenized in: $TGT_TOK" 153 | 154 | # reload BPE codes 155 | cd $MAIN_PATH 156 | if [ ! -f "$BPE_CODES" ] && [ -f "$RELOAD_CODES" ]; then 157 | echo "Reloading BPE codes from $RELOAD_CODES ..." 158 | cp $RELOAD_CODES $BPE_CODES 159 | fi 160 | 161 | # learn BPE codes 162 | if [ ! -f "$BPE_CODES" ]; then 163 | echo "Learning BPE codes..." 164 | $FASTBPE learnbpe $CODES $SRC_TOK $TGT_TOK > $BPE_CODES 165 | fi 166 | echo "BPE learned in $BPE_CODES" 167 | 168 | # apply BPE codes 169 | if ! [[ -f "$SRC_TRAIN_BPE" ]]; then 170 | echo "Applying $SRC BPE codes..." 171 | $FASTBPE applybpe $SRC_TRAIN_BPE $SRC_TOK $BPE_CODES 172 | fi 173 | if ! [[ -f "$TGT_TRAIN_BPE" ]]; then 174 | echo "Applying $TGT BPE codes..." 175 | $FASTBPE applybpe $TGT_TRAIN_BPE $TGT_TOK $BPE_CODES 176 | fi 177 | echo "BPE codes applied to $SRC in: $SRC_TRAIN_BPE" 178 | echo "BPE codes applied to $TGT in: $TGT_TRAIN_BPE" 179 | 180 | # extract source and target vocabulary 181 | if ! [[ -f "$SRC_VOCAB" && -f "$TGT_VOCAB" ]]; then 182 | echo "Extracting vocabulary..." 183 | $FASTBPE getvocab $SRC_TRAIN_BPE > $SRC_VOCAB 184 | $FASTBPE getvocab $TGT_TRAIN_BPE > $TGT_VOCAB 185 | fi 186 | echo "$SRC vocab in: $SRC_VOCAB" 187 | echo "$TGT vocab in: $TGT_VOCAB" 188 | 189 | # reload full vocabulary 190 | cd $MAIN_PATH 191 | if [ ! -f "$FULL_VOCAB" ] && [ -f "$RELOAD_VOCAB" ]; then 192 | echo "Reloading vocabulary from $RELOAD_VOCAB ..." 193 | cp $RELOAD_VOCAB $FULL_VOCAB 194 | fi 195 | 196 | # extract full vocabulary 197 | if ! [[ -f "$FULL_VOCAB" ]]; then 198 | echo "Extracting vocabulary..." 199 | $FASTBPE getvocab $SRC_TRAIN_BPE $TGT_TRAIN_BPE > $FULL_VOCAB 200 | fi 201 | echo "Full vocab in: $FULL_VOCAB" 202 | 203 | # binarize data 204 | if ! [[ -f "$SRC_TRAIN_BPE.pth" ]]; then 205 | echo "Binarizing $SRC data..." 206 | $MAIN_PATH/preprocess.py $FULL_VOCAB $SRC_TRAIN_BPE 207 | fi 208 | if ! [[ -f "$TGT_TRAIN_BPE.pth" ]]; then 209 | echo "Binarizing $TGT data..." 210 | $MAIN_PATH/preprocess.py $FULL_VOCAB $TGT_TRAIN_BPE 211 | fi 212 | echo "$SRC binarized data in: $SRC_TRAIN_BPE.pth" 213 | echo "$TGT binarized data in: $TGT_TRAIN_BPE.pth" 214 | 215 | # 216 | # Download parallel data (for evaluation only) 217 | # 218 | 219 | echo "Tokenizing parallel train, valid and test data..." 220 | eval "cat $PARA_SRC_TRAIN | $SRC_PREPROCESSING > $PARA_SRC_TRAIN.tok" 221 | eval "cat $PARA_TGT_TRAIN | $TGT_PREPROCESSING > $PARA_TGT_TRAIN.tok" 222 | eval "cat $PARA_SRC_VALID | $SRC_PREPROCESSING > $PARA_SRC_VALID.tok" 223 | eval "cat $PARA_TGT_VALID | $TGT_PREPROCESSING > $PARA_TGT_VALID.tok" 224 | eval "cat $PARA_SRC_TEST | $SRC_PREPROCESSING > $PARA_SRC_TEST.tok" 225 | eval "cat $PARA_TGT_TEST | $TGT_PREPROCESSING > $PARA_TGT_TEST.tok" 226 | 227 | echo "Applying BPE to train, valid and test files..." 
228 | $FASTBPE applybpe $PARA_SRC_TRAIN_BPE $PARA_SRC_TRAIN.tok $BPE_CODES $SRC_VOCAB 229 | $FASTBPE applybpe $PARA_TGT_TRAIN_BPE $PARA_TGT_TRAIN.tok $BPE_CODES $TGT_VOCAB 230 | $FASTBPE applybpe $PARA_SRC_VALID_BPE $PARA_SRC_VALID.tok $BPE_CODES $SRC_VOCAB 231 | $FASTBPE applybpe $PARA_TGT_VALID_BPE $PARA_TGT_VALID.tok $BPE_CODES $TGT_VOCAB 232 | $FASTBPE applybpe $PARA_SRC_TEST_BPE $PARA_SRC_TEST.tok $BPE_CODES $SRC_VOCAB 233 | $FASTBPE applybpe $PARA_TGT_TEST_BPE $PARA_TGT_TEST.tok $BPE_CODES $TGT_VOCAB 234 | 235 | echo "Binarizing data..." 236 | rm -f $PARA_SRC_TRAIN_BPE.pth $PARA_TGT_TRAIN_BPE.pth $PARA_SRC_VALID_BPE.pth $PARA_TGT_VALID_BPE.pth $PARA_SRC_TEST_BPE.pth $PARA_TGT_TEST_BPE.pth 237 | $MAIN_PATH/preprocess.py $FULL_VOCAB $PARA_SRC_TRAIN_BPE 238 | $MAIN_PATH/preprocess.py $FULL_VOCAB $PARA_TGT_TRAIN_BPE 239 | $MAIN_PATH/preprocess.py $FULL_VOCAB $PARA_SRC_VALID_BPE 240 | $MAIN_PATH/preprocess.py $FULL_VOCAB $PARA_TGT_VALID_BPE 241 | $MAIN_PATH/preprocess.py $FULL_VOCAB $PARA_SRC_TEST_BPE 242 | $MAIN_PATH/preprocess.py $FULL_VOCAB $PARA_TGT_TEST_BPE 243 | 244 | 245 | # 246 | # Link monolingual validation and test data to parallel data 247 | # 248 | cd $PROC_PATH 249 | ln -sf valid.$SRC-$TGT.$SRC.pth valid.$SRC.pth 250 | ln -sf valid.$SRC-$TGT.$TGT.pth valid.$TGT.pth 251 | ln -sf test.$SRC-$TGT.$SRC.pth test.$SRC.pth 252 | ln -sf test.$SRC-$TGT.$TGT.pth test.$TGT.pth 253 | 254 | # 255 | # Summary 256 | # 257 | echo "" 258 | echo "===== Data summary" 259 | echo "Monolingual training data:" 260 | echo " $SRC: $SRC_TRAIN_BPE.pth" 261 | echo " $TGT: $TGT_TRAIN_BPE.pth" 262 | echo "Monolingual validation data:" 263 | echo " $SRC: $SRC_VALID_BPE.pth" 264 | echo " $TGT: $TGT_VALID_BPE.pth" 265 | echo "Monolingual test data:" 266 | echo " $SRC: $SRC_TEST_BPE.pth" 267 | echo " $TGT: $TGT_TEST_BPE.pth" 268 | echo "Parallel training data:" 269 | echo " $SRC: $PARA_SRC_TRAIN_BPE.pth" 270 | echo " $TGT: $PARA_TGT_TRAIN_BPE.pth" 271 | echo "Parallel validation data:" 272 | echo " $SRC: $PARA_SRC_VALID_BPE.pth" 273 | echo " $TGT: $PARA_TGT_VALID_BPE.pth" 274 | echo "Parallel test data:" 275 | echo " $SRC: $PARA_SRC_TEST_BPE.pth" 276 | echo " $TGT: $PARA_TGT_TEST_BPE.pth" 277 | echo "" 278 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright (c) 2019-present, Facebook, Inc. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 9 | # 10 | 11 | 12 | """ 13 | Example: python data/vocab.txt data/train.txt 14 | vocab.txt: 1stline=word, 2ndline=count 15 | """ 16 | 17 | import os 18 | import sys 19 | 20 | from src.logger import create_logger 21 | from src.data.dictionary import Dictionary 22 | 23 | 24 | if __name__ == '__main__': 25 | 26 | logger = create_logger(None, 0) 27 | 28 | voc_path = sys.argv[1] 29 | txt_path = sys.argv[2] 30 | bin_path = sys.argv[2] + '.pth' 31 | assert os.path.isfile(voc_path) 32 | assert os.path.isfile(txt_path) 33 | 34 | dico = Dictionary.read_vocab(voc_path) 35 | logger.info("") 36 | 37 | data = Dictionary.index_data(txt_path, bin_path, dico) 38 | logger.info("%i words (%i unique) in %i sentences." 
% ( 39 | len(data['sentences']) - len(data['positions']), 40 | len(data['dico']), 41 | len(data['positions']) 42 | )) 43 | if len(data['unk_words']) > 0: 44 | logger.info("%i unknown words (%i unique), covering %.2f%% of the data." % ( 45 | sum(data['unk_words'].values()), 46 | len(data['unk_words']), 47 | sum(data['unk_words'].values()) * 100. / (len(data['sentences']) - len(data['positions'])) 48 | )) 49 | if len(data['unk_words']) < 30: 50 | for w, c in sorted(data['unk_words'].items(), key=lambda x: x[1])[::-1]: 51 | logger.info("%s: %i" % (w, c)) 52 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | certifi==2020.4.5.1 2 | future==0.18.2 3 | numpy==1.18.5 4 | torch==1.2.0 5 | -------------------------------------------------------------------------------- /src/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jind11/DAMT/3caa22822b936137e5da3e827d7a5a2078c3115e/src/.DS_Store -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jind11/DAMT/3caa22822b936137e5da3e827d7a5a2078c3115e/src/__init__.py -------------------------------------------------------------------------------- /src/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jind11/DAMT/3caa22822b936137e5da3e827d7a5a2078c3115e/src/__init__.pyc -------------------------------------------------------------------------------- /src/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jind11/DAMT/3caa22822b936137e5da3e827d7a5a2078c3115e/src/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /src/__pycache__/logger.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jind11/DAMT/3caa22822b936137e5da3e827d7a5a2078c3115e/src/__pycache__/logger.cpython-36.pyc -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jind11/DAMT/3caa22822b936137e5da3e827d7a5a2078c3115e/src/data/__init__.py -------------------------------------------------------------------------------- /src/data/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jind11/DAMT/3caa22822b936137e5da3e827d7a5a2078c3115e/src/data/__init__.pyc -------------------------------------------------------------------------------- /src/data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jind11/DAMT/3caa22822b936137e5da3e827d7a5a2078c3115e/src/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /src/data/__pycache__/dictionary.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jind11/DAMT/3caa22822b936137e5da3e827d7a5a2078c3115e/src/data/__pycache__/dictionary.cpython-36.pyc -------------------------------------------------------------------------------- /src/data/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from logging import getLogger 9 | import math 10 | import numpy as np 11 | import torch 12 | 13 | 14 | logger = getLogger() 15 | 16 | 17 | class StreamDataset(object): 18 | 19 | def __init__(self, sent, pos, bs, params): 20 | """ 21 | Prepare batches for data iterator. 22 | """ 23 | bptt = params.bptt 24 | self.eos = params.eos_index 25 | 26 | # checks 27 | assert len(pos) == (sent == self.eos).sum() 28 | assert len(pos) == (sent[pos[:, 1]] == self.eos).sum() 29 | 30 | n_tokens = len(sent) 31 | n_batches = math.ceil(n_tokens / (bs * bptt)) 32 | t_size = n_batches * bptt * bs 33 | 34 | buffer = np.zeros(t_size, dtype=sent.dtype) + self.eos 35 | buffer[t_size - n_tokens:] = sent 36 | buffer = buffer.reshape((bs, n_batches * bptt)).T 37 | self.data = np.zeros((n_batches * bptt + 1, bs), dtype=sent.dtype) + self.eos 38 | self.data[1:] = buffer 39 | 40 | self.bptt = bptt 41 | self.n_tokens = n_tokens 42 | self.n_batches = n_batches 43 | self.n_sentences = len(pos) 44 | self.lengths = torch.LongTensor(bs).fill_(bptt) 45 | 46 | def __len__(self): 47 | """ 48 | Number of sentences in the dataset. 49 | """ 50 | return self.n_sentences 51 | 52 | def select_data(self, a, b): 53 | """ 54 | Only select a subset of the dataset. 55 | """ 56 | if not (0 <= a < b <= self.n_batches): 57 | logger.warning("Invalid split values: %i %i - %i" % (a, b, self.n_batches)) 58 | return 59 | assert 0 <= a < b <= self.n_batches 60 | logger.info("Selecting batches from %i to %i ..." % (a, b)) 61 | 62 | # sub-select 63 | self.data = self.data[a * self.bptt:b * self.bptt] 64 | self.n_batches = b - a 65 | self.n_sentences = (self.data == self.eos).sum().item() 66 | 67 | def get_iterator(self, shuffle, subsample=1): 68 | """ 69 | Return a sentences iterator. 70 | """ 71 | indexes = (np.random.permutation if shuffle else range)(self.n_batches // subsample) 72 | for i in indexes: 73 | a = self.bptt * i 74 | b = self.bptt * (i + 1) 75 | yield torch.from_numpy(self.data[a:b].astype(np.int64)), self.lengths 76 | 77 | 78 | class Dataset(object): 79 | 80 | def __init__(self, sent, pos, params): 81 | 82 | self.eos_index = params.eos_index 83 | self.pad_index = params.pad_index 84 | self.batch_size = params.batch_size 85 | self.tokens_per_batch = params.tokens_per_batch 86 | self.max_batch_size = params.max_batch_size 87 | 88 | self.sent = sent 89 | self.pos = pos 90 | self.lengths = self.pos[:, 1] - self.pos[:, 0] 91 | 92 | # check number of sentences 93 | assert len(self.pos) == (self.sent == self.eos_index).sum() 94 | 95 | # # remove empty sentences 96 | # self.remove_empty_sentences() 97 | 98 | # sanity checks 99 | self.check() 100 | 101 | def __len__(self): 102 | """ 103 | Number of sentences in the dataset. 104 | """ 105 | return len(self.pos) 106 | 107 | def check(self): 108 | """ 109 | Sanity checks. 
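Verify that the stored sentence positions are aligned with EOS tokens in the flattened sentence array.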
110 | """ 111 | eos = self.eos_index 112 | assert len(self.pos) == (self.sent[self.pos[:, 1]] == eos).sum() # check sentences indices 113 | # assert self.lengths.min() > 0 # check empty sentences 114 | 115 | def batch_sentences(self, sentences): 116 | """ 117 | Take as input a list of n sentences (torch.LongTensor vectors) and return 118 | a tensor of size (slen, n) where slen is the length of the longest 119 | sentence, and a vector lengths containing the length of each sentence. 120 | """ 121 | # sentences = sorted(sentences, key=lambda x: len(x), reverse=True) 122 | lengths = torch.LongTensor([len(s) + 2 for s in sentences]) 123 | sent = torch.LongTensor(lengths.max().item(), lengths.size(0)).fill_(self.pad_index) 124 | 125 | sent[0] = self.eos_index 126 | for i, s in enumerate(sentences): 127 | if lengths[i] > 2: # if sentence not empty 128 | sent[1:lengths[i] - 1, i].copy_(torch.from_numpy(s.astype(np.int64))) 129 | sent[lengths[i] - 1, i] = self.eos_index 130 | 131 | return sent, lengths 132 | 133 | def remove_empty_sentences(self): 134 | """ 135 | Remove empty sentences. 136 | """ 137 | init_size = len(self.pos) 138 | indices = np.arange(len(self.pos)) 139 | indices = indices[self.lengths[indices] > 0] 140 | self.pos = self.pos[indices] 141 | self.lengths = self.pos[:, 1] - self.pos[:, 0] 142 | logger.info("Removed %i empty sentences." % (init_size - len(indices))) 143 | self.check() 144 | 145 | def remove_long_sentences(self, max_len): 146 | """ 147 | Remove sentences exceeding a certain length. 148 | """ 149 | assert max_len >= 0 150 | if max_len == 0: 151 | return 152 | init_size = len(self.pos) 153 | indices = np.arange(len(self.pos)) 154 | indices = indices[self.lengths[indices] <= max_len] 155 | self.pos = self.pos[indices] 156 | self.lengths = self.pos[:, 1] - self.pos[:, 0] 157 | logger.info("Removed %i too long sentences." % (init_size - len(indices))) 158 | self.check() 159 | 160 | def select_data(self, a, b): 161 | """ 162 | Only select a subset of the dataset. 163 | """ 164 | assert 0 <= a < b <= len(self.pos) 165 | logger.info("Selecting sentences from %i to %i ..." % (a, b)) 166 | 167 | # sub-select 168 | self.pos = self.pos[a:b] 169 | self.lengths = self.pos[:, 1] - self.pos[:, 0] 170 | 171 | # re-index 172 | min_pos = self.pos.min() 173 | max_pos = self.pos.max() 174 | self.pos -= min_pos 175 | self.sent = self.sent[min_pos:max_pos + 1] 176 | 177 | # sanity checks 178 | self.check() 179 | 180 | def get_batches_iterator(self, batches, return_indices): 181 | """ 182 | Return a sentences iterator, given the associated sentence batches. 183 | """ 184 | assert type(return_indices) is bool 185 | 186 | for sentence_ids in batches: 187 | if 0 < self.max_batch_size < len(sentence_ids): 188 | np.random.shuffle(sentence_ids) 189 | sentence_ids = sentence_ids[:self.max_batch_size] 190 | pos = self.pos[sentence_ids] 191 | sent = [self.sent[a:b] for a, b in pos] 192 | sent = self.batch_sentences(sent) 193 | yield (sent, sentence_ids) if return_indices else sent 194 | 195 | def get_iterator(self, shuffle, group_by_size=False, n_sentences=-1, seed=None, return_indices=False): 196 | """ 197 | Return a sentences iterator. 
198 | """ 199 | assert seed is None or shuffle is True and type(seed) is int 200 | rng = np.random.RandomState(seed) 201 | n_sentences = len(self.pos) if n_sentences == -1 else n_sentences 202 | assert 0 < n_sentences <= len(self.pos) 203 | assert type(shuffle) is bool and type(group_by_size) is bool 204 | assert group_by_size is False or shuffle is True 205 | 206 | # sentence lengths 207 | lengths = self.lengths + 2 208 | 209 | # select sentences to iterate over 210 | if shuffle: 211 | indices = rng.permutation(len(self.pos))[:n_sentences] 212 | else: 213 | indices = np.arange(n_sentences) 214 | 215 | # group sentences by lengths 216 | if group_by_size: 217 | indices = indices[np.argsort(lengths[indices], kind='mergesort')] 218 | 219 | # create batches - either have a fixed number of sentences, or a similar number of tokens 220 | if self.tokens_per_batch == -1: 221 | batches = np.array_split(indices, math.ceil(len(indices) * 1. / self.batch_size)) 222 | else: 223 | batch_ids = np.cumsum(lengths[indices]) // self.tokens_per_batch 224 | _, bounds = np.unique(batch_ids, return_index=True) 225 | batches = [indices[bounds[i]:bounds[i + 1]] for i in range(len(bounds) - 1)] 226 | if bounds[-1] < len(indices): 227 | batches.append(indices[bounds[-1]:]) 228 | 229 | # optionally shuffle batches 230 | if shuffle: 231 | rng.shuffle(batches) 232 | 233 | # sanity checks 234 | assert n_sentences == sum([len(x) for x in batches]) 235 | assert lengths[indices].sum() == sum([lengths[x].sum() for x in batches]) 236 | # assert set.union(*[set(x.tolist()) for x in batches]) == set(range(n_sentences)) # slow 237 | 238 | # return the iterator 239 | return self.get_batches_iterator(batches, return_indices) 240 | 241 | 242 | class ParallelDataset(Dataset): 243 | 244 | def __init__(self, sent1, pos1, sent2, pos2, params): 245 | 246 | self.eos_index = params.eos_index 247 | self.pad_index = params.pad_index 248 | self.batch_size = params.batch_size 249 | self.tokens_per_batch = params.tokens_per_batch 250 | self.max_batch_size = params.max_batch_size 251 | 252 | self.sent1 = sent1 253 | self.sent2 = sent2 254 | self.pos1 = pos1 255 | self.pos2 = pos2 256 | self.lengths1 = self.pos1[:, 1] - self.pos1[:, 0] 257 | self.lengths2 = self.pos2[:, 1] - self.pos2[:, 0] 258 | 259 | # check number of sentences 260 | assert len(self.pos1) == (self.sent1 == self.eos_index).sum() 261 | assert len(self.pos2) == (self.sent2 == self.eos_index).sum() 262 | 263 | # remove empty sentences 264 | self.remove_empty_sentences() 265 | 266 | # sanity checks 267 | self.check() 268 | 269 | def __len__(self): 270 | """ 271 | Number of sentences in the dataset. 272 | """ 273 | return len(self.pos1) 274 | 275 | def check(self): 276 | """ 277 | Sanity checks. 278 | """ 279 | eos = self.eos_index 280 | assert len(self.pos1) == len(self.pos2) > 0 # check number of sentences 281 | assert len(self.pos1) == (self.sent1[self.pos1[:, 1]] == eos).sum() # check sentences indices 282 | assert len(self.pos2) == (self.sent2[self.pos2[:, 1]] == eos).sum() # check sentences indices 283 | assert eos <= self.sent1.min() < self.sent1.max() # check dictionary indices 284 | assert eos <= self.sent2.min() < self.sent2.max() # check dictionary indices 285 | assert self.lengths1.min() > 0 # check empty sentences 286 | assert self.lengths2.min() > 0 # check empty sentences 287 | 288 | def remove_empty_sentences(self): 289 | """ 290 | Remove empty sentences. 
291 | """ 292 | init_size = len(self.pos1) 293 | indices = np.arange(len(self.pos1)) 294 | indices = indices[self.lengths1[indices] > 0] 295 | indices = indices[self.lengths2[indices] > 0] 296 | self.pos1 = self.pos1[indices] 297 | self.pos2 = self.pos2[indices] 298 | self.lengths1 = self.pos1[:, 1] - self.pos1[:, 0] 299 | self.lengths2 = self.pos2[:, 1] - self.pos2[:, 0] 300 | logger.info("Removed %i empty sentences." % (init_size - len(indices))) 301 | self.check() 302 | 303 | def remove_long_sentences(self, max_len): 304 | """ 305 | Remove sentences exceeding a certain length. 306 | """ 307 | assert max_len >= 0 308 | if max_len == 0: 309 | return 310 | init_size = len(self.pos1) 311 | indices = np.arange(len(self.pos1)) 312 | indices = indices[self.lengths1[indices] <= max_len] 313 | indices = indices[self.lengths2[indices] <= max_len] 314 | self.pos1 = self.pos1[indices] 315 | self.pos2 = self.pos2[indices] 316 | self.lengths1 = self.pos1[:, 1] - self.pos1[:, 0] 317 | self.lengths2 = self.pos2[:, 1] - self.pos2[:, 0] 318 | logger.info("Removed %i too long sentences." % (init_size - len(indices))) 319 | self.check() 320 | 321 | def select_data(self, a, b): 322 | """ 323 | Only select a subset of the dataset. 324 | """ 325 | assert 0 <= a < b <= len(self.pos1) 326 | logger.info("Selecting sentences from %i to %i ..." % (a, b)) 327 | 328 | # sub-select 329 | self.pos1 = self.pos1[a:b] 330 | self.pos2 = self.pos2[a:b] 331 | self.lengths1 = self.pos1[:, 1] - self.pos1[:, 0] 332 | self.lengths2 = self.pos2[:, 1] - self.pos2[:, 0] 333 | 334 | # re-index 335 | min_pos1 = self.pos1.min() 336 | max_pos1 = self.pos1.max() 337 | min_pos2 = self.pos2.min() 338 | max_pos2 = self.pos2.max() 339 | self.pos1 -= min_pos1 340 | self.pos2 -= min_pos2 341 | self.sent1 = self.sent1[min_pos1:max_pos1 + 1] 342 | self.sent2 = self.sent2[min_pos2:max_pos2 + 1] 343 | 344 | # sanity checks 345 | self.check() 346 | 347 | def get_batches_iterator(self, batches, return_indices): 348 | """ 349 | Return a sentences iterator, given the associated sentence batches. 350 | """ 351 | assert type(return_indices) is bool 352 | 353 | for sentence_ids in batches: 354 | if 0 < self.max_batch_size < len(sentence_ids): 355 | np.random.shuffle(sentence_ids) 356 | sentence_ids = sentence_ids[:self.max_batch_size] 357 | pos1 = self.pos1[sentence_ids] 358 | pos2 = self.pos2[sentence_ids] 359 | sent1 = self.batch_sentences([self.sent1[a:b] for a, b in pos1]) 360 | sent2 = self.batch_sentences([self.sent2[a:b] for a, b in pos2]) 361 | yield (sent1, sent2, sentence_ids) if return_indices else (sent1, sent2) 362 | 363 | def get_iterator(self, shuffle, group_by_size=False, n_sentences=-1, return_indices=False): 364 | """ 365 | Return a sentences iterator. 
366 | """ 367 | n_sentences = len(self.pos1) if n_sentences == -1 else n_sentences 368 | assert 0 < n_sentences <= len(self.pos1) 369 | assert type(shuffle) is bool and type(group_by_size) is bool 370 | 371 | # sentence lengths 372 | lengths = self.lengths1 + self.lengths2 + 4 373 | 374 | # select sentences to iterate over 375 | if shuffle: 376 | indices = np.random.permutation(len(self.pos1))[:n_sentences] 377 | else: 378 | indices = np.arange(n_sentences) 379 | 380 | # group sentences by lengths 381 | if group_by_size: 382 | indices = indices[np.argsort(lengths[indices], kind='mergesort')] 383 | 384 | # create batches - either have a fixed number of sentences, or a similar number of tokens 385 | if self.tokens_per_batch == -1: 386 | batches = np.array_split(indices, math.ceil(len(indices) * 1. / self.batch_size)) 387 | else: 388 | batch_ids = np.cumsum(lengths[indices]) // self.tokens_per_batch 389 | _, bounds = np.unique(batch_ids, return_index=True) 390 | batches = [indices[bounds[i]:bounds[i + 1]] for i in range(len(bounds) - 1)] 391 | if bounds[-1] < len(indices): 392 | batches.append(indices[bounds[-1]:]) 393 | 394 | # optionally shuffle batches 395 | if shuffle: 396 | np.random.shuffle(batches) 397 | 398 | # sanity checks 399 | assert n_sentences == sum([len(x) for x in batches]) 400 | assert lengths[indices].sum() == sum([lengths[x].sum() for x in batches]) 401 | # assert set.union(*[set(x.tolist()) for x in batches]) == set(range(n_sentences)) # slow 402 | 403 | # return the iterator 404 | return self.get_batches_iterator(batches, return_indices) 405 | -------------------------------------------------------------------------------- /src/data/dictionary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import os 9 | import numpy as np 10 | import torch 11 | from logging import getLogger 12 | 13 | 14 | logger = getLogger() 15 | 16 | 17 | BOS_WORD = '' 18 | EOS_WORD = '' 19 | PAD_WORD = '' 20 | UNK_WORD = '' 21 | 22 | SPECIAL_WORD = '' 23 | SPECIAL_WORDS = 10 24 | 25 | SEP_WORD = SPECIAL_WORD % 0 26 | MASK_WORD = SPECIAL_WORD % 1 27 | 28 | 29 | class Dictionary(object): 30 | 31 | def __init__(self, id2word, word2id, counts): 32 | assert len(id2word) == len(word2id) == len(counts) 33 | self.id2word = id2word 34 | self.word2id = word2id 35 | self.counts = counts 36 | self.bos_index = word2id[BOS_WORD] 37 | self.eos_index = word2id[EOS_WORD] 38 | self.pad_index = word2id[PAD_WORD] 39 | self.unk_index = word2id[UNK_WORD] 40 | self.check_valid() 41 | 42 | def __len__(self): 43 | """ 44 | Returns the number of words in the dictionary. 45 | """ 46 | return len(self.id2word) 47 | 48 | def __getitem__(self, i): 49 | """ 50 | Returns the word of the specified index. 51 | """ 52 | return self.id2word[i] 53 | 54 | def __contains__(self, w): 55 | """ 56 | Returns whether a word is in the dictionary. 57 | """ 58 | return w in self.word2id 59 | 60 | def __eq__(self, y): 61 | """ 62 | Compare this dictionary with another one. 63 | """ 64 | self.check_valid() 65 | y.check_valid() 66 | if len(self.id2word) != len(y): 67 | return False 68 | return all(self.id2word[i] == y[i] for i in range(len(y))) 69 | 70 | def check_valid(self): 71 | """ 72 | Check that the dictionary is valid. 
73 | """ 74 | assert self.bos_index == 0 75 | assert self.eos_index == 1 76 | assert self.pad_index == 2 77 | assert self.unk_index == 3 78 | assert all(self.id2word[4 + i] == SPECIAL_WORD % i for i in range(SPECIAL_WORDS)) 79 | assert len(self.id2word) == len(self.word2id) == len(self.counts) 80 | assert set(self.word2id.keys()) == set(self.counts.keys()) 81 | for i in range(len(self.id2word)): 82 | assert self.word2id[self.id2word[i]] == i 83 | last_count = 1e18 84 | for i in range(4 + SPECIAL_WORDS, len(self.id2word) - 1): 85 | count = self.counts[self.id2word[i]] 86 | assert count <= last_count 87 | last_count = count 88 | 89 | def index(self, word, no_unk=False): 90 | """ 91 | Returns the index of the specified word. 92 | """ 93 | if no_unk: 94 | return self.word2id[word] 95 | else: 96 | return self.word2id.get(word, self.unk_index) 97 | 98 | def max_vocab(self, max_vocab): 99 | """ 100 | Limit the vocabulary size. 101 | """ 102 | assert max_vocab >= 1 103 | init_size = len(self) 104 | self.id2word = {k: v for k, v in self.id2word.items() if k < max_vocab} 105 | self.word2id = {v: k for k, v in self.id2word.items()} 106 | self.counts = {k: v for k, v in self.counts.items() if k in self.word2id} 107 | self.check_valid() 108 | logger.info("Maximum vocabulary size: %i. Dictionary size: %i -> %i (removed %i words)." 109 | % (max_vocab, init_size, len(self), init_size - len(self))) 110 | 111 | def min_count(self, min_count): 112 | """ 113 | Threshold on the word frequency counts. 114 | """ 115 | assert min_count >= 0 116 | init_size = len(self) 117 | self.id2word = {k: v for k, v in self.id2word.items() if self.counts[self.id2word[k]] >= min_count or k < 4 + SPECIAL_WORDS} 118 | self.word2id = {v: k for k, v in self.id2word.items()} 119 | self.counts = {k: v for k, v in self.counts.items() if k in self.word2id} 120 | self.check_valid() 121 | logger.info("Minimum frequency count: %i. Dictionary size: %i -> %i (removed %i words)." 122 | % (min_count, init_size, len(self), init_size - len(self))) 123 | 124 | @staticmethod 125 | def read_vocab(vocab_path): 126 | """ 127 | Create a dictionary from a vocabulary file. 128 | """ 129 | skipped = 0 130 | assert os.path.isfile(vocab_path), vocab_path 131 | word2id = {BOS_WORD: 0, EOS_WORD: 1, PAD_WORD: 2, UNK_WORD: 3} 132 | for i in range(SPECIAL_WORDS): 133 | word2id[SPECIAL_WORD % i] = 4 + i 134 | counts = {k: 0 for k in word2id.keys()} 135 | f = open(vocab_path, 'r', encoding='utf-8') 136 | for i, line in enumerate(f): 137 | if '\u2028' in line: 138 | skipped += 1 139 | continue 140 | line = line.rstrip().split() 141 | if len(line) != 2: 142 | skipped += 1 143 | continue 144 | assert len(line) == 2, (i, line) 145 | # assert line[0] not in word2id and line[1].isdigit(), (i, line) 146 | assert line[1].isdigit(), (i, line) 147 | if line[0] in word2id: 148 | skipped += 1 149 | print('%s already in vocab' % line[0]) 150 | continue 151 | if not line[1].isdigit(): 152 | skipped += 1 153 | print('Empty word at line %s with count %s' % (i, line)) 154 | continue 155 | word2id[line[0]] = 4 + SPECIAL_WORDS + i - skipped # shift because of extra words 156 | counts[line[0]] = int(line[1]) 157 | f.close() 158 | id2word = {v: k for k, v in word2id.items()} 159 | dico = Dictionary(id2word, word2id, counts) 160 | logger.info("Read %i words from the vocabulary file." % len(dico)) 161 | if skipped > 0: 162 | logger.warning("Skipped %i empty lines!" 
% skipped) 163 | return dico 164 | 165 | @staticmethod 166 | def index_data(path, bin_path, dico): 167 | """ 168 | Index sentences with a dictionary. 169 | """ 170 | if bin_path is not None and os.path.isfile(bin_path): 171 | print("Loading data from %s ..." % bin_path) 172 | data = torch.load(bin_path) 173 | assert dico == data['dico'] 174 | return data 175 | 176 | positions = [] 177 | sentences = [] 178 | unk_words = {} 179 | 180 | # index sentences 181 | f = open(path, 'r', encoding='utf-8') 182 | for i, line in enumerate(f): 183 | if i % 1000000 == 0 and i > 0: 184 | print(i) 185 | s = line.rstrip().split() 186 | # skip empty sentences 187 | if len(s) == 0: 188 | print("Empty sentence in line %i." % i) 189 | # index sentence words 190 | count_unk = 0 191 | indexed = [] 192 | for w in s: 193 | word_id = dico.index(w, no_unk=False) 194 | # if we find a special word which is not an unknown word, skip the sentence 195 | if 0 <= word_id < 4 + SPECIAL_WORDS and word_id != 3: 196 | logger.warning('Found unexpected special word "%s" (%i)!!' % (w, word_id)) 197 | continue 198 | assert word_id >= 0 199 | indexed.append(word_id) 200 | if word_id == dico.unk_index: 201 | unk_words[w] = unk_words.get(w, 0) + 1 202 | count_unk += 1 203 | # add sentence 204 | positions.append([len(sentences), len(sentences) + len(indexed)]) 205 | sentences.extend(indexed) 206 | sentences.append(1) # EOS index 207 | f.close() 208 | 209 | # tensorize data 210 | positions = np.int64(positions) 211 | if len(dico) < 1 << 16: 212 | sentences = np.uint16(sentences) 213 | elif len(dico) < 1 << 31: 214 | sentences = np.int32(sentences) 215 | else: 216 | raise Exception("Dictionary is too big.") 217 | assert sentences.min() >= 0 218 | data = { 219 | 'dico': dico, 220 | 'positions': positions, 221 | 'sentences': sentences, 222 | 'unk_words': unk_words, 223 | } 224 | if bin_path is not None: 225 | print("Saving the data to %s ..." % bin_path) 226 | torch.save(data, bin_path, pickle_protocol=4) 227 | 228 | return data 229 | -------------------------------------------------------------------------------- /src/data/dictionary.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jind11/DAMT/3caa22822b936137e5da3e827d7a5a2078c3115e/src/data/dictionary.pyc -------------------------------------------------------------------------------- /src/data/loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from logging import getLogger 9 | import os 10 | import numpy as np 11 | import torch 12 | 13 | from .dataset import StreamDataset, Dataset, ParallelDataset 14 | from .dictionary import BOS_WORD, EOS_WORD, PAD_WORD, UNK_WORD, MASK_WORD 15 | 16 | 17 | logger = getLogger() 18 | 19 | 20 | def process_binarized(data, params): 21 | """ 22 | Process a binarized dataset and log main statistics. 23 | """ 24 | dico = data['dico'] 25 | assert ((data['sentences'].dtype == np.uint16) and (len(dico) < 1 << 16) or 26 | (data['sentences'].dtype == np.int32) and (1 << 16 <= len(dico) < 1 << 31)) 27 | logger.info("%i words (%i unique) in %i sentences. %i unknown words (%i unique) covering %.2f%% of the data." 
% ( 28 | len(data['sentences']) - len(data['positions']), 29 | len(dico), len(data['positions']), 30 | sum(data['unk_words'].values()), len(data['unk_words']), 31 | 100. * sum(data['unk_words'].values()) / (len(data['sentences']) - len(data['positions'])) 32 | )) 33 | if params.max_vocab != -1: 34 | assert params.max_vocab > 0 35 | logger.info("Selecting %i most frequent words ..." % params.max_vocab) 36 | dico.max_vocab(params.max_vocab) 37 | data['sentences'][data['sentences'] >= params.max_vocab] = dico.index(UNK_WORD) 38 | unk_count = (data['sentences'] == dico.index(UNK_WORD)).sum() 39 | logger.info("Now %i unknown words covering %.2f%% of the data." 40 | % (unk_count, 100. * unk_count / (len(data['sentences']) - len(data['positions'])))) 41 | if params.min_count > 0: 42 | logger.info("Selecting words with >= %i occurrences ..." % params.min_count) 43 | dico.min_count(params.min_count) 44 | data['sentences'][data['sentences'] >= len(dico)] = dico.index(UNK_WORD) 45 | unk_count = (data['sentences'] == dico.index(UNK_WORD)).sum() 46 | logger.info("Now %i unknown words covering %.2f%% of the data." 47 | % (unk_count, 100. * unk_count / (len(data['sentences']) - len(data['positions'])))) 48 | if (data['sentences'].dtype == np.int32) and (len(dico) < 1 << 16): 49 | logger.info("Less than 65536 words. Moving data from int32 to uint16 ...") 50 | data['sentences'] = data['sentences'].astype(np.uint16) 51 | return data 52 | 53 | 54 | def load_binarized(path, params): 55 | """ 56 | Load a binarized dataset. 57 | """ 58 | assert path.endswith('.pth') 59 | if params.debug_train: 60 | path = path.replace('train', 'valid') 61 | if getattr(params, 'multi_gpu', False): 62 | split_path = '%s.%i.pth' % (path[:-4], params.local_rank) 63 | if os.path.isfile(split_path): 64 | assert params.split_data is False 65 | path = split_path 66 | assert os.path.isfile(path), path 67 | logger.info("Loading data from %s ..." % path) 68 | data = torch.load(path) 69 | data = process_binarized(data, params) 70 | return data 71 | 72 | 73 | def set_dico_parameters(params, data, dico): 74 | """ 75 | Update dictionary parameters. 76 | """ 77 | if 'dico' in data: 78 | assert data['dico'] == dico 79 | else: 80 | data['dico'] = dico 81 | 82 | n_words = len(dico) 83 | bos_index = dico.index(BOS_WORD) 84 | eos_index = dico.index(EOS_WORD) 85 | pad_index = dico.index(PAD_WORD) 86 | unk_index = dico.index(UNK_WORD) 87 | mask_index = dico.index(MASK_WORD) 88 | if hasattr(params, 'bos_index'): 89 | assert params.n_words == n_words 90 | assert params.bos_index == bos_index 91 | assert params.eos_index == eos_index 92 | assert params.pad_index == pad_index 93 | assert params.unk_index == unk_index 94 | assert params.mask_index == mask_index 95 | else: 96 | params.n_words = n_words 97 | params.bos_index = bos_index 98 | params.eos_index = eos_index 99 | params.pad_index = pad_index 100 | params.unk_index = unk_index 101 | params.mask_index = mask_index 102 | 103 | 104 | def load_mono_data(params, data): 105 | """ 106 | Load monolingual data. 
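For each monolingual language, a StreamDataset is stored in data['mono_stream']; languages used for denoising auto-encoding or online back-translation additionally get a batched Dataset in data['mono'], with empty and overly long training sentences removed.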
107 | """ 108 | data['mono'] = {} 109 | data['mono_stream'] = {} 110 | 111 | for lang in params.mono_dataset.keys(): 112 | 113 | logger.info('============ Monolingual data (%s)' % lang) 114 | 115 | assert lang in params.langs and lang not in data['mono'] 116 | data['mono'][lang] = {} 117 | data['mono_stream'][lang] = {} 118 | 119 | for splt in ['train', 'valid', 'test']: 120 | 121 | # no need to load training data for evaluation 122 | if splt == 'train' and params.eval_only: 123 | continue 124 | 125 | # load data / update dictionary parameters / update data 126 | mono_data = load_binarized(params.mono_dataset[lang][splt], params) 127 | set_dico_parameters(params, data, mono_data['dico']) 128 | 129 | # create stream dataset 130 | bs = params.batch_size if splt == 'train' else 1 131 | data['mono_stream'][lang][splt] = StreamDataset(mono_data['sentences'], mono_data['positions'], bs, params) 132 | 133 | # if there are several processes on the same machine, we can split the dataset 134 | if splt == 'train' and params.split_data and 1 < params.n_gpu_per_node <= data['mono_stream'][lang][splt].n_batches: 135 | n_batches = data['mono_stream'][lang][splt].n_batches // params.n_gpu_per_node 136 | a = n_batches * params.local_rank 137 | b = n_batches * params.local_rank + n_batches 138 | data['mono_stream'][lang][splt].select_data(a, b) 139 | 140 | # for denoising auto-encoding and online back-translation, we need a non-stream (batched) dataset 141 | if lang in params.ae_steps or lang in params.bt_src_langs: 142 | 143 | # create batched dataset 144 | dataset = Dataset(mono_data['sentences'], mono_data['positions'], params) 145 | 146 | # remove empty and too long sentences 147 | if splt == 'train': 148 | dataset.remove_empty_sentences() 149 | dataset.remove_long_sentences(params.max_len) 150 | 151 | # if there are several processes on the same machine, we can split the dataset 152 | if splt == 'train' and params.n_gpu_per_node > 1 and params.split_data: 153 | n_sent = len(dataset) // params.n_gpu_per_node 154 | a = n_sent * params.local_rank 155 | b = n_sent * params.local_rank + n_sent 156 | dataset.select_data(a, b) 157 | 158 | data['mono'][lang][splt] = dataset 159 | 160 | logger.info("") 161 | 162 | logger.info("") 163 | 164 | 165 | def load_para_data(params, data): 166 | """ 167 | Load parallel data. 
168 | """ 169 | data['para'] = {} 170 | 171 | required_para_train = set(params.clm_steps + params.mlm_steps + params.pc_steps + params.mt_steps) 172 | 173 | for src, tgt in params.para_dataset.keys(): 174 | 175 | logger.info('============ Parallel data (%s-%s)' % (src, tgt)) 176 | 177 | assert (src, tgt) not in data['para'] 178 | data['para'][(src, tgt)] = {} 179 | 180 | for splt in ['train', 'valid', 'test']: 181 | 182 | # no need to load training data for evaluation 183 | if splt == 'train' and params.eval_only: 184 | continue 185 | 186 | # for back-translation, we can't load training data 187 | if splt == 'train' and (src, tgt) not in required_para_train and (tgt, src) not in required_para_train: 188 | continue 189 | 190 | # load binarized datasets 191 | src_path, tgt_path = params.para_dataset[(src, tgt)][splt] 192 | src_data = load_binarized(src_path, params) 193 | tgt_data = load_binarized(tgt_path, params) 194 | 195 | # update dictionary parameters 196 | set_dico_parameters(params, data, src_data['dico']) 197 | set_dico_parameters(params, data, tgt_data['dico']) 198 | 199 | # create ParallelDataset 200 | dataset = ParallelDataset( 201 | src_data['sentences'], src_data['positions'], 202 | tgt_data['sentences'], tgt_data['positions'], 203 | params 204 | ) 205 | 206 | # remove empty and too long sentences 207 | if splt == 'train': 208 | dataset.remove_empty_sentences() 209 | dataset.remove_long_sentences(params.max_len) 210 | 211 | # for validation and test set, enumerate sentence per sentence 212 | if splt != 'train': 213 | dataset.tokens_per_batch = -1 214 | 215 | # if there are several processes on the same machine, we can split the dataset 216 | if splt == 'train' and params.n_gpu_per_node > 1 and params.split_data: 217 | n_sent = len(dataset) // params.n_gpu_per_node 218 | a = n_sent * params.local_rank 219 | b = n_sent * params.local_rank + n_sent 220 | dataset.select_data(a, b) 221 | 222 | data['para'][(src, tgt)][splt] = dataset 223 | logger.info("") 224 | 225 | logger.info("") 226 | 227 | 228 | def check_data_params(params): 229 | """ 230 | Check datasets parameters. 
231 | """ 232 | # data path 233 | assert os.path.isdir(params.data_path), params.data_path 234 | 235 | # check languages 236 | params.langs = params.lgs.split('-') if params.lgs != 'debug' else ['en'] 237 | assert len(params.langs) == len(set(params.langs)) >= 1 238 | # assert sorted(params.langs) == params.langs 239 | params.id2lang = {k: v for k, v in enumerate(sorted(params.langs))} 240 | params.lang2id = {k: v for v, k in params.id2lang.items()} 241 | params.n_langs = len(params.langs) 242 | 243 | # CLM steps 244 | clm_steps = [s.split('-') for s in params.clm_steps.split(',') if len(s) > 0] 245 | params.clm_steps = [(s[0], None) if len(s) == 1 else tuple(s) for s in clm_steps] 246 | assert all([(l1 in params.langs) and (l2 in params.langs or l2 is None) for l1, l2 in params.clm_steps]) 247 | assert len(params.clm_steps) == len(set(params.clm_steps)) 248 | 249 | # MLM / TLM steps 250 | mlm_steps = [s.split('-') for s in params.mlm_steps.split(',') if len(s) > 0] 251 | params.mlm_steps = [(s[0], None) if len(s) == 1 else tuple(s) for s in mlm_steps] 252 | assert all([(l1 in params.langs) and (l2 in params.langs or l2 is None) for l1, l2 in params.mlm_steps]) 253 | assert len(params.mlm_steps) == len(set(params.mlm_steps)) 254 | 255 | # parallel classification steps 256 | params.pc_steps = [tuple(s.split('-')) for s in params.pc_steps.split(',') if len(s) > 0] 257 | assert all([len(x) == 2 for x in params.pc_steps]) 258 | assert all([l1 in params.langs and l2 in params.langs for l1, l2 in params.pc_steps]) 259 | assert all([l1 != l2 for l1, l2 in params.pc_steps]) 260 | assert len(params.pc_steps) == len(set(params.pc_steps)) 261 | 262 | # machine translation steps 263 | params.mt_steps = [tuple(s.split('-')) for s in params.mt_steps.split(',') if len(s) > 0] 264 | assert all([len(x) == 2 for x in params.mt_steps]) 265 | assert all([l1 in params.langs and l2 in params.langs for l1, l2 in params.mt_steps]) 266 | assert all([l1 != l2 for l1, l2 in params.mt_steps]) 267 | assert len(params.mt_steps) == len(set(params.mt_steps)) 268 | assert len(params.mt_steps) == 0 or not params.encoder_only 269 | 270 | # denoising auto-encoder steps 271 | params.ae_steps = [s for s in params.ae_steps.split(',') if len(s) > 0] 272 | assert all([lang in params.langs for lang in params.ae_steps]) 273 | assert len(params.ae_steps) == len(set(params.ae_steps)) 274 | assert len(params.ae_steps) == 0 or not params.encoder_only 275 | 276 | # back-translation steps 277 | params.bt_steps = [tuple(s.split('-')) for s in params.bt_steps.split(',') if len(s) > 0] 278 | assert all([len(x) == 3 for x in params.bt_steps]) 279 | assert all([l1 in params.langs and l2 in params.langs and l3 in params.langs for l1, l2, l3 in params.bt_steps]) 280 | assert all([l1 == l3 and l1 != l2 for l1, l2, l3 in params.bt_steps]) 281 | assert len(params.bt_steps) == len(set(params.bt_steps)) 282 | assert len(params.bt_steps) == 0 or not params.encoder_only 283 | params.bt_src_langs = [l1 for l1, _, _ in params.bt_steps] 284 | 285 | # check monolingual datasets 286 | required_mono = set([l1 for l1, l2 in (params.mlm_steps + params.clm_steps) if l2 is None] + params.ae_steps + params.bt_src_langs) 287 | params.mono_dataset = { 288 | lang: { 289 | splt: os.path.join(params.data_path, '%s.%s.pth' % (splt, lang)) 290 | for splt in ['train', 'valid', 'test'] 291 | } for lang in params.langs if lang in required_mono 292 | } 293 | for paths in params.mono_dataset.values(): 294 | for p in paths.values(): 295 | if not os.path.isfile(p): 296 | 
logger.error(f"{p} not found") 297 | assert all([all([os.path.isfile(p) for p in paths.values()]) for paths in params.mono_dataset.values()]) 298 | 299 | # check parallel datasets 300 | if not params.para_data_path: 301 | params.para_data_path = params.data_path 302 | required_para_train = set(params.clm_steps + params.mlm_steps + params.pc_steps + params.mt_steps) 303 | required_para = required_para_train | set([(l2, l3) for _, l2, l3 in params.bt_steps]) 304 | params.para_dataset = {} 305 | for src in params.langs: 306 | for tgt in params.langs: 307 | if src < tgt and ((src, tgt) in required_para or (tgt, src) in required_para): 308 | params.para_dataset[(src, tgt)] = {} 309 | for splt in ['train', 'valid', 'test']: 310 | if splt != 'train': 311 | params.para_dataset[(src, tgt)][splt] = \ 312 | (os.path.join(params.data_path, '%s.%s-%s.%s.pth' % (splt, src, tgt, src)), 313 | os.path.join(params.data_path, '%s.%s-%s.%s.pth' % (splt, src, tgt, tgt))) 314 | else: 315 | if (src, tgt) in required_para_train or (tgt, src) in required_para_train: 316 | params.para_dataset[(src, tgt)][splt] = \ 317 | (os.path.join(params.para_data_path, '%s.%s-%s.%s.pth' % (splt, src, tgt, src)), 318 | os.path.join(params.para_data_path, '%s.%s-%s.%s.pth' % (splt, src, tgt, tgt))) 319 | 320 | for paths in params.para_dataset.values(): 321 | for p1, p2 in paths.values(): 322 | if not os.path.isfile(p1): 323 | logger.error(f"{p1} not found") 324 | if not os.path.isfile(p2): 325 | logger.error(f"{p2} not found") 326 | assert all([all([os.path.isfile(p1) and os.path.isfile(p2) for p1, p2 in paths.values()]) for paths in params.para_dataset.values()]) 327 | 328 | # check that we can evaluate on BLEU 329 | assert params.eval_bleu is False or len(params.mt_steps + params.bt_steps) > 0 330 | 331 | 332 | def load_data(params): 333 | """ 334 | Load monolingual data. 335 | The returned dictionary contains: 336 | - dico (dictionary) 337 | - vocab (FloatTensor) 338 | - train / valid / test (monolingual datasets) 339 | """ 340 | data = {} 341 | 342 | # monolingual datasets 343 | load_mono_data(params, data) 344 | 345 | # parallel datasets 346 | load_para_data(params, data) 347 | 348 | # monolingual data summary 349 | logger.info('============ Data summary') 350 | for lang, v in data['mono_stream'].items(): 351 | for data_set in v.keys(): 352 | logger.info('{: <18} - {: >5} - {: >12}:{: >10}'.format('Monolingual data', data_set, lang, len(v[data_set]))) 353 | 354 | # parallel data summary 355 | for (src, tgt), v in data['para'].items(): 356 | for data_set in v.keys(): 357 | logger.info('{: <18} - {: >5} - {: >12}:{: >10}'.format('Parallel data', data_set, '%s-%s' % (src, tgt), len(v[data_set]))) 358 | 359 | logger.info("") 360 | return data 361 | -------------------------------------------------------------------------------- /src/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jind11/DAMT/3caa22822b936137e5da3e827d7a5a2078c3115e/src/evaluation/__init__.py -------------------------------------------------------------------------------- /src/evaluation/evaluator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | from logging import getLogger 9 | import os 10 | import re 11 | import subprocess 12 | from collections import OrderedDict 13 | import numpy as np 14 | import torch 15 | 16 | from ..utils import to_cuda, restore_segmentation, concat_batches 17 | from ..model.memory import HashingMemory 18 | 19 | 20 | BLEU_SCRIPT_PATH = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'multi-bleu.perl') 21 | assert os.path.isfile(BLEU_SCRIPT_PATH) 22 | 23 | 24 | logger = getLogger() 25 | 26 | 27 | def kl_score(x): 28 | # assert np.abs(np.sum(x) - 1) < 1e-5 29 | _x = x.copy() 30 | _x[x == 0] = 1 31 | return np.log(len(x)) + (x * np.log(_x)).sum() 32 | 33 | 34 | def gini_score(x): 35 | # assert np.abs(np.sum(x) - 1) < 1e-5 36 | B = np.cumsum(np.sort(x)).mean() 37 | return 1 - 2 * B 38 | 39 | 40 | def tops(x): 41 | # assert np.abs(np.sum(x) - 1) < 1e-5 42 | y = np.cumsum(np.sort(x)) 43 | top50, top90, top99 = y.shape[0] - np.searchsorted(y, [0.5, 0.1, 0.01]) 44 | return top50, top90, top99 45 | 46 | 47 | def eval_memory_usage(scores, name, mem_att, mem_size): 48 | """ 49 | Evaluate memory usage (HashingMemory / FFN). 50 | """ 51 | # memory slot scores 52 | assert mem_size > 0 53 | mem_scores_w = np.zeros(mem_size, dtype=np.float32) # weighted scores 54 | mem_scores_u = np.zeros(mem_size, dtype=np.float32) # unweighted scores 55 | 56 | # sum each slot usage 57 | for indices, weights in mem_att: 58 | np.add.at(mem_scores_w, indices, weights) 59 | np.add.at(mem_scores_u, indices, 1) 60 | 61 | # compute the KL distance to the uniform distribution 62 | mem_scores_w = mem_scores_w / mem_scores_w.sum() 63 | mem_scores_u = mem_scores_u / mem_scores_u.sum() 64 | 65 | # store stats 66 | scores['%s_mem_used' % name] = float(100 * (mem_scores_w != 0).sum() / len(mem_scores_w)) 67 | 68 | scores['%s_mem_kl_w' % name] = float(kl_score(mem_scores_w)) 69 | scores['%s_mem_kl_u' % name] = float(kl_score(mem_scores_u)) 70 | 71 | scores['%s_mem_gini_w' % name] = float(gini_score(mem_scores_w)) 72 | scores['%s_mem_gini_u' % name] = float(gini_score(mem_scores_u)) 73 | 74 | top50, top90, top99 = tops(mem_scores_w) 75 | scores['%s_mem_top50_w' % name] = float(top50) 76 | scores['%s_mem_top90_w' % name] = float(top90) 77 | scores['%s_mem_top99_w' % name] = float(top99) 78 | 79 | top50, top90, top99 = tops(mem_scores_u) 80 | scores['%s_mem_top50_u' % name] = float(top50) 81 | scores['%s_mem_top90_u' % name] = float(top90) 82 | scores['%s_mem_top99_u' % name] = float(top99) 83 | 84 | 85 | class Evaluator(object): 86 | 87 | def __init__(self, trainer, data, params): 88 | """ 89 | Initialize evaluator. 90 | """ 91 | self.trainer = trainer 92 | self.data = data 93 | self.dico = data['dico'] 94 | self.params = params 95 | self.memory_list = trainer.memory_list 96 | 97 | # create directory to store hypotheses, and reference files for BLEU evaluation 98 | if self.params.is_master: 99 | params.hyp_path = os.path.join(params.dump_path, 'hypotheses') 100 | subprocess.Popen('mkdir -p %s' % params.hyp_path, shell=True).wait() 101 | self.create_reference_files() 102 | 103 | def get_iterator(self, data_set, lang1, lang2=None, stream=False): 104 | """ 105 | Create a new iterator for a dataset. 
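Batches are always yielded in (lang1, lang2) order, even though parallel
datasets are stored under the alphabetically sorted language pair. When more
than 5 (or 30) languages are evaluated, the valid/test iterators are
subsampled to keep evaluation time manageable; otherwise all sentences are
used. A typical (hypothetical) call for a parallel set:

    for (x1, len1), (x2, len2) in self.get_iterator('valid', 'de', 'en'):
        ...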
106 | """ 107 | assert data_set in ['valid', 'test'] 108 | assert lang1 in self.params.langs 109 | assert lang2 is None or lang2 in self.params.langs 110 | assert stream is False or lang2 is None 111 | 112 | # hacks to reduce evaluation time when using many languages 113 | if len(self.params.langs) > 30: 114 | eval_lgs = set(["ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr", "ur", "vi", "zh", "ab", "ay", "bug", "ha", "ko", "ln", "min", "nds", "pap", "pt", "tg", "to", "udm", "uk", "zh_classical"]) 115 | eval_lgs = set(["ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr", "ur", "vi", "zh"]) 116 | subsample = 10 if (data_set == 'test' or lang1 not in eval_lgs) else 5 117 | n_sentences = 600 if (data_set == 'test' or lang1 not in eval_lgs) else 1500 118 | elif len(self.params.langs) > 5: 119 | subsample = 10 if data_set == 'test' else 5 120 | n_sentences = 300 if data_set == 'test' else 1500 121 | else: 122 | # n_sentences = -1 if data_set == 'valid' else 100 123 | n_sentences = -1 124 | subsample = 1 125 | 126 | if lang2 is None: 127 | if stream: 128 | iterator = self.data['mono_stream'][lang1][data_set].get_iterator(shuffle=False, subsample=subsample) 129 | else: 130 | iterator = self.data['mono'][lang1][data_set].get_iterator( 131 | shuffle=False, 132 | group_by_size=True, 133 | n_sentences=n_sentences, 134 | ) 135 | else: 136 | assert stream is False 137 | _lang1, _lang2 = (lang1, lang2) if lang1 < lang2 else (lang2, lang1) 138 | iterator = self.data['para'][(_lang1, _lang2)][data_set].get_iterator( 139 | shuffle=False, 140 | group_by_size=True, 141 | n_sentences=n_sentences 142 | ) 143 | 144 | for batch in iterator: 145 | yield batch if lang2 is None or lang1 < lang2 else batch[::-1] 146 | 147 | def create_reference_files(self): 148 | """ 149 | Create reference files for BLEU evaluation. 150 | """ 151 | params = self.params 152 | params.ref_paths = {} 153 | 154 | for (lang1, lang2), v in self.data['para'].items(): 155 | 156 | assert lang1 < lang2 157 | 158 | for data_set in ['valid', 'test']: 159 | 160 | # define data paths 161 | lang1_path = os.path.join(params.hyp_path, 'ref.{0}-{1}.{2}.txt'.format(lang2, lang1, data_set)) 162 | lang2_path = os.path.join(params.hyp_path, 'ref.{0}-{1}.{2}.txt'.format(lang1, lang2, data_set)) 163 | 164 | # store data paths 165 | params.ref_paths[(lang2, lang1, data_set)] = lang1_path 166 | params.ref_paths[(lang1, lang2, data_set)] = lang2_path 167 | 168 | # text sentences 169 | lang1_txt = [] 170 | lang2_txt = [] 171 | 172 | # convert to text 173 | for (sent1, len1), (sent2, len2) in self.get_iterator(data_set, lang1, lang2): 174 | lang1_txt.extend(convert_to_text(sent1, len1, self.dico, params)) 175 | lang2_txt.extend(convert_to_text(sent2, len2, self.dico, params)) 176 | 177 | # replace by <> as these tokens cannot be counted in BLEU 178 | lang1_txt = [x.replace('', '<>') for x in lang1_txt] 179 | lang2_txt = [x.replace('', '<>') for x in lang2_txt] 180 | 181 | # export hypothesis 182 | with open(lang1_path, 'w', encoding='utf-8') as f: 183 | f.write('\n'.join(lang1_txt) + '\n') 184 | with open(lang2_path, 'w', encoding='utf-8') as f: 185 | f.write('\n'.join(lang2_txt) + '\n') 186 | 187 | # restore original segmentation 188 | restore_segmentation(lang1_path, bpe_type=params.bpe_type) 189 | restore_segmentation(lang2_path, bpe_type=params.bpe_type) 190 | 191 | def mask_out(self, x, lengths, rng): 192 | """ 193 | Decide of random words to mask out. 
194 | We specify the random generator to ensure that the test is the same at each epoch. 195 | """ 196 | params = self.params 197 | slen, bs = x.size() 198 | 199 | # words to predict - be sure there is at least one word per sentence 200 | to_predict = rng.rand(slen, bs) <= params.word_pred 201 | to_predict[0] = 0 202 | for i in range(bs): 203 | to_predict[lengths[i] - 1:, i] = 0 204 | if not np.any(to_predict[:lengths[i] - 1, i]): 205 | v = rng.randint(1, lengths[i] - 1) 206 | to_predict[v, i] = 1 207 | pred_mask = torch.from_numpy(to_predict.astype(np.uint8)) 208 | 209 | # generate possible targets / update x input 210 | _x_real = x[pred_mask] 211 | _x_mask = _x_real.clone().fill_(params.mask_index) 212 | x = x.masked_scatter(pred_mask, _x_mask) 213 | 214 | assert 0 <= x.min() <= x.max() < params.n_words 215 | assert x.size() == (slen, bs) 216 | assert pred_mask.size() == (slen, bs) 217 | 218 | return x, _x_real, pred_mask 219 | 220 | def run_all_evals(self, trainer): 221 | """ 222 | Run all evaluations. 223 | """ 224 | params = self.params 225 | scores = OrderedDict({'epoch': trainer.epoch}) 226 | 227 | with torch.no_grad(): 228 | 229 | for data_set in ['valid', 'test']: 230 | 231 | # causal prediction task (evaluate perplexity and accuracy) 232 | for lang1, lang2 in params.clm_steps: 233 | self.evaluate_clm(scores, data_set, lang1, lang2) 234 | 235 | # prediction task (evaluate perplexity and accuracy) 236 | for lang1, lang2 in params.mlm_steps: 237 | self.evaluate_mlm(scores, data_set, lang1, lang2) 238 | 239 | # machine translation task (evaluate perplexity and accuracy) 240 | for lang1, lang2 in set(params.mt_steps + [(l2, l3) for _, l2, l3 in params.bt_steps]): 241 | eval_bleu = params.eval_bleu and params.is_master 242 | self.evaluate_mt(scores, data_set, lang1, lang2, eval_bleu) 243 | 244 | # report average metrics per language 245 | _clm_mono = [l1 for (l1, l2) in params.clm_steps if l2 is None] 246 | if len(_clm_mono) > 0: 247 | scores['%s_clm_ppl' % data_set] = np.mean([scores['%s_%s_clm_ppl' % (data_set, lang)] for lang in _clm_mono]) 248 | scores['%s_clm_acc' % data_set] = np.mean([scores['%s_%s_clm_acc' % (data_set, lang)] for lang in _clm_mono]) 249 | _mlm_mono = [l1 for (l1, l2) in params.mlm_steps if l2 is None] 250 | if len(_mlm_mono) > 0: 251 | scores['%s_mlm_ppl' % data_set] = np.mean([scores['%s_%s_mlm_ppl' % (data_set, lang)] for lang in _mlm_mono]) 252 | scores['%s_mlm_acc' % data_set] = np.mean([scores['%s_%s_mlm_acc' % (data_set, lang)] for lang in _mlm_mono]) 253 | 254 | return scores 255 | 256 | def evaluate_clm(self, scores, data_set, lang1, lang2): 257 | """ 258 | Evaluate perplexity and next word prediction accuracy. 
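Results are stored as '<split>_<lang(s)>_clm_ppl' and
'<split>_<lang(s)>_clm_acc', where perplexity is exp(total cross-entropy /
number of predicted words) and accuracy is the fraction of correctly
predicted next tokens.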
259 | """ 260 | params = self.params 261 | assert data_set in ['valid', 'test'] 262 | assert lang1 in params.langs 263 | assert lang2 in params.langs or lang2 is None 264 | 265 | model = self.model if params.encoder_only else self.decoder 266 | model.eval() 267 | model = model.module if params.multi_gpu else model 268 | 269 | lang1_id = params.lang2id[lang1] 270 | lang2_id = params.lang2id[lang2] if lang2 is not None else None 271 | l1l2 = lang1 if lang2 is None else f"{lang1}-{lang2}" 272 | 273 | n_words = 0 274 | xe_loss = 0 275 | n_valid = 0 276 | 277 | # only save states / evaluate usage on the validation set 278 | eval_memory = params.use_memory and data_set == 'valid' and self.params.is_master 279 | HashingMemory.EVAL_MEMORY = eval_memory 280 | if eval_memory: 281 | all_mem_att = {k: [] for k, _ in self.memory_list} 282 | 283 | for batch in self.get_iterator(data_set, lang1, lang2, stream=(lang2 is None)): 284 | 285 | # batch 286 | if lang2 is None: 287 | x, lengths = batch 288 | positions = None 289 | langs = x.clone().fill_(lang1_id) if params.n_langs > 1 else None 290 | else: 291 | (sent1, len1), (sent2, len2) = batch 292 | x, lengths, positions, langs = concat_batches(sent1, len1, lang1_id, sent2, len2, lang2_id, params.pad_index, params.eos_index, reset_positions=True) 293 | 294 | # words to predict 295 | alen = torch.arange(lengths.max(), dtype=torch.long, device=lengths.device) 296 | pred_mask = alen[:, None] < lengths[None] - 1 297 | y = x[1:].masked_select(pred_mask[:-1]) 298 | assert pred_mask.sum().item() == y.size(0) 299 | 300 | # cuda 301 | x, lengths, positions, langs, pred_mask, y = to_cuda(x, lengths, positions, langs, pred_mask, y) 302 | 303 | # forward / loss 304 | tensor = model('fwd', x=x, lengths=lengths, positions=positions, langs=langs, causal=True) 305 | word_scores, loss = model('predict', tensor=tensor, pred_mask=pred_mask, y=y, get_scores=True) 306 | 307 | # update stats 308 | n_words += y.size(0) 309 | xe_loss += loss.item() * len(y) 310 | n_valid += (word_scores.max(1)[1] == y).sum().item() 311 | if eval_memory: 312 | for k, v in self.memory_list: 313 | all_mem_att[k].append((v.last_indices, v.last_scores)) 314 | 315 | # log 316 | logger.info("Found %i words in %s. %i were predicted correctly." % (n_words, data_set, n_valid)) 317 | 318 | # compute perplexity and prediction accuracy 319 | ppl_name = '%s_%s_clm_ppl' % (data_set, l1l2) 320 | acc_name = '%s_%s_clm_acc' % (data_set, l1l2) 321 | scores[ppl_name] = np.exp(xe_loss / n_words) 322 | scores[acc_name] = 100. * n_valid / n_words 323 | 324 | # compute memory usage 325 | if eval_memory: 326 | for mem_name, mem_att in all_mem_att.items(): 327 | eval_memory_usage(scores, '%s_%s_%s' % (data_set, l1l2, mem_name), mem_att, params.mem_size) 328 | 329 | def evaluate_mlm(self, scores, data_set, lang1, lang2): 330 | """ 331 | Evaluate perplexity and next word prediction accuracy. 
332 | """ 333 | params = self.params 334 | assert data_set in ['valid', 'test'] 335 | assert lang1 in params.langs 336 | assert lang2 in params.langs or lang2 is None 337 | 338 | model = self.model if params.encoder_only else self.encoder 339 | model.eval() 340 | model = model.module if params.multi_gpu else model 341 | 342 | rng = np.random.RandomState(0) 343 | 344 | lang1_id = params.lang2id[lang1] 345 | lang2_id = params.lang2id[lang2] if lang2 is not None else None 346 | l1l2 = lang1 if lang2 is None else f"{lang1}_{lang2}" 347 | 348 | n_words = 0 349 | xe_loss = 0 350 | n_valid = 0 351 | 352 | # only save states / evaluate usage on the validation set 353 | eval_memory = params.use_memory and data_set == 'valid' and self.params.is_master 354 | HashingMemory.EVAL_MEMORY = eval_memory 355 | if eval_memory: 356 | all_mem_att = {k: [] for k, _ in self.memory_list} 357 | 358 | for batch in self.get_iterator(data_set, lang1, lang2, stream=(lang2 is None)): 359 | 360 | # batch 361 | if lang2 is None: 362 | x, lengths = batch 363 | positions = None 364 | langs = x.clone().fill_(lang1_id) if params.n_langs > 1 else None 365 | else: 366 | (sent1, len1), (sent2, len2) = batch 367 | x, lengths, positions, langs = concat_batches(sent1, len1, lang1_id, sent2, len2, lang2_id, params.pad_index, params.eos_index, reset_positions=True) 368 | 369 | # words to predict 370 | x, y, pred_mask = self.mask_out(x, lengths, rng) 371 | 372 | # cuda 373 | x, y, pred_mask, lengths, positions, langs = to_cuda(x, y, pred_mask, lengths, positions, langs) 374 | 375 | # forward / loss 376 | tensor = model('fwd', x=x, lengths=lengths, positions=positions, langs=langs, causal=False) 377 | word_scores, loss = model('predict', tensor=tensor, pred_mask=pred_mask, y=y, get_scores=True) 378 | 379 | # update stats 380 | n_words += len(y) 381 | xe_loss += loss.item() * len(y) 382 | n_valid += (word_scores.max(1)[1] == y).sum().item() 383 | if eval_memory: 384 | for k, v in self.memory_list: 385 | all_mem_att[k].append((v.last_indices, v.last_scores)) 386 | 387 | # compute perplexity and prediction accuracy 388 | ppl_name = '%s_%s_mlm_ppl' % (data_set, l1l2) 389 | acc_name = '%s_%s_mlm_acc' % (data_set, l1l2) 390 | scores[ppl_name] = np.exp(xe_loss / n_words) if n_words > 0 else 1e9 391 | scores[acc_name] = 100. * n_valid / n_words if n_words > 0 else 0. 392 | 393 | # compute memory usage 394 | if eval_memory: 395 | for mem_name, mem_att in all_mem_att.items(): 396 | eval_memory_usage(scores, '%s_%s_%s' % (data_set, l1l2, mem_name), mem_att, params.mem_size) 397 | 398 | 399 | class SingleEvaluator(Evaluator): 400 | 401 | def __init__(self, trainer, data, params): 402 | """ 403 | Build language model evaluator. 404 | """ 405 | super().__init__(trainer, data, params) 406 | self.model = trainer.model 407 | 408 | 409 | class EncDecEvaluator(Evaluator): 410 | 411 | def __init__(self, trainer, data, params): 412 | """ 413 | Build encoder / decoder evaluator. 414 | """ 415 | super().__init__(trainer, data, params) 416 | self.encoder = trainer.encoder 417 | self.decoder = trainer.decoder 418 | 419 | def evaluate_mt(self, scores, data_set, lang1, lang2, eval_bleu): 420 | """ 421 | Evaluate perplexity and next word prediction accuracy. 
422 | """ 423 | params = self.params 424 | assert data_set in ['valid', 'test'] 425 | assert lang1 in params.langs 426 | assert lang2 in params.langs 427 | 428 | self.encoder.eval() 429 | self.decoder.eval() 430 | encoder = self.encoder.module if params.multi_gpu else self.encoder 431 | decoder = self.decoder.module if params.multi_gpu else self.decoder 432 | 433 | params = params 434 | lang1_id = params.lang2id[lang1] 435 | lang2_id = params.lang2id[lang2] 436 | 437 | n_words = 0 438 | xe_loss = 0 439 | n_valid = 0 440 | 441 | # only save states / evaluate usage on the validation set 442 | eval_memory = params.use_memory and data_set == 'valid' and self.params.is_master 443 | HashingMemory.EVAL_MEMORY = eval_memory 444 | if eval_memory: 445 | all_mem_att = {k: [] for k, _ in self.memory_list} 446 | 447 | # store hypothesis to compute BLEU score 448 | if eval_bleu: 449 | hypothesis = [] 450 | 451 | for batch in self.get_iterator(data_set, lang1, lang2): 452 | 453 | # generate batch 454 | (x1, len1), (x2, len2) = batch 455 | langs1 = x1.clone().fill_(lang1_id) 456 | langs2 = x2.clone().fill_(lang2_id) 457 | 458 | # target words to predict 459 | alen = torch.arange(len2.max(), dtype=torch.long, device=len2.device) 460 | pred_mask = alen[:, None] < len2[None] - 1 # do not predict anything given the last target word 461 | y = x2[1:].masked_select(pred_mask[:-1]) 462 | assert len(y) == (len2 - 1).sum().item() 463 | 464 | # cuda 465 | x1, len1, langs1, x2, len2, langs2, y = to_cuda(x1, len1, langs1, x2, len2, langs2, y) 466 | 467 | # encode source sentence 468 | enc1 = encoder('fwd', x=x1, lengths=len1, langs=langs1, causal=False) 469 | enc1 = enc1.transpose(0, 1) 470 | enc1 = enc1.half() if params.fp16 else enc1 471 | 472 | # decode target sentence 473 | dec2 = decoder('fwd', x=x2, lengths=len2, langs=langs2, causal=True, src_enc=enc1, src_len=len1) 474 | 475 | # loss 476 | word_scores, loss = decoder('predict', tensor=dec2, pred_mask=pred_mask, y=y, get_scores=True) 477 | 478 | # update stats 479 | n_words += y.size(0) 480 | xe_loss += loss.item() * len(y) 481 | n_valid += (word_scores.max(1)[1] == y).sum().item() 482 | if eval_memory: 483 | for k, v in self.memory_list: 484 | all_mem_att[k].append((v.last_indices, v.last_scores)) 485 | 486 | # generate translation - translate / convert to text 487 | if eval_bleu: 488 | max_len = int(1.5 * len1.max().item() + 10) 489 | if params.beam_size == 1: 490 | generated, lengths = decoder.generate(enc1, len1, lang2_id, max_len=max_len) 491 | else: 492 | generated, lengths = decoder.generate_beam( 493 | enc1, len1, lang2_id, beam_size=params.beam_size, 494 | length_penalty=params.length_penalty, 495 | early_stopping=params.early_stopping, 496 | max_len=max_len 497 | ) 498 | hypothesis.extend(convert_to_text(generated, lengths, self.dico, params)) 499 | 500 | # compute perplexity and prediction accuracy 501 | scores['%s_%s-%s_mt_ppl' % (data_set, lang1, lang2)] = np.exp(xe_loss / n_words) 502 | scores['%s_%s-%s_mt_acc' % (data_set, lang1, lang2)] = 100. 
* n_valid / n_words 503 | 504 | # compute memory usage 505 | if eval_memory: 506 | for mem_name, mem_att in all_mem_att.items(): 507 | eval_memory_usage(scores, '%s_%s-%s_%s' % (data_set, lang1, lang2, mem_name), mem_att, params.mem_size) 508 | 509 | # compute BLEU 510 | if eval_bleu: 511 | 512 | # hypothesis / reference paths 513 | hyp_name = 'hyp{0}.{1}-{2}.{3}.txt'.format(scores['epoch'], lang1, lang2, data_set) 514 | hyp_path = os.path.join(params.hyp_path, hyp_name) 515 | ref_path = params.ref_paths[(lang1, lang2, data_set)] 516 | 517 | # export sentences to hypothesis file / restore BPE segmentation 518 | with open(hyp_path, 'w', encoding='utf-8') as f: 519 | f.write('\n'.join(hypothesis) + '\n') 520 | restore_segmentation(hyp_path, bpe_type=params.bpe_type) 521 | 522 | # evaluate BLEU score 523 | bleu = eval_moses_bleu(ref_path, hyp_path) 524 | sacrebleu = eval_sacrebleu(ref_path, hyp_path) 525 | logger.info("BLEU %s %s : %f" % (hyp_path, ref_path, bleu)) 526 | logger.info("SacreBLEU %s %s : %f" % (hyp_path, ref_path, sacrebleu)) 527 | scores['%s_%s-%s_mt_bleu' % (data_set, lang1, lang2)] = bleu 528 | scores['%s_%s-%s_mt_sacrebleu' % (data_set, lang1, lang2)] = sacrebleu 529 | 530 | 531 | def convert_to_text(batch, lengths, dico, params): 532 | """ 533 | Convert a batch of sentences to a list of text sentences. 534 | """ 535 | batch = batch.cpu().numpy() 536 | lengths = lengths.cpu().numpy() 537 | 538 | slen, bs = batch.shape 539 | assert lengths.max() == slen and lengths.shape[0] == bs 540 | assert (batch[0] == params.eos_index).sum() == bs 541 | assert (batch == params.eos_index).sum() == 2 * bs 542 | sentences = [] 543 | 544 | for j in range(bs): 545 | words = [] 546 | for k in range(1, lengths[j]): 547 | if batch[k, j] == params.eos_index: 548 | break 549 | words.append(dico[batch[k, j]]) 550 | sentences.append(" ".join(words)) 551 | return sentences 552 | 553 | 554 | def eval_moses_bleu(ref, hyp): 555 | """ 556 | Given a file of hypothesis and reference files, 557 | evaluate the BLEU score using Moses scripts. 558 | """ 559 | assert os.path.isfile(hyp) 560 | assert os.path.isfile(ref) or os.path.isfile(ref + '0') 561 | assert os.path.isfile(BLEU_SCRIPT_PATH) 562 | command = BLEU_SCRIPT_PATH + ' %s < %s' 563 | p = subprocess.Popen(command % (ref, hyp), stdout=subprocess.PIPE, shell=True) 564 | result = p.communicate()[0].decode("utf-8") 565 | if result.startswith('BLEU'): 566 | return float(result[7:result.index(',')]) 567 | else: 568 | logger.warning('Impossible to parse BLEU score! 
"%s"' % result) 569 | return -1 570 | 571 | 572 | def eval_sacrebleu(ref, hyp): 573 | ref_lines = open(ref).readlines() 574 | hyp_lines = open(hyp).readlines() 575 | scorer = SacrebleuScorer() 576 | for ref_line, hyp_line in zip(ref_lines, hyp_lines): 577 | scorer.add_string(ref_line, hyp_line) 578 | return float(re.findall("\d+\.\d+", str(scorer.result_string()))[0]) 579 | 580 | class SacrebleuScorer(object): 581 | def __init__(self): 582 | import sacrebleu 583 | self.sacrebleu = sacrebleu 584 | self.reset() 585 | 586 | def reset(self, one_init=False): 587 | if one_init: 588 | raise NotImplementedError 589 | self.ref = [] 590 | self.sys = [] 591 | 592 | def add_string(self, ref, pred): 593 | self.ref.append(ref) 594 | self.sys.append(pred) 595 | 596 | def score(self, order=4): 597 | return self.result_string(order).score 598 | 599 | def result_string(self, order=4): 600 | if order != 4: 601 | raise NotImplementedError 602 | return self.sacrebleu.corpus_bleu(self.sys, [self.ref]) 603 | -------------------------------------------------------------------------------- /src/evaluation/glue.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from logging import getLogger 9 | import os 10 | import copy 11 | import time 12 | import json 13 | from collections import OrderedDict 14 | 15 | import numpy as np 16 | import torch 17 | from torch import nn 18 | import torch.nn.functional as F 19 | 20 | from scipy.stats import spearmanr, pearsonr 21 | from sklearn.metrics import f1_score, matthews_corrcoef 22 | 23 | from ..optim import get_optimizer 24 | from ..utils import concat_batches, truncate, to_cuda 25 | from ..data.dataset import Dataset, ParallelDataset 26 | from ..data.loader import load_binarized, set_dico_parameters 27 | 28 | 29 | N_CLASSES = { 30 | 'MNLI-m': 3, 31 | 'MNLI-mm': 3, 32 | 'QQP': 2, 33 | 'QNLI': 2, 34 | 'SST-2': 2, 35 | 'CoLA': 2, 36 | 'MRPC': 2, 37 | 'RTE': 2, 38 | 'STS-B': 1, 39 | 'WNLI': 2, 40 | 'AX_MNLI-m': 3, 41 | } 42 | 43 | 44 | logger = getLogger() 45 | 46 | 47 | class GLUE: 48 | 49 | def __init__(self, embedder, scores, params): 50 | """ 51 | Initialize GLUE trainer / evaluator. 52 | Initial `embedder` should be on CPU to save memory. 53 | """ 54 | self._embedder = embedder 55 | self.params = params 56 | self.scores = scores 57 | 58 | def get_iterator(self, splt): 59 | """ 60 | Build data iterator. 61 | """ 62 | return self.data[splt]['x'].get_iterator( 63 | shuffle=(splt == 'train'), 64 | return_indices=True, 65 | group_by_size=self.params.group_by_size 66 | ) 67 | 68 | def run(self, task): 69 | """ 70 | Run GLUE training / evaluation. 71 | """ 72 | params = self.params 73 | 74 | # task parameters 75 | self.task = task 76 | params.out_features = N_CLASSES[task] 77 | self.is_classif = task != 'STS-B' 78 | 79 | # load data 80 | self.data = self.load_data(task) 81 | if not self.data['dico'] == self._embedder.dico: 82 | raise Exception(("Dictionary in evaluation data (%i words) seems different than the one " + 83 | "in the pretrained model (%i words). 
Please verify you used the same dictionary, " + 84 | "and the same values for max_vocab and min_count.") % (len(self.data['dico']), len(self._embedder.dico))) 85 | 86 | # embedder 87 | self.embedder = copy.deepcopy(self._embedder) 88 | self.embedder.cuda() 89 | 90 | # projection layer 91 | self.proj = nn.Sequential(*[ 92 | nn.Dropout(params.dropout), 93 | nn.Linear(self.embedder.out_dim, params.out_features) 94 | ]).cuda() 95 | 96 | # optimizers 97 | self.optimizer_e = get_optimizer(list(self.embedder.get_parameters(params.finetune_layers)), params.optimizer_e) 98 | self.optimizer_p = get_optimizer(self.proj.parameters(), params.optimizer_p) 99 | 100 | # train and evaluate the model 101 | for epoch in range(params.n_epochs): 102 | 103 | # update epoch 104 | self.epoch = epoch 105 | 106 | # training 107 | logger.info("GLUE - %s - Training epoch %i ..." % (task, epoch)) 108 | self.train() 109 | 110 | # evaluation 111 | logger.info("GLUE - %s - Evaluating epoch %i ..." % (task, epoch)) 112 | with torch.no_grad(): 113 | scores = self.eval('valid') 114 | self.scores.update(scores) 115 | self.eval('test') 116 | 117 | def train(self): 118 | """ 119 | Finetune for one epoch on the training set. 120 | """ 121 | params = self.params 122 | self.embedder.train() 123 | self.proj.train() 124 | 125 | # training variables 126 | losses = [] 127 | ns = 0 # number of sentences 128 | nw = 0 # number of words 129 | t = time.time() 130 | 131 | iterator = self.get_iterator('train') 132 | lang_id = params.lang2id['en'] 133 | 134 | while True: 135 | 136 | # batch 137 | try: 138 | batch = next(iterator) 139 | except StopIteration: 140 | break 141 | if self.n_sent == 1: 142 | (x, lengths), idx = batch 143 | x, lengths = truncate(x, lengths, params.max_len, params.eos_index) 144 | else: 145 | (sent1, len1), (sent2, len2), idx = batch 146 | sent1, len1 = truncate(sent1, len1, params.max_len, params.eos_index) 147 | sent2, len2 = truncate(sent2, len2, params.max_len, params.eos_index) 148 | x, lengths, _, _ = concat_batches(sent1, len1, lang_id, sent2, len2, lang_id, params.pad_index, params.eos_index, reset_positions=False) 149 | y = self.data['train']['y'][idx] 150 | bs = len(lengths) 151 | 152 | # cuda 153 | x, y, lengths = to_cuda(x, y, lengths) 154 | 155 | # loss 156 | output = self.proj(self.embedder.get_embeddings(x, lengths, positions=None, langs=None)) 157 | if self.is_classif: 158 | loss = F.cross_entropy(output, y, weight=self.weights) 159 | else: 160 | loss = F.mse_loss(output.squeeze(1), y.float()) 161 | 162 | # backward / optimization 163 | self.optimizer_e.zero_grad() 164 | self.optimizer_p.zero_grad() 165 | loss.backward() 166 | self.optimizer_e.step() 167 | self.optimizer_p.step() 168 | 169 | # update statistics 170 | ns += bs 171 | nw += lengths.sum().item() 172 | losses.append(loss.item()) 173 | 174 | # log 175 | if ns != 0 and ns % (10 * bs) < bs: 176 | logger.info( 177 | "GLUE - %s - Epoch %s - Train iter %7i - %.1f words/s - %s Loss: %.4f" 178 | % (self.task, self.epoch, ns, nw / (time.time() - t), 'XE' if self.is_classif else 'MSE', sum(losses) / len(losses)) 179 | ) 180 | nw, t = 0, time.time() 181 | losses = [] 182 | 183 | # epoch size 184 | if params.epoch_size != -1 and ns >= params.epoch_size: 185 | break 186 | 187 | def eval(self, splt): 188 | """ 189 | Evaluate on XNLI validation and test sets, for all languages. 
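Concretely, this scores the current GLUE task on its 'valid' and 'test'
splits (English only): accuracy, F1 and Matthews correlation for
classification tasks, Pearson/Spearman correlation for STS-B; the raw
predictions are also written to the experiment dump path.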
190 | """ 191 | params = self.params 192 | self.embedder.eval() 193 | self.proj.eval() 194 | 195 | assert splt in ['valid', 'test'] 196 | has_labels = 'y' in self.data[splt] 197 | 198 | scores = OrderedDict({'epoch': self.epoch}) 199 | task = self.task.lower() 200 | 201 | idxs = [] # sentence indices 202 | prob = [] # probabilities 203 | pred = [] # predicted values 204 | gold = [] # real values 205 | 206 | lang_id = params.lang2id['en'] 207 | 208 | for batch in self.get_iterator(splt): 209 | 210 | # batch 211 | if self.n_sent == 1: 212 | (x, lengths), idx = batch 213 | # x, lengths = truncate(x, lengths, params.max_len, params.eos_index) 214 | else: 215 | (sent1, len1), (sent2, len2), idx = batch 216 | # sent1, len1 = truncate(sent1, len1, params.max_len, params.eos_index) 217 | # sent2, len2 = truncate(sent2, len2, params.max_len, params.eos_index) 218 | x, lengths, _, _ = concat_batches(sent1, len1, lang_id, sent2, len2, lang_id, params.pad_index, params.eos_index, reset_positions=False) 219 | y = self.data[splt]['y'][idx] if has_labels else None 220 | 221 | # cuda 222 | x, y, lengths = to_cuda(x, y, lengths) 223 | 224 | # prediction 225 | output = self.proj(self.embedder.get_embeddings(x, lengths, positions=None, langs=None)) 226 | p = output.data.max(1)[1] if self.is_classif else output.squeeze(1) 227 | idxs.append(idx) 228 | prob.append(output.cpu().numpy()) 229 | pred.append(p.cpu().numpy()) 230 | if has_labels: 231 | gold.append(y.cpu().numpy()) 232 | 233 | # indices / predictions 234 | idxs = np.concatenate(idxs) 235 | prob = np.concatenate(prob) 236 | pred = np.concatenate(pred) 237 | assert len(idxs) == len(pred), (len(idxs), len(pred)) 238 | assert idxs[-1] == len(idxs) - 1, (idxs[-1], len(idxs) - 1) 239 | 240 | # score the predictions if we have labels 241 | if has_labels: 242 | gold = np.concatenate(gold) 243 | prefix = f'{splt}_{task}' 244 | if self.is_classif: 245 | scores['%s_acc' % prefix] = 100. * (pred == gold).sum() / len(pred) 246 | scores['%s_f1' % prefix] = 100. * f1_score(gold, pred, average='binary' if params.out_features == 2 else 'micro') 247 | scores['%s_mc' % prefix] = 100. * matthews_corrcoef(gold, pred) 248 | else: 249 | scores['%s_prs' % prefix] = 100. * pearsonr(pred, gold)[0] 250 | scores['%s_spr' % prefix] = 100. 
* spearmanr(pred, gold)[0] 251 | logger.info("__log__:%s" % json.dumps(scores)) 252 | 253 | # output predictions 254 | pred_path = os.path.join(params.dump_path, f'{splt}.pred.{self.epoch}') 255 | with open(pred_path, 'w') as f: 256 | for i, p in zip(idxs, prob): 257 | f.write('%i\t%s\n' % (i, ','.join([str(x) for x in p]))) 258 | logger.info(f"Wrote {len(idxs)} {splt} predictions to {pred_path}") 259 | 260 | return scores 261 | 262 | def load_data(self, task): 263 | """ 264 | Load pair regression/classification bi-sentence tasks 265 | """ 266 | params = self.params 267 | data = {splt: {} for splt in ['train', 'valid', 'test']} 268 | dpath = os.path.join(params.data_path, 'eval', task) 269 | 270 | self.n_sent = 1 if task in ['SST-2', 'CoLA'] else 2 271 | 272 | for splt in ['train', 'valid', 'test']: 273 | 274 | # load data and dictionary 275 | data1 = load_binarized(os.path.join(dpath, '%s.s1.pth' % splt), params) 276 | data2 = load_binarized(os.path.join(dpath, '%s.s2.pth' % splt), params) if self.n_sent == 2 else None 277 | data['dico'] = data.get('dico', data1['dico']) 278 | 279 | # set dictionary parameters 280 | set_dico_parameters(params, data, data1['dico']) 281 | if self.n_sent == 2: 282 | set_dico_parameters(params, data, data2['dico']) 283 | 284 | # create dataset 285 | if self.n_sent == 1: 286 | data[splt]['x'] = Dataset(data1['sentences'], data1['positions'], params) 287 | else: 288 | data[splt]['x'] = ParallelDataset( 289 | data1['sentences'], data1['positions'], 290 | data2['sentences'], data2['positions'], 291 | params 292 | ) 293 | 294 | # load labels 295 | if splt != 'test' or task in ['MRPC']: 296 | # read labels from file 297 | with open(os.path.join(dpath, '%s.label' % splt), 'r') as f: 298 | lines = [l.rstrip() for l in f] 299 | # STS-B task 300 | if task == 'STS-B': 301 | assert all(0 <= float(x) <= 5 for x in lines) 302 | y = [float(l) for l in lines] 303 | # QQP 304 | elif task == 'QQP': 305 | UNK_LABEL = 0 306 | lab2id = {x: i for i, x in enumerate(sorted(set(lines) - set([''])))} 307 | y = [lab2id.get(x, UNK_LABEL) for x in lines] 308 | # other tasks 309 | else: 310 | lab2id = {x: i for i, x in enumerate(sorted(set(lines)))} 311 | y = [lab2id[x] for x in lines] 312 | data[splt]['y'] = torch.LongTensor(y) 313 | assert len(data[splt]['x']) == len(data[splt]['y']) 314 | 315 | # compute weights for weighted training 316 | if task != 'STS-B' and params.weighted_training: 317 | weights = torch.FloatTensor([ 318 | 1.0 / (data['train']['y'] == i).sum().item() 319 | for i in range(len(lab2id)) 320 | ]).cuda() 321 | self.weights = weights / weights.sum() 322 | else: 323 | self.weights = None 324 | 325 | return data 326 | -------------------------------------------------------------------------------- /src/evaluation/multi-bleu.perl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # This file is part of moses. Its use is licensed under the GNU Lesser General 4 | # Public License version 2.1 or, at your option, any later version. 
5 | 6 | # $Id$ 7 | use warnings; 8 | use strict; 9 | 10 | my $lowercase = 0; 11 | if ($ARGV[0] eq "-lc") { 12 | $lowercase = 1; 13 | shift; 14 | } 15 | 16 | my $stem = $ARGV[0]; 17 | if (!defined $stem) { 18 | print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n"; 19 | print STDERR "Reads the references from reference or reference0, reference1, ...\n"; 20 | exit(1); 21 | } 22 | 23 | $stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0"; 24 | 25 | my @REF; 26 | my $ref=0; 27 | while(-e "$stem$ref") { 28 | &add_to_ref("$stem$ref",\@REF); 29 | $ref++; 30 | } 31 | &add_to_ref($stem,\@REF) if -e $stem; 32 | die("ERROR: could not find reference file $stem") unless scalar @REF; 33 | 34 | # add additional references explicitly specified on the command line 35 | shift; 36 | foreach my $stem (@ARGV) { 37 | &add_to_ref($stem,\@REF) if -e $stem; 38 | } 39 | 40 | 41 | 42 | sub add_to_ref { 43 | my ($file,$REF) = @_; 44 | my $s=0; 45 | if ($file =~ /.gz$/) { 46 | open(REF,"gzip -dc $file|") or die "Can't read $file"; 47 | } else { 48 | open(REF,$file) or die "Can't read $file"; 49 | } 50 | while() { 51 | chop; 52 | push @{$$REF[$s++]}, $_; 53 | } 54 | close(REF); 55 | } 56 | 57 | my(@CORRECT,@TOTAL,$length_translation,$length_reference); 58 | my $s=0; 59 | while() { 60 | chop; 61 | $_ = lc if $lowercase; 62 | my @WORD = split; 63 | my %REF_NGRAM = (); 64 | my $length_translation_this_sentence = scalar(@WORD); 65 | my ($closest_diff,$closest_length) = (9999,9999); 66 | foreach my $reference (@{$REF[$s]}) { 67 | # print "$s $_ <=> $reference\n"; 68 | $reference = lc($reference) if $lowercase; 69 | my @WORD = split(' ',$reference); 70 | my $length = scalar(@WORD); 71 | my $diff = abs($length_translation_this_sentence-$length); 72 | if ($diff < $closest_diff) { 73 | $closest_diff = $diff; 74 | $closest_length = $length; 75 | # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n"; 76 | } elsif ($diff == $closest_diff) { 77 | $closest_length = $length if $length < $closest_length; 78 | # from two references with the same closeness to me 79 | # take the *shorter* into account, not the "first" one. 80 | } 81 | for(my $n=1;$n<=4;$n++) { 82 | my %REF_NGRAM_N = (); 83 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 84 | my $ngram = "$n"; 85 | for(my $w=0;$w<$n;$w++) { 86 | $ngram .= " ".$WORD[$start+$w]; 87 | } 88 | $REF_NGRAM_N{$ngram}++; 89 | } 90 | foreach my $ngram (keys %REF_NGRAM_N) { 91 | if (!defined($REF_NGRAM{$ngram}) || 92 | $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) { 93 | $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram}; 94 | # print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}
\n"; 95 | } 96 | } 97 | } 98 | } 99 | $length_translation += $length_translation_this_sentence; 100 | $length_reference += $closest_length; 101 | for(my $n=1;$n<=4;$n++) { 102 | my %T_NGRAM = (); 103 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 104 | my $ngram = "$n"; 105 | for(my $w=0;$w<$n;$w++) { 106 | $ngram .= " ".$WORD[$start+$w]; 107 | } 108 | $T_NGRAM{$ngram}++; 109 | } 110 | foreach my $ngram (keys %T_NGRAM) { 111 | $ngram =~ /^(\d+) /; 112 | my $n = $1; 113 | # my $corr = 0; 114 | # print "$i e $ngram $T_NGRAM{$ngram}
\n"; 115 | $TOTAL[$n] += $T_NGRAM{$ngram}; 116 | if (defined($REF_NGRAM{$ngram})) { 117 | if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) { 118 | $CORRECT[$n] += $T_NGRAM{$ngram}; 119 | # $corr = $T_NGRAM{$ngram}; 120 | # print "$i e correct1 $T_NGRAM{$ngram}
\n"; 121 | } 122 | else { 123 | $CORRECT[$n] += $REF_NGRAM{$ngram}; 124 | # $corr = $REF_NGRAM{$ngram}; 125 | # print "$i e correct2 $REF_NGRAM{$ngram}
\n"; 126 | } 127 | } 128 | # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram}; 129 | # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n" 130 | } 131 | } 132 | $s++; 133 | } 134 | my $brevity_penalty = 1; 135 | my $bleu = 0; 136 | 137 | my @bleu=(); 138 | 139 | for(my $n=1;$n<=4;$n++) { 140 | if (defined ($TOTAL[$n])){ 141 | $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0; 142 | # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n"; 143 | }else{ 144 | $bleu[$n]=0; 145 | } 146 | } 147 | 148 | if ($length_reference==0){ 149 | printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n"; 150 | exit(1); 151 | } 152 | 153 | if ($length_translation<$length_reference) { 154 | $brevity_penalty = exp(1-$length_reference/$length_translation); 155 | } 156 | $bleu = $brevity_penalty * exp((my_log( $bleu[1] ) + 157 | my_log( $bleu[2] ) + 158 | my_log( $bleu[3] ) + 159 | my_log( $bleu[4] ) ) / 4) ; 160 | printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n", 161 | 100*$bleu, 162 | 100*$bleu[1], 163 | 100*$bleu[2], 164 | 100*$bleu[3], 165 | 100*$bleu[4], 166 | $brevity_penalty, 167 | $length_translation / $length_reference, 168 | $length_translation, 169 | $length_reference; 170 | 171 | 172 | # print STDERR "It is in-advisable to publish scores from multi-bleu.perl. The scores depend on your tokenizer, which is unlikely to be reproducible from your paper or consistent across research groups. Instead you should detokenize then use mteval-v14.pl, which has a standard tokenization. Scores from multi-bleu.perl can still be used for internal purposes when you have a consistent tokenizer.\n"; 173 | 174 | sub my_log { 175 | return -9999999999 unless $_[0]; 176 | return log($_[0]); 177 | } 178 | -------------------------------------------------------------------------------- /src/evaluation/xnli.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from logging import getLogger 9 | import os 10 | import copy 11 | import time 12 | import json 13 | from collections import OrderedDict 14 | 15 | import torch 16 | from torch import nn 17 | import torch.nn.functional as F 18 | 19 | from ..optim import get_optimizer 20 | from ..utils import concat_batches, truncate, to_cuda 21 | from ..data.dataset import ParallelDataset 22 | from ..data.loader import load_binarized, set_dico_parameters 23 | 24 | 25 | XNLI_LANGS = ['ar', 'bg', 'de', 'el', 'en', 'es', 'fr', 'hi', 'ru', 'sw', 'th', 'tr', 'ur', 'vi', 'zh'] 26 | 27 | 28 | logger = getLogger() 29 | 30 | 31 | class XNLI: 32 | 33 | def __init__(self, embedder, scores, params): 34 | """ 35 | Initialize XNLI trainer / evaluator. 36 | Initial `embedder` should be on CPU to save memory. 37 | """ 38 | self._embedder = embedder 39 | self.params = params 40 | self.scores = scores 41 | 42 | def get_iterator(self, splt, lang): 43 | """ 44 | Get a monolingual data iterator. 45 | """ 46 | assert splt in ['valid', 'test'] or splt == 'train' and lang == 'en' 47 | return self.data[lang][splt]['x'].get_iterator( 48 | shuffle=(splt == 'train'), 49 | group_by_size=self.params.group_by_size, 50 | return_indices=True 51 | ) 52 | 53 | def run(self): 54 | """ 55 | Run XNLI training / evaluation. 
56 | """ 57 | params = self.params 58 | 59 | # load data 60 | self.data = self.load_data() 61 | if not self.data['dico'] == self._embedder.dico: 62 | raise Exception(("Dictionary in evaluation data (%i words) seems different than the one " + 63 | "in the pretrained model (%i words). Please verify you used the same dictionary, " + 64 | "and the same values for max_vocab and min_count.") % (len(self.data['dico']), len(self._embedder.dico))) 65 | 66 | # embedder 67 | self.embedder = copy.deepcopy(self._embedder) 68 | self.embedder.cuda() 69 | 70 | # projection layer 71 | self.proj = nn.Sequential(*[ 72 | nn.Dropout(params.dropout), 73 | nn.Linear(self.embedder.out_dim, 3) 74 | ]).cuda() 75 | 76 | # optimizers 77 | self.optimizer_e = get_optimizer(list(self.embedder.get_parameters(params.finetune_layers)), params.optimizer_e) 78 | self.optimizer_p = get_optimizer(self.proj.parameters(), params.optimizer_p) 79 | 80 | # train and evaluate the model 81 | for epoch in range(params.n_epochs): 82 | 83 | # update epoch 84 | self.epoch = epoch 85 | 86 | # training 87 | logger.info("XNLI - Training epoch %i ..." % epoch) 88 | self.train() 89 | 90 | # evaluation 91 | logger.info("XNLI - Evaluating epoch %i ..." % epoch) 92 | with torch.no_grad(): 93 | scores = self.eval() 94 | self.scores.update(scores) 95 | 96 | def train(self): 97 | """ 98 | Finetune for one epoch on the XNLI English training set. 99 | """ 100 | params = self.params 101 | self.embedder.train() 102 | self.proj.train() 103 | 104 | # training variables 105 | losses = [] 106 | ns = 0 # number of sentences 107 | nw = 0 # number of words 108 | t = time.time() 109 | 110 | iterator = self.get_iterator('train', 'en') 111 | lang_id = params.lang2id['en'] 112 | 113 | while True: 114 | 115 | # batch 116 | try: 117 | batch = next(iterator) 118 | except StopIteration: 119 | break 120 | (sent1, len1), (sent2, len2), idx = batch 121 | sent1, len1 = truncate(sent1, len1, params.max_len, params.eos_index) 122 | sent2, len2 = truncate(sent2, len2, params.max_len, params.eos_index) 123 | x, lengths, positions, langs = concat_batches( 124 | sent1, len1, lang_id, 125 | sent2, len2, lang_id, 126 | params.pad_index, 127 | params.eos_index, 128 | reset_positions=False 129 | ) 130 | y = self.data['en']['train']['y'][idx] 131 | bs = len(len1) 132 | 133 | # cuda 134 | x, y, lengths, positions, langs = to_cuda(x, y, lengths, positions, langs) 135 | 136 | # loss 137 | output = self.proj(self.embedder.get_embeddings(x, lengths, positions, langs)) 138 | loss = F.cross_entropy(output, y) 139 | 140 | # backward / optimization 141 | self.optimizer_e.zero_grad() 142 | self.optimizer_p.zero_grad() 143 | loss.backward() 144 | self.optimizer_e.step() 145 | self.optimizer_p.step() 146 | 147 | # update statistics 148 | ns += bs 149 | nw += lengths.sum().item() 150 | losses.append(loss.item()) 151 | 152 | # log 153 | if ns % (100 * bs) < bs: 154 | logger.info("XNLI - Epoch %i - Train iter %7i - %.1f words/s - Loss: %.4f" % (self.epoch, ns, nw / (time.time() - t), sum(losses) / len(losses))) 155 | nw, t = 0, time.time() 156 | losses = [] 157 | 158 | # epoch size 159 | if params.epoch_size != -1 and ns >= params.epoch_size: 160 | break 161 | 162 | def eval(self): 163 | """ 164 | Evaluate on XNLI validation and test sets, for all languages. 
165 | """ 166 | params = self.params 167 | self.embedder.eval() 168 | self.proj.eval() 169 | 170 | scores = OrderedDict({'epoch': self.epoch}) 171 | 172 | for splt in ['valid', 'test']: 173 | 174 | for lang in XNLI_LANGS: 175 | if lang not in params.lang2id: 176 | continue 177 | 178 | lang_id = params.lang2id[lang] 179 | valid = 0 180 | total = 0 181 | 182 | for batch in self.get_iterator(splt, lang): 183 | 184 | # batch 185 | (sent1, len1), (sent2, len2), idx = batch 186 | x, lengths, positions, langs = concat_batches( 187 | sent1, len1, lang_id, 188 | sent2, len2, lang_id, 189 | params.pad_index, 190 | params.eos_index, 191 | reset_positions=False 192 | ) 193 | y = self.data[lang][splt]['y'][idx] 194 | 195 | # cuda 196 | x, y, lengths, positions, langs = to_cuda(x, y, lengths, positions, langs) 197 | 198 | # forward 199 | output = self.proj(self.embedder.get_embeddings(x, lengths, positions, langs)) 200 | predictions = output.data.max(1)[1] 201 | 202 | # update statistics 203 | valid += predictions.eq(y).sum().item() 204 | total += len(len1) 205 | 206 | # compute accuracy 207 | acc = 100.0 * valid / total 208 | scores['xnli_%s_%s_acc' % (splt, lang)] = acc 209 | logger.info("XNLI - %s - %s - Epoch %i - Acc: %.1f%%" % (splt, lang, self.epoch, acc)) 210 | 211 | logger.info("__log__:%s" % json.dumps(scores)) 212 | return scores 213 | 214 | def load_data(self): 215 | """ 216 | Load XNLI cross-lingual classification data. 217 | """ 218 | params = self.params 219 | data = {lang: {splt: {} for splt in ['train', 'valid', 'test']} for lang in XNLI_LANGS} 220 | label2id = {'contradiction': 0, 'neutral': 1, 'entailment': 2} 221 | dpath = os.path.join(params.data_path, 'eval', 'XNLI') 222 | 223 | for splt in ['train', 'valid', 'test']: 224 | 225 | for lang in XNLI_LANGS: 226 | 227 | # only English has a training set 228 | if splt == 'train' and lang != 'en': 229 | del data[lang]['train'] 230 | continue 231 | 232 | # load data and dictionary 233 | data1 = load_binarized(os.path.join(dpath, '%s.s1.%s.pth' % (splt, lang)), params) 234 | data2 = load_binarized(os.path.join(dpath, '%s.s2.%s.pth' % (splt, lang)), params) 235 | data['dico'] = data.get('dico', data1['dico']) 236 | 237 | # set dictionary parameters 238 | set_dico_parameters(params, data, data1['dico']) 239 | set_dico_parameters(params, data, data2['dico']) 240 | 241 | # create dataset 242 | data[lang][splt]['x'] = ParallelDataset( 243 | data1['sentences'], data1['positions'], 244 | data2['sentences'], data2['positions'], 245 | params 246 | ) 247 | 248 | # load labels 249 | with open(os.path.join(dpath, '%s.label.%s' % (splt, lang)), 'r') as f: 250 | labels = [label2id[l.rstrip()] for l in f] 251 | data[lang][splt]['y'] = torch.LongTensor(labels) 252 | assert len(data[lang][splt]['x']) == len(data[lang][splt]['y']) 253 | 254 | return data 255 | -------------------------------------------------------------------------------- /src/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | import logging 9 | import time 10 | from datetime import timedelta 11 | 12 | 13 | class LogFormatter(): 14 | 15 | def __init__(self): 16 | self.start_time = time.time() 17 | 18 | def format(self, record): 19 | elapsed_seconds = round(record.created - self.start_time) 20 | 21 | prefix = "%s - %s - %s" % ( 22 | record.levelname, 23 | time.strftime('%x %X'), 24 | timedelta(seconds=elapsed_seconds) 25 | ) 26 | message = record.getMessage() 27 | message = message.replace('\n', '\n' + ' ' * (len(prefix) + 3)) 28 | return "%s - %s" % (prefix, message) if message else '' 29 | 30 | 31 | def create_logger(filepath, rank): 32 | """ 33 | Create a logger. 34 | Use a different log file for each process. 35 | """ 36 | # create log formatter 37 | log_formatter = LogFormatter() 38 | 39 | # create file handler and set level to debug 40 | if filepath is not None: 41 | if rank > 0: 42 | filepath = '%s-%i' % (filepath, rank) 43 | file_handler = logging.FileHandler(filepath, "a") 44 | file_handler.setLevel(logging.DEBUG) 45 | file_handler.setFormatter(log_formatter) 46 | 47 | # create console handler and set level to info 48 | console_handler = logging.StreamHandler() 49 | console_handler.setLevel(logging.INFO) 50 | console_handler.setFormatter(log_formatter) 51 | 52 | # create logger and set level to debug 53 | logger = logging.getLogger() 54 | logger.handlers = [] 55 | logger.setLevel(logging.DEBUG) 56 | logger.propagate = False 57 | if filepath is not None: 58 | logger.addHandler(file_handler) 59 | logger.addHandler(console_handler) 60 | 61 | # reset logger elapsed time 62 | def reset_time(): 63 | log_formatter.start_time = time.time() 64 | logger.reset_time = reset_time 65 | 66 | return logger 67 | -------------------------------------------------------------------------------- /src/logger.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jind11/DAMT/3caa22822b936137e5da3e827d7a5a2078c3115e/src/logger.pyc -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from logging import getLogger 9 | import os 10 | import torch 11 | 12 | from .pretrain import load_embeddings 13 | from .transformer import DECODER_ONLY_PARAMS, TransformerModel # , TRANSFORMER_LAYER_PARAMS 14 | from .memory import HashingMemory 15 | 16 | 17 | logger = getLogger() 18 | 19 | 20 | def check_model_params(params): 21 | """ 22 | Check models parameters. 
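Besides validating the configuration, this normalizes several string
parameters in place: `word_mask_keep_rand` is split into `word_mask` /
`word_keep` / `word_rand` (three floats summing to 1), adaptive-softmax
cutoffs are parsed into integers, and memory positions are turned into
(layer, 'in' | 'after') tuples.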
23 | """ 24 | # masked language modeling task parameters 25 | assert params.bptt >= 1 26 | assert 0 <= params.word_pred < 1 27 | assert 0 <= params.sample_alpha < 1 28 | s = params.word_mask_keep_rand.split(',') 29 | assert len(s) == 3 30 | s = [float(x) for x in s] 31 | assert all([0 <= x <= 1 for x in s]) and sum(s) == 1 32 | params.word_mask = s[0] 33 | params.word_keep = s[1] 34 | params.word_rand = s[2] 35 | 36 | # input sentence noise for DAE 37 | if len(params.ae_steps) == 0: 38 | assert params.word_shuffle == 0 39 | assert params.word_dropout == 0 40 | assert params.word_blank == 0 41 | else: 42 | assert params.word_shuffle == 0 or params.word_shuffle > 1 43 | assert 0 <= params.word_dropout < 1 44 | assert 0 <= params.word_blank < 1 45 | 46 | # model dimensions 47 | assert params.emb_dim % params.n_heads == 0 48 | 49 | # share input and output embeddings 50 | assert params.share_inout_emb is False or params.asm is False 51 | 52 | # adaptive softmax 53 | if params.asm: 54 | assert params.asm_div_value > 1 55 | s = params.asm_cutoffs.split(',') 56 | assert all([x.isdigit() for x in s]) 57 | params.asm_cutoffs = [int(x) for x in s] 58 | assert params.max_vocab == -1 or params.asm_cutoffs[-1] < params.max_vocab 59 | 60 | # memory 61 | if params.use_memory: 62 | HashingMemory.check_params(params) 63 | s_enc = [x for x in params.mem_enc_positions.split(',') if x != ''] 64 | s_dec = [x for x in params.mem_dec_positions.split(',') if x != ''] 65 | assert len(s_enc) == len(set(s_enc)) 66 | assert len(s_dec) == len(set(s_dec)) 67 | assert all(x.isdigit() or x[-1] == '+' and x[:-1].isdigit() for x in s_enc) 68 | assert all(x.isdigit() or x[-1] == '+' and x[:-1].isdigit() for x in s_dec) 69 | params.mem_enc_positions = [(int(x[:-1]), 'after') if x[-1] == '+' else (int(x), 'in') for x in s_enc] 70 | params.mem_dec_positions = [(int(x[:-1]), 'after') if x[-1] == '+' else (int(x), 'in') for x in s_dec] 71 | assert len(params.mem_enc_positions) + len(params.mem_dec_positions) > 0 72 | assert len(params.mem_enc_positions) == 0 or 0 <= min([x[0] for x in params.mem_enc_positions]) <= max([x[0] for x in params.mem_enc_positions]) <= params.n_layers - 1 73 | assert len(params.mem_dec_positions) == 0 or 0 <= min([x[0] for x in params.mem_dec_positions]) <= max([x[0] for x in params.mem_dec_positions]) <= params.n_layers - 1 74 | 75 | # reload pretrained word embeddings 76 | if params.reload_emb != '': 77 | assert os.path.isfile(params.reload_emb) 78 | 79 | # reload a pretrained model 80 | if params.reload_model != '': 81 | if params.encoder_only: 82 | assert os.path.isfile(params.reload_model) 83 | else: 84 | s = params.reload_model.split(',') 85 | assert len(s) == 2 86 | assert all([x == '' or os.path.isfile(x) for x in s]) 87 | 88 | 89 | def set_pretrain_emb(model, dico, word2id, embeddings): 90 | """ 91 | Pretrain word embeddings. 92 | """ 93 | n_found = 0 94 | with torch.no_grad(): 95 | for i in range(len(dico)): 96 | idx = word2id.get(dico[i], None) 97 | if idx is None: 98 | continue 99 | n_found += 1 100 | model.embeddings.weight[i] = embeddings[idx].cuda() 101 | model.pred_layer.proj.weight[i] = embeddings[idx].cuda() 102 | logger.info("Pretrained %i/%i words (%.3f%%)." 103 | % (n_found, len(dico), 100. * n_found / len(dico))) 104 | 105 | 106 | def build_model(params, dico): 107 | """ 108 | Build model. 
109 | """ 110 | if params.encoder_only: 111 | # build 112 | model = TransformerModel(params, dico, is_encoder=True, with_output=True) 113 | 114 | # reload pretrained word embeddings 115 | if params.reload_emb != '': 116 | word2id, embeddings = load_embeddings(params.reload_emb, params) 117 | set_pretrain_emb(model, dico, word2id, embeddings) 118 | 119 | # reload a pretrained model 120 | if params.reload_model != '': 121 | logger.info("Reloading model from %s ..." % params.reload_model) 122 | reloaded = torch.load(params.reload_model, map_location=lambda storage, loc: storage.cuda(params.local_rank))['model'] 123 | if all([k.startswith('module.') for k in reloaded.keys()]): 124 | reloaded = {k[len('module.'):]: v for k, v in reloaded.items()} 125 | 126 | # # HACK to reload models with less layers 127 | # for i in range(12, 24): 128 | # for k in TRANSFORMER_LAYER_PARAMS: 129 | # k = k % i 130 | # if k in model.state_dict() and k not in reloaded: 131 | # logger.warning("Parameter %s not found. Ignoring ..." % k) 132 | # reloaded[k] = model.state_dict()[k] 133 | 134 | model.load_state_dict(reloaded) 135 | 136 | logger.info("Model: {}".format(model)) 137 | logger.info("Number of parameters (model): %i" % sum([p.numel() for p in model.parameters() if p.requires_grad])) 138 | 139 | return model.cuda() 140 | 141 | else: 142 | # build 143 | encoder = TransformerModel(params, dico, is_encoder=True, with_output=True) # TODO: only output when necessary - len(params.clm_steps + params.mlm_steps) > 0 144 | decoder = TransformerModel(params, dico, is_encoder=False, with_output=True) 145 | 146 | # reload pretrained word embeddings 147 | if params.reload_emb != '': 148 | word2id, embeddings = load_embeddings(params.reload_emb, params) 149 | set_pretrain_emb(encoder, dico, word2id, embeddings) 150 | set_pretrain_emb(decoder, dico, word2id, embeddings) 151 | 152 | # reload a pretrained model 153 | if params.reload_model != '': 154 | enc_path, dec_path = params.reload_model.split(',') 155 | assert not (enc_path == '' and dec_path == '') 156 | 157 | # reload encoder 158 | if enc_path != '': 159 | logger.info("Reloading encoder from %s ..." % enc_path) 160 | enc_reload = torch.load(enc_path, map_location=lambda storage, loc: storage.cuda(params.local_rank)) 161 | enc_reload = enc_reload['model' if 'model' in enc_reload else 'encoder'] 162 | if all([k.startswith('module.') for k in enc_reload.keys()]): 163 | enc_reload = {k[len('module.'):]: v for k, v in enc_reload.items()} 164 | encoder.load_state_dict(enc_reload) 165 | 166 | # reload decoder 167 | if dec_path != '': 168 | logger.info("Reloading decoder from %s ..." % dec_path) 169 | dec_reload = torch.load(dec_path, map_location=lambda storage, loc: storage.cuda(params.local_rank)) 170 | dec_reload = dec_reload['model' if 'model' in dec_reload else 'decoder'] 171 | if all([k.startswith('module.') for k in dec_reload.keys()]): 172 | dec_reload = {k[len('module.'):]: v for k, v in dec_reload.items()} 173 | for i in range(params.n_layers): 174 | for name in DECODER_ONLY_PARAMS: 175 | if name % i not in dec_reload: 176 | logger.warning("Parameter %s not found." 
% (name % i)) 177 | dec_reload[name % i] = decoder.state_dict()[name % i] 178 | decoder.load_state_dict(dec_reload) 179 | 180 | logger.debug("Encoder: {}".format(encoder)) 181 | logger.debug("Decoder: {}".format(decoder)) 182 | logger.info("Number of parameters (encoder): %i" % sum([p.numel() for p in encoder.parameters() if p.requires_grad])) 183 | logger.info("Number of parameters (decoder): %i" % sum([p.numel() for p in decoder.parameters() if p.requires_grad])) 184 | 185 | return encoder.cuda(), decoder.cuda() 186 | -------------------------------------------------------------------------------- /src/model/embedder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from logging import getLogger 9 | import torch 10 | 11 | from .transformer import TransformerModel 12 | from ..data.dictionary import Dictionary, BOS_WORD, EOS_WORD, PAD_WORD, UNK_WORD, MASK_WORD 13 | from ..utils import AttrDict 14 | 15 | 16 | logger = getLogger() 17 | 18 | 19 | class SentenceEmbedder(object): 20 | 21 | @staticmethod 22 | def reload(path, params): 23 | """ 24 | Create a sentence embedder from a pretrained model. 25 | """ 26 | # reload model 27 | reloaded = torch.load(path) 28 | state_dict = reloaded['model'] 29 | 30 | # handle models from multi-GPU checkpoints 31 | if 'checkpoint' in path: 32 | state_dict = {(k[7:] if k.startswith('module.') else k): v for k, v in state_dict.items()} 33 | 34 | # reload dictionary and model parameters 35 | dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts']) 36 | pretrain_params = AttrDict(reloaded['params']) 37 | pretrain_params.n_words = len(dico) 38 | pretrain_params.bos_index = dico.index(BOS_WORD) 39 | pretrain_params.eos_index = dico.index(EOS_WORD) 40 | pretrain_params.pad_index = dico.index(PAD_WORD) 41 | pretrain_params.unk_index = dico.index(UNK_WORD) 42 | pretrain_params.mask_index = dico.index(MASK_WORD) 43 | 44 | # build model and reload weights 45 | model = TransformerModel(pretrain_params, dico, True, True) 46 | model.load_state_dict(state_dict) 47 | model.eval() 48 | 49 | # adding missing parameters 50 | params.max_batch_size = 0 51 | 52 | return SentenceEmbedder(model, dico, pretrain_params) 53 | 54 | def __init__(self, model, dico, pretrain_params): 55 | """ 56 | Wrapper on top of the different sentence embedders. 57 | Returns sequence-wise or single-vector sentence representations. 
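`get_embeddings` runs the Transformer and returns the last-layer hidden state
of the first position, i.e. one vector of size `out_dim` (equal to the model
embedding dimension) per input sentence. Example (hypothetical checkpoint
path):

    embedder = SentenceEmbedder.reload('mlm_ende_1024.pth', params)
    sent_emb = embedder.get_embeddings(x, lengths)   # shape (bs, out_dim)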
58 | """ 59 | self.pretrain_params = {k: v for k, v in pretrain_params.__dict__.items()} 60 | self.model = model 61 | self.dico = dico 62 | self.n_layers = model.n_layers 63 | self.out_dim = model.dim 64 | self.n_words = model.n_words 65 | 66 | def train(self): 67 | self.model.train() 68 | 69 | def eval(self): 70 | self.model.eval() 71 | 72 | def cuda(self): 73 | self.model.cuda() 74 | 75 | def get_parameters(self, layer_range): 76 | 77 | s = layer_range.split(':') 78 | assert len(s) == 2 79 | i, j = int(s[0].replace('_', '-')), int(s[1].replace('_', '-')) 80 | 81 | # negative indexing 82 | i = self.n_layers + i + 1 if i < 0 else i 83 | j = self.n_layers + j + 1 if j < 0 else j 84 | 85 | # sanity check 86 | assert 0 <= i <= self.n_layers 87 | assert 0 <= j <= self.n_layers 88 | 89 | if i > j: 90 | return [] 91 | 92 | parameters = [] 93 | 94 | # embeddings 95 | if i == 0: 96 | # embeddings 97 | parameters += self.model.embeddings.parameters() 98 | logger.info("Adding embedding parameters to optimizer") 99 | # positional embeddings 100 | if self.pretrain_params['sinusoidal_embeddings'] is False: 101 | parameters += self.model.position_embeddings.parameters() 102 | logger.info("Adding positional embedding parameters to optimizer") 103 | # language embeddings 104 | if hasattr(self.model, 'lang_embeddings'): 105 | parameters += self.model.lang_embeddings.parameters() 106 | logger.info("Adding language embedding parameters to optimizer") 107 | parameters += self.model.layer_norm_emb.parameters() 108 | # layers 109 | for l in range(max(i - 1, 0), j): 110 | parameters += self.model.attentions[l].parameters() 111 | parameters += self.model.layer_norm1[l].parameters() 112 | parameters += self.model.ffns[l].parameters() 113 | parameters += self.model.layer_norm2[l].parameters() 114 | logger.info("Adding layer-%s parameters to optimizer" % (l + 1)) 115 | 116 | logger.info("Optimizing on %i Transformer elements." % sum([p.nelement() for p in parameters])) 117 | 118 | return parameters 119 | 120 | def get_embeddings(self, x, lengths, positions=None, langs=None): 121 | """ 122 | Inputs: 123 | `x` : LongTensor of shape (slen, bs) 124 | `lengths` : LongTensor of shape (bs,) 125 | Outputs: 126 | `sent_emb` : FloatTensor of shape (bs, out_dim) 127 | With out_dim == emb_dim 128 | """ 129 | slen, bs = x.size() 130 | assert lengths.size(0) == bs and lengths.max().item() == slen 131 | 132 | # get transformer last hidden layer 133 | tensor = self.model('fwd', x=x, lengths=lengths, positions=positions, langs=langs, causal=False) 134 | assert tensor.size() == (slen, bs, self.out_dim) 135 | 136 | # single-vector sentence representation (first column of last layer) 137 | return tensor[0] 138 | -------------------------------------------------------------------------------- /src/model/memory/__init__.py: -------------------------------------------------------------------------------- 1 | from .memory import HashingMemory 2 | -------------------------------------------------------------------------------- /src/model/memory/query.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from .utils import get_slices 5 | 6 | 7 | def mlp(sizes, bias=True, batchnorm=True, groups=1): 8 | """ 9 | Generate a feedforward neural network. 
10 | """ 11 | assert len(sizes) >= 2 12 | pairs = [(sizes[i], sizes[i + 1]) for i in range(len(sizes) - 1)] 13 | layers = [] 14 | 15 | for i, (dim_in, dim_out) in enumerate(pairs): 16 | if groups == 1 or i == 0: 17 | layers.append(nn.Linear(dim_in, groups * dim_out, bias=bias)) 18 | else: 19 | layers.append(GroupedLinear(groups * dim_in, groups * dim_out, bias=bias, groups=groups)) 20 | if batchnorm: 21 | layers.append(nn.BatchNorm1d(groups * dim_out)) 22 | if i < len(pairs) - 1: 23 | layers.append(nn.ReLU()) 24 | 25 | return nn.Sequential(*layers) 26 | 27 | 28 | def convs(channel_sizes, kernel_sizes, bias=True, batchnorm=True, residual=False, groups=1): 29 | """ 30 | Generate a convolutional neural network. 31 | """ 32 | assert len(channel_sizes) >= 2 33 | assert len(channel_sizes) == len(kernel_sizes) + 1 34 | pairs = [(channel_sizes[i], channel_sizes[i + 1]) for i in range(len(channel_sizes) - 1)] 35 | layers = [] 36 | 37 | for i, (dim_in, dim_out) in enumerate(pairs): 38 | ks = (kernel_sizes[i], kernel_sizes[i]) 39 | in_group = 1 if i == 0 else groups 40 | _dim_in = dim_in * in_group 41 | _dim_out = dim_out * groups 42 | if not residual: 43 | layers.append(nn.Conv2d(_dim_in, _dim_out, ks, padding=[k // 2 for k in ks], bias=bias, groups=in_group)) 44 | if batchnorm: 45 | layers.append(nn.BatchNorm2d(_dim_out)) 46 | if i < len(pairs) - 1: 47 | layers.append(nn.ReLU()) 48 | else: 49 | layers.append(BottleneckResidualConv2d( 50 | _dim_in, _dim_out, ks, bias=bias, 51 | batchnorm=batchnorm, groups=in_group 52 | )) 53 | if i == len(pairs) - 1: 54 | layers.append(nn.Conv2d(_dim_out, _dim_out, (1, 1), bias=bias)) 55 | 56 | return nn.Sequential(*layers) 57 | 58 | 59 | class GroupedLinear(nn.Module): 60 | 61 | def __init__(self, in_features, out_features, bias=True, groups=1): 62 | 63 | super().__init__() 64 | self.in_features = in_features 65 | self.out_features = out_features 66 | self.groups = groups 67 | self.bias = bias 68 | assert groups > 1 69 | 70 | self.layer = nn.Conv1d(in_features, out_features, bias=bias, kernel_size=1, groups=groups) 71 | 72 | def forward(self, input): 73 | assert input.dim() == 2 and input.size(1) == self.in_features 74 | return self.layer(input.unsqueeze(2)).squeeze(2) 75 | 76 | def extra_repr(self): 77 | return 'in_features={}, out_features={}, groups={}, bias={}'.format( 78 | self.in_features, self.out_features, self.groups, self.bias is not None 79 | ) 80 | 81 | 82 | class BottleneckResidualConv2d(nn.Module): 83 | 84 | def __init__(self, input_channels, output_channels, kernel_size, bias=True, batchnorm=True, groups=1): 85 | 86 | super().__init__() 87 | hidden_channels = min(input_channels, output_channels) 88 | assert all(k % 2 == 1 for k in kernel_size) 89 | 90 | self.conv1 = nn.Conv2d(input_channels, hidden_channels, kernel_size, padding=[k // 2 for k in kernel_size], bias=bias, groups=groups) 91 | self.conv2 = nn.Conv2d(hidden_channels, output_channels, kernel_size, padding=[k // 2 for k in kernel_size], bias=bias, groups=groups) 92 | self.act = nn.ReLU() 93 | 94 | self.batchnorm = batchnorm 95 | if self.batchnorm: 96 | self.bn1 = nn.BatchNorm2d(hidden_channels) 97 | self.bn2 = nn.BatchNorm2d(output_channels) 98 | 99 | if input_channels == output_channels: 100 | self.residual = nn.Sequential() 101 | else: 102 | self.residual = nn.Conv2d(input_channels, output_channels, (1, 1), bias=False, groups=groups) 103 | 104 | def forward(self, input): 105 | x = self.conv1(input) 106 | x = self.bn1(x) if self.batchnorm else x 107 | x = self.act(x) 108 | x = 
self.conv2(x) 109 | x = self.bn2(x) if self.batchnorm else x 110 | x = self.act(x + self.residual(input)) 111 | return x 112 | 113 | 114 | class QueryIdentity(nn.Module): 115 | 116 | def __init__(self, input_dim, heads, shuffle_hidden): 117 | super().__init__() 118 | self.input_dim = input_dim 119 | self.heads = heads 120 | self.shuffle_query = shuffle_hidden 121 | assert shuffle_hidden is False or heads > 1 122 | assert shuffle_hidden is False or self.input_dim % (2 ** self.heads) == 0 123 | if shuffle_hidden: 124 | self.slices = {head_id: get_slices(input_dim, head_id) for head_id in range(heads)} 125 | 126 | def forward(self, input): 127 | """ 128 | Generate queries from hidden states by either 129 | repeating them or creating some shuffled version. 130 | """ 131 | assert input.shape[-1] == self.input_dim 132 | input = input.contiguous().view(-1, self.input_dim) if input.dim() > 2 else input 133 | bs = len(input) 134 | 135 | if self.heads == 1: 136 | query = input 137 | 138 | elif not self.shuffle_query: 139 | query = input.unsqueeze(1).repeat(1, self.heads, 1) 140 | query = query.view(bs * self.heads, self.input_dim) 141 | 142 | else: 143 | query = torch.cat([ 144 | input[:, a:b] 145 | for head_id in range(self.heads) 146 | for a, b in self.slices[head_id] 147 | ], 1).view(bs * self.heads, self.input_dim) 148 | 149 | assert query.shape == (bs * self.heads, self.input_dim) 150 | return query 151 | 152 | 153 | class QueryMLP(nn.Module): 154 | 155 | def __init__( 156 | self, input_dim, heads, k_dim, product_quantization, multi_query_net, 157 | sizes, bias=True, batchnorm=True, grouped_conv=False 158 | ): 159 | super().__init__() 160 | self.input_dim = input_dim 161 | self.heads = heads 162 | self.k_dim = k_dim 163 | self.sizes = sizes 164 | self.grouped_conv = grouped_conv 165 | assert not multi_query_net or product_quantization or heads >= 2 166 | assert sizes[0] == input_dim 167 | assert sizes[-1] == (k_dim // 2) if multi_query_net else (heads * k_dim) 168 | assert self.grouped_conv is False or len(sizes) > 2 169 | 170 | # number of required MLPs 171 | self.groups = (2 * heads) if multi_query_net else 1 172 | 173 | # MLPs 174 | if self.grouped_conv: 175 | self.query_mlps = mlp(sizes, bias=bias, batchnorm=batchnorm, groups=self.groups) 176 | elif len(self.sizes) == 2: 177 | sizes_ = list(sizes) 178 | sizes_[-1] = sizes_[-1] * self.groups 179 | self.query_mlps = mlp(sizes_, bias=bias, batchnorm=batchnorm, groups=1) 180 | else: 181 | self.query_mlps = nn.ModuleList([ 182 | mlp(sizes, bias=bias, batchnorm=batchnorm, groups=1) 183 | for _ in range(self.groups) 184 | ]) 185 | 186 | def forward(self, input): 187 | """ 188 | Compute queries using either grouped 1D convolutions or ModuleList + concat. 
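        The input of shape (..., input_dim) is flattened to (bs, input_dim); the
        concatenated MLP outputs form a (bs, heads * k_dim) tensor that is returned
        reshaped to (bs * heads, k_dim), i.e. one query per head.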
189 | """ 190 | assert input.shape[-1] == self.input_dim 191 | input = input.contiguous().view(-1, self.input_dim) if input.dim() > 2 else input 192 | bs = len(input) 193 | 194 | if self.grouped_conv or len(self.sizes) == 2: 195 | query = self.query_mlps(input) 196 | else: 197 | outputs = [m(input) for m in self.query_mlps] 198 | query = torch.cat(outputs, 1) if len(outputs) > 1 else outputs[0] 199 | 200 | assert query.shape == (bs, self.heads * self.k_dim) 201 | return query.view(bs * self.heads, self.k_dim) 202 | 203 | 204 | class QueryConv(nn.Module): 205 | 206 | def __init__( 207 | self, input_dim, heads, k_dim, product_quantization, multi_query_net, 208 | sizes, kernel_sizes, bias=True, batchnorm=True, 209 | residual=False, grouped_conv=False 210 | ): 211 | super().__init__() 212 | self.input_dim = input_dim 213 | self.heads = heads 214 | self.k_dim = k_dim 215 | self.sizes = sizes 216 | self.grouped_conv = grouped_conv 217 | assert not multi_query_net or product_quantization or heads >= 2 218 | assert sizes[0] == input_dim 219 | assert sizes[-1] == (k_dim // 2) if multi_query_net else (heads * k_dim) 220 | assert self.grouped_conv is False or len(sizes) > 2 221 | assert len(sizes) == len(kernel_sizes) + 1 >= 2 and all(ks % 2 == 1 for ks in kernel_sizes) 222 | 223 | # number of required CNNs 224 | self.groups = (2 * heads) if multi_query_net else 1 225 | 226 | # CNNs 227 | if self.grouped_conv: 228 | self.query_convs = convs(sizes, kernel_sizes, bias=bias, batchnorm=batchnorm, residual=residual, groups=self.groups) 229 | elif len(self.sizes) == 2: 230 | sizes_ = list(sizes) 231 | sizes_[-1] = sizes_[-1] * self.groups 232 | self.query_convs = convs(sizes_, kernel_sizes, bias=bias, batchnorm=batchnorm, residual=residual, groups=1) 233 | else: 234 | self.query_convs = nn.ModuleList([ 235 | convs(sizes, kernel_sizes, bias=bias, batchnorm=batchnorm, residual=residual, groups=1) 236 | for _ in range(self.groups) 237 | ]) 238 | 239 | def forward(self, input): 240 | 241 | bs, nf, h, w = input.shape 242 | assert nf == self.input_dim 243 | 244 | if self.grouped_conv or len(self.sizes) == 2: 245 | query = self.query_convs(input) 246 | else: 247 | outputs = [m(input) for m in self.query_convs] 248 | query = torch.cat(outputs, 1) if len(outputs) > 1 else outputs[0] 249 | 250 | assert query.shape == (bs, self.heads * self.k_dim, h, w) 251 | query = query.transpose(1, 3).contiguous().view(bs * w * h * self.heads, self.k_dim) 252 | return query 253 | -------------------------------------------------------------------------------- /src/model/memory/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import math 3 | import numpy as np 4 | import torch 5 | 6 | 7 | # load FAISS GPU library if available (dramatically accelerates the nearest neighbor search) 8 | try: 9 | import faiss 10 | FAISS_AVAILABLE = hasattr(faiss, 'StandardGpuResources') 11 | except ImportError: 12 | FAISS_AVAILABLE = False 13 | sys.stderr.write("FAISS library was not found.\n") 14 | 15 | 16 | def get_gaussian_keys(n_keys, dim, normalized, seed): 17 | """ 18 | Generate random Gaussian keys. 19 | """ 20 | rng = np.random.RandomState(seed) 21 | X = rng.randn(n_keys, dim) 22 | if normalized: 23 | X /= np.linalg.norm(X, axis=1, keepdims=True) 24 | return X.astype(np.float32) 25 | 26 | 27 | def get_uniform_keys(n_keys, dim, normalized, seed): 28 | """ 29 | Generate random uniform keys (same initialization as nn.Linear). 
30 | """ 31 | rng = np.random.RandomState(seed) 32 | bound = 1 / math.sqrt(dim) 33 | X = rng.uniform(-bound, bound, (n_keys, dim)) 34 | if normalized: 35 | X /= np.linalg.norm(X, axis=1, keepdims=True) 36 | return X.astype(np.float32) 37 | 38 | 39 | def get_slices(dim, head_id): 40 | """ 41 | Generate slices of hidden dimensions. 42 | Used when there are multiple heads and/or different set of keys, 43 | and that there is no query network. 44 | """ 45 | if head_id == 0: 46 | return [(0, dim)] 47 | offset = dim // (2 ** (head_id + 1)) 48 | starts = np.arange(0, dim, offset) 49 | slices1 = [(x, x + offset) for i, x in enumerate(starts) if i % 2 == 0] 50 | slices2 = [(x, x + offset) for i, x in enumerate(starts) if i % 2 == 1] 51 | return slices1 + slices2 52 | 53 | 54 | def cartesian_product(a, b): 55 | """ 56 | Compute the batched cartesian product between two matrices. 57 | Input: 58 | a: Tensor(n, d1) 59 | b: Tensor(n, d2) 60 | Output: 61 | output: Tensor(n, d1 * d2, 2) 62 | """ 63 | n1, d1 = a.shape 64 | n2, d2 = b.shape 65 | assert n1 == n2 66 | return torch.cat([ 67 | a.unsqueeze(-1).repeat(1, 1, d2).unsqueeze(-1), 68 | b.repeat(1, d1).view(n2, d1, d2).unsqueeze(-1) 69 | ], 3).view(n1, d1 * d2, 2) 70 | 71 | 72 | def swig_ptr_from_FloatTensor(x): 73 | assert x.is_contiguous() 74 | assert x.dtype == torch.float32 75 | return faiss.cast_integer_to_float_ptr(x.storage().data_ptr() + x.storage_offset() * 4) 76 | 77 | 78 | def swig_ptr_from_LongTensor(x): 79 | assert x.is_contiguous() 80 | assert x.dtype == torch.int64, 'dtype=%s' % x.dtype 81 | return faiss.cast_integer_to_long_ptr(x.storage().data_ptr() + x.storage_offset() * 8) 82 | 83 | 84 | def get_knn_pytorch(a, b, k, distance='dot_product'): 85 | """ 86 | Input: 87 | - matrix of size (m, d) (keys) 88 | - matrix of size (n, d) (queries) 89 | - number of nearest neighbors 90 | - distance metric 91 | Output: 92 | - `scores` matrix of size (n, k) with nearest neighors scores 93 | - `indices` matrix of size (n, k) with nearest neighors indices 94 | """ 95 | m, d = a.size() 96 | n, _ = b.size() 97 | assert b.size(1) == d 98 | assert k > 0 99 | assert distance in ['dot_product', 'cosine', 'l2'] 100 | 101 | with torch.no_grad(): 102 | 103 | if distance == 'dot_product': 104 | scores = a.mm(b.t()) # (m, n) 105 | 106 | elif distance == 'cosine': 107 | scores = a.mm(b.t()) # (m, n) 108 | scores /= (a.norm(2, 1)[:, None] + 1e-9) # (m, n) 109 | scores /= (b.norm(2, 1)[None, :] + 1e-9) # (m, n) 110 | 111 | elif distance == 'l2': 112 | scores = a.mm(b.t()) # (m, n) 113 | scores *= 2 # (m, n) 114 | scores -= (a ** 2).sum(1)[:, None] # (m, n) 115 | scores -= (b ** 2).sum(1)[None, :] # (m, n) 116 | 117 | scores, indices = scores.topk(k=k, dim=0, largest=True) # (k, n) 118 | scores = scores.t() # (n, k) 119 | indices = indices.t() # (n, k) 120 | 121 | return scores, indices 122 | 123 | 124 | def get_knn_faiss(xb, xq, k, distance='dot_product'): 125 | """ 126 | `metric` can be faiss.METRIC_INNER_PRODUCT or faiss.METRIC_L2 127 | https://github.com/facebookresearch/faiss/blob/master/gpu/test/test_pytorch_faiss.py 128 | """ 129 | assert xb.device == xq.device 130 | assert distance in ['dot_product', 'l2'] 131 | metric = faiss.METRIC_INNER_PRODUCT if distance == 'dot_product' else faiss.METRIC_L2 132 | 133 | xq_ptr = swig_ptr_from_FloatTensor(xq) 134 | xb_ptr = swig_ptr_from_FloatTensor(xb) 135 | 136 | nq, d1 = xq.size() 137 | nb, d2 = xb.size() 138 | assert d1 == d2 139 | 140 | D = torch.empty(nq, k, device=xb.device, dtype=torch.float32) 141 | I = 
torch.empty(nq, k, device=xb.device, dtype=torch.int64) 142 | 143 | D_ptr = swig_ptr_from_FloatTensor(D) 144 | I_ptr = swig_ptr_from_LongTensor(I) 145 | 146 | faiss.bruteForceKnn( 147 | FAISS_RES, metric, 148 | xb_ptr, nb, 149 | xq_ptr, nq, 150 | d1, k, D_ptr, I_ptr 151 | ) 152 | 153 | return D, I 154 | 155 | 156 | if FAISS_AVAILABLE: 157 | FAISS_RES = faiss.StandardGpuResources() 158 | FAISS_RES.setDefaultNullStreamAllDevices() 159 | FAISS_RES.setTempMemory(1200 * 1024 * 1024) 160 | get_knn = get_knn_faiss 161 | else: 162 | sys.stderr.write("FAISS not available. Switching to standard nearest neighbors search implementation.\n") 163 | get_knn = get_knn_pytorch 164 | -------------------------------------------------------------------------------- /src/model/pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from logging import getLogger 9 | import io 10 | import numpy as np 11 | import torch 12 | 13 | 14 | logger = getLogger() 15 | 16 | 17 | def load_fasttext_model(path): 18 | """ 19 | Load a binarized fastText model. 20 | """ 21 | try: 22 | import fastText 23 | except ImportError: 24 | raise Exception("Unable to import fastText. Please install fastText for Python: " 25 | "https://github.com/facebookresearch/fastText") 26 | return fastText.load_model(path) 27 | 28 | 29 | def read_txt_embeddings(path, params): 30 | """ 31 | Reload pretrained embeddings from a text file. 32 | """ 33 | word2id = {} 34 | vectors = [] 35 | 36 | # load pretrained embeddings 37 | _emb_dim_file = params.emb_dim 38 | with io.open(path, 'r', encoding='utf-8', newline='\n', errors='ignore') as f: 39 | for i, line in enumerate(f): 40 | if i == 0: 41 | split = line.split() 42 | assert len(split) == 2 43 | assert _emb_dim_file == int(split[1]) 44 | continue 45 | word, vect = line.rstrip().split(' ', 1) 46 | vect = np.fromstring(vect, sep=' ') 47 | if word in word2id: 48 | logger.warning("Word \"%s\" found twice!" % word) 49 | continue 50 | if not vect.shape == (_emb_dim_file,): 51 | logger.warning("Invalid dimension (%i) for word \"%s\" in line %i." 52 | % (vect.shape[0], word, i)) 53 | continue 54 | assert vect.shape == (_emb_dim_file,) 55 | word2id[word] = len(word2id) 56 | vectors.append(vect[None]) 57 | 58 | assert len(word2id) == len(vectors) 59 | logger.info("Loaded %i pretrained word embeddings from %s" % (len(vectors), path)) 60 | 61 | # compute new vocabulary / embeddings 62 | embeddings = np.concatenate(vectors, 0) 63 | embeddings = torch.from_numpy(embeddings).float() 64 | 65 | assert embeddings.size() == (len(word2id), params.emb_dim) 66 | return word2id, embeddings 67 | 68 | 69 | def load_bin_embeddings(path, params): 70 | """ 71 | Reload pretrained embeddings from a fastText binary file. 72 | """ 73 | model = load_fasttext_model(path) 74 | assert model.get_dimension() == params.emb_dim 75 | words = model.get_labels() 76 | logger.info("Loaded binary model from %s" % path) 77 | 78 | # compute new vocabulary / embeddings 79 | embeddings = np.concatenate([model.get_word_vector(w)[None] for w in words], 0) 80 | embeddings = torch.from_numpy(embeddings).float() 81 | word2id = {w: i for i, w in enumerate(words)} 82 | logger.info("Generated embeddings for %i words." 
% len(words)) 83 | 84 | assert embeddings.size() == (len(word2id), params.emb_dim) 85 | return word2id, embeddings 86 | 87 | 88 | def load_embeddings(path, params): 89 | """ 90 | Reload pretrained embeddings. 91 | """ 92 | if path.endswith('.bin'): 93 | return load_bin_embeddings(path, params) 94 | else: 95 | return read_txt_embeddings(path, params) 96 | -------------------------------------------------------------------------------- /src/optim.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import re 9 | import math 10 | import inspect 11 | 12 | import torch 13 | from torch import optim 14 | 15 | 16 | class Adam(optim.Optimizer): 17 | """ 18 | Same as https://github.com/pytorch/pytorch/blob/master/torch/optim/adam.py, 19 | without amsgrad, with step in a tensor, and states initialization in __init__. 20 | It was important to add `.item()` in `state['step'].item()`. 21 | """ 22 | 23 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 24 | if not 0.0 <= lr: 25 | raise ValueError("Invalid learning rate: {}".format(lr)) 26 | if not 0.0 <= eps: 27 | raise ValueError("Invalid epsilon value: {}".format(eps)) 28 | if not 0.0 <= betas[0] < 1.0: 29 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 30 | if not 0.0 <= betas[1] < 1.0: 31 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 32 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 33 | super().__init__(params, defaults) 34 | 35 | for group in self.param_groups: 36 | for p in group['params']: 37 | state = self.state[p] 38 | state['step'] = 0 # torch.zeros(1) 39 | state['exp_avg'] = torch.zeros_like(p.data) 40 | state['exp_avg_sq'] = torch.zeros_like(p.data) 41 | 42 | def __setstate__(self, state): 43 | super().__setstate__(state) 44 | 45 | def step(self, closure=None): 46 | """ 47 | Step. 48 | """ 49 | loss = None 50 | if closure is not None: 51 | loss = closure() 52 | 53 | for group in self.param_groups: 54 | for p in group['params']: 55 | if p.grad is None: 56 | continue 57 | grad = p.grad.data 58 | if grad.is_sparse: 59 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 60 | 61 | state = self.state[p] 62 | 63 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 64 | beta1, beta2 = group['betas'] 65 | 66 | state['step'] += 1 67 | 68 | # if group['weight_decay'] != 0: 69 | # grad.add_(group['weight_decay'], p.data) 70 | 71 | # Decay the first and second moment running average coefficient 72 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 73 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 74 | denom = exp_avg_sq.sqrt().add_(group['eps']) 75 | # denom = exp_avg_sq.sqrt().clamp_(min=group['eps']) 76 | 77 | bias_correction1 = 1 - beta1 ** state['step'] # .item() 78 | bias_correction2 = 1 - beta2 ** state['step'] # .item() 79 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 80 | 81 | if group['weight_decay'] != 0: 82 | p.data.add_(-group['weight_decay'] * group['lr'], p.data) 83 | 84 | p.data.addcdiv_(-step_size, exp_avg, denom) 85 | 86 | return loss 87 | 88 | 89 | class AdamInverseSqrtWithWarmup(Adam): 90 | """ 91 | Decay the LR based on the inverse square root of the update number. 
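    (The decay exponent is configurable via `exp_factor`; the default of 0.5
    gives the standard inverse-square-root schedule described below.)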
92 | We also support a warmup phase where we linearly increase the learning rate 93 | from some initial learning rate (`warmup-init-lr`) until the configured 94 | learning rate (`lr`). Thereafter we decay proportional to the number of 95 | updates, with a decay factor set to align with the configured learning rate. 96 | During warmup: 97 | lrs = torch.linspace(warmup_init_lr, lr, warmup_updates) 98 | lr = lrs[update_num] 99 | After warmup: 100 | lr = decay_factor / sqrt(update_num) 101 | where 102 | decay_factor = lr * sqrt(warmup_updates) 103 | """ 104 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 105 | weight_decay=0, warmup_updates=4000, warmup_init_lr=1e-7, 106 | exp_factor=0.5): 107 | super().__init__( 108 | params, 109 | lr=warmup_init_lr, 110 | betas=betas, 111 | eps=eps, 112 | weight_decay=weight_decay, 113 | ) 114 | 115 | # linearly warmup for the first warmup_updates 116 | self.warmup_updates = warmup_updates 117 | self.warmup_init_lr = warmup_init_lr 118 | warmup_end_lr = lr 119 | self.lr_step = (warmup_end_lr - warmup_init_lr) / warmup_updates 120 | 121 | # then, decay prop. to the inverse square root of the update number 122 | self.exp_factor = exp_factor 123 | self.decay_factor = warmup_end_lr * warmup_updates ** self.exp_factor 124 | 125 | # total number of updates 126 | for param_group in self.param_groups: 127 | param_group['num_updates'] = 0 128 | 129 | def get_lr_for_step(self, num_updates): 130 | if num_updates < self.warmup_updates: 131 | return self.warmup_init_lr + num_updates * self.lr_step 132 | else: 133 | return self.decay_factor * (num_updates ** -self.exp_factor) 134 | 135 | def step(self, closure=None): 136 | super().step(closure) 137 | for param_group in self.param_groups: 138 | param_group['num_updates'] += 1 139 | param_group['lr'] = self.get_lr_for_step(param_group['num_updates']) 140 | 141 | 142 | class AdamCosineWithWarmup(Adam): 143 | """ 144 | Assign LR based on a cyclical schedule that follows the cosine function. 145 | See https://arxiv.org/pdf/1608.03983.pdf for details. 146 | We also support a warmup phase where we linearly increase the learning rate 147 | from some initial learning rate (``--warmup-init-lr``) until the configured 148 | learning rate (``--lr``). 149 | During warmup:: 150 | lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates) 151 | lr = lrs[update_num] 152 | After warmup:: 153 | lr = lr_min + 0.5*(lr_max - lr_min)*(1 + cos(t_curr / t_i)) 154 | where ``t_curr`` is current percentage of updates within the current period 155 | range and ``t_i`` is the current period range, which is scaled by ``t_mul`` 156 | after every iteration. 
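    Note that the implementation evaluates cos(pi * t_curr / t_i), so within each
    period the LR decays from ``max_lr`` down to ``min_lr``; after every completed
    period both bounds are shrunk by another factor of ``lr_shrink``.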
157 | """ 158 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 159 | weight_decay=0, warmup_updates=4000, warmup_init_lr=1e-7, 160 | min_lr=1e-9, init_period=1000000, period_mult=1, lr_shrink=0.75): 161 | super().__init__( 162 | params, 163 | lr=warmup_init_lr, 164 | betas=betas, 165 | eps=eps, 166 | weight_decay=weight_decay, 167 | ) 168 | 169 | # linearly warmup for the first warmup_updates 170 | self.warmup_updates = warmup_updates 171 | self.warmup_init_lr = warmup_init_lr 172 | warmup_end_lr = lr 173 | self.lr_step = (warmup_end_lr - warmup_init_lr) / warmup_updates 174 | 175 | # then, apply cosine scheduler 176 | self.min_lr = min_lr 177 | self.max_lr = lr 178 | self.period = init_period 179 | self.period_mult = period_mult 180 | self.lr_shrink = lr_shrink 181 | 182 | # total number of updates 183 | for param_group in self.param_groups: 184 | param_group['num_updates'] = 0 185 | 186 | def get_lr_for_step(self, num_updates): 187 | if num_updates < self.warmup_updates: 188 | return self.warmup_init_lr + num_updates * self.lr_step 189 | else: 190 | t = num_updates - self.warmup_updates 191 | if self.period_mult == 1: 192 | pid = math.floor(t / self.period) 193 | t_i = self.period 194 | t_curr = t - (self.period * pid) 195 | else: 196 | pid = math.floor(math.log(1 - t / self.period * (1 - self.period_mult), self.period_mult)) 197 | t_i = self.period * (self.period_mult ** pid) 198 | t_curr = t - (1 - self.period_mult ** pid) / (1 - self.period_mult) * self.period 199 | lr_shrink = self.lr_shrink ** pid 200 | min_lr = self.min_lr * lr_shrink 201 | max_lr = self.max_lr * lr_shrink 202 | return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t_curr / t_i)) 203 | 204 | def step(self, closure=None): 205 | super().step(closure) 206 | for param_group in self.param_groups: 207 | param_group['num_updates'] += 1 208 | param_group['lr'] = self.get_lr_for_step(param_group['num_updates']) 209 | 210 | 211 | def get_optimizer(parameters, s): 212 | """ 213 | Parse optimizer parameters. 
214 | Input should be of the form: 215 | - "sgd,lr=0.01" 216 | - "adagrad,lr=0.1,lr_decay=0.05" 217 | """ 218 | if "," in s: 219 | method = s[:s.find(',')] 220 | optim_params = {} 221 | for x in s[s.find(',') + 1:].split(','): 222 | split = x.split('=') 223 | assert len(split) == 2 224 | assert re.match("^[+-]?(\d+(\.\d*)?|\.\d+)$", split[1]) is not None 225 | optim_params[split[0]] = float(split[1]) 226 | else: 227 | method = s 228 | optim_params = {} 229 | 230 | if method == 'adadelta': 231 | optim_fn = optim.Adadelta 232 | elif method == 'adagrad': 233 | optim_fn = optim.Adagrad 234 | elif method == 'adam': 235 | optim_fn = Adam 236 | optim_params['betas'] = (optim_params.get('beta1', 0.9), optim_params.get('beta2', 0.999)) 237 | optim_params.pop('beta1', None) 238 | optim_params.pop('beta2', None) 239 | elif method == 'adam_inverse_sqrt': 240 | optim_fn = AdamInverseSqrtWithWarmup 241 | optim_params['betas'] = (optim_params.get('beta1', 0.9), optim_params.get('beta2', 0.999)) 242 | optim_params.pop('beta1', None) 243 | optim_params.pop('beta2', None) 244 | elif method == 'adam_cosine': 245 | optim_fn = AdamCosineWithWarmup 246 | optim_params['betas'] = (optim_params.get('beta1', 0.9), optim_params.get('beta2', 0.999)) 247 | optim_params.pop('beta1', None) 248 | optim_params.pop('beta2', None) 249 | elif method == 'adamax': 250 | optim_fn = optim.Adamax 251 | elif method == 'asgd': 252 | optim_fn = optim.ASGD 253 | elif method == 'rmsprop': 254 | optim_fn = optim.RMSprop 255 | elif method == 'rprop': 256 | optim_fn = optim.Rprop 257 | elif method == 'sgd': 258 | optim_fn = optim.SGD 259 | assert 'lr' in optim_params 260 | else: 261 | raise Exception('Unknown optimization method: "%s"' % method) 262 | 263 | # check that we give good parameters to the optimizer 264 | expected_args = inspect.getargspec(optim_fn.__init__)[0] 265 | assert expected_args[:2] == ['self', 'params'] 266 | if not all(k in expected_args[2:] for k in optim_params.keys()): 267 | raise Exception('Unexpected parameters: expected "%s", got "%s"' % ( 268 | str(expected_args[2:]), str(optim_params.keys()))) 269 | 270 | return optim_fn(parameters, **optim_params) 271 | -------------------------------------------------------------------------------- /src/slurm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from logging import getLogger 9 | import os 10 | import sys 11 | import torch 12 | import socket 13 | import signal 14 | import subprocess 15 | 16 | 17 | logger = getLogger() 18 | 19 | 20 | def sig_handler(signum, frame): 21 | logger.warning("Signal handler called with signal " + str(signum)) 22 | prod_id = int(os.environ['SLURM_PROCID']) 23 | logger.warning("Host: %s - Global rank: %i" % (socket.gethostname(), prod_id)) 24 | if prod_id == 0: 25 | logger.warning("Requeuing job " + os.environ['SLURM_JOB_ID']) 26 | os.system('scontrol requeue ' + os.environ['SLURM_JOB_ID']) 27 | else: 28 | logger.warning("Not the master process, no need to requeue.") 29 | sys.exit(-1) 30 | 31 | 32 | def term_handler(signum, frame): 33 | logger.warning("Signal handler called with signal " + str(signum)) 34 | logger.warning("Bypassing SIGTERM.") 35 | 36 | 37 | def init_signal_handler(): 38 | """ 39 | Handle signals sent by SLURM for time limit / pre-emption. 
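    SIGUSR1 is routed to `sig_handler`, which requeues the SLURM job from the
    master process (SLURM_PROCID == 0), and SIGTERM is routed to `term_handler`,
    which only logs the signal and ignores it.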
40 | """ 41 | signal.signal(signal.SIGUSR1, sig_handler) 42 | signal.signal(signal.SIGTERM, term_handler) 43 | logger.warning("Signal handler installed.") 44 | 45 | 46 | def init_distributed_mode(params): 47 | """ 48 | Handle single and multi-GPU / multi-node / SLURM jobs. 49 | Initialize the following variables: 50 | - n_nodes 51 | - node_id 52 | - local_rank 53 | - global_rank 54 | - world_size 55 | """ 56 | params.is_slurm_job = 'SLURM_JOB_ID' in os.environ and not params.debug_slurm 57 | print("SLURM job: %s" % str(params.is_slurm_job)) 58 | 59 | # SLURM job 60 | if params.is_slurm_job: 61 | 62 | assert params.local_rank == -1 # on the cluster, this is handled by SLURM 63 | 64 | SLURM_VARIABLES = [ 65 | 'SLURM_JOB_ID', 66 | 'SLURM_JOB_NODELIST', 'SLURM_JOB_NUM_NODES', 'SLURM_NTASKS', 'SLURM_TASKS_PER_NODE', 67 | 'SLURM_MEM_PER_NODE', 'SLURM_MEM_PER_CPU', 68 | 'SLURM_NODEID', 'SLURM_PROCID', 'SLURM_LOCALID', 'SLURM_TASK_PID' 69 | ] 70 | 71 | PREFIX = "%i - " % int(os.environ['SLURM_PROCID']) 72 | for name in SLURM_VARIABLES: 73 | value = os.environ.get(name, None) 74 | print(PREFIX + "%s: %s" % (name, str(value))) 75 | 76 | # # job ID 77 | # params.job_id = os.environ['SLURM_JOB_ID'] 78 | 79 | # number of nodes / node ID 80 | params.n_nodes = int(os.environ['SLURM_JOB_NUM_NODES']) 81 | params.node_id = int(os.environ['SLURM_NODEID']) 82 | 83 | # local rank on the current node / global rank 84 | params.local_rank = int(os.environ['SLURM_LOCALID']) 85 | params.global_rank = int(os.environ['SLURM_PROCID']) 86 | 87 | # number of processes / GPUs per node 88 | params.world_size = int(os.environ['SLURM_NTASKS']) 89 | params.n_gpu_per_node = params.world_size // params.n_nodes 90 | 91 | # define master address and master port 92 | hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', os.environ['SLURM_JOB_NODELIST']]) 93 | params.master_addr = hostnames.split()[0].decode('utf-8') 94 | assert 10001 <= params.master_port <= 20000 or params.world_size == 1 95 | print(PREFIX + "Master address: %s" % params.master_addr) 96 | print(PREFIX + "Master port : %i" % params.master_port) 97 | 98 | # set environment variables for 'env://' 99 | os.environ['MASTER_ADDR'] = params.master_addr 100 | os.environ['MASTER_PORT'] = str(params.master_port) 101 | os.environ['WORLD_SIZE'] = str(params.world_size) 102 | os.environ['RANK'] = str(params.global_rank) 103 | 104 | # multi-GPU job (local or multi-node) - jobs started with torch.distributed.launch 105 | elif params.local_rank != -1: 106 | 107 | assert params.master_port == -1 108 | 109 | # read environment variables 110 | params.global_rank = int(os.environ['RANK']) 111 | params.world_size = int(os.environ['WORLD_SIZE']) 112 | params.n_gpu_per_node = int(os.environ['NGPU']) 113 | 114 | # number of nodes / node ID 115 | params.n_nodes = params.world_size // params.n_gpu_per_node 116 | params.node_id = params.global_rank // params.n_gpu_per_node 117 | 118 | # local job (single GPU) 119 | else: 120 | assert params.local_rank == -1 121 | assert params.master_port == -1 122 | params.n_nodes = 1 123 | params.node_id = 0 124 | params.local_rank = 0 125 | params.global_rank = 0 126 | params.world_size = 1 127 | params.n_gpu_per_node = 1 128 | 129 | # sanity checks 130 | assert params.n_nodes >= 1 131 | assert 0 <= params.node_id < params.n_nodes 132 | assert 0 <= params.local_rank <= params.global_rank < params.world_size 133 | assert params.world_size == params.n_nodes * params.n_gpu_per_node 134 | 135 | # define whether this is the master process / 
if we are in distributed mode 136 | params.is_master = params.node_id == 0 and params.local_rank == 0 137 | params.multi_node = params.n_nodes > 1 138 | params.multi_gpu = params.world_size > 1 139 | 140 | # summary 141 | PREFIX = "%i - " % params.global_rank 142 | print(PREFIX + "Number of nodes: %i" % params.n_nodes) 143 | print(PREFIX + "Node ID : %i" % params.node_id) 144 | print(PREFIX + "Local rank : %i" % params.local_rank) 145 | print(PREFIX + "Global rank : %i" % params.global_rank) 146 | print(PREFIX + "World size : %i" % params.world_size) 147 | print(PREFIX + "GPUs per node : %i" % params.n_gpu_per_node) 148 | print(PREFIX + "Master : %s" % str(params.is_master)) 149 | print(PREFIX + "Multi-node : %s" % str(params.multi_node)) 150 | print(PREFIX + "Multi-GPU : %s" % str(params.multi_gpu)) 151 | print(PREFIX + "Hostname : %s" % socket.gethostname()) 152 | 153 | # set GPU device 154 | torch.cuda.set_device(params.local_rank) 155 | 156 | # initialize multi-GPU 157 | if params.multi_gpu: 158 | 159 | # http://pytorch.apachecn.org/en/0.3.0/distributed.html#environment-variable-initialization 160 | # 'env://' will read these environment variables: 161 | # MASTER_PORT - required; has to be a free port on machine with rank 0 162 | # MASTER_ADDR - required (except for rank 0); address of rank 0 node 163 | # WORLD_SIZE - required; can be set either here, or in a call to init function 164 | # RANK - required; can be set either here, or in a call to init function 165 | 166 | print("Initializing PyTorch distributed ...") 167 | torch.distributed.init_process_group( 168 | init_method='env://', 169 | backend='nccl', 170 | ) 171 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import os 9 | import re 10 | import sys 11 | import pickle 12 | import random 13 | import getpass 14 | import argparse 15 | import subprocess 16 | import numpy as np 17 | import torch 18 | 19 | from .logger import create_logger 20 | 21 | 22 | FALSY_STRINGS = {'off', 'false', '0'} 23 | TRUTHY_STRINGS = {'on', 'true', '1'} 24 | 25 | DUMP_PATH = '/checkpoint/%s/dumped' % getpass.getuser() 26 | DYNAMIC_COEFF = ['lambda_clm', 'lambda_mlm', 'lambda_pc', 'lambda_ae', 'lambda_mt', 'lambda_bt'] 27 | 28 | 29 | class AttrDict(dict): 30 | def __init__(self, *args, **kwargs): 31 | super(AttrDict, self).__init__(*args, **kwargs) 32 | self.__dict__ = self 33 | 34 | 35 | def bool_flag(s): 36 | """ 37 | Parse boolean arguments from the command line. 
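    Accepted values (case-insensitive): 'off' / 'false' / '0' map to False,
    'on' / 'true' / '1' map to True; anything else raises an ArgumentTypeError.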
38 | """ 39 | if s.lower() in FALSY_STRINGS: 40 | return False 41 | elif s.lower() in TRUTHY_STRINGS: 42 | return True 43 | else: 44 | raise argparse.ArgumentTypeError("Invalid value for a boolean flag!") 45 | 46 | 47 | def initialize_exp(params): 48 | """ 49 | Initialize the experience: 50 | - dump parameters 51 | - create a logger 52 | """ 53 | # dump parameters 54 | get_dump_path(params) 55 | pickle.dump(params, open(os.path.join(params.dump_path, 'params.pkl'), 'wb')) 56 | 57 | # get running command 58 | command = ["python", sys.argv[0]] 59 | for x in sys.argv[1:]: 60 | if x.startswith('--'): 61 | assert '"' not in x and "'" not in x 62 | command.append(x) 63 | else: 64 | assert "'" not in x 65 | if re.match('^[a-zA-Z0-9_]+$', x): 66 | command.append("%s" % x) 67 | else: 68 | command.append("'%s'" % x) 69 | command = ' '.join(command) 70 | params.command = command + ' --exp_id "%s"' % params.exp_id 71 | 72 | # check experiment name 73 | assert len(params.exp_name.strip()) > 0 74 | 75 | # create a logger 76 | logger = create_logger(os.path.join(params.dump_path, 'train.log'), rank=getattr(params, 'global_rank', 0)) 77 | logger.info("============ Initialized logger ============") 78 | logger.info("\n".join("%s: %s" % (k, str(v)) 79 | for k, v in sorted(dict(vars(params)).items()))) 80 | logger.info("The experiment will be stored in %s\n" % params.dump_path) 81 | logger.info("Running command: %s" % command) 82 | logger.info("") 83 | return logger 84 | 85 | 86 | def get_dump_path(params): 87 | """ 88 | Create a directory to store the experiment. 89 | """ 90 | dump_path = DUMP_PATH if params.dump_path == '' else params.dump_path 91 | assert len(params.exp_name) > 0 92 | 93 | # create the sweep path if it does not exist 94 | sweep_path = os.path.join(dump_path, params.exp_name) 95 | if not os.path.exists(sweep_path): 96 | subprocess.Popen("mkdir -p %s" % sweep_path, shell=True).wait() 97 | 98 | # create an ID for the job if it is not given in the parameters. 99 | # if we run on the cluster, the job ID is the one of Chronos. 100 | # otherwise, it is randomly generated 101 | if params.exp_id == '': 102 | chronos_job_id = os.environ.get('CHRONOS_JOB_ID') 103 | slurm_job_id = os.environ.get('SLURM_JOB_ID') 104 | assert chronos_job_id is None or slurm_job_id is None 105 | exp_id = chronos_job_id if chronos_job_id is not None else slurm_job_id 106 | if exp_id is None: 107 | chars = 'abcdefghijklmnopqrstuvwxyz0123456789' 108 | while True: 109 | exp_id = ''.join(random.choice(chars) for _ in range(10)) 110 | if not os.path.isdir(os.path.join(sweep_path, exp_id)): 111 | break 112 | else: 113 | assert exp_id.isdigit() 114 | params.exp_id = exp_id 115 | 116 | # create the dump folder / update parameters 117 | params.dump_path = os.path.join(sweep_path, params.exp_id) 118 | if not os.path.isdir(params.dump_path): 119 | subprocess.Popen("mkdir -p %s" % params.dump_path, shell=True).wait() 120 | 121 | 122 | def to_cuda(*args): 123 | """ 124 | Move tensors to CUDA. 125 | """ 126 | return [None if x is None else x.cuda() for x in args] 127 | 128 | 129 | def restore_segmentation(path, bpe_type='fastBPE'): 130 | """ 131 | Take a file segmented with BPE and restore it to its original segmentation. 
132 | """ 133 | assert os.path.isfile(path) 134 | if bpe_type == 'fastBPE': 135 | restore_cmd = "sed -i -r 's/(@@ )|(@@ ?$)//g' %s" 136 | elif bpe_type == 'sentencepiece': 137 | restore_cmd = u"sed -i -e 's/ //g' -e 's/^\u2581//g' -e 's/\u2581/ /g' %s" 138 | else: 139 | raise NotImplementedError 140 | subprocess.Popen(restore_cmd % path, shell=True).wait() 141 | 142 | 143 | def parse_lambda_config(params): 144 | """ 145 | Parse the configuration of lambda coefficient (for scheduling). 146 | x = "3" # lambda will be a constant equal to x 147 | x = "0:1,1000:0" # lambda will start from 1 and linearly decrease to 0 during the first 1000 iterations 148 | x = "0:0,1000:0,2000:1" # lambda will be equal to 0 for the first 1000 iterations, then will linearly increase to 1 until iteration 2000 149 | """ 150 | for name in DYNAMIC_COEFF: 151 | x = getattr(params, name) 152 | split = x.split(',') 153 | if len(split) == 1: 154 | setattr(params, name, float(x)) 155 | setattr(params, name + '_config', None) 156 | else: 157 | split = [s.split(':') for s in split] 158 | assert all(len(s) == 2 for s in split) 159 | assert all(k.isdigit() for k, _ in split) 160 | assert all(int(split[i][0]) < int(split[i + 1][0]) for i in range(len(split) - 1)) 161 | setattr(params, name, float(split[0][1])) 162 | setattr(params, name + '_config', [(int(k), float(v)) for k, v in split]) 163 | 164 | 165 | def get_lambda_value(config, n_iter): 166 | """ 167 | Compute a lambda value according to its schedule configuration. 168 | """ 169 | ranges = [i for i in range(len(config) - 1) if config[i][0] <= n_iter < config[i + 1][0]] 170 | if len(ranges) == 0: 171 | assert n_iter >= config[-1][0] 172 | return config[-1][1] 173 | assert len(ranges) == 1 174 | i = ranges[0] 175 | x_a, y_a = config[i] 176 | x_b, y_b = config[i + 1] 177 | return y_a + (n_iter - x_a) * float(y_b - y_a) / float(x_b - x_a) 178 | 179 | 180 | def update_lambdas(params, n_iter): 181 | """ 182 | Update all lambda coefficients. 183 | """ 184 | for name in DYNAMIC_COEFF: 185 | config = getattr(params, name + '_config') 186 | if config is not None: 187 | setattr(params, name, get_lambda_value(config, n_iter)) 188 | 189 | 190 | def set_sampling_probs(data, params): 191 | """ 192 | Set the probability of sampling specific languages / language pairs during training. 193 | """ 194 | coeff = params.lg_sampling_factor 195 | if coeff == -1: 196 | return 197 | assert coeff > 0 198 | 199 | # monolingual data 200 | params.mono_list = [k for k, v in data['mono_stream'].items() if 'train' in v] 201 | if len(params.mono_list) > 0: 202 | probs = np.array([1.0 * len(data['mono_stream'][lang]['train']) for lang in params.mono_list]) 203 | probs /= probs.sum() 204 | probs = np.array([p ** coeff for p in probs]) 205 | probs /= probs.sum() 206 | params.mono_probs = probs 207 | 208 | # parallel data 209 | params.para_list = [k for k, v in data['para'].items() if 'train' in v] 210 | if len(params.para_list) > 0: 211 | probs = np.array([1.0 * len(data['para'][(l1, l2)]['train']) for (l1, l2) in params.para_list]) 212 | probs /= probs.sum() 213 | probs = np.array([p ** coeff for p in probs]) 214 | probs /= probs.sum() 215 | params.para_probs = probs 216 | 217 | 218 | def concat_batches(x1, len1, lang1_id, x2, len2, lang2_id, pad_idx, eos_idx, reset_positions): 219 | """ 220 | Concat batches with different languages. 
221 | """ 222 | assert reset_positions is False or lang1_id != lang2_id 223 | lengths = len1 + len2 224 | if not reset_positions: 225 | lengths -= 1 226 | slen, bs = lengths.max().item(), lengths.size(0) 227 | 228 | x = x1.new(slen, bs).fill_(pad_idx) 229 | x[:len1.max().item()].copy_(x1) 230 | positions = torch.arange(slen)[:, None].repeat(1, bs).to(x1.device) 231 | langs = x1.new(slen, bs).fill_(lang1_id) 232 | 233 | for i in range(bs): 234 | l1 = len1[i] if reset_positions else len1[i] - 1 235 | x[l1:l1 + len2[i], i].copy_(x2[:len2[i], i]) 236 | if reset_positions: 237 | positions[l1:, i] -= len1[i] 238 | langs[l1:, i] = lang2_id 239 | 240 | assert (x == eos_idx).long().sum().item() == (4 if reset_positions else 3) * bs 241 | 242 | return x, lengths, positions, langs 243 | 244 | 245 | def truncate(x, lengths, max_len, eos_index): 246 | """ 247 | Truncate long sentences. 248 | """ 249 | if lengths.max().item() > max_len: 250 | x = x[:max_len].clone() 251 | lengths = lengths.clone() 252 | for i in range(len(lengths)): 253 | if lengths[i] > max_len: 254 | lengths[i] = max_len 255 | x[max_len - 1, i] = eos_index 256 | return x, lengths 257 | 258 | 259 | def shuf_order(langs, params=None, n=5): 260 | """ 261 | Randomize training order. 262 | """ 263 | if len(langs) == 0: 264 | return [] 265 | 266 | if params is None: 267 | return [langs[i] for i in np.random.permutation(len(langs))] 268 | 269 | # sample monolingual and parallel languages separately 270 | mono = [l1 for l1, l2 in langs if l2 is None] 271 | para = [(l1, l2) for l1, l2 in langs if l2 is not None] 272 | 273 | # uniform / weighted sampling 274 | if params.lg_sampling_factor == -1: 275 | p_mono = None 276 | p_para = None 277 | else: 278 | p_mono = np.array([params.mono_probs[params.mono_list.index(k)] for k in mono]) 279 | p_para = np.array([params.para_probs[params.para_list.index(tuple(sorted(k)))] for k in para]) 280 | p_mono = p_mono / p_mono.sum() 281 | p_para = p_para / p_para.sum() 282 | 283 | s_mono = [mono[i] for i in np.random.choice(len(mono), size=min(n, len(mono)), p=p_mono, replace=True)] if len(mono) > 0 else [] 284 | s_para = [para[i] for i in np.random.choice(len(para), size=min(n, len(para)), p=p_para, replace=True)] if len(para) > 0 else [] 285 | 286 | assert len(s_mono) + len(s_para) > 0 287 | return [(lang, None) for lang in s_mono] + s_para 288 | 289 | 290 | def find_modules(module, module_name, module_instance, found): 291 | """ 292 | Recursively find all instances of a specific module inside a module. 
293 | """ 294 | if isinstance(module, module_instance): 295 | found.append((module_name, module)) 296 | else: 297 | for name, child in module.named_children(): 298 | name = ('%s[%s]' if name.isdigit() else '%s.%s') % (module_name, name) 299 | find_modules(child, name, module_instance, found) 300 | -------------------------------------------------------------------------------- /tools/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jind11/DAMT/3caa22822b936137e5da3e827d7a5a2078c3115e/tools/.DS_Store -------------------------------------------------------------------------------- /tools/README.md: -------------------------------------------------------------------------------- 1 | # Tools 2 | 3 | In `XLM/tools/`, you will need to install the following tools: 4 | 5 | ## Tokenizers 6 | 7 | [Moses](https://github.com/moses-smt/mosesdecoder/tree/master/scripts/tokenizer) tokenizer: 8 | ``` 9 | git clone https://github.com/moses-smt/mosesdecoder 10 | ``` 11 | 12 | Thai [PythaiNLP](https://github.com/PyThaiNLP/pythainlp) tokenizer: 13 | ``` 14 | pip install pythainlp 15 | ``` 16 | 17 | Japanese [KyTea](http://www.phontron.com/kytea) tokenizer: 18 | ``` 19 | wget http://www.phontron.com/kytea/download/kytea-0.4.7.tar.gz 20 | tar -xzf kytea-0.4.7.tar.gz 21 | cd kytea-0.4.7 22 | ./configure 23 | make 24 | make install 25 | kytea --help 26 | ``` 27 | 28 | Chinese Stanford segmenter: 29 | ``` 30 | wget https://nlp.stanford.edu/software/stanford-segmenter-2018-10-16.zip 31 | unzip stanford-segmenter-2018-10-16.zip 32 | ``` 33 | 34 | ## fastBPE 35 | 36 | ``` 37 | git clone https://github.com/glample/fastBPE 38 | cd fastBPE 39 | g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o fast 40 | ``` 41 | -------------------------------------------------------------------------------- /tools/lowercase_and_remove_accent.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import sys 9 | import unicodedata 10 | import six 11 | 12 | 13 | def convert_to_unicode(text): 14 | """ 15 | Converts `text` to Unicode (if it's not already), assuming UTF-8 input. 16 | """ 17 | # six_ensure_text is copied from https://github.com/benjaminp/six 18 | def six_ensure_text(s, encoding='utf-8', errors='strict'): 19 | if isinstance(s, six.binary_type): 20 | return s.decode(encoding, errors) 21 | elif isinstance(s, six.text_type): 22 | return s 23 | else: 24 | raise TypeError("not expecting type '%s'" % type(s)) 25 | 26 | return six_ensure_text(text, encoding="utf-8", errors="ignore") 27 | 28 | 29 | def run_strip_accents(text): 30 | """ 31 | Strips accents from a piece of text. 32 | """ 33 | text = unicodedata.normalize("NFD", text) 34 | output = [] 35 | for char in text: 36 | cat = unicodedata.category(char) 37 | if cat == "Mn": 38 | continue 39 | output.append(char) 40 | return "".join(output) 41 | 42 | 43 | for line in sys.stdin: 44 | line = convert_to_unicode(line.rstrip().lower()) 45 | line = run_strip_accents(line) 46 | print(u'%s' % line.lower()) 47 | -------------------------------------------------------------------------------- /tools/segment_th.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 
2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import sys 9 | from pythainlp.tokenize import word_tokenize 10 | 11 | for line in sys.stdin.readlines(): 12 | line = line.rstrip('\n') 13 | print(' '.join(word_tokenize(line))) 14 | -------------------------------------------------------------------------------- /tools/tokenize.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | # Tokenize text data in various languages 9 | # Usage: e.g. cat wiki.ar | tokenize.sh ar 10 | 11 | set -e 12 | 13 | N_THREADS=8 14 | 15 | lg=$1 16 | TOOLS_PATH=$PWD/tools 17 | 18 | # moses 19 | MOSES=$TOOLS_PATH/mosesdecoder 20 | REPLACE_UNICODE_PUNCT=$MOSES/scripts/tokenizer/replace-unicode-punctuation.perl 21 | NORM_PUNC=$MOSES/scripts/tokenizer/normalize-punctuation.perl 22 | REM_NON_PRINT_CHAR=$MOSES/scripts/tokenizer/remove-non-printing-char.perl 23 | TOKENIZER=$MOSES/scripts/tokenizer/tokenizer.perl 24 | 25 | # Chinese 26 | if [ "$lg" = "zh" ]; then 27 | $TOOLS_PATH/stanford-segmenter-*/segment.sh pku /dev/stdin UTF-8 0 | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $lg | $REM_NON_PRINT_CHAR 28 | # Thai 29 | elif [ "$lg" = "th" ]; then 30 | cat - | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $lg | $REM_NON_PRINT_CHAR | python $TOOLS_PATH/segment_th.py 31 | # Japanese 32 | elif [ "$lg" = "ja" ]; then 33 | cat - | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $lg | $REM_NON_PRINT_CHAR | kytea -notags 34 | # other languages 35 | else 36 | cat - | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $lg | $REM_NON_PRINT_CHAR | $TOKENIZER -no-escape -threads $N_THREADS -l $lg 37 | fi 38 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import json 9 | import random 10 | import argparse 11 | 12 | from src.slurm import init_signal_handler, init_distributed_mode 13 | from src.data.loader import check_data_params, load_data 14 | from src.utils import bool_flag, initialize_exp, set_sampling_probs, shuf_order 15 | from src.model import check_model_params, build_model 16 | from src.model.memory import HashingMemory 17 | from src.trainer import SingleTrainer, EncDecTrainer 18 | from src.evaluation.evaluator import SingleEvaluator, EncDecEvaluator 19 | 20 | 21 | def get_parser(): 22 | """ 23 | Generate a parameters parser. 
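    Boolean options are declared with type=bool_flag and therefore accept
    on/off, true/false and 0/1 on the command line.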
24 | """ 25 | # parse parameters 26 | parser = argparse.ArgumentParser(description="Language transfer") 27 | 28 | # main parameters 29 | parser.add_argument("--dump_path", type=str, default="./dumped/", 30 | help="Experiment dump path") 31 | parser.add_argument("--exp_name", type=str, default="", 32 | help="Experiment name") 33 | parser.add_argument("--save_periodic", type=int, default=0, 34 | help="Save the model periodically (0 to disable)") 35 | parser.add_argument("--exp_id", type=str, default="", 36 | help="Experiment ID") 37 | 38 | # float16 / AMP API 39 | parser.add_argument("--fp16", type=bool_flag, default=False, 40 | help="Run model with float16") 41 | parser.add_argument("--amp", type=int, default=-1, 42 | help="Use AMP wrapper for float16 / distributed / gradient accumulation. Level of optimization. -1 to disable.") 43 | 44 | # only use an encoder (use a specific decoder for machine translation) 45 | parser.add_argument("--encoder_only", type=bool_flag, default=True, 46 | help="Only use an encoder") 47 | 48 | # model parameters 49 | parser.add_argument("--emb_dim", type=int, default=512, 50 | help="Embedding layer size") 51 | parser.add_argument("--n_layers", type=int, default=4, 52 | help="Number of Transformer layers") 53 | parser.add_argument("--n_heads", type=int, default=8, 54 | help="Number of Transformer heads") 55 | parser.add_argument("--dropout", type=float, default=0, 56 | help="Dropout") 57 | parser.add_argument("--attention_dropout", type=float, default=0, 58 | help="Dropout in the attention layer") 59 | parser.add_argument("--gelu_activation", type=bool_flag, default=False, 60 | help="Use a GELU activation instead of ReLU") 61 | parser.add_argument("--share_inout_emb", type=bool_flag, default=True, 62 | help="Share input and output embeddings") 63 | parser.add_argument("--sinusoidal_embeddings", type=bool_flag, default=False, 64 | help="Use sinusoidal embeddings") 65 | parser.add_argument("--use_lang_emb", type=bool_flag, default=True, 66 | help="Use language embedding") 67 | 68 | # memory parameters 69 | parser.add_argument("--use_memory", type=bool_flag, default=False, 70 | help="Use an external memory") 71 | if parser.parse_known_args()[0].use_memory: 72 | HashingMemory.register_args(parser) 73 | parser.add_argument("--mem_enc_positions", type=str, default="", 74 | help="Memory positions in the encoder ('4' for inside layer 4, '7,10+' for inside layer 7 and after layer 10)") 75 | parser.add_argument("--mem_dec_positions", type=str, default="", 76 | help="Memory positions in the decoder. 
Same syntax as `mem_enc_positions`.") 77 | 78 | # adaptive softmax 79 | parser.add_argument("--asm", type=bool_flag, default=False, 80 | help="Use adaptive softmax") 81 | if parser.parse_known_args()[0].asm: 82 | parser.add_argument("--asm_cutoffs", type=str, default="8000,20000", 83 | help="Adaptive softmax cutoffs") 84 | parser.add_argument("--asm_div_value", type=float, default=4, 85 | help="Adaptive softmax cluster sizes ratio") 86 | 87 | # causal language modeling task parameters 88 | parser.add_argument("--context_size", type=int, default=0, 89 | help="Context size (0 means that the first elements in sequences won't have any context)") 90 | 91 | # masked language modeling task parameters 92 | parser.add_argument("--word_pred", type=float, default=0.15, 93 | help="Fraction of words for which we need to make a prediction") 94 | parser.add_argument("--sample_alpha", type=float, default=0, 95 | help="Exponent for transforming word counts to probabilities (~word2vec sampling)") 96 | parser.add_argument("--word_mask_keep_rand", type=str, default="0.8,0.1,0.1", 97 | help="Fraction of words to mask out / keep / randomize, among the words to predict") 98 | 99 | # input sentence noise 100 | parser.add_argument("--word_shuffle", type=float, default=0, 101 | help="Randomly shuffle input words (0 to disable)") 102 | parser.add_argument("--word_dropout", type=float, default=0, 103 | help="Randomly dropout input words (0 to disable)") 104 | parser.add_argument("--word_blank", type=float, default=0, 105 | help="Randomly blank input words (0 to disable)") 106 | 107 | # data 108 | parser.add_argument("--data_path", type=str, default="", 109 | help="Data path") 110 | parser.add_argument("--para_data_path", type=str, default="", 111 | help="Parallel Data path") 112 | parser.add_argument("--lgs", type=str, default="", 113 | help="Languages (lg1-lg2-lg3 .. 
ex: en-fr-es-de)") 114 | parser.add_argument("--max_vocab", type=int, default=-1, 115 | help="Maximum vocabulary size (-1 to disable)") 116 | parser.add_argument("--min_count", type=int, default=0, 117 | help="Minimum vocabulary count") 118 | parser.add_argument("--lg_sampling_factor", type=float, default=-1, 119 | help="Language sampling factor") 120 | 121 | # batch parameters 122 | parser.add_argument("--bptt", type=int, default=256, 123 | help="Sequence length") 124 | parser.add_argument("--max_len", type=int, default=100, 125 | help="Maximum length of sentences (after BPE)") 126 | parser.add_argument("--group_by_size", type=bool_flag, default=True, 127 | help="Sort sentences by size during the training") 128 | parser.add_argument("--batch_size", type=int, default=32, 129 | help="Number of sentences per batch") 130 | parser.add_argument("--max_batch_size", type=int, default=0, 131 | help="Maximum number of sentences per batch (used in combination with tokens_per_batch, 0 to disable)") 132 | parser.add_argument("--tokens_per_batch", type=int, default=-1, 133 | help="Number of tokens per batch") 134 | 135 | # training parameters 136 | parser.add_argument("--split_data", type=bool_flag, default=False, 137 | help="Split data across workers of a same node") 138 | parser.add_argument("--optimizer", type=str, default="adam,lr=0.0001", 139 | help="Optimizer (SGD / RMSprop / Adam, etc.)") 140 | parser.add_argument("--clip_grad_norm", type=float, default=5, 141 | help="Clip gradients norm (0 to disable)") 142 | parser.add_argument("--epoch_size", type=int, default=100000, 143 | help="Epoch size / evaluation frequency (-1 for parallel data size)") 144 | parser.add_argument("--max_epoch", type=int, default=100000, 145 | help="Maximum epoch size") 146 | parser.add_argument("--stopping_criterion", type=str, default="", 147 | help="Stopping criterion, and number of non-increase before stopping the experiment") 148 | parser.add_argument("--validation_metrics", type=str, default="", 149 | help="Validation metrics") 150 | parser.add_argument("--accumulate_gradients", type=int, default=1, 151 | help="Accumulate model gradients over N iterations (N times larger batch sizes)") 152 | 153 | # training coefficients 154 | parser.add_argument("--lambda_mlm", type=str, default="1", 155 | help="Prediction coefficient (MLM)") 156 | parser.add_argument("--lambda_clm", type=str, default="1", 157 | help="Causal coefficient (LM)") 158 | parser.add_argument("--lambda_pc", type=str, default="1", 159 | help="PC coefficient") 160 | parser.add_argument("--lambda_ae", type=str, default="1", 161 | help="AE coefficient") 162 | parser.add_argument("--lambda_mt", type=str, default="1", 163 | help="MT coefficient") 164 | parser.add_argument("--lambda_bt", type=str, default="1", 165 | help="BT coefficient") 166 | 167 | # training steps 168 | parser.add_argument("--clm_steps", type=str, default="", 169 | help="Causal prediction steps (CLM)") 170 | parser.add_argument("--mlm_steps", type=str, default="", 171 | help="Masked prediction steps (MLM / TLM)") 172 | parser.add_argument("--mt_steps", type=str, default="", 173 | help="Machine translation steps") 174 | parser.add_argument("--ae_steps", type=str, default="", 175 | help="Denoising auto-encoder steps") 176 | parser.add_argument("--bt_steps", type=str, default="", 177 | help="Back-translation steps") 178 | parser.add_argument("--pc_steps", type=str, default="", 179 | help="Parallel classification steps") 180 | parser.add_argument("--delay_umt_epoch_num", type=int, default=0, 181 
| help="The number of epochs to delay the umt steps") 182 | 183 | # reload pretrained embeddings / pretrained model / checkpoint 184 | parser.add_argument("--reload_emb", type=str, default="", 185 | help="Reload pretrained word embeddings") 186 | parser.add_argument("--reload_model", type=str, default="", 187 | help="Reload a pretrained model") 188 | parser.add_argument("--reload_checkpoint", type=str, default="", 189 | help="Reload a checkpoint") 190 | 191 | # beam search (for MT only) 192 | parser.add_argument("--beam_size", type=int, default=1, 193 | help="Beam size, default = 1 (greedy decoding)") 194 | parser.add_argument("--length_penalty", type=float, default=1, 195 | help="Length penalty, values < 1.0 favor shorter sentences, while values > 1.0 favor longer ones.") 196 | parser.add_argument("--early_stopping", type=bool_flag, default=False, 197 | help="Early stopping, stop as soon as we have `beam_size` hypotheses, although longer ones may have better scores.") 198 | 199 | # evaluation 200 | parser.add_argument("--eval_bleu", type=bool_flag, default=False, 201 | help="Evaluate BLEU score during MT training") 202 | parser.add_argument("--eval_only", type=bool_flag, default=False, 203 | help="Only run evaluations") 204 | parser.add_argument("--bpe_type", type=str, default='fastBPE', 205 | help="Approach to implement BPE such as: fastBPE, sentencepiece") 206 | 207 | # debug 208 | parser.add_argument("--debug_train", type=bool_flag, default=False, 209 | help="Use valid sets for train sets (faster loading)") 210 | parser.add_argument("--debug_slurm", type=bool_flag, default=False, 211 | help="Debug multi-GPU / multi-node within a SLURM job") 212 | parser.add_argument("--debug", help="Enable all debug flags", 213 | action="store_true") 214 | 215 | # multi-gpu / multi-node 216 | parser.add_argument("--local_rank", type=int, default=-1, 217 | help="Multi-GPU - Local rank") 218 | parser.add_argument("--master_port", type=int, default=-1, 219 | help="Master port (for multi-node SLURM jobs)") 220 | 221 | return parser 222 | 223 | 224 | def main(params): 225 | 226 | # initialize the multi-GPU / multi-node training 227 | init_distributed_mode(params) 228 | 229 | # initialize the experiment 230 | logger = initialize_exp(params) 231 | 232 | # initialize SLURM signal handler for time limit / pre-emption 233 | init_signal_handler() 234 | 235 | # load data 236 | data = load_data(params) 237 | 238 | # build model 239 | if params.encoder_only: 240 | model = build_model(params, data['dico']) 241 | else: 242 | encoder, decoder = build_model(params, data['dico']) 243 | 244 | # build trainer, reload potential checkpoints / build evaluator 245 | if params.encoder_only: 246 | trainer = SingleTrainer(model, data, params) 247 | evaluator = SingleEvaluator(trainer, data, params) 248 | else: 249 | trainer = EncDecTrainer(encoder, decoder, data, params) 250 | evaluator = EncDecEvaluator(trainer, data, params) 251 | 252 | # evaluation 253 | if params.eval_only: 254 | scores = evaluator.run_all_evals(trainer) 255 | for k, v in scores.items(): 256 | logger.info("%s -> %.6f" % (k, v)) 257 | logger.info("__log__:%s" % json.dumps(scores)) 258 | exit() 259 | 260 | # set sampling probabilities for training 261 | set_sampling_probs(data, params) 262 | 263 | # language model training 264 | for epoch in range(params.max_epoch): 265 | 266 | logger.info("============ Starting epoch %i ... 
============" % trainer.epoch) 267 | 268 | trainer.n_sentences = 0 269 | 270 | while trainer.n_sentences < trainer.epoch_size: 271 | 272 | # CLM steps 273 | if epoch >= params.delay_umt_epoch_num: 274 | for lang1, lang2 in shuf_order(params.clm_steps, params): 275 | trainer.clm_step(lang1, lang2, params.lambda_clm) 276 | 277 | # MLM steps (also includes TLM if lang2 is not None) 278 | if epoch >= params.delay_umt_epoch_num: 279 | for lang1, lang2 in shuf_order(params.mlm_steps, params): 280 | trainer.mlm_step(lang1, lang2, params.lambda_mlm) 281 | 282 | # parallel classification steps 283 | for lang1, lang2 in shuf_order(params.pc_steps, params): 284 | trainer.pc_step(lang1, lang2, params.lambda_pc) 285 | 286 | # denoising auto-encoder steps 287 | if epoch >= params.delay_umt_epoch_num: 288 | for lang in shuf_order(params.ae_steps): 289 | trainer.mt_step(lang, lang, params.lambda_ae) 290 | 291 | # machine translation steps 292 | for lang1, lang2 in shuf_order(params.mt_steps, params): 293 | trainer.mt_step(lang1, lang2, params.lambda_mt) 294 | 295 | # back-translation steps 296 | if epoch >= params.delay_umt_epoch_num: 297 | for lang1, lang2, lang3 in shuf_order(params.bt_steps): 298 | trainer.bt_step(lang1, lang2, lang3, params.lambda_bt) 299 | 300 | trainer.iter() 301 | 302 | logger.info("============ End of epoch %i ============" % trainer.epoch) 303 | 304 | # evaluate perplexity 305 | scores = evaluator.run_all_evals(trainer) 306 | 307 | # print / JSON log 308 | for k, v in scores.items(): 309 | logger.info("%s -> %.6f" % (k, v)) 310 | if params.is_master: 311 | logger.info("__log__:%s" % json.dumps(scores)) 312 | 313 | # end of epoch 314 | trainer.save_best_model(scores) 315 | trainer.save_periodic() 316 | trainer.end_epoch(scores) 317 | 318 | 319 | if __name__ == '__main__': 320 | 321 | # generate parser / parse parameters 322 | parser = get_parser() 323 | params = parser.parse_args() 324 | 325 | # debug mode 326 | if params.debug: 327 | params.exp_name = 'debug' 328 | params.exp_id = 'debug_%08i' % random.randint(0, 100000000) 329 | params.debug_slurm = True 330 | params.debug_train = True 331 | 332 | # check parameters 333 | check_data_params(params) 334 | check_model_params(params) 335 | 336 | # run experiment 337 | main(params) 338 | -------------------------------------------------------------------------------- /train_IBT.sh: -------------------------------------------------------------------------------- 1 | # 2 | # Read arguments 3 | # 4 | POSITIONAL=() 5 | while [[ $# -gt 0 ]] 6 | do 7 | key="$1" 8 | case $key in 9 | --src) 10 | SRC="$2"; shift 2;; 11 | --tgt) 12 | TGT="$2"; shift 2;; 13 | --data_name) 14 | DATA_NAME="$2"; shift 2;; 15 | --pretrained_model_dir) 16 | PRETRAINED_MODEL_DIR="$2"; shift 2;; 17 | *) 18 | POSITIONAL+=("$1") 19 | shift 20 | ;; 21 | esac 22 | done 23 | set -- "${POSITIONAL[@]}" 24 | 25 | if [ "$SRC" != 'en' ]; then 26 | OTHER_LANG=$SRC 27 | else 28 | OTHER_LANG=$TGT 29 | fi 30 | echo $OTHER_LANG 31 | 32 | if [ "$SRC" \< "$TGT" ]; then 33 | ORDERED_SRC=$SRC 34 | ORDERED_TGT=$TGT 35 | else 36 | ORDERED_SRC=$TGT 37 | ORDERED_TGT=$SRC 38 | fi 39 | 40 | 41 | epoch_size=$(cat data/$ORDERED_SRC-$ORDERED_TGT/$DATA_NAME/processed/$ORDERED_SRC-$ORDERED_TGT/train.$SRC | wc -l) 42 | max_epoch_size=300000 43 | epoch_size=$((epoch_size>max_epoch_size ? 
max_epoch_size : epoch_size)) 44 | echo $epoch_size 45 | 46 | python -W ignore train.py \ 47 | --exp_name ibt_$DATA_NAME\_$SRC\_$TGT \ 48 | --dump_path ./tmp/ \ 49 | --reload_model ${PRETRAINED_MODEL_DIR}/mlm_en${OTHER_LANG}_1024.pth,${PRETRAINED_MODEL_DIR}/mlm_en${OTHER_LANG}_1024.pth \ 50 | --data_path data/$ORDERED_SRC-$ORDERED_TGT/$DATA_NAME/processed/$ORDERED_SRC-$ORDERED_TGT \ 51 | --lgs $SRC-$TGT \ 52 | --ae_steps $SRC,$TGT \ 53 | --bt_steps $SRC-$TGT-$SRC,$TGT-$SRC-$TGT \ 54 | --word_shuffle 3 \ 55 | --word_dropout 0.1 \ 56 | --word_blank 0.1 \ 57 | --lambda_ae '0:1,100000:0.1,300000:0' \ 58 | --encoder_only false \ 59 | --emb_dim 1024 \ 60 | --n_layers 6 \ 61 | --n_heads 8 \ 62 | --dropout 0.1 \ 63 | --attention_dropout 0.1 \ 64 | --gelu_activation true \ 65 | --tokens_per_batch 1000 \ 66 | --batch_size 32 \ 67 | --bptt 256 \ 68 | --optimizer adam_inverse_sqrt,beta1=0.9,beta2=0.98,lr=0.0001 \ 69 | --epoch_size $epoch_size \ 70 | --eval_bleu true \ 71 | --stopping_criterion valid_$SRC-$TGT\_mt_bleu,3 \ 72 | --validation_metrics valid_$SRC-$TGT\_mt_bleu \ 73 | --max_epoch 100 \ 74 | --max_len 150 \ -------------------------------------------------------------------------------- /train_IBT_plus_BACK.sh: -------------------------------------------------------------------------------- 1 | # 2 | # Read arguments 3 | # 4 | POSITIONAL=() 5 | while [[ $# -gt 0 ]] 6 | do 7 | key="$1" 8 | case $key in 9 | --src) 10 | SRC="$2"; shift 2;; 11 | --tgt) 12 | TGT="$2"; shift 2;; 13 | --src_data_name) 14 | SRC_DATA_NAME="$2"; shift 2;; 15 | --tgt_data_name) 16 | TGT_DATA_NAME="$2"; shift 2;; 17 | --pretrained_model_dir) 18 | PRETRAINED_MODEL_DIR="$2"; shift 2;; 19 | *) 20 | POSITIONAL+=("$1") 21 | shift 22 | ;; 23 | esac 24 | done 25 | set -- "${POSITIONAL[@]}" 26 | 27 | if [ "$SRC" != 'en' ]; then 28 | OTHER_LANG=$SRC 29 | else 30 | OTHER_LANG=$TGT 31 | fi 32 | echo $OTHER_LANG 33 | 34 | if [ "$SRC" \< "$TGT" ]; then 35 | ORDERED_SRC=$SRC 36 | ORDERED_TGT=$TGT 37 | else 38 | ORDERED_SRC=$TGT 39 | ORDERED_TGT=$SRC 40 | fi 41 | 42 | epoch_size=$(cat data/$ORDERED_SRC-$ORDERED_TGT/$TGT_DATA_NAME/processed/$ORDERED_SRC-$ORDERED_TGT/train.$SRC | wc -l) 43 | max_epoch_size=300000 44 | epoch_size=$((epoch_size>max_epoch_size ? 
max_epoch_size : epoch_size)) 45 | echo $epoch_size 46 | 47 | 48 | python -W ignore train.py \ 49 | --exp_name IBT_BACK_src_$SRC_DATA_NAME\_tgt_$TGT_DATA_NAME\_$SRC\_$TGT \ 50 | --dump_path ./tmp/ \ 51 | --reload_model ${PRETRAINED_MODEL_DIR}/mlm_en${OTHER_LANG}_1024.pth,${PRETRAINED_MODEL_DIR}/mlm_en${OTHER_LANG}_1024.pth \ 52 | --data_path data/$ORDERED_SRC-$ORDERED_TGT/$TGT_DATA_NAME/processed/$ORDERED_SRC-$ORDERED_TGT \ 53 | --para_data_path data/$ORDERED_SRC-$ORDERED_TGT/$TGT_DATA_NAME/back_translate/$SRC_DATA_NAME \ 54 | --lgs $SRC-$TGT \ 55 | --ae_steps $SRC,$TGT \ 56 | --bt_steps $SRC-$TGT-$SRC,$TGT-$SRC-$TGT \ 57 | --mt_steps $SRC-$TGT \ 58 | --word_shuffle 3 \ 59 | --word_dropout 0.1 \ 60 | --word_blank 0.1 \ 61 | --lambda_ae '0:1,100000:0.1,300000:0' \ 62 | --encoder_only false \ 63 | --emb_dim 1024 \ 64 | --n_layers 6 \ 65 | --n_heads 8 \ 66 | --dropout 0.1 \ 67 | --attention_dropout 0.1 \ 68 | --gelu_activation true \ 69 | --tokens_per_batch 1500 \ 70 | --batch_size 32 \ 71 | --bptt 256 \ 72 | --optimizer adam_inverse_sqrt,beta1=0.9,beta2=0.98,lr=0.0001 \ 73 | --epoch_size $epoch_size \ 74 | --eval_bleu true \ 75 | --stopping_criterion valid_$SRC-$TGT\_mt_bleu,3 \ 76 | --validation_metrics valid_$SRC-$TGT\_mt_bleu \ 77 | --max_epoch 50 \ 78 | --max_len 150 \ -------------------------------------------------------------------------------- /train_IBT_plus_SRC.sh: -------------------------------------------------------------------------------- 1 | # 2 | # Read arguments 3 | # 4 | POSITIONAL=() 5 | while [[ $# -gt 0 ]] 6 | do 7 | key="$1" 8 | case $key in 9 | --src) 10 | SRC="$2"; shift 2;; 11 | --tgt) 12 | TGT="$2"; shift 2;; 13 | --src_data_name) 14 | SRC_DATA_NAME="$2"; shift 2;; 15 | --tgt_data_name) 16 | TGT_DATA_NAME="$2"; shift 2;; 17 | --pretrained_model_dir) 18 | PRETRAINED_MODEL_DIR="$2"; shift 2;; 19 | *) 20 | POSITIONAL+=("$1") 21 | shift 22 | ;; 23 | esac 24 | done 25 | set -- "${POSITIONAL[@]}" 26 | 27 | if [ "$SRC" != 'en' ]; then 28 | OTHER_LANG=$SRC 29 | else 30 | OTHER_LANG=$TGT 31 | fi 32 | echo $OTHER_LANG 33 | 34 | if [ "$SRC" \< "$TGT" ]; then 35 | ORDERED_SRC=$SRC 36 | ORDERED_TGT=$TGT 37 | else 38 | ORDERED_SRC=$TGT 39 | ORDERED_TGT=$SRC 40 | fi 41 | 42 | epoch_size=$(cat data/$ORDERED_SRC-$ORDERED_TGT/$SRC_DATA_NAME/processed/$ORDERED_SRC-$ORDERED_TGT/train.$ORDERED_SRC-$ORDERED_TGT.$SRC | wc -l) 43 | max_epoch_size=500000 44 | epoch_size=$((epoch_size>max_epoch_size ? 
max_epoch_size : epoch_size)) 45 | echo $epoch_size 46 | 47 | python -W ignore train.py \ 48 | --exp_name semi_sup_bidir_src_$SRC_DATA_NAME\_tgt_$TGT_DATA_NAME\_$SRC\_$TGT \ 49 | --dump_path ./tmp/ \ 50 | --reload_model ${PRETRAINED_MODEL_DIR}/mlm_en${OTHER_LANG}_1024.pth,${PRETRAINED_MODEL_DIR}/mlm_en${OTHER_LANG}_1024.pth \ 51 | --data_path data/$ORDERED_SRC-$ORDERED_TGT/$TGT_DATA_NAME/processed/$ORDERED_SRC-$ORDERED_TGT \ 52 | --para_data_path data/$ORDERED_SRC-$ORDERED_TGT/$SRC_DATA_NAME/processed/$ORDERED_SRC-$ORDERED_TGT \ 53 | --lgs $SRC-$TGT \ 54 | --ae_steps $SRC,$TGT \ 55 | --bt_steps $SRC-$TGT-$SRC,$TGT-$SRC-$TGT \ 56 | --mt_steps $SRC-$TGT \ 57 | --word_shuffle 3 \ 58 | --word_dropout 0.1 \ 59 | --word_blank 0.1 \ 60 | --lambda_ae '0:1,100000:0.1,300000:0' \ 61 | --encoder_only false \ 62 | --emb_dim 1024 \ 63 | --n_layers 6 \ 64 | --n_heads 8 \ 65 | --dropout 0.1 \ 66 | --attention_dropout 0.1 \ 67 | --gelu_activation true \ 68 | --tokens_per_batch 1200 \ 69 | --batch_size 32 \ 70 | --bptt 256 \ 71 | --optimizer adam_inverse_sqrt,beta1=0.9,beta2=0.98,lr=0.0001 \ 72 | --epoch_size $epoch_size \ 73 | --eval_bleu true \ 74 | --stopping_criterion valid_$SRC-$TGT\_mt_bleu,3 \ 75 | --validation_metrics valid_$SRC-$TGT\_mt_bleu \ 76 | --max_epoch 50 \ 77 | --max_len 150 \ -------------------------------------------------------------------------------- /train_sup.sh: -------------------------------------------------------------------------------- 1 | # 2 | # Read arguments 3 | # 4 | POSITIONAL=() 5 | while [[ $# -gt 0 ]] 6 | do 7 | key="$1" 8 | case $key in 9 | --src) 10 | SRC="$2"; shift 2;; 11 | --tgt) 12 | TGT="$2"; shift 2;; 13 | --data_name) 14 | DATA_NAME="$2"; shift 2;; 15 | --pretrained_model_dir) 16 | PRETRAINED_MODEL_DIR="$2"; shift 2;; 17 | *) 18 | POSITIONAL+=("$1") 19 | shift 20 | ;; 21 | esac 22 | done 23 | set -- "${POSITIONAL[@]}" 24 | 25 | if [ "$SRC" != 'en' ]; then 26 | OTHER_LANG=$SRC 27 | else 28 | OTHER_LANG=$TGT 29 | fi 30 | echo $OTHER_LANG 31 | 32 | if [ "$SRC" \< "$TGT" ]; then 33 | ORDERED_SRC=$SRC 34 | ORDERED_TGT=$TGT 35 | else 36 | ORDERED_SRC=$TGT 37 | ORDERED_TGT=$SRC 38 | fi 39 | 40 | epoch_size=$(cat data/$ORDERED_SRC-$ORDERED_TGT/$DATA_NAME/processed/$ORDERED_SRC-$ORDERED_TGT/train.$ORDERED_SRC-$ORDERED_TGT.$SRC | wc -l) 41 | max_epoch_size=300000 42 | epoch_size=$((epoch_size>max_epoch_size ? 
max_epoch_size : epoch_size)) 43 | echo $epoch_size 44 | 45 | python -W ignore train.py \ 46 | --exp_name sup_$DATA_NAME\_$SRC\_$TGT \ 47 | --dump_path ./tmp/ \ 48 | --reload_model ${PRETRAINED_MODEL_DIR}/mlm_en${OTHER_LANG}_1024.pth,${PRETRAINED_MODEL_DIR}/mlm_en${OTHER_LANG}_1024.pth \ 49 | --data_path data/$ORDERED_SRC-$ORDERED_TGT/$DATA_NAME/processed/$ORDERED_SRC-$ORDERED_TGT \ 50 | --lgs $SRC-$TGT \ 51 | --mt_steps $SRC-$TGT \ 52 | --encoder_only false \ 53 | --emb_dim 1024 \ 54 | --n_layers 6 \ 55 | --n_heads 8 \ 56 | --dropout 0.1 \ 57 | --attention_dropout 0.1 \ 58 | --gelu_activation true \ 59 | --tokens_per_batch 2500 \ 60 | --batch_size 32 \ 61 | --bptt 256 \ 62 | --optimizer adam_inverse_sqrt,beta1=0.9,beta2=0.98,lr=0.0001 \ 63 | --epoch_size $epoch_size \ 64 | --eval_bleu true \ 65 | --stopping_criterion valid_$SRC-$TGT\_mt_bleu,2 \ 66 | --validation_metrics valid_$SRC-$TGT\_mt_bleu \ 67 | --max_epoch 50 \ 68 | --max_len 150 \ -------------------------------------------------------------------------------- /translate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | # Translate sentences from an input file (given by --src_data_path). 8 | # The model will be faster if sentences are sorted by length. 9 | # Input sentences must have the same tokenization and BPE codes as the ones used in the model. 10 | # 11 | # Usage: 12 | # python translate.py --exp_name translate \ 13 | # --src_lang en --tgt_lang fr \ 14 | # --model_path trained_model.pth --src_data_path INPUT_BPE_FILE \ 15 | # --output_path_source OUT_SRC --output_path_target OUT_TGT 16 | # 17 | 18 | import os 19 | import io 20 | import sys 21 | import argparse 22 | import torch 23 | 24 | from src.utils import AttrDict 25 | from src.utils import bool_flag, initialize_exp 26 | from src.data.dictionary import Dictionary 27 | from src.model.transformer import TransformerModel 28 | 29 | 30 | def get_parser(): 31 | """ 32 | Generate a parameters parser. 
33 | """ 34 | # parse parameters 35 | parser = argparse.ArgumentParser(description="Translate sentences") 36 | 37 | # main parameters 38 | parser.add_argument("--dump_path", type=str, default="./dumped/", help="Experiment dump path") 39 | parser.add_argument("--exp_name", type=str, default="", help="Experiment name") 40 | parser.add_argument("--exp_id", type=str, default="", help="Experiment ID") 41 | parser.add_argument("--batch_size", type=int, default=32, help="Number of sentences per batch") 42 | 43 | # model / output paths 44 | parser.add_argument("--model_path", type=str, default="", help="Model path") 45 | parser.add_argument("--output_path_source", type=str, default="", help="Output path for source") 46 | parser.add_argument("--output_path_target", type=str, default="", help="Output path for target") 47 | 48 | # parser.add_argument("--max_vocab", type=int, default=-1, help="Maximum vocabulary size (-1 to disable)") 49 | # parser.add_argument("--min_count", type=int, default=0, help="Minimum vocabulary count") 50 | 51 | # source language / target language 52 | parser.add_argument("--src_lang", type=str, default="", help="Source language") 53 | parser.add_argument("--tgt_lang", type=str, default="", help="Target language") 54 | parser.add_argument("--src_data_path", type=str, default="", help="Input data path") 55 | 56 | return parser 57 | 58 | 59 | def main(params): 60 | 61 | # initialize the experiment 62 | logger = initialize_exp(params) 63 | 64 | # generate parser / parse parameters 65 | parser = get_parser() 66 | params = parser.parse_args() 67 | reloaded = torch.load(params.model_path) 68 | model_params = AttrDict(reloaded['params']) 69 | logger.info("Supported languages: %s" % ", ".join(model_params.lang2id.keys())) 70 | 71 | # update dictionary parameters 72 | for name in ['n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index', 'mask_index']: 73 | setattr(params, name, getattr(model_params, name)) 74 | 75 | # build dictionary / build encoder / build decoder / reload weights 76 | dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts']) 77 | encoder = TransformerModel(model_params, dico, is_encoder=True, with_output=True).cuda().eval() 78 | decoder = TransformerModel(model_params, dico, is_encoder=False, with_output=True).cuda().eval() 79 | encoder.load_state_dict(reloaded['encoder']) 80 | decoder.load_state_dict(reloaded['decoder']) 81 | params.src_id = model_params.lang2id[params.src_lang] 82 | params.tgt_id = model_params.lang2id[params.tgt_lang] 83 | 84 | # read sentences from the input file 85 | src_sent = [] 86 | for line in open(params.src_data_path, 'r').readlines(): 87 | if line.strip() and len(line.split()) <= 130: 88 | src_sent.append(line) 89 | logger.info("Read %i sentences from the input file. Translating ..." 
% len(src_sent)) 90 | 91 | f_src = io.open(params.output_path_source, 'w', encoding='utf-8') 92 | f_tgt = io.open(params.output_path_target, 'w', encoding='utf-8') 93 | 94 | for i in range(0, len(src_sent), params.batch_size): 95 | 96 | # prepare batch 97 | word_ids = [torch.LongTensor([dico.index(w) for w in s.strip().split()]) 98 | for s in src_sent[i:i + params.batch_size]] 99 | lengths = torch.LongTensor([len(s) + 2 for s in word_ids]) 100 | batch = torch.LongTensor(lengths.max().item(), lengths.size(0)).fill_(params.pad_index) 101 | batch[0] = params.eos_index 102 | for j, s in enumerate(word_ids): 103 | if lengths[j] > 2: # if sentence not empty 104 | batch[1:lengths[j] - 1, j].copy_(s) 105 | batch[lengths[j] - 1, j] = params.eos_index 106 | langs = batch.clone().fill_(params.src_id) 107 | 108 | # encode source batch and translate it 109 | encoded = encoder('fwd', x=batch.cuda(), lengths=lengths.cuda(), langs=langs.cuda(), causal=False) 110 | encoded = encoded.transpose(0, 1) 111 | try: 112 | decoded, dec_lengths = decoder.generate(encoded, lengths.cuda(), params.tgt_id, max_len=int(1.5 * lengths.max().item() + 10)) 113 | except: 114 | print(max([len(line.split()) for line in src_sent[i:i + params.batch_size]])) 115 | else: 116 | # convert sentences to words 117 | for j in range(decoded.size(1)): 118 | 119 | # remove delimiters 120 | sent = decoded[:, j] 121 | delimiters = (sent == params.eos_index).nonzero().view(-1) 122 | assert len(delimiters) >= 1 and delimiters[0].item() == 0 123 | sent = sent[1:] if len(delimiters) == 1 else sent[1:delimiters[1]] 124 | 125 | # output translation 126 | source = src_sent[i + j].strip() 127 | target = " ".join([dico[sent[k].item()] for k in range(len(sent))]) 128 | sys.stderr.write("%i / %i: %s -> %s\n" % (i + j, len(src_sent), source, target)) 129 | f_src.write(source + "\n") 130 | f_tgt.write(target + "\n") 131 | 132 | f_src.close() 133 | f_tgt.close() 134 | 135 | 136 | if __name__ == '__main__': 137 | 138 | # generate parser / parse parameters 139 | parser = get_parser() 140 | params = parser.parse_args() 141 | 142 | # check parameters 143 | assert os.path.isfile(params.model_path) 144 | assert params.src_lang != '' and params.tgt_lang != '' and params.src_lang != params.tgt_lang 145 | # assert params.output_path and not os.path.isfile(params.output_path) 146 | 147 | # translate 148 | with torch.no_grad(): 149 | main(params) 150 | -------------------------------------------------------------------------------- /translate_exe.sh: -------------------------------------------------------------------------------- 1 | # 2 | # Read arguments 3 | # 4 | POSITIONAL=() 5 | while [[ $# -gt 0 ]] 6 | do 7 | key="$1" 8 | case $key in 9 | --src) 10 | SRC="$2"; shift 2;; 11 | --tgt) 12 | TGT="$2"; shift 2;; 13 | --data_name) 14 | DATA_NAME="$2"; shift 2;; 15 | --model_name) 16 | MODEL_NAME="$2"; shift 2;; 17 | --model_dir) 18 | MODEL_DIR="$2"; shift 2;; 19 | *) 20 | POSITIONAL+=("$1") 21 | shift 22 | ;; 23 | esac 24 | done 25 | set -- "${POSITIONAL[@]}" 26 | 27 | if [ "$SRC" \< "$TGT" ]; then 28 | ORDERED_SRC=$SRC 29 | ORDERED_TGT=$TGT 30 | else 31 | ORDERED_SRC=$TGT 32 | ORDERED_TGT=$SRC 33 | fi 34 | 35 | OUT_DIR=data/$ORDERED_SRC-$ORDERED_TGT/$DATA_NAME/back_translate/$MODEL_NAME 36 | mkdir -p $OUT_DIR 37 | 38 | python -W ignore translate.py \ 39 | --exp_name $MODEL_NAME\_$SRC\_to_$TGT \ 40 | --dump_path ./back_translate/ \ 41 | --model_path $MODEL_DIR/best-valid_$SRC-$TGT\_mt_bleu.pth \ 42 | --src_data_path 
data/$ORDERED_SRC-$ORDERED_TGT/$DATA_NAME/processed/$ORDERED_SRC-$ORDERED_TGT/train.$ORDERED_SRC-$ORDERED_TGT.$SRC \ 43 | --output_path_source $OUT_DIR/train.$ORDERED_SRC-$ORDERED_TGT.$SRC \ 44 | --output_path_target $OUT_DIR/train.$ORDERED_SRC-$ORDERED_TGT.$TGT \ 45 | --src_lang $SRC \ 46 | --tgt_lang $TGT \ 47 | --batch_size 128 \ --------------------------------------------------------------------------------