├── .DS_Store
├── README.md
├── get-data-back-translate.sh
├── get-data-nmt-local.sh
├── preprocess.py
├── requirements.txt
├── src
│   ├── .DS_Store
│   ├── __init__.py
│   ├── __init__.pyc
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   └── logger.cpython-36.pyc
│   ├── data
│   │   ├── __init__.py
│   │   ├── __init__.pyc
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   └── dictionary.cpython-36.pyc
│   │   ├── dataset.py
│   │   ├── dictionary.py
│   │   ├── dictionary.pyc
│   │   └── loader.py
│   ├── evaluation
│   │   ├── __init__.py
│   │   ├── evaluator.py
│   │   ├── glue.py
│   │   ├── multi-bleu.perl
│   │   └── xnli.py
│   ├── logger.py
│   ├── logger.pyc
│   ├── model
│   │   ├── __init__.py
│   │   ├── embedder.py
│   │   ├── memory
│   │   │   ├── __init__.py
│   │   │   ├── memory.py
│   │   │   ├── query.py
│   │   │   └── utils.py
│   │   ├── pretrain.py
│   │   └── transformer.py
│   ├── optim.py
│   ├── slurm.py
│   ├── trainer.py
│   └── utils.py
├── tools
│   ├── .DS_Store
│   ├── README.md
│   ├── lowercase_and_remove_accent.py
│   ├── segment_th.py
│   └── tokenize.sh
├── train.py
├── train_IBT.sh
├── train_IBT_plus_BACK.sh
├── train_IBT_plus_SRC.sh
├── train_sup.sh
├── translate.py
└── translate_exe.sh
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jind11/DAMT/3caa22822b936137e5da3e827d7a5a2078c3115e/.DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DAMT
2 | A new method for semi-supervised domain adaptation of Neural Machine Translation (NMT)
3 |
4 | This is the source code for the paper: [Jin, D., Jin, Z., Zhou, J.T., & Szolovits, P. (2020). Unsupervised Domain Adaptation for Neural Machine Translation with Iterative Back Translation. ArXiv, abs/2001.08140.](https://arxiv.org/abs/2001.08140). If you use the code, please cite the paper:
5 |
6 | ```
7 | @article{Jin2020UnsupervisedDA,
8 | title={Unsupervised Domain Adaptation for Neural Machine Translation with Iterative Back Translation},
9 | author={Di Jin and Zhijing Jin and Joey Tianyi Zhou and Peter Szolovits},
10 | journal={ArXiv},
11 | year={2020},
12 | volume={abs/2001.08140}
13 | }
14 | ```
15 |
16 | ## Prerequisites:
17 | Run the following command to install the prerequisite packages:
18 | ```
19 | pip install -r requirements.txt
20 | ```
21 | You also need to install the Moses tokenizer and the fastBPE tool in the `tools` folder by running the following commands:
22 | ```
23 | cd tools
24 | git clone https://github.com/moses-smt/mosesdecoder
25 | git clone https://github.com/glample/fastBPE
26 | cd fastBPE
27 | g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o fast
28 | cd ../..
29 | ```
30 |
31 | ## Data:
32 | Please download the data from [Google Drive](https://drive.google.com/file/d/1aQOXfcGpPbQemG4mQQuiy6ZrCRn6WiDj/view?usp=sharing) and unzip it in the main directory of this repository. The downloaded data covers the MED (EMEA), IT, LAW (ACQUIS), and TED domains for the DE-EN language pair, and the MED, LAW, and TED domains for the EN-RO language pair. The WMT14 DE-EN data can be downloaded [here](https://nlp.stanford.edu/projects/nmt/) and the WMT16 EN-RO data [here](https://www.statmt.org/wmt16/translation-task.html).
33 |
34 | ## How to use
35 | 1. First we need to download the pretrained model parameter files from the [XLM repository](https://github.com/facebookresearch/XLM#pretrained-xlmmlm-models).
36 |
37 | 2. Then we need to preprocess the data. Suppose we want to train an NMT model from German (de) to English (en), where the source domain is LAW (dataset name: acquis) and the target domain is IT. Then run the following commands (the raw files each domain folder is expected to contain are sketched below):
38 | ```
39 | ./get-data-nmt-local.sh --src de --tgt en --data_name it --data_path ./data/de-en/it --reload_codes PATH_TO_PRETRAINED_MODEL_CODES --reload_vocab PATH_TO_PRETRAINED_MODEL_VOCAB
40 | ./get-data-nmt-local.sh --src de --tgt en --data_name acquis --data_path ./data/de-en/acquis --reload_codes PATH_TO_PRETRAINED_MODEL_CODES --reload_vocab PATH_TO_PRETRAINED_MODEL_VOCAB
41 | ```
42 |
43 | 3. After data processing, to reproduce the "IBT" setting as mentioned in the paper, run the following command:
44 | ```
45 | ./train_IBT.sh --src de --tgt en --data_name it --pretrained_model_dir DIR_TO_PRETRAINED_MODEL
46 | ```
47 |
48 | 4. To reproduce the "IBT+SRC" setting, recall that we want to adapt from the LAW domain (dataset name: acquis) to the IT domain; run the following command:
49 | ```
50 | ./train_IBT_plus_SRC.sh --src de --tgt en --src_data_name acquis --tgt_data_name it --pretrained_model_dir DIR_TO_PRETRAINED_MODEL
51 | ```
52 |
53 | 5. In order to reproduce the "IBT+Back" setting, we need to go through several steps.
54 |
55 | * First, we need to train an NMT model to translate from en to de on the source domain data (acquis) by running the following command:
56 | ```
57 | ./train_sup.sh --src en --tgt de --data_name acquis --pretrained_model_dir DIR_TO_PRETRAINED_MODEL
58 | ```
59 |
60 | * After training this model, we use it to translate the English sentences of the target domain (it) into German; these translations serve as the back-translated data:
61 | ```
62 | ./translate_exe.sh --src en --tgt de --data_name it --model_name acquis --model_dir DIR_TO_TRAINED_MODEL
63 | ./get-data-back-translate.sh --src en --tgt de --data_name it --model_name acquis
64 | ```
65 |
66 | * When the back-translated data is ready, we can finally run this command:
67 | ```
68 | ./train_IBT_plus_BACK.sh --src de --tgt en --src_data_name acquis --tgt_data_name it --pretrained_model_dir DIR_TO_PRETRAINED_MODEL
69 | ```
70 |
--------------------------------------------------------------------------------
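Note on step 2 of the README above: get-data-nmt-local.sh reads a fixed set of raw files from each domain folder before tokenizing and binarizing them. Below is a minimal sketch (not part of the repository) that checks this layout for the de-en IT domain; the file names are taken from the SRC_RAW/TGT_RAW and PARA_*_TRAIN/VALID/TEST variables in that script.

```
# Minimal sketch (not part of the repository): verify that a domain folder
# contains the raw files get-data-nmt-local.sh expects before running it.
import os
import sys

def check_raw_layout(data_path, src, tgt):
    expected = [f"train.{lang}.mono" for lang in (src, tgt)]           # monolingual raw text
    expected += [f"{splt}.{lang}" for splt in ("train", "dev", "test")
                 for lang in (src, tgt)]                               # parallel raw text
    missing = [f for f in expected if not os.path.isfile(os.path.join(data_path, f))]
    for name in missing:
        print("missing:", os.path.join(data_path, name))
    return not missing

if __name__ == "__main__":
    ok = check_raw_layout("./data/de-en/it", "de", "en")
    sys.exit(0 if ok else 1)
```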
/get-data-back-translate.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | set -e
9 |
10 | #
11 | # Read arguments
12 | #
13 | POSITIONAL=()
14 | while [[ $# -gt 0 ]]
15 | do
16 | key="$1"
17 | case $key in
18 | --src)
19 | SRC="$2"; shift 2;;
20 | --tgt)
21 | TGT="$2"; shift 2;;
22 | --data_name)
23 | DATA_NAME="$2"; shift 2;;
24 | --model_name)
25 | MODEL_NAME="$2"; shift 2;;
26 | *)
27 | POSITIONAL+=("$1")
28 | shift
29 | ;;
30 | esac
31 | done
32 | set -- "${POSITIONAL[@]}"
33 |
34 | if [ "$SRC" \> "$TGT" ]; then echo "please ensure SRC < TGT"; exit; fi
35 |
36 | MAIN_PATH=$PWD
37 | DATA_PATH=data/$SRC-$TGT/$DATA_NAME
38 | PROC_PATH=$DATA_PATH/processed/$SRC-$TGT
39 | BACK_DATA_DIR=$DATA_PATH/back_translate/$MODEL_NAME
40 | FULL_VOCAB=$PROC_PATH/vocab.$SRC-$TGT
41 |
42 | $MAIN_PATH/preprocess.py $FULL_VOCAB $BACK_DATA_DIR/train.$SRC-$TGT.$SRC
43 | $MAIN_PATH/preprocess.py $FULL_VOCAB $BACK_DATA_DIR/train.$SRC-$TGT.$TGT
44 |
--------------------------------------------------------------------------------
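get-data-back-translate.sh above only re-binarizes the two sides of the back-translated corpus with the existing joint vocabulary. A small sanity check one might run first is sketched below (not part of the repository); it only assumes the two files are line-aligned, since the resulting .pth pair is later loaded as a parallel dataset. The path follows the DATA_PATH/BACK_DATA_DIR variables in the script for the de-en IT domain and an acquis-trained model.

```
# Sketch (not in the repository): confirm that the two sides of the back-translated
# corpus are line-aligned before binarizing them with preprocess.py.
import sys

def count_lines(path):
    with open(path, encoding="utf-8") as f:
        return sum(1 for _ in f)

back_dir = "data/de-en/it/back_translate/acquis"   # example path; adjust to your setup
n_src = count_lines(f"{back_dir}/train.de-en.de")
n_tgt = count_lines(f"{back_dir}/train.de-en.en")
print(n_src, n_tgt)
sys.exit(0 if n_src == n_tgt else 1)
```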
/get-data-nmt-local.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | set -e
9 |
10 |
11 | #
12 | # Data preprocessing configuration
13 | #
14 | CODES=60000 # number of BPE codes
15 | N_THREADS=16 # number of threads in data preprocessing
16 |
17 |
18 | #
19 | # Read arguments
20 | #
21 | POSITIONAL=()
22 | while [[ $# -gt 0 ]]
23 | do
24 | key="$1"
25 | case $key in
26 | --src)
27 | SRC="$2"; shift 2;;
28 | --tgt)
29 | TGT="$2"; shift 2;;
30 | --data_name)
31 | DATA_NAME="$2"; shift 2;;
32 | --data_path)
33 | DATA_PATH="$2"; shift 2;;
34 | --reload_codes)
35 | RELOAD_CODES="$2"; shift 2;;
36 | --reload_vocab)
37 | RELOAD_VOCAB="$2"; shift 2;;
38 | *)
39 | POSITIONAL+=("$1")
40 | shift
41 | ;;
42 | esac
43 | done
44 | set -- "${POSITIONAL[@]}"
45 |
46 |
47 | #
48 | # Check parameters
49 | #
50 | if [ "$SRC" == "" ]; then echo "--src not provided"; exit; fi
51 | if [ "$TGT" == "" ]; then echo "--tgt not provided"; exit; fi
52 | if [ "$SRC" != "de" -a "$SRC" != "en" -a "$SRC" != "fr" -a "$SRC" != "ro" ]; then echo "unknown source language"; exit; fi
53 | if [ "$TGT" != "de" -a "$TGT" != "en" -a "$TGT" != "fr" -a "$TGT" != "ro" ]; then echo "unknown target language"; exit; fi
54 | if [ "$SRC" == "$TGT" ]; then echo "source and target cannot be identical"; exit; fi
55 | if [ "$SRC" \> "$TGT" ]; then echo "please ensure SRC < TGT"; exit; fi
56 | if [ "$RELOAD_CODES" != "" ] && [ ! -f "$RELOAD_CODES" ]; then echo "cannot locate BPE codes"; exit; fi
57 | if [ "$RELOAD_VOCAB" != "" ] && [ ! -f "$RELOAD_VOCAB" ]; then echo "cannot locate vocabulary"; exit; fi
58 | if [ "$RELOAD_CODES" == "" -a "$RELOAD_VOCAB" != "" -o "$RELOAD_CODES" != "" -a "$RELOAD_VOCAB" == "" ]; then echo "BPE codes should be provided if and only if vocabulary is also provided"; exit; fi
59 |
60 |
61 | #
62 | # Initialize tools and data paths
63 | #
64 |
65 | # main paths
66 | MAIN_PATH=$PWD
67 | TOOLS_PATH=$PWD/tools
68 | PROC_PATH=$DATA_PATH/processed/$SRC-$TGT
69 |
70 | # create paths
71 | mkdir -p $PROC_PATH
72 |
73 | # moses
74 | MOSES=$TOOLS_PATH/mosesdecoder
75 | REPLACE_UNICODE_PUNCT=$MOSES/scripts/tokenizer/replace-unicode-punctuation.perl
76 | NORM_PUNC=$MOSES/scripts/tokenizer/normalize-punctuation.perl
77 | REM_NON_PRINT_CHAR=$MOSES/scripts/tokenizer/remove-non-printing-char.perl
78 | TOKENIZER=$MOSES/scripts/tokenizer/tokenizer.perl
79 | INPUT_FROM_SGM=$MOSES/scripts/ems/support/input-from-sgm.perl
80 |
81 | # fastBPE
82 | FASTBPE_DIR=$TOOLS_PATH/fastBPE
83 | FASTBPE=$TOOLS_PATH/fastBPE/fast
84 |
85 | # Sennrich's WMT16 scripts for Romanian preprocessing
86 | WMT16_SCRIPTS=$TOOLS_PATH/wmt16-scripts
87 | NORMALIZE_ROMANIAN=$WMT16_SCRIPTS/preprocess/normalise-romanian.py
88 | REMOVE_DIACRITICS=$WMT16_SCRIPTS/preprocess/remove-diacritics.py
89 |
90 | # raw and tokenized files
91 | SRC_RAW=$DATA_PATH/train.$SRC.mono
92 | TGT_RAW=$DATA_PATH/train.$TGT.mono
93 | SRC_TOK=$SRC_RAW.tok
94 | TGT_TOK=$TGT_RAW.tok
95 |
96 | # BPE / vocab files
97 | BPE_CODES=$PROC_PATH/codes
98 | SRC_VOCAB=$PROC_PATH/vocab.$SRC
99 | TGT_VOCAB=$PROC_PATH/vocab.$TGT
100 | FULL_VOCAB=$PROC_PATH/vocab.$SRC-$TGT
101 |
102 | # train / valid / test monolingual BPE data
103 | SRC_TRAIN_BPE=$PROC_PATH/train.$SRC
104 | TGT_TRAIN_BPE=$PROC_PATH/train.$TGT
105 | SRC_VALID_BPE=$PROC_PATH/valid.$SRC
106 | TGT_VALID_BPE=$PROC_PATH/valid.$TGT
107 | SRC_TEST_BPE=$PROC_PATH/test.$SRC
108 | TGT_TEST_BPE=$PROC_PATH/test.$TGT
109 |
110 | # train / valid / test parallel BPE data
111 | PARA_SRC_TRAIN_BPE=$PROC_PATH/train.$SRC-$TGT.$SRC
112 | PARA_TGT_TRAIN_BPE=$PROC_PATH/train.$SRC-$TGT.$TGT
113 | PARA_SRC_VALID_BPE=$PROC_PATH/valid.$SRC-$TGT.$SRC
114 | PARA_TGT_VALID_BPE=$PROC_PATH/valid.$SRC-$TGT.$TGT
115 | PARA_SRC_TEST_BPE=$PROC_PATH/test.$SRC-$TGT.$SRC
116 | PARA_TGT_TEST_BPE=$PROC_PATH/test.$SRC-$TGT.$TGT
117 |
118 | # train / valid / test raw parallel data
119 | unset PARA_SRC_VALID PARA_TGT_VALID PARA_SRC_TEST PARA_TGT_TEST
120 | PARA_SRC_TRAIN=$DATA_PATH/train.$SRC
121 | PARA_TGT_TRAIN=$DATA_PATH/train.$TGT
122 | PARA_SRC_VALID=$DATA_PATH/dev.$SRC
123 | PARA_TGT_VALID=$DATA_PATH/dev.$TGT
124 | PARA_SRC_TEST=$DATA_PATH/test.$SRC
125 | PARA_TGT_TEST=$DATA_PATH/test.$TGT
126 |
127 | #cd $DATA_PATH
128 |
129 | # preprocessing commands - special case for Romanian
130 | if [ "$SRC" == "ro" ]; then
131 | SRC_PREPROCESSING="$REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $SRC | $REM_NON_PRINT_CHAR | $NORMALIZE_ROMANIAN | $REMOVE_DIACRITICS | $TOKENIZER -l $SRC -no-escape -threads $N_THREADS"
132 | else
133 | SRC_PREPROCESSING="$REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $SRC | $REM_NON_PRINT_CHAR | $TOKENIZER -l $SRC -no-escape -threads $N_THREADS"
134 | fi
135 | if [ "$TGT" == "ro" ]; then
136 | TGT_PREPROCESSING="$REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $TGT | $REM_NON_PRINT_CHAR | $NORMALIZE_ROMANIAN | $REMOVE_DIACRITICS | $TOKENIZER -l $TGT -no-escape -threads $N_THREADS"
137 | else
138 | TGT_PREPROCESSING="$REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $TGT | $REM_NON_PRINT_CHAR | $TOKENIZER -l $TGT -no-escape -threads $N_THREADS"
139 | fi
140 |
141 | # tokenize data
142 | if ! [[ -f "$SRC_TOK" ]]; then
143 | echo "Tokenize $SRC monolingual data..."
144 | eval "cat $SRC_RAW | $SRC_PREPROCESSING > $SRC_TOK"
145 | fi
146 |
147 | if ! [[ -f "$TGT_TOK" ]]; then
148 | echo "Tokenize $TGT monolingual data..."
149 | eval "cat $TGT_RAW | $TGT_PREPROCESSING > $TGT_TOK"
150 | fi
151 | echo "$SRC monolingual data tokenized in: $SRC_TOK"
152 | echo "$TGT monolingual data tokenized in: $TGT_TOK"
153 |
154 | # reload BPE codes
155 | cd $MAIN_PATH
156 | if [ ! -f "$BPE_CODES" ] && [ -f "$RELOAD_CODES" ]; then
157 | echo "Reloading BPE codes from $RELOAD_CODES ..."
158 | cp $RELOAD_CODES $BPE_CODES
159 | fi
160 |
161 | # learn BPE codes
162 | if [ ! -f "$BPE_CODES" ]; then
163 | echo "Learning BPE codes..."
164 | $FASTBPE learnbpe $CODES $SRC_TOK $TGT_TOK > $BPE_CODES
165 | fi
166 | echo "BPE learned in $BPE_CODES"
167 |
168 | # apply BPE codes
169 | if ! [[ -f "$SRC_TRAIN_BPE" ]]; then
170 | echo "Applying $SRC BPE codes..."
171 | $FASTBPE applybpe $SRC_TRAIN_BPE $SRC_TOK $BPE_CODES
172 | fi
173 | if ! [[ -f "$TGT_TRAIN_BPE" ]]; then
174 | echo "Applying $TGT BPE codes..."
175 | $FASTBPE applybpe $TGT_TRAIN_BPE $TGT_TOK $BPE_CODES
176 | fi
177 | echo "BPE codes applied to $SRC in: $SRC_TRAIN_BPE"
178 | echo "BPE codes applied to $TGT in: $TGT_TRAIN_BPE"
179 |
180 | # extract source and target vocabulary
181 | if ! [[ -f "$SRC_VOCAB" && -f "$TGT_VOCAB" ]]; then
182 | echo "Extracting vocabulary..."
183 | $FASTBPE getvocab $SRC_TRAIN_BPE > $SRC_VOCAB
184 | $FASTBPE getvocab $TGT_TRAIN_BPE > $TGT_VOCAB
185 | fi
186 | echo "$SRC vocab in: $SRC_VOCAB"
187 | echo "$TGT vocab in: $TGT_VOCAB"
188 |
189 | # reload full vocabulary
190 | cd $MAIN_PATH
191 | if [ ! -f "$FULL_VOCAB" ] && [ -f "$RELOAD_VOCAB" ]; then
192 | echo "Reloading vocabulary from $RELOAD_VOCAB ..."
193 | cp $RELOAD_VOCAB $FULL_VOCAB
194 | fi
195 |
196 | # extract full vocabulary
197 | if ! [[ -f "$FULL_VOCAB" ]]; then
198 | echo "Extracting vocabulary..."
199 | $FASTBPE getvocab $SRC_TRAIN_BPE $TGT_TRAIN_BPE > $FULL_VOCAB
200 | fi
201 | echo "Full vocab in: $FULL_VOCAB"
202 |
203 | # binarize data
204 | if ! [[ -f "$SRC_TRAIN_BPE.pth" ]]; then
205 | echo "Binarizing $SRC data..."
206 | $MAIN_PATH/preprocess.py $FULL_VOCAB $SRC_TRAIN_BPE
207 | fi
208 | if ! [[ -f "$TGT_TRAIN_BPE.pth" ]]; then
209 | echo "Binarizing $TGT data..."
210 | $MAIN_PATH/preprocess.py $FULL_VOCAB $TGT_TRAIN_BPE
211 | fi
212 | echo "$SRC binarized data in: $SRC_TRAIN_BPE.pth"
213 | echo "$TGT binarized data in: $TGT_TRAIN_BPE.pth"
214 |
215 | #
216 | # Preprocess parallel data (train / valid / test)
217 | #
218 |
219 | echo "Tokenizing parallel train, valid and test data..."
220 | eval "cat $PARA_SRC_TRAIN | $SRC_PREPROCESSING > $PARA_SRC_TRAIN.tok"
221 | eval "cat $PARA_TGT_TRAIN | $TGT_PREPROCESSING > $PARA_TGT_TRAIN.tok"
222 | eval "cat $PARA_SRC_VALID | $SRC_PREPROCESSING > $PARA_SRC_VALID.tok"
223 | eval "cat $PARA_TGT_VALID | $TGT_PREPROCESSING > $PARA_TGT_VALID.tok"
224 | eval "cat $PARA_SRC_TEST | $SRC_PREPROCESSING > $PARA_SRC_TEST.tok"
225 | eval "cat $PARA_TGT_TEST | $TGT_PREPROCESSING > $PARA_TGT_TEST.tok"
226 |
227 | echo "Applying BPE to train, valid and test files..."
228 | $FASTBPE applybpe $PARA_SRC_TRAIN_BPE $PARA_SRC_TRAIN.tok $BPE_CODES $SRC_VOCAB
229 | $FASTBPE applybpe $PARA_TGT_TRAIN_BPE $PARA_TGT_TRAIN.tok $BPE_CODES $TGT_VOCAB
230 | $FASTBPE applybpe $PARA_SRC_VALID_BPE $PARA_SRC_VALID.tok $BPE_CODES $SRC_VOCAB
231 | $FASTBPE applybpe $PARA_TGT_VALID_BPE $PARA_TGT_VALID.tok $BPE_CODES $TGT_VOCAB
232 | $FASTBPE applybpe $PARA_SRC_TEST_BPE $PARA_SRC_TEST.tok $BPE_CODES $SRC_VOCAB
233 | $FASTBPE applybpe $PARA_TGT_TEST_BPE $PARA_TGT_TEST.tok $BPE_CODES $TGT_VOCAB
234 |
235 | echo "Binarizing data..."
236 | rm -f $PARA_SRC_TRAIN_BPE.pth $PARA_TGT_TRAIN_BPE.pth $PARA_SRC_VALID_BPE.pth $PARA_TGT_VALID_BPE.pth $PARA_SRC_TEST_BPE.pth $PARA_TGT_TEST_BPE.pth
237 | $MAIN_PATH/preprocess.py $FULL_VOCAB $PARA_SRC_TRAIN_BPE
238 | $MAIN_PATH/preprocess.py $FULL_VOCAB $PARA_TGT_TRAIN_BPE
239 | $MAIN_PATH/preprocess.py $FULL_VOCAB $PARA_SRC_VALID_BPE
240 | $MAIN_PATH/preprocess.py $FULL_VOCAB $PARA_TGT_VALID_BPE
241 | $MAIN_PATH/preprocess.py $FULL_VOCAB $PARA_SRC_TEST_BPE
242 | $MAIN_PATH/preprocess.py $FULL_VOCAB $PARA_TGT_TEST_BPE
243 |
244 |
245 | #
246 | # Link monolingual validation and test data to parallel data
247 | #
248 | cd $PROC_PATH
249 | ln -sf valid.$SRC-$TGT.$SRC.pth valid.$SRC.pth
250 | ln -sf valid.$SRC-$TGT.$TGT.pth valid.$TGT.pth
251 | ln -sf test.$SRC-$TGT.$SRC.pth test.$SRC.pth
252 | ln -sf test.$SRC-$TGT.$TGT.pth test.$TGT.pth
253 |
254 | #
255 | # Summary
256 | #
257 | echo ""
258 | echo "===== Data summary"
259 | echo "Monolingual training data:"
260 | echo " $SRC: $SRC_TRAIN_BPE.pth"
261 | echo " $TGT: $TGT_TRAIN_BPE.pth"
262 | echo "Monolingual validation data:"
263 | echo " $SRC: $SRC_VALID_BPE.pth"
264 | echo " $TGT: $TGT_VALID_BPE.pth"
265 | echo "Monolingual test data:"
266 | echo " $SRC: $SRC_TEST_BPE.pth"
267 | echo " $TGT: $TGT_TEST_BPE.pth"
268 | echo "Parallel training data:"
269 | echo " $SRC: $PARA_SRC_TRAIN_BPE.pth"
270 | echo " $TGT: $PARA_TGT_TRAIN_BPE.pth"
271 | echo "Parallel validation data:"
272 | echo " $SRC: $PARA_SRC_VALID_BPE.pth"
273 | echo " $TGT: $PARA_TGT_VALID_BPE.pth"
274 | echo "Parallel test data:"
275 | echo " $SRC: $PARA_SRC_TEST_BPE.pth"
276 | echo " $TGT: $PARA_TGT_TEST_BPE.pth"
277 | echo ""
278 |
--------------------------------------------------------------------------------
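For reference, after get-data-nmt-local.sh above finishes, $PROC_PATH holds the BPE codes, the joint vocabulary, and the binarized monolingual and parallel splits listed in its data summary. A minimal sketch (not part of the repository) that lists these outputs for the de-en IT domain:

```
# Sketch (not in the repository): the files get-data-nmt-local.sh leaves in $PROC_PATH,
# following its BPE/vocab paths and the "Data summary" section above.
import os

def processed_files(proc_path, src="de", tgt="en"):
    files = ["codes", f"vocab.{src}-{tgt}"]                                  # BPE codes + joint vocab
    for splt in ("train", "valid", "test"):
        files += [f"{splt}.{lang}.pth" for lang in (src, tgt)]               # monolingual splits
        files += [f"{splt}.{src}-{tgt}.{lang}.pth" for lang in (src, tgt)]   # parallel splits
    return [os.path.join(proc_path, name) for name in files]

for path in processed_files("data/de-en/it/processed/de-en"):
    print(path, os.path.isfile(path))
```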
/preprocess.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) 2019-present, Facebook, Inc.
5 | # All rights reserved.
6 | #
7 | # This source code is licensed under the license found in the
8 | # LICENSE file in the root directory of this source tree.
9 | #
10 |
11 |
12 | """
13 | Example: python preprocess.py data/vocab.txt data/train.txt
14 | vocab.txt: one "word count" pair per line
15 | """
16 |
17 | import os
18 | import sys
19 |
20 | from src.logger import create_logger
21 | from src.data.dictionary import Dictionary
22 |
23 |
24 | if __name__ == '__main__':
25 |
26 | logger = create_logger(None, 0)
27 |
28 | voc_path = sys.argv[1]
29 | txt_path = sys.argv[2]
30 | bin_path = sys.argv[2] + '.pth'
31 | assert os.path.isfile(voc_path)
32 | assert os.path.isfile(txt_path)
33 |
34 | dico = Dictionary.read_vocab(voc_path)
35 | logger.info("")
36 |
37 | data = Dictionary.index_data(txt_path, bin_path, dico)
38 | logger.info("%i words (%i unique) in %i sentences." % (
39 | len(data['sentences']) - len(data['positions']),
40 | len(data['dico']),
41 | len(data['positions'])
42 | ))
43 | if len(data['unk_words']) > 0:
44 | logger.info("%i unknown words (%i unique), covering %.2f%% of the data." % (
45 | sum(data['unk_words'].values()),
46 | len(data['unk_words']),
47 | sum(data['unk_words'].values()) * 100. / (len(data['sentences']) - len(data['positions']))
48 | ))
49 | if len(data['unk_words']) < 30:
50 | for w, c in sorted(data['unk_words'].items(), key=lambda x: x[1])[::-1]:
51 | logger.info("%s: %i" % (w, c))
52 |
--------------------------------------------------------------------------------
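preprocess.py above writes the binarized corpus with torch.save. A minimal sketch (not part of the repository) of how to load one of the resulting .pth files and inspect the fields built by Dictionary.index_data; the path is just the one produced for the de-en IT domain in the README example.

```
# Sketch (not in the repository): inspect a binarized .pth file produced by preprocess.py.
# The keys mirror the `data` dict assembled in Dictionary.index_data.
# (Run from the repository root so the pickled Dictionary class can be imported.)
import torch

data = torch.load("data/de-en/it/processed/de-en/train.de-en.de.pth")

dico = data["dico"]             # Dictionary object
positions = data["positions"]   # (n_sentences, 2) array of [start, end) offsets
sentences = data["sentences"]   # flat array of word ids, EOS (index 1) after each sentence
unk_words = data["unk_words"]   # {word: count} for out-of-vocabulary words

print(len(dico), "words in the dictionary")
print(len(positions), "sentences,", len(sentences) - len(positions), "tokens")
```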
/requirements.txt:
--------------------------------------------------------------------------------
1 | certifi==2020.4.5.1
2 | future==0.18.2
3 | numpy==1.18.5
4 | torch==1.2.0
5 |
--------------------------------------------------------------------------------
/src/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jind11/DAMT/3caa22822b936137e5da3e827d7a5a2078c3115e/src/.DS_Store
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jind11/DAMT/3caa22822b936137e5da3e827d7a5a2078c3115e/src/__init__.py
--------------------------------------------------------------------------------
/src/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jind11/DAMT/3caa22822b936137e5da3e827d7a5a2078c3115e/src/__init__.pyc
--------------------------------------------------------------------------------
/src/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jind11/DAMT/3caa22822b936137e5da3e827d7a5a2078c3115e/src/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/src/__pycache__/logger.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jind11/DAMT/3caa22822b936137e5da3e827d7a5a2078c3115e/src/__pycache__/logger.cpython-36.pyc
--------------------------------------------------------------------------------
/src/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jind11/DAMT/3caa22822b936137e5da3e827d7a5a2078c3115e/src/data/__init__.py
--------------------------------------------------------------------------------
/src/data/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jind11/DAMT/3caa22822b936137e5da3e827d7a5a2078c3115e/src/data/__init__.pyc
--------------------------------------------------------------------------------
/src/data/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jind11/DAMT/3caa22822b936137e5da3e827d7a5a2078c3115e/src/data/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/src/data/__pycache__/dictionary.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jind11/DAMT/3caa22822b936137e5da3e827d7a5a2078c3115e/src/data/__pycache__/dictionary.cpython-36.pyc
--------------------------------------------------------------------------------
/src/data/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | from logging import getLogger
9 | import math
10 | import numpy as np
11 | import torch
12 |
13 |
14 | logger = getLogger()
15 |
16 |
17 | class StreamDataset(object):
18 |
19 | def __init__(self, sent, pos, bs, params):
20 | """
21 | Prepare batches for data iterator.
22 | """
23 | bptt = params.bptt
24 | self.eos = params.eos_index
25 |
26 | # checks
27 | assert len(pos) == (sent == self.eos).sum()
28 | assert len(pos) == (sent[pos[:, 1]] == self.eos).sum()
29 |
30 | n_tokens = len(sent)
31 | n_batches = math.ceil(n_tokens / (bs * bptt))
32 | t_size = n_batches * bptt * bs
33 |
34 | buffer = np.zeros(t_size, dtype=sent.dtype) + self.eos
35 | buffer[t_size - n_tokens:] = sent
36 | buffer = buffer.reshape((bs, n_batches * bptt)).T
37 | self.data = np.zeros((n_batches * bptt + 1, bs), dtype=sent.dtype) + self.eos
38 | self.data[1:] = buffer
39 |
40 | self.bptt = bptt
41 | self.n_tokens = n_tokens
42 | self.n_batches = n_batches
43 | self.n_sentences = len(pos)
44 | self.lengths = torch.LongTensor(bs).fill_(bptt)
45 |
46 | def __len__(self):
47 | """
48 | Number of sentences in the dataset.
49 | """
50 | return self.n_sentences
51 |
52 | def select_data(self, a, b):
53 | """
54 | Only select a subset of the dataset.
55 | """
56 | if not (0 <= a < b <= self.n_batches):
57 | logger.warning("Invalid split values: %i %i - %i" % (a, b, self.n_batches))
58 | return
59 | assert 0 <= a < b <= self.n_batches
60 | logger.info("Selecting batches from %i to %i ..." % (a, b))
61 |
62 | # sub-select
63 | self.data = self.data[a * self.bptt:b * self.bptt]
64 | self.n_batches = b - a
65 | self.n_sentences = (self.data == self.eos).sum().item()
66 |
67 | def get_iterator(self, shuffle, subsample=1):
68 | """
69 | Return a sentences iterator.
70 | """
71 | indexes = (np.random.permutation if shuffle else range)(self.n_batches // subsample)
72 | for i in indexes:
73 | a = self.bptt * i
74 | b = self.bptt * (i + 1)
75 | yield torch.from_numpy(self.data[a:b].astype(np.int64)), self.lengths
76 |
77 |
78 | class Dataset(object):
79 |
80 | def __init__(self, sent, pos, params):
81 |
82 | self.eos_index = params.eos_index
83 | self.pad_index = params.pad_index
84 | self.batch_size = params.batch_size
85 | self.tokens_per_batch = params.tokens_per_batch
86 | self.max_batch_size = params.max_batch_size
87 |
88 | self.sent = sent
89 | self.pos = pos
90 | self.lengths = self.pos[:, 1] - self.pos[:, 0]
91 |
92 | # check number of sentences
93 | assert len(self.pos) == (self.sent == self.eos_index).sum()
94 |
95 | # # remove empty sentences
96 | # self.remove_empty_sentences()
97 |
98 | # sanity checks
99 | self.check()
100 |
101 | def __len__(self):
102 | """
103 | Number of sentences in the dataset.
104 | """
105 | return len(self.pos)
106 |
107 | def check(self):
108 | """
109 | Sanity checks.
110 | """
111 | eos = self.eos_index
112 | assert len(self.pos) == (self.sent[self.pos[:, 1]] == eos).sum() # check sentences indices
113 | # assert self.lengths.min() > 0 # check empty sentences
114 |
115 | def batch_sentences(self, sentences):
116 | """
117 | Take as input a list of n sentences (torch.LongTensor vectors) and return
118 | a tensor of size (slen, n) where slen is the length of the longest
119 | sentence, and a vector lengths containing the length of each sentence.
120 | """
121 | # sentences = sorted(sentences, key=lambda x: len(x), reverse=True)
122 | lengths = torch.LongTensor([len(s) + 2 for s in sentences])
123 | sent = torch.LongTensor(lengths.max().item(), lengths.size(0)).fill_(self.pad_index)
124 |
125 | sent[0] = self.eos_index
126 | for i, s in enumerate(sentences):
127 | if lengths[i] > 2: # if sentence not empty
128 | sent[1:lengths[i] - 1, i].copy_(torch.from_numpy(s.astype(np.int64)))
129 | sent[lengths[i] - 1, i] = self.eos_index
130 |
131 | return sent, lengths
132 |
133 | def remove_empty_sentences(self):
134 | """
135 | Remove empty sentences.
136 | """
137 | init_size = len(self.pos)
138 | indices = np.arange(len(self.pos))
139 | indices = indices[self.lengths[indices] > 0]
140 | self.pos = self.pos[indices]
141 | self.lengths = self.pos[:, 1] - self.pos[:, 0]
142 | logger.info("Removed %i empty sentences." % (init_size - len(indices)))
143 | self.check()
144 |
145 | def remove_long_sentences(self, max_len):
146 | """
147 | Remove sentences exceeding a certain length.
148 | """
149 | assert max_len >= 0
150 | if max_len == 0:
151 | return
152 | init_size = len(self.pos)
153 | indices = np.arange(len(self.pos))
154 | indices = indices[self.lengths[indices] <= max_len]
155 | self.pos = self.pos[indices]
156 | self.lengths = self.pos[:, 1] - self.pos[:, 0]
157 | logger.info("Removed %i too long sentences." % (init_size - len(indices)))
158 | self.check()
159 |
160 | def select_data(self, a, b):
161 | """
162 | Only select a subset of the dataset.
163 | """
164 | assert 0 <= a < b <= len(self.pos)
165 | logger.info("Selecting sentences from %i to %i ..." % (a, b))
166 |
167 | # sub-select
168 | self.pos = self.pos[a:b]
169 | self.lengths = self.pos[:, 1] - self.pos[:, 0]
170 |
171 | # re-index
172 | min_pos = self.pos.min()
173 | max_pos = self.pos.max()
174 | self.pos -= min_pos
175 | self.sent = self.sent[min_pos:max_pos + 1]
176 |
177 | # sanity checks
178 | self.check()
179 |
180 | def get_batches_iterator(self, batches, return_indices):
181 | """
182 | Return a sentences iterator, given the associated sentence batches.
183 | """
184 | assert type(return_indices) is bool
185 |
186 | for sentence_ids in batches:
187 | if 0 < self.max_batch_size < len(sentence_ids):
188 | np.random.shuffle(sentence_ids)
189 | sentence_ids = sentence_ids[:self.max_batch_size]
190 | pos = self.pos[sentence_ids]
191 | sent = [self.sent[a:b] for a, b in pos]
192 | sent = self.batch_sentences(sent)
193 | yield (sent, sentence_ids) if return_indices else sent
194 |
195 | def get_iterator(self, shuffle, group_by_size=False, n_sentences=-1, seed=None, return_indices=False):
196 | """
197 | Return a sentences iterator.
198 | """
199 | assert seed is None or shuffle is True and type(seed) is int
200 | rng = np.random.RandomState(seed)
201 | n_sentences = len(self.pos) if n_sentences == -1 else n_sentences
202 | assert 0 < n_sentences <= len(self.pos)
203 | assert type(shuffle) is bool and type(group_by_size) is bool
204 | assert group_by_size is False or shuffle is True
205 |
206 | # sentence lengths
207 | lengths = self.lengths + 2
208 |
209 | # select sentences to iterate over
210 | if shuffle:
211 | indices = rng.permutation(len(self.pos))[:n_sentences]
212 | else:
213 | indices = np.arange(n_sentences)
214 |
215 | # group sentences by lengths
216 | if group_by_size:
217 | indices = indices[np.argsort(lengths[indices], kind='mergesort')]
218 |
219 | # create batches - either have a fixed number of sentences, or a similar number of tokens
220 | if self.tokens_per_batch == -1:
221 | batches = np.array_split(indices, math.ceil(len(indices) * 1. / self.batch_size))
222 | else:
223 | batch_ids = np.cumsum(lengths[indices]) // self.tokens_per_batch
224 | _, bounds = np.unique(batch_ids, return_index=True)
225 | batches = [indices[bounds[i]:bounds[i + 1]] for i in range(len(bounds) - 1)]
226 | if bounds[-1] < len(indices):
227 | batches.append(indices[bounds[-1]:])
228 |
229 | # optionally shuffle batches
230 | if shuffle:
231 | rng.shuffle(batches)
232 |
233 | # sanity checks
234 | assert n_sentences == sum([len(x) for x in batches])
235 | assert lengths[indices].sum() == sum([lengths[x].sum() for x in batches])
236 | # assert set.union(*[set(x.tolist()) for x in batches]) == set(range(n_sentences)) # slow
237 |
238 | # return the iterator
239 | return self.get_batches_iterator(batches, return_indices)
240 |
241 |
242 | class ParallelDataset(Dataset):
243 |
244 | def __init__(self, sent1, pos1, sent2, pos2, params):
245 |
246 | self.eos_index = params.eos_index
247 | self.pad_index = params.pad_index
248 | self.batch_size = params.batch_size
249 | self.tokens_per_batch = params.tokens_per_batch
250 | self.max_batch_size = params.max_batch_size
251 |
252 | self.sent1 = sent1
253 | self.sent2 = sent2
254 | self.pos1 = pos1
255 | self.pos2 = pos2
256 | self.lengths1 = self.pos1[:, 1] - self.pos1[:, 0]
257 | self.lengths2 = self.pos2[:, 1] - self.pos2[:, 0]
258 |
259 | # check number of sentences
260 | assert len(self.pos1) == (self.sent1 == self.eos_index).sum()
261 | assert len(self.pos2) == (self.sent2 == self.eos_index).sum()
262 |
263 | # remove empty sentences
264 | self.remove_empty_sentences()
265 |
266 | # sanity checks
267 | self.check()
268 |
269 | def __len__(self):
270 | """
271 | Number of sentences in the dataset.
272 | """
273 | return len(self.pos1)
274 |
275 | def check(self):
276 | """
277 | Sanity checks.
278 | """
279 | eos = self.eos_index
280 | assert len(self.pos1) == len(self.pos2) > 0 # check number of sentences
281 | assert len(self.pos1) == (self.sent1[self.pos1[:, 1]] == eos).sum() # check sentences indices
282 | assert len(self.pos2) == (self.sent2[self.pos2[:, 1]] == eos).sum() # check sentences indices
283 | assert eos <= self.sent1.min() < self.sent1.max() # check dictionary indices
284 | assert eos <= self.sent2.min() < self.sent2.max() # check dictionary indices
285 | assert self.lengths1.min() > 0 # check empty sentences
286 | assert self.lengths2.min() > 0 # check empty sentences
287 |
288 | def remove_empty_sentences(self):
289 | """
290 | Remove empty sentences.
291 | """
292 | init_size = len(self.pos1)
293 | indices = np.arange(len(self.pos1))
294 | indices = indices[self.lengths1[indices] > 0]
295 | indices = indices[self.lengths2[indices] > 0]
296 | self.pos1 = self.pos1[indices]
297 | self.pos2 = self.pos2[indices]
298 | self.lengths1 = self.pos1[:, 1] - self.pos1[:, 0]
299 | self.lengths2 = self.pos2[:, 1] - self.pos2[:, 0]
300 | logger.info("Removed %i empty sentences." % (init_size - len(indices)))
301 | self.check()
302 |
303 | def remove_long_sentences(self, max_len):
304 | """
305 | Remove sentences exceeding a certain length.
306 | """
307 | assert max_len >= 0
308 | if max_len == 0:
309 | return
310 | init_size = len(self.pos1)
311 | indices = np.arange(len(self.pos1))
312 | indices = indices[self.lengths1[indices] <= max_len]
313 | indices = indices[self.lengths2[indices] <= max_len]
314 | self.pos1 = self.pos1[indices]
315 | self.pos2 = self.pos2[indices]
316 | self.lengths1 = self.pos1[:, 1] - self.pos1[:, 0]
317 | self.lengths2 = self.pos2[:, 1] - self.pos2[:, 0]
318 | logger.info("Removed %i too long sentences." % (init_size - len(indices)))
319 | self.check()
320 |
321 | def select_data(self, a, b):
322 | """
323 | Only select a subset of the dataset.
324 | """
325 | assert 0 <= a < b <= len(self.pos1)
326 | logger.info("Selecting sentences from %i to %i ..." % (a, b))
327 |
328 | # sub-select
329 | self.pos1 = self.pos1[a:b]
330 | self.pos2 = self.pos2[a:b]
331 | self.lengths1 = self.pos1[:, 1] - self.pos1[:, 0]
332 | self.lengths2 = self.pos2[:, 1] - self.pos2[:, 0]
333 |
334 | # re-index
335 | min_pos1 = self.pos1.min()
336 | max_pos1 = self.pos1.max()
337 | min_pos2 = self.pos2.min()
338 | max_pos2 = self.pos2.max()
339 | self.pos1 -= min_pos1
340 | self.pos2 -= min_pos2
341 | self.sent1 = self.sent1[min_pos1:max_pos1 + 1]
342 | self.sent2 = self.sent2[min_pos2:max_pos2 + 1]
343 |
344 | # sanity checks
345 | self.check()
346 |
347 | def get_batches_iterator(self, batches, return_indices):
348 | """
349 | Return a sentences iterator, given the associated sentence batches.
350 | """
351 | assert type(return_indices) is bool
352 |
353 | for sentence_ids in batches:
354 | if 0 < self.max_batch_size < len(sentence_ids):
355 | np.random.shuffle(sentence_ids)
356 | sentence_ids = sentence_ids[:self.max_batch_size]
357 | pos1 = self.pos1[sentence_ids]
358 | pos2 = self.pos2[sentence_ids]
359 | sent1 = self.batch_sentences([self.sent1[a:b] for a, b in pos1])
360 | sent2 = self.batch_sentences([self.sent2[a:b] for a, b in pos2])
361 | yield (sent1, sent2, sentence_ids) if return_indices else (sent1, sent2)
362 |
363 | def get_iterator(self, shuffle, group_by_size=False, n_sentences=-1, return_indices=False):
364 | """
365 | Return a sentences iterator.
366 | """
367 | n_sentences = len(self.pos1) if n_sentences == -1 else n_sentences
368 | assert 0 < n_sentences <= len(self.pos1)
369 | assert type(shuffle) is bool and type(group_by_size) is bool
370 |
371 | # sentence lengths
372 | lengths = self.lengths1 + self.lengths2 + 4
373 |
374 | # select sentences to iterate over
375 | if shuffle:
376 | indices = np.random.permutation(len(self.pos1))[:n_sentences]
377 | else:
378 | indices = np.arange(n_sentences)
379 |
380 | # group sentences by lengths
381 | if group_by_size:
382 | indices = indices[np.argsort(lengths[indices], kind='mergesort')]
383 |
384 | # create batches - either have a fixed number of sentences, or a similar number of tokens
385 | if self.tokens_per_batch == -1:
386 | batches = np.array_split(indices, math.ceil(len(indices) * 1. / self.batch_size))
387 | else:
388 | batch_ids = np.cumsum(lengths[indices]) // self.tokens_per_batch
389 | _, bounds = np.unique(batch_ids, return_index=True)
390 | batches = [indices[bounds[i]:bounds[i + 1]] for i in range(len(bounds) - 1)]
391 | if bounds[-1] < len(indices):
392 | batches.append(indices[bounds[-1]:])
393 |
394 | # optionally shuffle batches
395 | if shuffle:
396 | np.random.shuffle(batches)
397 |
398 | # sanity checks
399 | assert n_sentences == sum([len(x) for x in batches])
400 | assert lengths[indices].sum() == sum([lengths[x].sum() for x in batches])
401 | # assert set.union(*[set(x.tolist()) for x in batches]) == set(range(n_sentences)) # slow
402 |
403 | # return the iterator
404 | return self.get_batches_iterator(batches, return_indices)
405 |
--------------------------------------------------------------------------------
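To make the flat storage layout used by Dataset in dataset.py above concrete: `sent` is one long vector of word ids with the EOS index after every sentence, and `pos` stores the [start, end) offsets of each sentence. Below is a minimal sketch (not part of the repository); the params object is a stand-in carrying only the attributes the class reads, with eos_index/pad_index matching the Dictionary convention.

```
# Sketch (not in the repository): feed Dataset a tiny hand-built corpus and iterate once.
from types import SimpleNamespace

import numpy as np

from src.data.dataset import Dataset

# stand-in params with only the attributes Dataset reads
params = SimpleNamespace(eos_index=1, pad_index=2, batch_size=2,
                         tokens_per_batch=-1, max_batch_size=0)

# two sentences, [14 15 16] and [17 18], each followed by the EOS index (1)
sent = np.array([14, 15, 16, 1, 17, 18, 1], dtype=np.uint16)
pos = np.array([[0, 3], [4, 6]], dtype=np.int64)

dataset = Dataset(sent, pos, params)
for batch, lengths in dataset.get_iterator(shuffle=False):
    # batch: (max_len, n_sentences) LongTensor framed by EOS and padded with pad_index
    print(batch.size(), lengths.tolist())   # torch.Size([5, 2]) [5, 4]
```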
/src/data/dictionary.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import os
9 | import numpy as np
10 | import torch
11 | from logging import getLogger
12 |
13 |
14 | logger = getLogger()
15 |
16 |
17 | BOS_WORD = '<s>'
18 | EOS_WORD = '</s>'
19 | PAD_WORD = '<pad>'
20 | UNK_WORD = '<unk>'
21 | 
22 | SPECIAL_WORD = '<special%i>'
23 | SPECIAL_WORDS = 10
24 |
25 | SEP_WORD = SPECIAL_WORD % 0
26 | MASK_WORD = SPECIAL_WORD % 1
27 |
28 |
29 | class Dictionary(object):
30 |
31 | def __init__(self, id2word, word2id, counts):
32 | assert len(id2word) == len(word2id) == len(counts)
33 | self.id2word = id2word
34 | self.word2id = word2id
35 | self.counts = counts
36 | self.bos_index = word2id[BOS_WORD]
37 | self.eos_index = word2id[EOS_WORD]
38 | self.pad_index = word2id[PAD_WORD]
39 | self.unk_index = word2id[UNK_WORD]
40 | self.check_valid()
41 |
42 | def __len__(self):
43 | """
44 | Returns the number of words in the dictionary.
45 | """
46 | return len(self.id2word)
47 |
48 | def __getitem__(self, i):
49 | """
50 | Returns the word of the specified index.
51 | """
52 | return self.id2word[i]
53 |
54 | def __contains__(self, w):
55 | """
56 | Returns whether a word is in the dictionary.
57 | """
58 | return w in self.word2id
59 |
60 | def __eq__(self, y):
61 | """
62 | Compare this dictionary with another one.
63 | """
64 | self.check_valid()
65 | y.check_valid()
66 | if len(self.id2word) != len(y):
67 | return False
68 | return all(self.id2word[i] == y[i] for i in range(len(y)))
69 |
70 | def check_valid(self):
71 | """
72 | Check that the dictionary is valid.
73 | """
74 | assert self.bos_index == 0
75 | assert self.eos_index == 1
76 | assert self.pad_index == 2
77 | assert self.unk_index == 3
78 | assert all(self.id2word[4 + i] == SPECIAL_WORD % i for i in range(SPECIAL_WORDS))
79 | assert len(self.id2word) == len(self.word2id) == len(self.counts)
80 | assert set(self.word2id.keys()) == set(self.counts.keys())
81 | for i in range(len(self.id2word)):
82 | assert self.word2id[self.id2word[i]] == i
83 | last_count = 1e18
84 | for i in range(4 + SPECIAL_WORDS, len(self.id2word) - 1):
85 | count = self.counts[self.id2word[i]]
86 | assert count <= last_count
87 | last_count = count
88 |
89 | def index(self, word, no_unk=False):
90 | """
91 | Returns the index of the specified word.
92 | """
93 | if no_unk:
94 | return self.word2id[word]
95 | else:
96 | return self.word2id.get(word, self.unk_index)
97 |
98 | def max_vocab(self, max_vocab):
99 | """
100 | Limit the vocabulary size.
101 | """
102 | assert max_vocab >= 1
103 | init_size = len(self)
104 | self.id2word = {k: v for k, v in self.id2word.items() if k < max_vocab}
105 | self.word2id = {v: k for k, v in self.id2word.items()}
106 | self.counts = {k: v for k, v in self.counts.items() if k in self.word2id}
107 | self.check_valid()
108 | logger.info("Maximum vocabulary size: %i. Dictionary size: %i -> %i (removed %i words)."
109 | % (max_vocab, init_size, len(self), init_size - len(self)))
110 |
111 | def min_count(self, min_count):
112 | """
113 | Threshold on the word frequency counts.
114 | """
115 | assert min_count >= 0
116 | init_size = len(self)
117 | self.id2word = {k: v for k, v in self.id2word.items() if self.counts[self.id2word[k]] >= min_count or k < 4 + SPECIAL_WORDS}
118 | self.word2id = {v: k for k, v in self.id2word.items()}
119 | self.counts = {k: v for k, v in self.counts.items() if k in self.word2id}
120 | self.check_valid()
121 | logger.info("Minimum frequency count: %i. Dictionary size: %i -> %i (removed %i words)."
122 | % (min_count, init_size, len(self), init_size - len(self)))
123 |
124 | @staticmethod
125 | def read_vocab(vocab_path):
126 | """
127 | Create a dictionary from a vocabulary file.
128 | """
129 | skipped = 0
130 | assert os.path.isfile(vocab_path), vocab_path
131 | word2id = {BOS_WORD: 0, EOS_WORD: 1, PAD_WORD: 2, UNK_WORD: 3}
132 | for i in range(SPECIAL_WORDS):
133 | word2id[SPECIAL_WORD % i] = 4 + i
134 | counts = {k: 0 for k in word2id.keys()}
135 | f = open(vocab_path, 'r', encoding='utf-8')
136 | for i, line in enumerate(f):
137 | if '\u2028' in line:
138 | skipped += 1
139 | continue
140 | line = line.rstrip().split()
141 | if len(line) != 2:
142 | skipped += 1
143 | continue
144 | assert len(line) == 2, (i, line)
145 | # assert line[0] not in word2id and line[1].isdigit(), (i, line)
146 | assert line[1].isdigit(), (i, line)
147 | if line[0] in word2id:
148 | skipped += 1
149 | print('%s already in vocab' % line[0])
150 | continue
151 | if not line[1].isdigit():
152 | skipped += 1
153 | print('Empty word at line %s with count %s' % (i, line))
154 | continue
155 | word2id[line[0]] = 4 + SPECIAL_WORDS + i - skipped # shift because of extra words
156 | counts[line[0]] = int(line[1])
157 | f.close()
158 | id2word = {v: k for k, v in word2id.items()}
159 | dico = Dictionary(id2word, word2id, counts)
160 | logger.info("Read %i words from the vocabulary file." % len(dico))
161 | if skipped > 0:
162 | logger.warning("Skipped %i empty lines!" % skipped)
163 | return dico
164 |
165 | @staticmethod
166 | def index_data(path, bin_path, dico):
167 | """
168 | Index sentences with a dictionary.
169 | """
170 | if bin_path is not None and os.path.isfile(bin_path):
171 | print("Loading data from %s ..." % bin_path)
172 | data = torch.load(bin_path)
173 | assert dico == data['dico']
174 | return data
175 |
176 | positions = []
177 | sentences = []
178 | unk_words = {}
179 |
180 | # index sentences
181 | f = open(path, 'r', encoding='utf-8')
182 | for i, line in enumerate(f):
183 | if i % 1000000 == 0 and i > 0:
184 | print(i)
185 | s = line.rstrip().split()
186 | # skip empty sentences
187 | if len(s) == 0:
188 | print("Empty sentence in line %i." % i)
189 | # index sentence words
190 | count_unk = 0
191 | indexed = []
192 | for w in s:
193 | word_id = dico.index(w, no_unk=False)
194 | # if we find a special word which is not an unknown word, skip the sentence
195 | if 0 <= word_id < 4 + SPECIAL_WORDS and word_id != 3:
196 | logger.warning('Found unexpected special word "%s" (%i)!!' % (w, word_id))
197 | continue
198 | assert word_id >= 0
199 | indexed.append(word_id)
200 | if word_id == dico.unk_index:
201 | unk_words[w] = unk_words.get(w, 0) + 1
202 | count_unk += 1
203 | # add sentence
204 | positions.append([len(sentences), len(sentences) + len(indexed)])
205 | sentences.extend(indexed)
206 | sentences.append(1) # EOS index
207 | f.close()
208 |
209 | # tensorize data
210 | positions = np.int64(positions)
211 | if len(dico) < 1 << 16:
212 | sentences = np.uint16(sentences)
213 | elif len(dico) < 1 << 31:
214 | sentences = np.int32(sentences)
215 | else:
216 | raise Exception("Dictionary is too big.")
217 | assert sentences.min() >= 0
218 | data = {
219 | 'dico': dico,
220 | 'positions': positions,
221 | 'sentences': sentences,
222 | 'unk_words': unk_words,
223 | }
224 | if bin_path is not None:
225 | print("Saving the data to %s ..." % bin_path)
226 | torch.save(data, bin_path, pickle_protocol=4)
227 |
228 | return data
229 |
--------------------------------------------------------------------------------
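The vocabulary file parsed by Dictionary.read_vocab in dictionary.py above is plain text with one word and its count per line, separated by whitespace; words not in the file map to unk_index. A minimal usage sketch (not part of the repository; the path and the query tokens are only examples):

```
# Sketch (not in the repository): read a fastBPE vocabulary ("word count" per line)
# and query the resulting Dictionary.
from src.data.dictionary import Dictionary

dico = Dictionary.read_vocab("data/de-en/it/processed/de-en/vocab.de-en")

print(len(dico))                        # vocabulary size (including the 14 special tokens)
print(dico.index("the"))                # id of a word present in the vocabulary
print(dico.index("some-unseen-token"))  # falls back to dico.unk_index (3)
```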
/src/data/dictionary.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jind11/DAMT/3caa22822b936137e5da3e827d7a5a2078c3115e/src/data/dictionary.pyc
--------------------------------------------------------------------------------
/src/data/loader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | from logging import getLogger
9 | import os
10 | import numpy as np
11 | import torch
12 |
13 | from .dataset import StreamDataset, Dataset, ParallelDataset
14 | from .dictionary import BOS_WORD, EOS_WORD, PAD_WORD, UNK_WORD, MASK_WORD
15 |
16 |
17 | logger = getLogger()
18 |
19 |
20 | def process_binarized(data, params):
21 | """
22 | Process a binarized dataset and log main statistics.
23 | """
24 | dico = data['dico']
25 | assert ((data['sentences'].dtype == np.uint16) and (len(dico) < 1 << 16) or
26 | (data['sentences'].dtype == np.int32) and (1 << 16 <= len(dico) < 1 << 31))
27 | logger.info("%i words (%i unique) in %i sentences. %i unknown words (%i unique) covering %.2f%% of the data." % (
28 | len(data['sentences']) - len(data['positions']),
29 | len(dico), len(data['positions']),
30 | sum(data['unk_words'].values()), len(data['unk_words']),
31 | 100. * sum(data['unk_words'].values()) / (len(data['sentences']) - len(data['positions']))
32 | ))
33 | if params.max_vocab != -1:
34 | assert params.max_vocab > 0
35 | logger.info("Selecting %i most frequent words ..." % params.max_vocab)
36 | dico.max_vocab(params.max_vocab)
37 | data['sentences'][data['sentences'] >= params.max_vocab] = dico.index(UNK_WORD)
38 | unk_count = (data['sentences'] == dico.index(UNK_WORD)).sum()
39 | logger.info("Now %i unknown words covering %.2f%% of the data."
40 | % (unk_count, 100. * unk_count / (len(data['sentences']) - len(data['positions']))))
41 | if params.min_count > 0:
42 | logger.info("Selecting words with >= %i occurrences ..." % params.min_count)
43 | dico.min_count(params.min_count)
44 | data['sentences'][data['sentences'] >= len(dico)] = dico.index(UNK_WORD)
45 | unk_count = (data['sentences'] == dico.index(UNK_WORD)).sum()
46 | logger.info("Now %i unknown words covering %.2f%% of the data."
47 | % (unk_count, 100. * unk_count / (len(data['sentences']) - len(data['positions']))))
48 | if (data['sentences'].dtype == np.int32) and (len(dico) < 1 << 16):
49 | logger.info("Less than 65536 words. Moving data from int32 to uint16 ...")
50 | data['sentences'] = data['sentences'].astype(np.uint16)
51 | return data
52 |
53 |
54 | def load_binarized(path, params):
55 | """
56 | Load a binarized dataset.
57 | """
58 | assert path.endswith('.pth')
59 | if params.debug_train:
60 | path = path.replace('train', 'valid')
61 | if getattr(params, 'multi_gpu', False):
62 | split_path = '%s.%i.pth' % (path[:-4], params.local_rank)
63 | if os.path.isfile(split_path):
64 | assert params.split_data is False
65 | path = split_path
66 | assert os.path.isfile(path), path
67 | logger.info("Loading data from %s ..." % path)
68 | data = torch.load(path)
69 | data = process_binarized(data, params)
70 | return data
71 |
72 |
73 | def set_dico_parameters(params, data, dico):
74 | """
75 | Update dictionary parameters.
76 | """
77 | if 'dico' in data:
78 | assert data['dico'] == dico
79 | else:
80 | data['dico'] = dico
81 |
82 | n_words = len(dico)
83 | bos_index = dico.index(BOS_WORD)
84 | eos_index = dico.index(EOS_WORD)
85 | pad_index = dico.index(PAD_WORD)
86 | unk_index = dico.index(UNK_WORD)
87 | mask_index = dico.index(MASK_WORD)
88 | if hasattr(params, 'bos_index'):
89 | assert params.n_words == n_words
90 | assert params.bos_index == bos_index
91 | assert params.eos_index == eos_index
92 | assert params.pad_index == pad_index
93 | assert params.unk_index == unk_index
94 | assert params.mask_index == mask_index
95 | else:
96 | params.n_words = n_words
97 | params.bos_index = bos_index
98 | params.eos_index = eos_index
99 | params.pad_index = pad_index
100 | params.unk_index = unk_index
101 | params.mask_index = mask_index
102 |
103 |
104 | def load_mono_data(params, data):
105 | """
106 | Load monolingual data.
107 | """
108 | data['mono'] = {}
109 | data['mono_stream'] = {}
110 |
111 | for lang in params.mono_dataset.keys():
112 |
113 | logger.info('============ Monolingual data (%s)' % lang)
114 |
115 | assert lang in params.langs and lang not in data['mono']
116 | data['mono'][lang] = {}
117 | data['mono_stream'][lang] = {}
118 |
119 | for splt in ['train', 'valid', 'test']:
120 |
121 | # no need to load training data for evaluation
122 | if splt == 'train' and params.eval_only:
123 | continue
124 |
125 | # load data / update dictionary parameters / update data
126 | mono_data = load_binarized(params.mono_dataset[lang][splt], params)
127 | set_dico_parameters(params, data, mono_data['dico'])
128 |
129 | # create stream dataset
130 | bs = params.batch_size if splt == 'train' else 1
131 | data['mono_stream'][lang][splt] = StreamDataset(mono_data['sentences'], mono_data['positions'], bs, params)
132 |
133 | # if there are several processes on the same machine, we can split the dataset
134 | if splt == 'train' and params.split_data and 1 < params.n_gpu_per_node <= data['mono_stream'][lang][splt].n_batches:
135 | n_batches = data['mono_stream'][lang][splt].n_batches // params.n_gpu_per_node
136 | a = n_batches * params.local_rank
137 | b = n_batches * params.local_rank + n_batches
138 | data['mono_stream'][lang][splt].select_data(a, b)
139 |
140 | # for denoising auto-encoding and online back-translation, we need a non-stream (batched) dataset
141 | if lang in params.ae_steps or lang in params.bt_src_langs:
142 |
143 | # create batched dataset
144 | dataset = Dataset(mono_data['sentences'], mono_data['positions'], params)
145 |
146 | # remove empty and too long sentences
147 | if splt == 'train':
148 | dataset.remove_empty_sentences()
149 | dataset.remove_long_sentences(params.max_len)
150 |
151 | # if there are several processes on the same machine, we can split the dataset
152 | if splt == 'train' and params.n_gpu_per_node > 1 and params.split_data:
153 | n_sent = len(dataset) // params.n_gpu_per_node
154 | a = n_sent * params.local_rank
155 | b = n_sent * params.local_rank + n_sent
156 | dataset.select_data(a, b)
157 |
158 | data['mono'][lang][splt] = dataset
159 |
160 | logger.info("")
161 |
162 | logger.info("")
163 |
164 |
165 | def load_para_data(params, data):
166 | """
167 | Load parallel data.
168 | """
169 | data['para'] = {}
170 |
171 | required_para_train = set(params.clm_steps + params.mlm_steps + params.pc_steps + params.mt_steps)
172 |
173 | for src, tgt in params.para_dataset.keys():
174 |
175 | logger.info('============ Parallel data (%s-%s)' % (src, tgt))
176 |
177 | assert (src, tgt) not in data['para']
178 | data['para'][(src, tgt)] = {}
179 |
180 | for splt in ['train', 'valid', 'test']:
181 |
182 | # no need to load training data for evaluation
183 | if splt == 'train' and params.eval_only:
184 | continue
185 |
186 | # for back-translation, we can't load training data
187 | if splt == 'train' and (src, tgt) not in required_para_train and (tgt, src) not in required_para_train:
188 | continue
189 |
190 | # load binarized datasets
191 | src_path, tgt_path = params.para_dataset[(src, tgt)][splt]
192 | src_data = load_binarized(src_path, params)
193 | tgt_data = load_binarized(tgt_path, params)
194 |
195 | # update dictionary parameters
196 | set_dico_parameters(params, data, src_data['dico'])
197 | set_dico_parameters(params, data, tgt_data['dico'])
198 |
199 | # create ParallelDataset
200 | dataset = ParallelDataset(
201 | src_data['sentences'], src_data['positions'],
202 | tgt_data['sentences'], tgt_data['positions'],
203 | params
204 | )
205 |
206 | # remove empty and too long sentences
207 | if splt == 'train':
208 | dataset.remove_empty_sentences()
209 | dataset.remove_long_sentences(params.max_len)
210 |
211 | # for validation and test set, enumerate sentence per sentence
212 | if splt != 'train':
213 | dataset.tokens_per_batch = -1
214 |
215 | # if there are several processes on the same machine, we can split the dataset
216 | if splt == 'train' and params.n_gpu_per_node > 1 and params.split_data:
217 | n_sent = len(dataset) // params.n_gpu_per_node
218 | a = n_sent * params.local_rank
219 | b = n_sent * params.local_rank + n_sent
220 | dataset.select_data(a, b)
221 |
222 | data['para'][(src, tgt)][splt] = dataset
223 | logger.info("")
224 |
225 | logger.info("")
226 |
227 |
228 | def check_data_params(params):
229 | """
230 | Check datasets parameters.
231 | """
232 | # data path
233 | assert os.path.isdir(params.data_path), params.data_path
234 |
235 | # check languages
236 | params.langs = params.lgs.split('-') if params.lgs != 'debug' else ['en']
237 | assert len(params.langs) == len(set(params.langs)) >= 1
238 | # assert sorted(params.langs) == params.langs
239 | params.id2lang = {k: v for k, v in enumerate(sorted(params.langs))}
240 | params.lang2id = {k: v for v, k in params.id2lang.items()}
241 | params.n_langs = len(params.langs)
242 |
243 | # CLM steps
244 | clm_steps = [s.split('-') for s in params.clm_steps.split(',') if len(s) > 0]
245 | params.clm_steps = [(s[0], None) if len(s) == 1 else tuple(s) for s in clm_steps]
246 | assert all([(l1 in params.langs) and (l2 in params.langs or l2 is None) for l1, l2 in params.clm_steps])
247 | assert len(params.clm_steps) == len(set(params.clm_steps))
248 |
249 | # MLM / TLM steps
250 | mlm_steps = [s.split('-') for s in params.mlm_steps.split(',') if len(s) > 0]
251 | params.mlm_steps = [(s[0], None) if len(s) == 1 else tuple(s) for s in mlm_steps]
252 | assert all([(l1 in params.langs) and (l2 in params.langs or l2 is None) for l1, l2 in params.mlm_steps])
253 | assert len(params.mlm_steps) == len(set(params.mlm_steps))
254 |
255 | # parallel classification steps
256 | params.pc_steps = [tuple(s.split('-')) for s in params.pc_steps.split(',') if len(s) > 0]
257 | assert all([len(x) == 2 for x in params.pc_steps])
258 | assert all([l1 in params.langs and l2 in params.langs for l1, l2 in params.pc_steps])
259 | assert all([l1 != l2 for l1, l2 in params.pc_steps])
260 | assert len(params.pc_steps) == len(set(params.pc_steps))
261 |
262 | # machine translation steps
263 | params.mt_steps = [tuple(s.split('-')) for s in params.mt_steps.split(',') if len(s) > 0]
264 | assert all([len(x) == 2 for x in params.mt_steps])
265 | assert all([l1 in params.langs and l2 in params.langs for l1, l2 in params.mt_steps])
266 | assert all([l1 != l2 for l1, l2 in params.mt_steps])
267 | assert len(params.mt_steps) == len(set(params.mt_steps))
268 | assert len(params.mt_steps) == 0 or not params.encoder_only
269 |
270 | # denoising auto-encoder steps
271 | params.ae_steps = [s for s in params.ae_steps.split(',') if len(s) > 0]
272 | assert all([lang in params.langs for lang in params.ae_steps])
273 | assert len(params.ae_steps) == len(set(params.ae_steps))
274 | assert len(params.ae_steps) == 0 or not params.encoder_only
275 |
276 | # back-translation steps
277 | params.bt_steps = [tuple(s.split('-')) for s in params.bt_steps.split(',') if len(s) > 0]
278 | assert all([len(x) == 3 for x in params.bt_steps])
279 | assert all([l1 in params.langs and l2 in params.langs and l3 in params.langs for l1, l2, l3 in params.bt_steps])
280 | assert all([l1 == l3 and l1 != l2 for l1, l2, l3 in params.bt_steps])
281 | assert len(params.bt_steps) == len(set(params.bt_steps))
282 | assert len(params.bt_steps) == 0 or not params.encoder_only
283 | params.bt_src_langs = [l1 for l1, _, _ in params.bt_steps]
284 |
285 | # check monolingual datasets
286 | required_mono = set([l1 for l1, l2 in (params.mlm_steps + params.clm_steps) if l2 is None] + params.ae_steps + params.bt_src_langs)
287 | params.mono_dataset = {
288 | lang: {
289 | splt: os.path.join(params.data_path, '%s.%s.pth' % (splt, lang))
290 | for splt in ['train', 'valid', 'test']
291 | } for lang in params.langs if lang in required_mono
292 | }
293 | for paths in params.mono_dataset.values():
294 | for p in paths.values():
295 | if not os.path.isfile(p):
296 | logger.error(f"{p} not found")
297 | assert all([all([os.path.isfile(p) for p in paths.values()]) for paths in params.mono_dataset.values()])
298 |
299 | # check parallel datasets
300 | if not params.para_data_path:
301 | params.para_data_path = params.data_path
302 | required_para_train = set(params.clm_steps + params.mlm_steps + params.pc_steps + params.mt_steps)
303 | required_para = required_para_train | set([(l2, l3) for _, l2, l3 in params.bt_steps])
304 | params.para_dataset = {}
305 | for src in params.langs:
306 | for tgt in params.langs:
307 | if src < tgt and ((src, tgt) in required_para or (tgt, src) in required_para):
308 | params.para_dataset[(src, tgt)] = {}
309 | for splt in ['train', 'valid', 'test']:
310 | if splt != 'train':
311 | params.para_dataset[(src, tgt)][splt] = \
312 | (os.path.join(params.data_path, '%s.%s-%s.%s.pth' % (splt, src, tgt, src)),
313 | os.path.join(params.data_path, '%s.%s-%s.%s.pth' % (splt, src, tgt, tgt)))
314 | else:
315 | if (src, tgt) in required_para_train or (tgt, src) in required_para_train:
316 | params.para_dataset[(src, tgt)][splt] = \
317 | (os.path.join(params.para_data_path, '%s.%s-%s.%s.pth' % (splt, src, tgt, src)),
318 | os.path.join(params.para_data_path, '%s.%s-%s.%s.pth' % (splt, src, tgt, tgt)))
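# Example layout (hypothetical paths) for the pair ('de', 'en'):
#   valid/test splits come from data_path, e.g. (valid.de-en.de.pth, valid.de-en.en.pth);
#   the train split comes from para_data_path (which defaults to data_path above),
#   e.g. (train.de-en.de.pth, train.de-en.en.pth)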
319 |
320 | for paths in params.para_dataset.values():
321 | for p1, p2 in paths.values():
322 | if not os.path.isfile(p1):
323 | logger.error(f"{p1} not found")
324 | if not os.path.isfile(p2):
325 | logger.error(f"{p2} not found")
326 | assert all([all([os.path.isfile(p1) and os.path.isfile(p2) for p1, p2 in paths.values()]) for paths in params.para_dataset.values()])
327 |
328 | # check that we can evaluate on BLEU
329 | assert params.eval_bleu is False or len(params.mt_steps + params.bt_steps) > 0
330 |
331 |
332 | def load_data(params):
333 | """
334 | Load monolingual and parallel data.
335 | The returned dictionary contains:
336 | - dico (dictionary)
337 | - mono / mono_stream (monolingual datasets)
338 | - para (parallel datasets)
339 | """
340 | data = {}
341 |
342 | # monolingual datasets
343 | load_mono_data(params, data)
344 |
345 | # parallel datasets
346 | load_para_data(params, data)
347 |
348 | # monolingual data summary
349 | logger.info('============ Data summary')
350 | for lang, v in data['mono_stream'].items():
351 | for data_set in v.keys():
352 | logger.info('{: <18} - {: >5} - {: >12}:{: >10}'.format('Monolingual data', data_set, lang, len(v[data_set])))
353 |
354 | # parallel data summary
355 | for (src, tgt), v in data['para'].items():
356 | for data_set in v.keys():
357 | logger.info('{: <18} - {: >5} - {: >12}:{: >10}'.format('Parallel data', data_set, '%s-%s' % (src, tgt), len(v[data_set])))
358 |
359 | logger.info("")
360 | return data
361 |
--------------------------------------------------------------------------------
/src/evaluation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jind11/DAMT/3caa22822b936137e5da3e827d7a5a2078c3115e/src/evaluation/__init__.py
--------------------------------------------------------------------------------
/src/evaluation/evaluator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | from logging import getLogger
9 | import os
10 | import re
11 | import subprocess
12 | from collections import OrderedDict
13 | import numpy as np
14 | import torch
15 |
16 | from ..utils import to_cuda, restore_segmentation, concat_batches
17 | from ..model.memory import HashingMemory
18 |
19 |
20 | BLEU_SCRIPT_PATH = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'multi-bleu.perl')
21 | assert os.path.isfile(BLEU_SCRIPT_PATH)
22 |
23 |
24 | logger = getLogger()
25 |
26 |
27 | def kl_score(x):
28 | # assert np.abs(np.sum(x) - 1) < 1e-5
29 | _x = x.copy()
30 | _x[x == 0] = 1
31 | return np.log(len(x)) + (x * np.log(_x)).sum()
32 |
33 |
34 | def gini_score(x):
35 | # assert np.abs(np.sum(x) - 1) < 1e-5
36 | B = np.cumsum(np.sort(x)).mean()
37 | return 1 - 2 * B
38 |
39 |
40 | def tops(x):
41 | # assert np.abs(np.sum(x) - 1) < 1e-5
42 | y = np.cumsum(np.sort(x))
43 | top50, top90, top99 = y.shape[0] - np.searchsorted(y, [0.5, 0.1, 0.01])
44 | return top50, top90, top99
45 |
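# The three statistics above summarize how concentrated a usage distribution is:
#   kl_score   - KL divergence from the uniform distribution (0 for perfectly
#                uniform usage, log(len(x)) when one slot takes all the mass)
#   gini_score - Gini-style concentration index (~0 for uniform usage, ~1 when
#                a single slot dominates)
#   tops       - roughly how many of the most-used slots cover 50% / 90% / 99%
#                of the total mass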
46 |
47 | def eval_memory_usage(scores, name, mem_att, mem_size):
48 | """
49 | Evaluate memory usage (HashingMemory / FFN).
50 | """
51 | # memory slot scores
52 | assert mem_size > 0
53 | mem_scores_w = np.zeros(mem_size, dtype=np.float32) # weighted scores
54 | mem_scores_u = np.zeros(mem_size, dtype=np.float32) # unweighted scores
55 |
56 | # sum each slot usage
57 | for indices, weights in mem_att:
58 | np.add.at(mem_scores_w, indices, weights)
59 | np.add.at(mem_scores_u, indices, 1)
60 |
61 | # compute the KL distance to the uniform distribution
62 | mem_scores_w = mem_scores_w / mem_scores_w.sum()
63 | mem_scores_u = mem_scores_u / mem_scores_u.sum()
64 |
65 | # store stats
66 | scores['%s_mem_used' % name] = float(100 * (mem_scores_w != 0).sum() / len(mem_scores_w))
67 |
68 | scores['%s_mem_kl_w' % name] = float(kl_score(mem_scores_w))
69 | scores['%s_mem_kl_u' % name] = float(kl_score(mem_scores_u))
70 |
71 | scores['%s_mem_gini_w' % name] = float(gini_score(mem_scores_w))
72 | scores['%s_mem_gini_u' % name] = float(gini_score(mem_scores_u))
73 |
74 | top50, top90, top99 = tops(mem_scores_w)
75 | scores['%s_mem_top50_w' % name] = float(top50)
76 | scores['%s_mem_top90_w' % name] = float(top90)
77 | scores['%s_mem_top99_w' % name] = float(top99)
78 |
79 | top50, top90, top99 = tops(mem_scores_u)
80 | scores['%s_mem_top50_u' % name] = float(top50)
81 | scores['%s_mem_top90_u' % name] = float(top90)
82 | scores['%s_mem_top99_u' % name] = float(top99)
83 |
84 |
85 | class Evaluator(object):
86 |
87 | def __init__(self, trainer, data, params):
88 | """
89 | Initialize evaluator.
90 | """
91 | self.trainer = trainer
92 | self.data = data
93 | self.dico = data['dico']
94 | self.params = params
95 | self.memory_list = trainer.memory_list
96 |
97 | # create directory to store hypotheses, and reference files for BLEU evaluation
98 | if self.params.is_master:
99 | params.hyp_path = os.path.join(params.dump_path, 'hypotheses')
100 | subprocess.Popen('mkdir -p %s' % params.hyp_path, shell=True).wait()
101 | self.create_reference_files()
102 |
103 | def get_iterator(self, data_set, lang1, lang2=None, stream=False):
104 | """
105 | Create a new iterator for a dataset.
106 | """
107 | assert data_set in ['valid', 'test']
108 | assert lang1 in self.params.langs
109 | assert lang2 is None or lang2 in self.params.langs
110 | assert stream is False or lang2 is None
111 |
112 | # hacks to reduce evaluation time when using many languages
113 | if len(self.params.langs) > 30:
114 | eval_lgs = set(["ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr", "ur", "vi", "zh", "ab", "ay", "bug", "ha", "ko", "ln", "min", "nds", "pap", "pt", "tg", "to", "udm", "uk", "zh_classical"])
115 | eval_lgs = set(["ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr", "ur", "vi", "zh"])
116 | subsample = 10 if (data_set == 'test' or lang1 not in eval_lgs) else 5
117 | n_sentences = 600 if (data_set == 'test' or lang1 not in eval_lgs) else 1500
118 | elif len(self.params.langs) > 5:
119 | subsample = 10 if data_set == 'test' else 5
120 | n_sentences = 300 if data_set == 'test' else 1500
121 | else:
122 | # n_sentences = -1 if data_set == 'valid' else 100
123 | n_sentences = -1
124 | subsample = 1
125 |
126 | if lang2 is None:
127 | if stream:
128 | iterator = self.data['mono_stream'][lang1][data_set].get_iterator(shuffle=False, subsample=subsample)
129 | else:
130 | iterator = self.data['mono'][lang1][data_set].get_iterator(
131 | shuffle=False,
132 | group_by_size=True,
133 | n_sentences=n_sentences,
134 | )
135 | else:
136 | assert stream is False
137 | _lang1, _lang2 = (lang1, lang2) if lang1 < lang2 else (lang2, lang1)
138 | iterator = self.data['para'][(_lang1, _lang2)][data_set].get_iterator(
139 | shuffle=False,
140 | group_by_size=True,
141 | n_sentences=n_sentences
142 | )
143 |
144 | for batch in iterator:
145 | yield batch if lang2 is None or lang1 < lang2 else batch[::-1]
146 |
147 | def create_reference_files(self):
148 | """
149 | Create reference files for BLEU evaluation.
150 | """
151 | params = self.params
152 | params.ref_paths = {}
153 |
154 | for (lang1, lang2), v in self.data['para'].items():
155 |
156 | assert lang1 < lang2
157 |
158 | for data_set in ['valid', 'test']:
159 |
160 | # define data paths
161 | lang1_path = os.path.join(params.hyp_path, 'ref.{0}-{1}.{2}.txt'.format(lang2, lang1, data_set))
162 | lang2_path = os.path.join(params.hyp_path, 'ref.{0}-{1}.{2}.txt'.format(lang1, lang2, data_set))
163 |
164 | # store data paths
165 | params.ref_paths[(lang2, lang1, data_set)] = lang1_path
166 | params.ref_paths[(lang1, lang2, data_set)] = lang2_path
167 |
168 | # text sentences
169 | lang1_txt = []
170 | lang2_txt = []
171 |
172 | # convert to text
173 | for (sent1, len1), (sent2, len2) in self.get_iterator(data_set, lang1, lang2):
174 | lang1_txt.extend(convert_to_text(sent1, len1, self.dico, params))
175 | lang2_txt.extend(convert_to_text(sent2, len2, self.dico, params))
176 |
177 | # replace <unk> by <<unk>> as these tokens cannot be counted in BLEU
178 | lang1_txt = [x.replace('<unk>', '<<unk>>') for x in lang1_txt]
179 | lang2_txt = [x.replace('<unk>', '<<unk>>') for x in lang2_txt]
180 |
181 | # export hypothesis
182 | with open(lang1_path, 'w', encoding='utf-8') as f:
183 | f.write('\n'.join(lang1_txt) + '\n')
184 | with open(lang2_path, 'w', encoding='utf-8') as f:
185 | f.write('\n'.join(lang2_txt) + '\n')
186 |
187 | # restore original segmentation
188 | restore_segmentation(lang1_path, bpe_type=params.bpe_type)
189 | restore_segmentation(lang2_path, bpe_type=params.bpe_type)
190 |
191 | def mask_out(self, x, lengths, rng):
192 | """
193 | Randomly decide which words to mask out.
194 | We specify the random generator to ensure that the test is the same at each epoch.
195 | """
196 | params = self.params
197 | slen, bs = x.size()
198 |
199 | # words to predict - be sure there is at least one word per sentence
200 | to_predict = rng.rand(slen, bs) <= params.word_pred
201 | to_predict[0] = 0
202 | for i in range(bs):
203 | to_predict[lengths[i] - 1:, i] = 0
204 | if not np.any(to_predict[:lengths[i] - 1, i]):
205 | v = rng.randint(1, lengths[i] - 1)
206 | to_predict[v, i] = 1
207 | pred_mask = torch.from_numpy(to_predict.astype(np.uint8))
208 |
209 | # generate possible targets / update x input
210 | _x_real = x[pred_mask]
211 | _x_mask = _x_real.clone().fill_(params.mask_index)
212 | x = x.masked_scatter(pred_mask, _x_mask)
213 |
214 | assert 0 <= x.min() <= x.max() < params.n_words
215 | assert x.size() == (slen, bs)
216 | assert pred_mask.size() == (slen, bs)
217 |
218 | return x, _x_real, pred_mask
219 |
220 | def run_all_evals(self, trainer):
221 | """
222 | Run all evaluations.
223 | """
224 | params = self.params
225 | scores = OrderedDict({'epoch': trainer.epoch})
226 |
227 | with torch.no_grad():
228 |
229 | for data_set in ['valid', 'test']:
230 |
231 | # causal prediction task (evaluate perplexity and accuracy)
232 | for lang1, lang2 in params.clm_steps:
233 | self.evaluate_clm(scores, data_set, lang1, lang2)
234 |
235 | # prediction task (evaluate perplexity and accuracy)
236 | for lang1, lang2 in params.mlm_steps:
237 | self.evaluate_mlm(scores, data_set, lang1, lang2)
238 |
239 | # machine translation task (evaluate perplexity and accuracy)
240 | for lang1, lang2 in set(params.mt_steps + [(l2, l3) for _, l2, l3 in params.bt_steps]):
241 | eval_bleu = params.eval_bleu and params.is_master
242 | self.evaluate_mt(scores, data_set, lang1, lang2, eval_bleu)
243 |
244 | # report average metrics per language
245 | _clm_mono = [l1 for (l1, l2) in params.clm_steps if l2 is None]
246 | if len(_clm_mono) > 0:
247 | scores['%s_clm_ppl' % data_set] = np.mean([scores['%s_%s_clm_ppl' % (data_set, lang)] for lang in _clm_mono])
248 | scores['%s_clm_acc' % data_set] = np.mean([scores['%s_%s_clm_acc' % (data_set, lang)] for lang in _clm_mono])
249 | _mlm_mono = [l1 for (l1, l2) in params.mlm_steps if l2 is None]
250 | if len(_mlm_mono) > 0:
251 | scores['%s_mlm_ppl' % data_set] = np.mean([scores['%s_%s_mlm_ppl' % (data_set, lang)] for lang in _mlm_mono])
252 | scores['%s_mlm_acc' % data_set] = np.mean([scores['%s_%s_mlm_acc' % (data_set, lang)] for lang in _mlm_mono])
253 |
254 | return scores
255 |
256 | def evaluate_clm(self, scores, data_set, lang1, lang2):
257 | """
258 | Evaluate perplexity and next word prediction accuracy.
259 | """
260 | params = self.params
261 | assert data_set in ['valid', 'test']
262 | assert lang1 in params.langs
263 | assert lang2 in params.langs or lang2 is None
264 |
265 | model = self.model if params.encoder_only else self.decoder
266 | model.eval()
267 | model = model.module if params.multi_gpu else model
268 |
269 | lang1_id = params.lang2id[lang1]
270 | lang2_id = params.lang2id[lang2] if lang2 is not None else None
271 | l1l2 = lang1 if lang2 is None else f"{lang1}-{lang2}"
272 |
273 | n_words = 0
274 | xe_loss = 0
275 | n_valid = 0
276 |
277 | # only save states / evaluate usage on the validation set
278 | eval_memory = params.use_memory and data_set == 'valid' and self.params.is_master
279 | HashingMemory.EVAL_MEMORY = eval_memory
280 | if eval_memory:
281 | all_mem_att = {k: [] for k, _ in self.memory_list}
282 |
283 | for batch in self.get_iterator(data_set, lang1, lang2, stream=(lang2 is None)):
284 |
285 | # batch
286 | if lang2 is None:
287 | x, lengths = batch
288 | positions = None
289 | langs = x.clone().fill_(lang1_id) if params.n_langs > 1 else None
290 | else:
291 | (sent1, len1), (sent2, len2) = batch
292 | x, lengths, positions, langs = concat_batches(sent1, len1, lang1_id, sent2, len2, lang2_id, params.pad_index, params.eos_index, reset_positions=True)
293 |
294 | # words to predict
295 | alen = torch.arange(lengths.max(), dtype=torch.long, device=lengths.device)
296 | pred_mask = alen[:, None] < lengths[None] - 1
297 | y = x[1:].masked_select(pred_mask[:-1])
298 | assert pred_mask.sum().item() == y.size(0)
299 |
300 | # cuda
301 | x, lengths, positions, langs, pred_mask, y = to_cuda(x, lengths, positions, langs, pred_mask, y)
302 |
303 | # forward / loss
304 | tensor = model('fwd', x=x, lengths=lengths, positions=positions, langs=langs, causal=True)
305 | word_scores, loss = model('predict', tensor=tensor, pred_mask=pred_mask, y=y, get_scores=True)
306 |
307 | # update stats
308 | n_words += y.size(0)
309 | xe_loss += loss.item() * len(y)
310 | n_valid += (word_scores.max(1)[1] == y).sum().item()
311 | if eval_memory:
312 | for k, v in self.memory_list:
313 | all_mem_att[k].append((v.last_indices, v.last_scores))
314 |
315 | # log
316 | logger.info("Found %i words in %s. %i were predicted correctly." % (n_words, data_set, n_valid))
317 |
318 | # compute perplexity and prediction accuracy
319 | ppl_name = '%s_%s_clm_ppl' % (data_set, l1l2)
320 | acc_name = '%s_%s_clm_acc' % (data_set, l1l2)
321 | scores[ppl_name] = np.exp(xe_loss / n_words)
322 | scores[acc_name] = 100. * n_valid / n_words
323 |
324 | # compute memory usage
325 | if eval_memory:
326 | for mem_name, mem_att in all_mem_att.items():
327 | eval_memory_usage(scores, '%s_%s_%s' % (data_set, l1l2, mem_name), mem_att, params.mem_size)
328 |
329 | def evaluate_mlm(self, scores, data_set, lang1, lang2):
330 | """
331 | Evaluate perplexity and masked word prediction accuracy.
332 | """
333 | params = self.params
334 | assert data_set in ['valid', 'test']
335 | assert lang1 in params.langs
336 | assert lang2 in params.langs or lang2 is None
337 |
338 | model = self.model if params.encoder_only else self.encoder
339 | model.eval()
340 | model = model.module if params.multi_gpu else model
341 |
342 | rng = np.random.RandomState(0)
343 |
344 | lang1_id = params.lang2id[lang1]
345 | lang2_id = params.lang2id[lang2] if lang2 is not None else None
346 | l1l2 = lang1 if lang2 is None else f"{lang1}_{lang2}"
347 |
348 | n_words = 0
349 | xe_loss = 0
350 | n_valid = 0
351 |
352 | # only save states / evaluate usage on the validation set
353 | eval_memory = params.use_memory and data_set == 'valid' and self.params.is_master
354 | HashingMemory.EVAL_MEMORY = eval_memory
355 | if eval_memory:
356 | all_mem_att = {k: [] for k, _ in self.memory_list}
357 |
358 | for batch in self.get_iterator(data_set, lang1, lang2, stream=(lang2 is None)):
359 |
360 | # batch
361 | if lang2 is None:
362 | x, lengths = batch
363 | positions = None
364 | langs = x.clone().fill_(lang1_id) if params.n_langs > 1 else None
365 | else:
366 | (sent1, len1), (sent2, len2) = batch
367 | x, lengths, positions, langs = concat_batches(sent1, len1, lang1_id, sent2, len2, lang2_id, params.pad_index, params.eos_index, reset_positions=True)
368 |
369 | # words to predict
370 | x, y, pred_mask = self.mask_out(x, lengths, rng)
371 |
372 | # cuda
373 | x, y, pred_mask, lengths, positions, langs = to_cuda(x, y, pred_mask, lengths, positions, langs)
374 |
375 | # forward / loss
376 | tensor = model('fwd', x=x, lengths=lengths, positions=positions, langs=langs, causal=False)
377 | word_scores, loss = model('predict', tensor=tensor, pred_mask=pred_mask, y=y, get_scores=True)
378 |
379 | # update stats
380 | n_words += len(y)
381 | xe_loss += loss.item() * len(y)
382 | n_valid += (word_scores.max(1)[1] == y).sum().item()
383 | if eval_memory:
384 | for k, v in self.memory_list:
385 | all_mem_att[k].append((v.last_indices, v.last_scores))
386 |
387 | # compute perplexity and prediction accuracy
388 | ppl_name = '%s_%s_mlm_ppl' % (data_set, l1l2)
389 | acc_name = '%s_%s_mlm_acc' % (data_set, l1l2)
390 | scores[ppl_name] = np.exp(xe_loss / n_words) if n_words > 0 else 1e9
391 | scores[acc_name] = 100. * n_valid / n_words if n_words > 0 else 0.
392 |
393 | # compute memory usage
394 | if eval_memory:
395 | for mem_name, mem_att in all_mem_att.items():
396 | eval_memory_usage(scores, '%s_%s_%s' % (data_set, l1l2, mem_name), mem_att, params.mem_size)
397 |
398 |
399 | class SingleEvaluator(Evaluator):
400 |
401 | def __init__(self, trainer, data, params):
402 | """
403 | Build language model evaluator.
404 | """
405 | super().__init__(trainer, data, params)
406 | self.model = trainer.model
407 |
408 |
409 | class EncDecEvaluator(Evaluator):
410 |
411 | def __init__(self, trainer, data, params):
412 | """
413 | Build encoder / decoder evaluator.
414 | """
415 | super().__init__(trainer, data, params)
416 | self.encoder = trainer.encoder
417 | self.decoder = trainer.decoder
418 |
419 | def evaluate_mt(self, scores, data_set, lang1, lang2, eval_bleu):
420 | """
421 | Evaluate perplexity, next word prediction accuracy, and optionally BLEU.
422 | """
423 | params = self.params
424 | assert data_set in ['valid', 'test']
425 | assert lang1 in params.langs
426 | assert lang2 in params.langs
427 |
428 | self.encoder.eval()
429 | self.decoder.eval()
430 | encoder = self.encoder.module if params.multi_gpu else self.encoder
431 | decoder = self.decoder.module if params.multi_gpu else self.decoder
432 |
433 | params = params
434 | lang1_id = params.lang2id[lang1]
435 | lang2_id = params.lang2id[lang2]
436 |
437 | n_words = 0
438 | xe_loss = 0
439 | n_valid = 0
440 |
441 | # only save states / evaluate usage on the validation set
442 | eval_memory = params.use_memory and data_set == 'valid' and self.params.is_master
443 | HashingMemory.EVAL_MEMORY = eval_memory
444 | if eval_memory:
445 | all_mem_att = {k: [] for k, _ in self.memory_list}
446 |
447 | # store hypothesis to compute BLEU score
448 | if eval_bleu:
449 | hypothesis = []
450 |
451 | for batch in self.get_iterator(data_set, lang1, lang2):
452 |
453 | # generate batch
454 | (x1, len1), (x2, len2) = batch
455 | langs1 = x1.clone().fill_(lang1_id)
456 | langs2 = x2.clone().fill_(lang2_id)
457 |
458 | # target words to predict
459 | alen = torch.arange(len2.max(), dtype=torch.long, device=len2.device)
460 | pred_mask = alen[:, None] < len2[None] - 1 # do not predict anything given the last target word
461 | y = x2[1:].masked_select(pred_mask[:-1])
462 | assert len(y) == (len2 - 1).sum().item()
463 |
464 | # cuda
465 | x1, len1, langs1, x2, len2, langs2, y = to_cuda(x1, len1, langs1, x2, len2, langs2, y)
466 |
467 | # encode source sentence
468 | enc1 = encoder('fwd', x=x1, lengths=len1, langs=langs1, causal=False)
469 | enc1 = enc1.transpose(0, 1)
470 | enc1 = enc1.half() if params.fp16 else enc1
471 |
472 | # decode target sentence
473 | dec2 = decoder('fwd', x=x2, lengths=len2, langs=langs2, causal=True, src_enc=enc1, src_len=len1)
474 |
475 | # loss
476 | word_scores, loss = decoder('predict', tensor=dec2, pred_mask=pred_mask, y=y, get_scores=True)
477 |
478 | # update stats
479 | n_words += y.size(0)
480 | xe_loss += loss.item() * len(y)
481 | n_valid += (word_scores.max(1)[1] == y).sum().item()
482 | if eval_memory:
483 | for k, v in self.memory_list:
484 | all_mem_att[k].append((v.last_indices, v.last_scores))
485 |
486 | # generate translation - translate / convert to text
487 | if eval_bleu:
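# cap the generation length with a simple heuristic: 1.5x the source length plus 10 tokens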
488 | max_len = int(1.5 * len1.max().item() + 10)
489 | if params.beam_size == 1:
490 | generated, lengths = decoder.generate(enc1, len1, lang2_id, max_len=max_len)
491 | else:
492 | generated, lengths = decoder.generate_beam(
493 | enc1, len1, lang2_id, beam_size=params.beam_size,
494 | length_penalty=params.length_penalty,
495 | early_stopping=params.early_stopping,
496 | max_len=max_len
497 | )
498 | hypothesis.extend(convert_to_text(generated, lengths, self.dico, params))
499 |
500 | # compute perplexity and prediction accuracy
501 | scores['%s_%s-%s_mt_ppl' % (data_set, lang1, lang2)] = np.exp(xe_loss / n_words)
502 | scores['%s_%s-%s_mt_acc' % (data_set, lang1, lang2)] = 100. * n_valid / n_words
503 |
504 | # compute memory usage
505 | if eval_memory:
506 | for mem_name, mem_att in all_mem_att.items():
507 | eval_memory_usage(scores, '%s_%s-%s_%s' % (data_set, lang1, lang2, mem_name), mem_att, params.mem_size)
508 |
509 | # compute BLEU
510 | if eval_bleu:
511 |
512 | # hypothesis / reference paths
513 | hyp_name = 'hyp{0}.{1}-{2}.{3}.txt'.format(scores['epoch'], lang1, lang2, data_set)
514 | hyp_path = os.path.join(params.hyp_path, hyp_name)
515 | ref_path = params.ref_paths[(lang1, lang2, data_set)]
516 |
517 | # export sentences to hypothesis file / restore BPE segmentation
518 | with open(hyp_path, 'w', encoding='utf-8') as f:
519 | f.write('\n'.join(hypothesis) + '\n')
520 | restore_segmentation(hyp_path, bpe_type=params.bpe_type)
521 |
522 | # evaluate BLEU score
523 | bleu = eval_moses_bleu(ref_path, hyp_path)
524 | sacrebleu = eval_sacrebleu(ref_path, hyp_path)
525 | logger.info("BLEU %s %s : %f" % (hyp_path, ref_path, bleu))
526 | logger.info("SacreBLEU %s %s : %f" % (hyp_path, ref_path, sacrebleu))
527 | scores['%s_%s-%s_mt_bleu' % (data_set, lang1, lang2)] = bleu
528 | scores['%s_%s-%s_mt_sacrebleu' % (data_set, lang1, lang2)] = sacrebleu
529 |
530 |
531 | def convert_to_text(batch, lengths, dico, params):
532 | """
533 | Convert a batch of sentences to a list of text sentences.
534 | """
535 | batch = batch.cpu().numpy()
536 | lengths = lengths.cpu().numpy()
537 |
538 | slen, bs = batch.shape
539 | assert lengths.max() == slen and lengths.shape[0] == bs
540 | assert (batch[0] == params.eos_index).sum() == bs
541 | assert (batch == params.eos_index).sum() == 2 * bs
542 | sentences = []
543 |
544 | for j in range(bs):
545 | words = []
546 | for k in range(1, lengths[j]):
547 | if batch[k, j] == params.eos_index:
548 | break
549 | words.append(dico[batch[k, j]])
550 | sentences.append(" ".join(words))
551 | return sentences
552 |
553 |
554 | def eval_moses_bleu(ref, hyp):
555 | """
556 | Given a file of hypothesis and reference files,
557 | evaluate the BLEU score using Moses scripts.
558 | """
559 | assert os.path.isfile(hyp)
560 | assert os.path.isfile(ref) or os.path.isfile(ref + '0')
561 | assert os.path.isfile(BLEU_SCRIPT_PATH)
562 | command = BLEU_SCRIPT_PATH + ' %s < %s'
563 | p = subprocess.Popen(command % (ref, hyp), stdout=subprocess.PIPE, shell=True)
564 | result = p.communicate()[0].decode("utf-8")
565 | if result.startswith('BLEU'):
566 | return float(result[7:result.index(',')])
567 | else:
568 | logger.warning('Impossible to parse BLEU score! "%s"' % result)
569 | return -1
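# Example of the multi-bleu.perl output parsed above (numbers are illustrative):
#   "BLEU = 27.53, 60.1/33.0/20.1/12.7 (BP=1.000, ratio=0.987, hyp_len=61234, ref_len=62021)"
# result[7:result.index(',')] extracts the corpus-level score, here "27.53".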
570 |
571 |
572 | def eval_sacrebleu(ref, hyp):
573 | ref_lines = open(ref, encoding='utf-8').readlines()
574 | hyp_lines = open(hyp, encoding='utf-8').readlines()
575 | scorer = SacrebleuScorer()
576 | for ref_line, hyp_line in zip(ref_lines, hyp_lines):
577 | scorer.add_string(ref_line, hyp_line)
578 | return float(re.findall(r"\d+\.\d+", str(scorer.result_string()))[0])
579 |
580 | class SacrebleuScorer(object):
581 | def __init__(self):
582 | import sacrebleu
583 | self.sacrebleu = sacrebleu
584 | self.reset()
585 |
586 | def reset(self, one_init=False):
587 | if one_init:
588 | raise NotImplementedError
589 | self.ref = []
590 | self.sys = []
591 |
592 | def add_string(self, ref, pred):
593 | self.ref.append(ref)
594 | self.sys.append(pred)
595 |
596 | def score(self, order=4):
597 | return self.result_string(order).score
598 |
599 | def result_string(self, order=4):
600 | if order != 4:
601 | raise NotImplementedError
602 | return self.sacrebleu.corpus_bleu(self.sys, [self.ref])
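# Minimal usage sketch (illustrative strings; assumes the `sacrebleu` package is installed):
#   scorer = SacrebleuScorer()
#   scorer.add_string("the cat sat on the mat", "the cat sat on a mat")  # (reference, hypothesis)
#   corpus_bleu = scorer.score()                # float BLEU over all added pairs
#   summary = str(scorer.result_string())       # full result string, as parsed by eval_sacrebleu above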
603 |
--------------------------------------------------------------------------------
/src/evaluation/glue.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | from logging import getLogger
9 | import os
10 | import copy
11 | import time
12 | import json
13 | from collections import OrderedDict
14 |
15 | import numpy as np
16 | import torch
17 | from torch import nn
18 | import torch.nn.functional as F
19 |
20 | from scipy.stats import spearmanr, pearsonr
21 | from sklearn.metrics import f1_score, matthews_corrcoef
22 |
23 | from ..optim import get_optimizer
24 | from ..utils import concat_batches, truncate, to_cuda
25 | from ..data.dataset import Dataset, ParallelDataset
26 | from ..data.loader import load_binarized, set_dico_parameters
27 |
28 |
29 | N_CLASSES = {
30 | 'MNLI-m': 3,
31 | 'MNLI-mm': 3,
32 | 'QQP': 2,
33 | 'QNLI': 2,
34 | 'SST-2': 2,
35 | 'CoLA': 2,
36 | 'MRPC': 2,
37 | 'RTE': 2,
38 | 'STS-B': 1,
39 | 'WNLI': 2,
40 | 'AX_MNLI-m': 3,
41 | }
42 |
43 |
44 | logger = getLogger()
45 |
46 |
47 | class GLUE:
48 |
49 | def __init__(self, embedder, scores, params):
50 | """
51 | Initialize GLUE trainer / evaluator.
52 | Initial `embedder` should be on CPU to save memory.
53 | """
54 | self._embedder = embedder
55 | self.params = params
56 | self.scores = scores
57 |
58 | def get_iterator(self, splt):
59 | """
60 | Build data iterator.
61 | """
62 | return self.data[splt]['x'].get_iterator(
63 | shuffle=(splt == 'train'),
64 | return_indices=True,
65 | group_by_size=self.params.group_by_size
66 | )
67 |
68 | def run(self, task):
69 | """
70 | Run GLUE training / evaluation.
71 | """
72 | params = self.params
73 |
74 | # task parameters
75 | self.task = task
76 | params.out_features = N_CLASSES[task]
77 | self.is_classif = task != 'STS-B'
78 |
79 | # load data
80 | self.data = self.load_data(task)
81 | if not self.data['dico'] == self._embedder.dico:
82 | raise Exception(("Dictionary in evaluation data (%i words) seems different than the one " +
83 | "in the pretrained model (%i words). Please verify you used the same dictionary, " +
84 | "and the same values for max_vocab and min_count.") % (len(self.data['dico']), len(self._embedder.dico)))
85 |
86 | # embedder
87 | self.embedder = copy.deepcopy(self._embedder)
88 | self.embedder.cuda()
89 |
90 | # projection layer
91 | self.proj = nn.Sequential(*[
92 | nn.Dropout(params.dropout),
93 | nn.Linear(self.embedder.out_dim, params.out_features)
94 | ]).cuda()
95 |
96 | # optimizers
97 | self.optimizer_e = get_optimizer(list(self.embedder.get_parameters(params.finetune_layers)), params.optimizer_e)
98 | self.optimizer_p = get_optimizer(self.proj.parameters(), params.optimizer_p)
99 |
100 | # train and evaluate the model
101 | for epoch in range(params.n_epochs):
102 |
103 | # update epoch
104 | self.epoch = epoch
105 |
106 | # training
107 | logger.info("GLUE - %s - Training epoch %i ..." % (task, epoch))
108 | self.train()
109 |
110 | # evaluation
111 | logger.info("GLUE - %s - Evaluating epoch %i ..." % (task, epoch))
112 | with torch.no_grad():
113 | scores = self.eval('valid')
114 | self.scores.update(scores)
115 | self.eval('test')
116 |
117 | def train(self):
118 | """
119 | Finetune for one epoch on the training set.
120 | """
121 | params = self.params
122 | self.embedder.train()
123 | self.proj.train()
124 |
125 | # training variables
126 | losses = []
127 | ns = 0 # number of sentences
128 | nw = 0 # number of words
129 | t = time.time()
130 |
131 | iterator = self.get_iterator('train')
132 | lang_id = params.lang2id['en']
133 |
134 | while True:
135 |
136 | # batch
137 | try:
138 | batch = next(iterator)
139 | except StopIteration:
140 | break
141 | if self.n_sent == 1:
142 | (x, lengths), idx = batch
143 | x, lengths = truncate(x, lengths, params.max_len, params.eos_index)
144 | else:
145 | (sent1, len1), (sent2, len2), idx = batch
146 | sent1, len1 = truncate(sent1, len1, params.max_len, params.eos_index)
147 | sent2, len2 = truncate(sent2, len2, params.max_len, params.eos_index)
148 | x, lengths, _, _ = concat_batches(sent1, len1, lang_id, sent2, len2, lang_id, params.pad_index, params.eos_index, reset_positions=False)
149 | y = self.data['train']['y'][idx]
150 | bs = len(lengths)
151 |
152 | # cuda
153 | x, y, lengths = to_cuda(x, y, lengths)
154 |
155 | # loss
156 | output = self.proj(self.embedder.get_embeddings(x, lengths, positions=None, langs=None))
157 | if self.is_classif:
158 | loss = F.cross_entropy(output, y, weight=self.weights)
159 | else:
160 | loss = F.mse_loss(output.squeeze(1), y.float())
161 |
162 | # backward / optimization
163 | self.optimizer_e.zero_grad()
164 | self.optimizer_p.zero_grad()
165 | loss.backward()
166 | self.optimizer_e.step()
167 | self.optimizer_p.step()
168 |
169 | # update statistics
170 | ns += bs
171 | nw += lengths.sum().item()
172 | losses.append(loss.item())
173 |
174 | # log
175 | if ns != 0 and ns % (10 * bs) < bs:
176 | logger.info(
177 | "GLUE - %s - Epoch %s - Train iter %7i - %.1f words/s - %s Loss: %.4f"
178 | % (self.task, self.epoch, ns, nw / (time.time() - t), 'XE' if self.is_classif else 'MSE', sum(losses) / len(losses))
179 | )
180 | nw, t = 0, time.time()
181 | losses = []
182 |
183 | # epoch size
184 | if params.epoch_size != -1 and ns >= params.epoch_size:
185 | break
186 |
187 | def eval(self, splt):
188 | """
189 | Evaluate on the GLUE validation or test set of the current task.
190 | """
191 | params = self.params
192 | self.embedder.eval()
193 | self.proj.eval()
194 |
195 | assert splt in ['valid', 'test']
196 | has_labels = 'y' in self.data[splt]
197 |
198 | scores = OrderedDict({'epoch': self.epoch})
199 | task = self.task.lower()
200 |
201 | idxs = [] # sentence indices
202 | prob = [] # probabilities
203 | pred = [] # predicted values
204 | gold = [] # real values
205 |
206 | lang_id = params.lang2id['en']
207 |
208 | for batch in self.get_iterator(splt):
209 |
210 | # batch
211 | if self.n_sent == 1:
212 | (x, lengths), idx = batch
213 | # x, lengths = truncate(x, lengths, params.max_len, params.eos_index)
214 | else:
215 | (sent1, len1), (sent2, len2), idx = batch
216 | # sent1, len1 = truncate(sent1, len1, params.max_len, params.eos_index)
217 | # sent2, len2 = truncate(sent2, len2, params.max_len, params.eos_index)
218 | x, lengths, _, _ = concat_batches(sent1, len1, lang_id, sent2, len2, lang_id, params.pad_index, params.eos_index, reset_positions=False)
219 | y = self.data[splt]['y'][idx] if has_labels else None
220 |
221 | # cuda
222 | x, y, lengths = to_cuda(x, y, lengths)
223 |
224 | # prediction
225 | output = self.proj(self.embedder.get_embeddings(x, lengths, positions=None, langs=None))
226 | p = output.data.max(1)[1] if self.is_classif else output.squeeze(1)
227 | idxs.append(idx)
228 | prob.append(output.cpu().numpy())
229 | pred.append(p.cpu().numpy())
230 | if has_labels:
231 | gold.append(y.cpu().numpy())
232 |
233 | # indices / predictions
234 | idxs = np.concatenate(idxs)
235 | prob = np.concatenate(prob)
236 | pred = np.concatenate(pred)
237 | assert len(idxs) == len(pred), (len(idxs), len(pred))
238 | assert idxs[-1] == len(idxs) - 1, (idxs[-1], len(idxs) - 1)
239 |
240 | # score the predictions if we have labels
241 | if has_labels:
242 | gold = np.concatenate(gold)
243 | prefix = f'{splt}_{task}'
244 | if self.is_classif:
245 | scores['%s_acc' % prefix] = 100. * (pred == gold).sum() / len(pred)
246 | scores['%s_f1' % prefix] = 100. * f1_score(gold, pred, average='binary' if params.out_features == 2 else 'micro')
247 | scores['%s_mc' % prefix] = 100. * matthews_corrcoef(gold, pred)
248 | else:
249 | scores['%s_prs' % prefix] = 100. * pearsonr(pred, gold)[0]
250 | scores['%s_spr' % prefix] = 100. * spearmanr(pred, gold)[0]
251 | logger.info("__log__:%s" % json.dumps(scores))
252 |
253 | # output predictions
254 | pred_path = os.path.join(params.dump_path, f'{splt}.pred.{self.epoch}')
255 | with open(pred_path, 'w') as f:
256 | for i, p in zip(idxs, prob):
257 | f.write('%i\t%s\n' % (i, ','.join([str(x) for x in p])))
258 | logger.info(f"Wrote {len(idxs)} {splt} predictions to {pred_path}")
259 |
260 | return scores
261 |
262 | def load_data(self, task):
263 | """
264 | Load data for GLUE single-sentence and sentence-pair classification / regression tasks.
265 | """
266 | params = self.params
267 | data = {splt: {} for splt in ['train', 'valid', 'test']}
268 | dpath = os.path.join(params.data_path, 'eval', task)
269 |
270 | self.n_sent = 1 if task in ['SST-2', 'CoLA'] else 2
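# SST-2 and CoLA are single-sentence tasks; the remaining GLUE tasks here pair two sentences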
271 |
272 | for splt in ['train', 'valid', 'test']:
273 |
274 | # load data and dictionary
275 | data1 = load_binarized(os.path.join(dpath, '%s.s1.pth' % splt), params)
276 | data2 = load_binarized(os.path.join(dpath, '%s.s2.pth' % splt), params) if self.n_sent == 2 else None
277 | data['dico'] = data.get('dico', data1['dico'])
278 |
279 | # set dictionary parameters
280 | set_dico_parameters(params, data, data1['dico'])
281 | if self.n_sent == 2:
282 | set_dico_parameters(params, data, data2['dico'])
283 |
284 | # create dataset
285 | if self.n_sent == 1:
286 | data[splt]['x'] = Dataset(data1['sentences'], data1['positions'], params)
287 | else:
288 | data[splt]['x'] = ParallelDataset(
289 | data1['sentences'], data1['positions'],
290 | data2['sentences'], data2['positions'],
291 | params
292 | )
293 |
294 | # load labels
295 | if splt != 'test' or task in ['MRPC']:
296 | # read labels from file
297 | with open(os.path.join(dpath, '%s.label' % splt), 'r') as f:
298 | lines = [l.rstrip() for l in f]
299 | # STS-B task
300 | if task == 'STS-B':
301 | assert all(0 <= float(x) <= 5 for x in lines)
302 | y = [float(l) for l in lines]
303 | # QQP
304 | elif task == 'QQP':
305 | UNK_LABEL = 0
306 | lab2id = {x: i for i, x in enumerate(sorted(set(lines) - set([''])))}
307 | y = [lab2id.get(x, UNK_LABEL) for x in lines]
308 | # other tasks
309 | else:
310 | lab2id = {x: i for i, x in enumerate(sorted(set(lines)))}
311 | y = [lab2id[x] for x in lines]
312 | data[splt]['y'] = torch.LongTensor(y)
313 | assert len(data[splt]['x']) == len(data[splt]['y'])
314 |
315 | # compute weights for weighted training
316 | if task != 'STS-B' and params.weighted_training:
317 | weights = torch.FloatTensor([
318 | 1.0 / (data['train']['y'] == i).sum().item()
319 | for i in range(len(lab2id))
320 | ]).cuda()
321 | self.weights = weights / weights.sum()
322 | else:
323 | self.weights = None
324 |
325 | return data
326 |
--------------------------------------------------------------------------------
/src/evaluation/multi-bleu.perl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env perl
2 | #
3 | # This file is part of moses. Its use is licensed under the GNU Lesser General
4 | # Public License version 2.1 or, at your option, any later version.
5 |
6 | # $Id$
7 | use warnings;
8 | use strict;
9 |
10 | my $lowercase = 0;
11 | if ($ARGV[0] eq "-lc") {
12 | $lowercase = 1;
13 | shift;
14 | }
15 |
16 | my $stem = $ARGV[0];
17 | if (!defined $stem) {
18 | print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n";
19 | print STDERR "Reads the references from reference or reference0, reference1, ...\n";
20 | exit(1);
21 | }
22 |
23 | $stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0";
24 |
25 | my @REF;
26 | my $ref=0;
27 | while(-e "$stem$ref") {
28 | &add_to_ref("$stem$ref",\@REF);
29 | $ref++;
30 | }
31 | &add_to_ref($stem,\@REF) if -e $stem;
32 | die("ERROR: could not find reference file $stem") unless scalar @REF;
33 |
34 | # add additional references explicitly specified on the command line
35 | shift;
36 | foreach my $stem (@ARGV) {
37 | &add_to_ref($stem,\@REF) if -e $stem;
38 | }
39 |
40 |
41 |
42 | sub add_to_ref {
43 | my ($file,$REF) = @_;
44 | my $s=0;
45 | if ($file =~ /.gz$/) {
46 | open(REF,"gzip -dc $file|") or die "Can't read $file";
47 | } else {
48 | open(REF,$file) or die "Can't read $file";
49 | }
50 | while(<REF>) {
51 | chop;
52 | push @{$$REF[$s++]}, $_;
53 | }
54 | close(REF);
55 | }
56 |
57 | my(@CORRECT,@TOTAL,$length_translation,$length_reference);
58 | my $s=0;
59 | while(<STDIN>) {
60 | chop;
61 | $_ = lc if $lowercase;
62 | my @WORD = split;
63 | my %REF_NGRAM = ();
64 | my $length_translation_this_sentence = scalar(@WORD);
65 | my ($closest_diff,$closest_length) = (9999,9999);
66 | foreach my $reference (@{$REF[$s]}) {
67 | # print "$s $_ <=> $reference\n";
68 | $reference = lc($reference) if $lowercase;
69 | my @WORD = split(' ',$reference);
70 | my $length = scalar(@WORD);
71 | my $diff = abs($length_translation_this_sentence-$length);
72 | if ($diff < $closest_diff) {
73 | $closest_diff = $diff;
74 | $closest_length = $length;
75 | # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n";
76 | } elsif ($diff == $closest_diff) {
77 | $closest_length = $length if $length < $closest_length;
78 | # from two references with the same closeness to me
79 | # take the *shorter* into account, not the "first" one.
80 | }
81 | for(my $n=1;$n<=4;$n++) {
82 | my %REF_NGRAM_N = ();
83 | for(my $start=0;$start<=$#WORD-($n-1);$start++) {
84 | my $ngram = "$n";
85 | for(my $w=0;$w<$n;$w++) {
86 | $ngram .= " ".$WORD[$start+$w];
87 | }
88 | $REF_NGRAM_N{$ngram}++;
89 | }
90 | foreach my $ngram (keys %REF_NGRAM_N) {
91 | if (!defined($REF_NGRAM{$ngram}) ||
92 | $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) {
93 | $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram};
94 | # print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}<BR>\n";
95 | }
96 | }
97 | }
98 | }
99 | $length_translation += $length_translation_this_sentence;
100 | $length_reference += $closest_length;
101 | for(my $n=1;$n<=4;$n++) {
102 | my %T_NGRAM = ();
103 | for(my $start=0;$start<=$#WORD-($n-1);$start++) {
104 | my $ngram = "$n";
105 | for(my $w=0;$w<$n;$w++) {
106 | $ngram .= " ".$WORD[$start+$w];
107 | }
108 | $T_NGRAM{$ngram}++;
109 | }
110 | foreach my $ngram (keys %T_NGRAM) {
111 | $ngram =~ /^(\d+) /;
112 | my $n = $1;
113 | # my $corr = 0;
114 | # print "$i e $ngram $T_NGRAM{$ngram}<BR>\n";
115 | $TOTAL[$n] += $T_NGRAM{$ngram};
116 | if (defined($REF_NGRAM{$ngram})) {
117 | if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) {
118 | $CORRECT[$n] += $T_NGRAM{$ngram};
119 | # $corr = $T_NGRAM{$ngram};
120 | # print "$i e correct1 $T_NGRAM{$ngram}<BR>\n";
121 | }
122 | else {
123 | $CORRECT[$n] += $REF_NGRAM{$ngram};
124 | # $corr = $REF_NGRAM{$ngram};
125 | # print "$i e correct2 $REF_NGRAM{$ngram}<BR>\n";
126 | }
127 | }
128 | # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram};
129 | # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n"
130 | }
131 | }
132 | $s++;
133 | }
134 | my $brevity_penalty = 1;
135 | my $bleu = 0;
136 |
137 | my @bleu=();
138 |
139 | for(my $n=1;$n<=4;$n++) {
140 | if (defined ($TOTAL[$n])){
141 | $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0;
142 | # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n";
143 | }else{
144 | $bleu[$n]=0;
145 | }
146 | }
147 |
148 | if ($length_reference==0){
149 | printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n";
150 | exit(1);
151 | }
152 |
153 | if ($length_translation<$length_reference) {
154 | $brevity_penalty = exp(1-$length_reference/$length_translation);
155 | }
156 | $bleu = $brevity_penalty * exp((my_log( $bleu[1] ) +
157 | my_log( $bleu[2] ) +
158 | my_log( $bleu[3] ) +
159 | my_log( $bleu[4] ) ) / 4) ;
160 | printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n",
161 | 100*$bleu,
162 | 100*$bleu[1],
163 | 100*$bleu[2],
164 | 100*$bleu[3],
165 | 100*$bleu[4],
166 | $brevity_penalty,
167 | $length_translation / $length_reference,
168 | $length_translation,
169 | $length_reference;
170 |
171 |
172 | # print STDERR "It is in-advisable to publish scores from multi-bleu.perl. The scores depend on your tokenizer, which is unlikely to be reproducible from your paper or consistent across research groups. Instead you should detokenize then use mteval-v14.pl, which has a standard tokenization. Scores from multi-bleu.perl can still be used for internal purposes when you have a consistent tokenizer.\n";
173 |
174 | sub my_log {
175 | return -9999999999 unless $_[0];
176 | return log($_[0]);
177 | }
178 |
--------------------------------------------------------------------------------
/src/evaluation/xnli.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | from logging import getLogger
9 | import os
10 | import copy
11 | import time
12 | import json
13 | from collections import OrderedDict
14 |
15 | import torch
16 | from torch import nn
17 | import torch.nn.functional as F
18 |
19 | from ..optim import get_optimizer
20 | from ..utils import concat_batches, truncate, to_cuda
21 | from ..data.dataset import ParallelDataset
22 | from ..data.loader import load_binarized, set_dico_parameters
23 |
24 |
25 | XNLI_LANGS = ['ar', 'bg', 'de', 'el', 'en', 'es', 'fr', 'hi', 'ru', 'sw', 'th', 'tr', 'ur', 'vi', 'zh']
26 |
27 |
28 | logger = getLogger()
29 |
30 |
31 | class XNLI:
32 |
33 | def __init__(self, embedder, scores, params):
34 | """
35 | Initialize XNLI trainer / evaluator.
36 | Initial `embedder` should be on CPU to save memory.
37 | """
38 | self._embedder = embedder
39 | self.params = params
40 | self.scores = scores
41 |
42 | def get_iterator(self, splt, lang):
43 | """
44 | Get a monolingual data iterator.
45 | """
46 | assert splt in ['valid', 'test'] or splt == 'train' and lang == 'en'
47 | return self.data[lang][splt]['x'].get_iterator(
48 | shuffle=(splt == 'train'),
49 | group_by_size=self.params.group_by_size,
50 | return_indices=True
51 | )
52 |
53 | def run(self):
54 | """
55 | Run XNLI training / evaluation.
56 | """
57 | params = self.params
58 |
59 | # load data
60 | self.data = self.load_data()
61 | if not self.data['dico'] == self._embedder.dico:
62 | raise Exception(("Dictionary in evaluation data (%i words) seems different than the one " +
63 | "in the pretrained model (%i words). Please verify you used the same dictionary, " +
64 | "and the same values for max_vocab and min_count.") % (len(self.data['dico']), len(self._embedder.dico)))
65 |
66 | # embedder
67 | self.embedder = copy.deepcopy(self._embedder)
68 | self.embedder.cuda()
69 |
70 | # projection layer
71 | self.proj = nn.Sequential(*[
72 | nn.Dropout(params.dropout),
73 | nn.Linear(self.embedder.out_dim, 3)
74 | ]).cuda()
75 |
76 | # optimizers
77 | self.optimizer_e = get_optimizer(list(self.embedder.get_parameters(params.finetune_layers)), params.optimizer_e)
78 | self.optimizer_p = get_optimizer(self.proj.parameters(), params.optimizer_p)
79 |
80 | # train and evaluate the model
81 | for epoch in range(params.n_epochs):
82 |
83 | # update epoch
84 | self.epoch = epoch
85 |
86 | # training
87 | logger.info("XNLI - Training epoch %i ..." % epoch)
88 | self.train()
89 |
90 | # evaluation
91 | logger.info("XNLI - Evaluating epoch %i ..." % epoch)
92 | with torch.no_grad():
93 | scores = self.eval()
94 | self.scores.update(scores)
95 |
96 | def train(self):
97 | """
98 | Finetune for one epoch on the XNLI English training set.
99 | """
100 | params = self.params
101 | self.embedder.train()
102 | self.proj.train()
103 |
104 | # training variables
105 | losses = []
106 | ns = 0 # number of sentences
107 | nw = 0 # number of words
108 | t = time.time()
109 |
110 | iterator = self.get_iterator('train', 'en')
111 | lang_id = params.lang2id['en']
112 |
113 | while True:
114 |
115 | # batch
116 | try:
117 | batch = next(iterator)
118 | except StopIteration:
119 | break
120 | (sent1, len1), (sent2, len2), idx = batch
121 | sent1, len1 = truncate(sent1, len1, params.max_len, params.eos_index)
122 | sent2, len2 = truncate(sent2, len2, params.max_len, params.eos_index)
123 | x, lengths, positions, langs = concat_batches(
124 | sent1, len1, lang_id,
125 | sent2, len2, lang_id,
126 | params.pad_index,
127 | params.eos_index,
128 | reset_positions=False
129 | )
130 | y = self.data['en']['train']['y'][idx]
131 | bs = len(len1)
132 |
133 | # cuda
134 | x, y, lengths, positions, langs = to_cuda(x, y, lengths, positions, langs)
135 |
136 | # loss
137 | output = self.proj(self.embedder.get_embeddings(x, lengths, positions, langs))
138 | loss = F.cross_entropy(output, y)
139 |
140 | # backward / optimization
141 | self.optimizer_e.zero_grad()
142 | self.optimizer_p.zero_grad()
143 | loss.backward()
144 | self.optimizer_e.step()
145 | self.optimizer_p.step()
146 |
147 | # update statistics
148 | ns += bs
149 | nw += lengths.sum().item()
150 | losses.append(loss.item())
151 |
152 | # log
153 | if ns % (100 * bs) < bs:
154 | logger.info("XNLI - Epoch %i - Train iter %7i - %.1f words/s - Loss: %.4f" % (self.epoch, ns, nw / (time.time() - t), sum(losses) / len(losses)))
155 | nw, t = 0, time.time()
156 | losses = []
157 |
158 | # epoch size
159 | if params.epoch_size != -1 and ns >= params.epoch_size:
160 | break
161 |
162 | def eval(self):
163 | """
164 | Evaluate on XNLI validation and test sets, for all languages.
165 | """
166 | params = self.params
167 | self.embedder.eval()
168 | self.proj.eval()
169 |
170 | scores = OrderedDict({'epoch': self.epoch})
171 |
172 | for splt in ['valid', 'test']:
173 |
174 | for lang in XNLI_LANGS:
175 | if lang not in params.lang2id:
176 | continue
177 |
178 | lang_id = params.lang2id[lang]
179 | valid = 0
180 | total = 0
181 |
182 | for batch in self.get_iterator(splt, lang):
183 |
184 | # batch
185 | (sent1, len1), (sent2, len2), idx = batch
186 | x, lengths, positions, langs = concat_batches(
187 | sent1, len1, lang_id,
188 | sent2, len2, lang_id,
189 | params.pad_index,
190 | params.eos_index,
191 | reset_positions=False
192 | )
193 | y = self.data[lang][splt]['y'][idx]
194 |
195 | # cuda
196 | x, y, lengths, positions, langs = to_cuda(x, y, lengths, positions, langs)
197 |
198 | # forward
199 | output = self.proj(self.embedder.get_embeddings(x, lengths, positions, langs))
200 | predictions = output.data.max(1)[1]
201 |
202 | # update statistics
203 | valid += predictions.eq(y).sum().item()
204 | total += len(len1)
205 |
206 | # compute accuracy
207 | acc = 100.0 * valid / total
208 | scores['xnli_%s_%s_acc' % (splt, lang)] = acc
209 | logger.info("XNLI - %s - %s - Epoch %i - Acc: %.1f%%" % (splt, lang, self.epoch, acc))
210 |
211 | logger.info("__log__:%s" % json.dumps(scores))
212 | return scores
213 |
214 | def load_data(self):
215 | """
216 | Load XNLI cross-lingual classification data.
217 | """
218 | params = self.params
219 | data = {lang: {splt: {} for splt in ['train', 'valid', 'test']} for lang in XNLI_LANGS}
220 | label2id = {'contradiction': 0, 'neutral': 1, 'entailment': 2}
221 | dpath = os.path.join(params.data_path, 'eval', 'XNLI')
222 |
223 | for splt in ['train', 'valid', 'test']:
224 |
225 | for lang in XNLI_LANGS:
226 |
227 | # only English has a training set
228 | if splt == 'train' and lang != 'en':
229 | del data[lang]['train']
230 | continue
231 |
232 | # load data and dictionary
233 | data1 = load_binarized(os.path.join(dpath, '%s.s1.%s.pth' % (splt, lang)), params)
234 | data2 = load_binarized(os.path.join(dpath, '%s.s2.%s.pth' % (splt, lang)), params)
235 | data['dico'] = data.get('dico', data1['dico'])
236 |
237 | # set dictionary parameters
238 | set_dico_parameters(params, data, data1['dico'])
239 | set_dico_parameters(params, data, data2['dico'])
240 |
241 | # create dataset
242 | data[lang][splt]['x'] = ParallelDataset(
243 | data1['sentences'], data1['positions'],
244 | data2['sentences'], data2['positions'],
245 | params
246 | )
247 |
248 | # load labels
249 | with open(os.path.join(dpath, '%s.label.%s' % (splt, lang)), 'r') as f:
250 | labels = [label2id[l.rstrip()] for l in f]
251 | data[lang][splt]['y'] = torch.LongTensor(labels)
252 | assert len(data[lang][splt]['x']) == len(data[lang][splt]['y'])
253 |
254 | return data
255 |
--------------------------------------------------------------------------------
/src/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import logging
9 | import time
10 | from datetime import timedelta
11 |
12 |
13 | class LogFormatter():
14 |
15 | def __init__(self):
16 | self.start_time = time.time()
17 |
18 | def format(self, record):
19 | elapsed_seconds = round(record.created - self.start_time)
20 |
21 | prefix = "%s - %s - %s" % (
22 | record.levelname,
23 | time.strftime('%x %X'),
24 | timedelta(seconds=elapsed_seconds)
25 | )
26 | message = record.getMessage()
27 | message = message.replace('\n', '\n' + ' ' * (len(prefix) + 3))
28 | return "%s - %s" % (prefix, message) if message else ''
29 |
30 |
31 | def create_logger(filepath, rank):
32 | """
33 | Create a logger.
34 | Use a different log file for each process.
35 | """
36 | # create log formatter
37 | log_formatter = LogFormatter()
38 |
39 | # create file handler and set level to debug
40 | if filepath is not None:
41 | if rank > 0:
42 | filepath = '%s-%i' % (filepath, rank)
43 | file_handler = logging.FileHandler(filepath, "a")
44 | file_handler.setLevel(logging.DEBUG)
45 | file_handler.setFormatter(log_formatter)
46 |
47 | # create console handler and set level to info
48 | console_handler = logging.StreamHandler()
49 | console_handler.setLevel(logging.INFO)
50 | console_handler.setFormatter(log_formatter)
51 |
52 | # create logger and set level to debug
53 | logger = logging.getLogger()
54 | logger.handlers = []
55 | logger.setLevel(logging.DEBUG)
56 | logger.propagate = False
57 | if filepath is not None:
58 | logger.addHandler(file_handler)
59 | logger.addHandler(console_handler)
60 |
61 | # reset logger elapsed time
62 | def reset_time():
63 | log_formatter.start_time = time.time()
64 | logger.reset_time = reset_time
65 |
66 | return logger
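# Minimal usage sketch (illustrative file name):
#   logger = create_logger('train.log', rank=0)  # rank > 0 would write to 'train.log-<rank>'
#   logger.info("starting training")             # goes to the console and to train.log
#   logger.reset_time()                          # restart the elapsed-time prefix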
67 |
--------------------------------------------------------------------------------
/src/logger.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jind11/DAMT/3caa22822b936137e5da3e827d7a5a2078c3115e/src/logger.pyc
--------------------------------------------------------------------------------
/src/model/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | from logging import getLogger
9 | import os
10 | import torch
11 |
12 | from .pretrain import load_embeddings
13 | from .transformer import DECODER_ONLY_PARAMS, TransformerModel # , TRANSFORMER_LAYER_PARAMS
14 | from .memory import HashingMemory
15 |
16 |
17 | logger = getLogger()
18 |
19 |
20 | def check_model_params(params):
21 | """
22 | Check models parameters.
23 | """
24 | # masked language modeling task parameters
25 | assert params.bptt >= 1
26 | assert 0 <= params.word_pred < 1
27 | assert 0 <= params.sample_alpha < 1
28 | s = params.word_mask_keep_rand.split(',')
29 | assert len(s) == 3
30 | s = [float(x) for x in s]
31 | assert all([0 <= x <= 1 for x in s]) and sum(s) == 1
32 | params.word_mask = s[0]
33 | params.word_keep = s[1]
34 | params.word_rand = s[2]
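# Example (hypothetical flag value): --word_mask_keep_rand 0.8,0.1,0.1 gives
#   params.word_mask = 0.8, params.word_keep = 0.1, params.word_rand = 0.1
# i.e. of the tokens selected for prediction, 80% are masked, 10% are kept
# unchanged, and 10% are replaced by a random word.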
35 |
36 | # input sentence noise for DAE
37 | if len(params.ae_steps) == 0:
38 | assert params.word_shuffle == 0
39 | assert params.word_dropout == 0
40 | assert params.word_blank == 0
41 | else:
42 | assert params.word_shuffle == 0 or params.word_shuffle > 1
43 | assert 0 <= params.word_dropout < 1
44 | assert 0 <= params.word_blank < 1
45 |
46 | # model dimensions
47 | assert params.emb_dim % params.n_heads == 0
48 |
49 | # share input and output embeddings
50 | assert params.share_inout_emb is False or params.asm is False
51 |
52 | # adaptive softmax
53 | if params.asm:
54 | assert params.asm_div_value > 1
55 | s = params.asm_cutoffs.split(',')
56 | assert all([x.isdigit() for x in s])
57 | params.asm_cutoffs = [int(x) for x in s]
58 | assert params.max_vocab == -1 or params.asm_cutoffs[-1] < params.max_vocab
59 |
60 | # memory
61 | if params.use_memory:
62 | HashingMemory.check_params(params)
63 | s_enc = [x for x in params.mem_enc_positions.split(',') if x != '']
64 | s_dec = [x for x in params.mem_dec_positions.split(',') if x != '']
65 | assert len(s_enc) == len(set(s_enc))
66 | assert len(s_dec) == len(set(s_dec))
67 | assert all(x.isdigit() or x[-1] == '+' and x[:-1].isdigit() for x in s_enc)
68 | assert all(x.isdigit() or x[-1] == '+' and x[:-1].isdigit() for x in s_dec)
69 | params.mem_enc_positions = [(int(x[:-1]), 'after') if x[-1] == '+' else (int(x), 'in') for x in s_enc]
70 | params.mem_dec_positions = [(int(x[:-1]), 'after') if x[-1] == '+' else (int(x), 'in') for x in s_dec]
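# Example (hypothetical flag value): --mem_enc_positions 4,7+ is parsed into
#   params.mem_enc_positions = [(4, 'in'), (7, 'after')]
# i.e. a layer index plus a marker for whether the memory sits in or after that layer.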
71 | assert len(params.mem_enc_positions) + len(params.mem_dec_positions) > 0
72 | assert len(params.mem_enc_positions) == 0 or 0 <= min([x[0] for x in params.mem_enc_positions]) <= max([x[0] for x in params.mem_enc_positions]) <= params.n_layers - 1
73 | assert len(params.mem_dec_positions) == 0 or 0 <= min([x[0] for x in params.mem_dec_positions]) <= max([x[0] for x in params.mem_dec_positions]) <= params.n_layers - 1
74 |
75 | # reload pretrained word embeddings
76 | if params.reload_emb != '':
77 | assert os.path.isfile(params.reload_emb)
78 |
79 | # reload a pretrained model
80 | if params.reload_model != '':
81 | if params.encoder_only:
82 | assert os.path.isfile(params.reload_model)
83 | else:
84 | s = params.reload_model.split(',')
85 | assert len(s) == 2
86 | assert all([x == '' or os.path.isfile(x) for x in s])
87 |
88 |
89 | def set_pretrain_emb(model, dico, word2id, embeddings):
90 | """
91 | Pretrain word embeddings.
92 | """
93 | n_found = 0
94 | with torch.no_grad():
95 | for i in range(len(dico)):
96 | idx = word2id.get(dico[i], None)
97 | if idx is None:
98 | continue
99 | n_found += 1
100 | model.embeddings.weight[i] = embeddings[idx].cuda()
101 | model.pred_layer.proj.weight[i] = embeddings[idx].cuda()
102 | logger.info("Pretrained %i/%i words (%.3f%%)."
103 | % (n_found, len(dico), 100. * n_found / len(dico)))
104 |
105 |
106 | def build_model(params, dico):
107 | """
108 | Build model.
109 | """
110 | if params.encoder_only:
111 | # build
112 | model = TransformerModel(params, dico, is_encoder=True, with_output=True)
113 |
114 | # reload pretrained word embeddings
115 | if params.reload_emb != '':
116 | word2id, embeddings = load_embeddings(params.reload_emb, params)
117 | set_pretrain_emb(model, dico, word2id, embeddings)
118 |
119 | # reload a pretrained model
120 | if params.reload_model != '':
121 | logger.info("Reloading model from %s ..." % params.reload_model)
122 | reloaded = torch.load(params.reload_model, map_location=lambda storage, loc: storage.cuda(params.local_rank))['model']
123 | if all([k.startswith('module.') for k in reloaded.keys()]):
124 | reloaded = {k[len('module.'):]: v for k, v in reloaded.items()}
125 |
126 | # # HACK to reload models with less layers
127 | # for i in range(12, 24):
128 | # for k in TRANSFORMER_LAYER_PARAMS:
129 | # k = k % i
130 | # if k in model.state_dict() and k not in reloaded:
131 | # logger.warning("Parameter %s not found. Ignoring ..." % k)
132 | # reloaded[k] = model.state_dict()[k]
133 |
134 | model.load_state_dict(reloaded)
135 |
136 | logger.info("Model: {}".format(model))
137 | logger.info("Number of parameters (model): %i" % sum([p.numel() for p in model.parameters() if p.requires_grad]))
138 |
139 | return model.cuda()
140 |
141 | else:
142 | # build
143 | encoder = TransformerModel(params, dico, is_encoder=True, with_output=True) # TODO: only output when necessary - len(params.clm_steps + params.mlm_steps) > 0
144 | decoder = TransformerModel(params, dico, is_encoder=False, with_output=True)
145 |
146 | # reload pretrained word embeddings
147 | if params.reload_emb != '':
148 | word2id, embeddings = load_embeddings(params.reload_emb, params)
149 | set_pretrain_emb(encoder, dico, word2id, embeddings)
150 | set_pretrain_emb(decoder, dico, word2id, embeddings)
151 |
152 | # reload a pretrained model
153 | if params.reload_model != '':
154 | enc_path, dec_path = params.reload_model.split(',')
155 | assert not (enc_path == '' and dec_path == '')
156 |
157 | # reload encoder
158 | if enc_path != '':
159 | logger.info("Reloading encoder from %s ..." % enc_path)
160 | enc_reload = torch.load(enc_path, map_location=lambda storage, loc: storage.cuda(params.local_rank))
161 | enc_reload = enc_reload['model' if 'model' in enc_reload else 'encoder']
162 | if all([k.startswith('module.') for k in enc_reload.keys()]):
163 | enc_reload = {k[len('module.'):]: v for k, v in enc_reload.items()}
164 | encoder.load_state_dict(enc_reload)
165 |
166 | # reload decoder
167 | if dec_path != '':
168 | logger.info("Reloading decoder from %s ..." % dec_path)
169 | dec_reload = torch.load(dec_path, map_location=lambda storage, loc: storage.cuda(params.local_rank))
170 | dec_reload = dec_reload['model' if 'model' in dec_reload else 'decoder']
171 | if all([k.startswith('module.') for k in dec_reload.keys()]):
172 | dec_reload = {k[len('module.'):]: v for k, v in dec_reload.items()}
173 | for i in range(params.n_layers):
174 | for name in DECODER_ONLY_PARAMS:
175 | if name % i not in dec_reload:
176 | logger.warning("Parameter %s not found." % (name % i))
177 | dec_reload[name % i] = decoder.state_dict()[name % i]
178 | decoder.load_state_dict(dec_reload)
179 |
180 | logger.debug("Encoder: {}".format(encoder))
181 | logger.debug("Decoder: {}".format(decoder))
182 | logger.info("Number of parameters (encoder): %i" % sum([p.numel() for p in encoder.parameters() if p.requires_grad]))
183 | logger.info("Number of parameters (decoder): %i" % sum([p.numel() for p in decoder.parameters() if p.requires_grad]))
184 |
185 | return encoder.cuda(), decoder.cuda()
186 |
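A note on the `reload_model` convention enforced above: with `encoder_only`, it is a single checkpoint path; for encoder-decoder training it is a comma-separated `<enc_path>,<dec_path>` pair, where either side may be left empty to skip reloading that component. A minimal sketch (the checkpoint paths are hypothetical):

```python
from argparse import Namespace

# Hypothetical paths, for illustration only; in practice these come from the
# command-line arguments parsed in train.py.
params = Namespace(encoder_only=False,
                   reload_model="dumped/xlm_ende/best.pth,dumped/xlm_ende/best.pth")

enc_path, dec_path = params.reload_model.split(',')
assert not (enc_path == '' and dec_path == '')  # at least one side must be reloaded
```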
--------------------------------------------------------------------------------
/src/model/embedder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | from logging import getLogger
9 | import torch
10 |
11 | from .transformer import TransformerModel
12 | from ..data.dictionary import Dictionary, BOS_WORD, EOS_WORD, PAD_WORD, UNK_WORD, MASK_WORD
13 | from ..utils import AttrDict
14 |
15 |
16 | logger = getLogger()
17 |
18 |
19 | class SentenceEmbedder(object):
20 |
21 | @staticmethod
22 | def reload(path, params):
23 | """
24 | Create a sentence embedder from a pretrained model.
25 | """
26 | # reload model
27 | reloaded = torch.load(path)
28 | state_dict = reloaded['model']
29 |
30 | # handle models from multi-GPU checkpoints
31 | if 'checkpoint' in path:
32 | state_dict = {(k[7:] if k.startswith('module.') else k): v for k, v in state_dict.items()}
33 |
34 | # reload dictionary and model parameters
35 | dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])
36 | pretrain_params = AttrDict(reloaded['params'])
37 | pretrain_params.n_words = len(dico)
38 | pretrain_params.bos_index = dico.index(BOS_WORD)
39 | pretrain_params.eos_index = dico.index(EOS_WORD)
40 | pretrain_params.pad_index = dico.index(PAD_WORD)
41 | pretrain_params.unk_index = dico.index(UNK_WORD)
42 | pretrain_params.mask_index = dico.index(MASK_WORD)
43 |
44 | # build model and reload weights
45 | model = TransformerModel(pretrain_params, dico, True, True)
46 | model.load_state_dict(state_dict)
47 | model.eval()
48 |
49 | # adding missing parameters
50 | params.max_batch_size = 0
51 |
52 | return SentenceEmbedder(model, dico, pretrain_params)
53 |
54 | def __init__(self, model, dico, pretrain_params):
55 | """
56 | Wrapper on top of the different sentence embedders.
57 |         Returns single-vector sentence representations (first position of the last layer).
58 | """
59 | self.pretrain_params = {k: v for k, v in pretrain_params.__dict__.items()}
60 | self.model = model
61 | self.dico = dico
62 | self.n_layers = model.n_layers
63 | self.out_dim = model.dim
64 | self.n_words = model.n_words
65 |
66 | def train(self):
67 | self.model.train()
68 |
69 | def eval(self):
70 | self.model.eval()
71 |
72 | def cuda(self):
73 | self.model.cuda()
74 |
75 | def get_parameters(self, layer_range):
76 |
77 | s = layer_range.split(':')
78 | assert len(s) == 2
79 | i, j = int(s[0].replace('_', '-')), int(s[1].replace('_', '-'))
80 |
81 | # negative indexing
82 | i = self.n_layers + i + 1 if i < 0 else i
83 | j = self.n_layers + j + 1 if j < 0 else j
84 |
85 | # sanity check
86 | assert 0 <= i <= self.n_layers
87 | assert 0 <= j <= self.n_layers
88 |
89 | if i > j:
90 | return []
91 |
92 | parameters = []
93 |
94 | # embeddings
95 | if i == 0:
96 | # embeddings
97 | parameters += self.model.embeddings.parameters()
98 | logger.info("Adding embedding parameters to optimizer")
99 | # positional embeddings
100 | if self.pretrain_params['sinusoidal_embeddings'] is False:
101 | parameters += self.model.position_embeddings.parameters()
102 | logger.info("Adding positional embedding parameters to optimizer")
103 | # language embeddings
104 | if hasattr(self.model, 'lang_embeddings'):
105 | parameters += self.model.lang_embeddings.parameters()
106 | logger.info("Adding language embedding parameters to optimizer")
107 | parameters += self.model.layer_norm_emb.parameters()
108 | # layers
109 | for l in range(max(i - 1, 0), j):
110 | parameters += self.model.attentions[l].parameters()
111 | parameters += self.model.layer_norm1[l].parameters()
112 | parameters += self.model.ffns[l].parameters()
113 | parameters += self.model.layer_norm2[l].parameters()
114 | logger.info("Adding layer-%s parameters to optimizer" % (l + 1))
115 |
116 | logger.info("Optimizing on %i Transformer elements." % sum([p.nelement() for p in parameters]))
117 |
118 | return parameters
119 |
120 | def get_embeddings(self, x, lengths, positions=None, langs=None):
121 | """
122 | Inputs:
123 | `x` : LongTensor of shape (slen, bs)
124 | `lengths` : LongTensor of shape (bs,)
125 | Outputs:
126 | `sent_emb` : FloatTensor of shape (bs, out_dim)
127 | With out_dim == emb_dim
128 | """
129 | slen, bs = x.size()
130 | assert lengths.size(0) == bs and lengths.max().item() == slen
131 |
132 | # get transformer last hidden layer
133 | tensor = self.model('fwd', x=x, lengths=lengths, positions=positions, langs=langs, causal=False)
134 | assert tensor.size() == (slen, bs, self.out_dim)
135 |
136 | # single-vector sentence representation (first column of last layer)
137 | return tensor[0]
138 |
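A minimal usage sketch for `SentenceEmbedder`, assuming an XLM-style pretrained checkpoint (the path and tokens below are hypothetical, and tokens must already be BPE-split and present in the pretrained vocabulary):

```python
import torch
from argparse import Namespace
from src.model.embedder import SentenceEmbedder
from src.data.dictionary import EOS_WORD

params = Namespace()  # reload() only sets params.max_batch_size on it
embedder = SentenceEmbedder.reload('dumped/mlm_ende/best.pth', params)  # hypothetical path
embedder.cuda()
embedder.eval()

words = [EOS_WORD, 'hello', 'world', EOS_WORD]  # EOS-delimited toy sentence
x = torch.LongTensor([[embedder.dico.index(w)] for w in words]).cuda()  # (slen, bs=1)
lengths = torch.LongTensor([len(words)]).cuda()
with torch.no_grad():
    sent_emb = embedder.get_embeddings(x, lengths)  # (1, out_dim)
```

Depending on the pretrained model, a `langs` tensor may also need to be passed.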
--------------------------------------------------------------------------------
/src/model/memory/__init__.py:
--------------------------------------------------------------------------------
1 | from .memory import HashingMemory
2 |
--------------------------------------------------------------------------------
/src/model/memory/query.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from .utils import get_slices
5 |
6 |
7 | def mlp(sizes, bias=True, batchnorm=True, groups=1):
8 | """
9 | Generate a feedforward neural network.
10 | """
11 | assert len(sizes) >= 2
12 | pairs = [(sizes[i], sizes[i + 1]) for i in range(len(sizes) - 1)]
13 | layers = []
14 |
15 | for i, (dim_in, dim_out) in enumerate(pairs):
16 | if groups == 1 or i == 0:
17 | layers.append(nn.Linear(dim_in, groups * dim_out, bias=bias))
18 | else:
19 | layers.append(GroupedLinear(groups * dim_in, groups * dim_out, bias=bias, groups=groups))
20 | if batchnorm:
21 | layers.append(nn.BatchNorm1d(groups * dim_out))
22 | if i < len(pairs) - 1:
23 | layers.append(nn.ReLU())
24 |
25 | return nn.Sequential(*layers)
26 |
27 |
28 | def convs(channel_sizes, kernel_sizes, bias=True, batchnorm=True, residual=False, groups=1):
29 | """
30 | Generate a convolutional neural network.
31 | """
32 | assert len(channel_sizes) >= 2
33 | assert len(channel_sizes) == len(kernel_sizes) + 1
34 | pairs = [(channel_sizes[i], channel_sizes[i + 1]) for i in range(len(channel_sizes) - 1)]
35 | layers = []
36 |
37 | for i, (dim_in, dim_out) in enumerate(pairs):
38 | ks = (kernel_sizes[i], kernel_sizes[i])
39 | in_group = 1 if i == 0 else groups
40 | _dim_in = dim_in * in_group
41 | _dim_out = dim_out * groups
42 | if not residual:
43 | layers.append(nn.Conv2d(_dim_in, _dim_out, ks, padding=[k // 2 for k in ks], bias=bias, groups=in_group))
44 | if batchnorm:
45 | layers.append(nn.BatchNorm2d(_dim_out))
46 | if i < len(pairs) - 1:
47 | layers.append(nn.ReLU())
48 | else:
49 | layers.append(BottleneckResidualConv2d(
50 | _dim_in, _dim_out, ks, bias=bias,
51 | batchnorm=batchnorm, groups=in_group
52 | ))
53 | if i == len(pairs) - 1:
54 | layers.append(nn.Conv2d(_dim_out, _dim_out, (1, 1), bias=bias))
55 |
56 | return nn.Sequential(*layers)
57 |
58 |
59 | class GroupedLinear(nn.Module):
60 |
61 | def __init__(self, in_features, out_features, bias=True, groups=1):
62 |
63 | super().__init__()
64 | self.in_features = in_features
65 | self.out_features = out_features
66 | self.groups = groups
67 | self.bias = bias
68 | assert groups > 1
69 |
70 | self.layer = nn.Conv1d(in_features, out_features, bias=bias, kernel_size=1, groups=groups)
71 |
72 | def forward(self, input):
73 | assert input.dim() == 2 and input.size(1) == self.in_features
74 | return self.layer(input.unsqueeze(2)).squeeze(2)
75 |
76 | def extra_repr(self):
77 | return 'in_features={}, out_features={}, groups={}, bias={}'.format(
78 | self.in_features, self.out_features, self.groups, self.bias is not None
79 | )
80 |
81 |
82 | class BottleneckResidualConv2d(nn.Module):
83 |
84 | def __init__(self, input_channels, output_channels, kernel_size, bias=True, batchnorm=True, groups=1):
85 |
86 | super().__init__()
87 | hidden_channels = min(input_channels, output_channels)
88 | assert all(k % 2 == 1 for k in kernel_size)
89 |
90 | self.conv1 = nn.Conv2d(input_channels, hidden_channels, kernel_size, padding=[k // 2 for k in kernel_size], bias=bias, groups=groups)
91 | self.conv2 = nn.Conv2d(hidden_channels, output_channels, kernel_size, padding=[k // 2 for k in kernel_size], bias=bias, groups=groups)
92 | self.act = nn.ReLU()
93 |
94 | self.batchnorm = batchnorm
95 | if self.batchnorm:
96 | self.bn1 = nn.BatchNorm2d(hidden_channels)
97 | self.bn2 = nn.BatchNorm2d(output_channels)
98 |
99 | if input_channels == output_channels:
100 | self.residual = nn.Sequential()
101 | else:
102 | self.residual = nn.Conv2d(input_channels, output_channels, (1, 1), bias=False, groups=groups)
103 |
104 | def forward(self, input):
105 | x = self.conv1(input)
106 | x = self.bn1(x) if self.batchnorm else x
107 | x = self.act(x)
108 | x = self.conv2(x)
109 | x = self.bn2(x) if self.batchnorm else x
110 | x = self.act(x + self.residual(input))
111 | return x
112 |
113 |
114 | class QueryIdentity(nn.Module):
115 |
116 | def __init__(self, input_dim, heads, shuffle_hidden):
117 | super().__init__()
118 | self.input_dim = input_dim
119 | self.heads = heads
120 | self.shuffle_query = shuffle_hidden
121 | assert shuffle_hidden is False or heads > 1
122 | assert shuffle_hidden is False or self.input_dim % (2 ** self.heads) == 0
123 | if shuffle_hidden:
124 | self.slices = {head_id: get_slices(input_dim, head_id) for head_id in range(heads)}
125 |
126 | def forward(self, input):
127 | """
128 | Generate queries from hidden states by either
129 | repeating them or creating some shuffled version.
130 | """
131 | assert input.shape[-1] == self.input_dim
132 | input = input.contiguous().view(-1, self.input_dim) if input.dim() > 2 else input
133 | bs = len(input)
134 |
135 | if self.heads == 1:
136 | query = input
137 |
138 | elif not self.shuffle_query:
139 | query = input.unsqueeze(1).repeat(1, self.heads, 1)
140 | query = query.view(bs * self.heads, self.input_dim)
141 |
142 | else:
143 | query = torch.cat([
144 | input[:, a:b]
145 | for head_id in range(self.heads)
146 | for a, b in self.slices[head_id]
147 | ], 1).view(bs * self.heads, self.input_dim)
148 |
149 | assert query.shape == (bs * self.heads, self.input_dim)
150 | return query
151 |
152 |
153 | class QueryMLP(nn.Module):
154 |
155 | def __init__(
156 | self, input_dim, heads, k_dim, product_quantization, multi_query_net,
157 | sizes, bias=True, batchnorm=True, grouped_conv=False
158 | ):
159 | super().__init__()
160 | self.input_dim = input_dim
161 | self.heads = heads
162 | self.k_dim = k_dim
163 | self.sizes = sizes
164 | self.grouped_conv = grouped_conv
165 | assert not multi_query_net or product_quantization or heads >= 2
166 | assert sizes[0] == input_dim
167 |         assert sizes[-1] == ((k_dim // 2) if multi_query_net else (heads * k_dim))
168 | assert self.grouped_conv is False or len(sizes) > 2
169 |
170 | # number of required MLPs
171 | self.groups = (2 * heads) if multi_query_net else 1
172 |
173 | # MLPs
174 | if self.grouped_conv:
175 | self.query_mlps = mlp(sizes, bias=bias, batchnorm=batchnorm, groups=self.groups)
176 | elif len(self.sizes) == 2:
177 | sizes_ = list(sizes)
178 | sizes_[-1] = sizes_[-1] * self.groups
179 | self.query_mlps = mlp(sizes_, bias=bias, batchnorm=batchnorm, groups=1)
180 | else:
181 | self.query_mlps = nn.ModuleList([
182 | mlp(sizes, bias=bias, batchnorm=batchnorm, groups=1)
183 | for _ in range(self.groups)
184 | ])
185 |
186 | def forward(self, input):
187 | """
188 | Compute queries using either grouped 1D convolutions or ModuleList + concat.
189 | """
190 | assert input.shape[-1] == self.input_dim
191 | input = input.contiguous().view(-1, self.input_dim) if input.dim() > 2 else input
192 | bs = len(input)
193 |
194 | if self.grouped_conv or len(self.sizes) == 2:
195 | query = self.query_mlps(input)
196 | else:
197 | outputs = [m(input) for m in self.query_mlps]
198 | query = torch.cat(outputs, 1) if len(outputs) > 1 else outputs[0]
199 |
200 | assert query.shape == (bs, self.heads * self.k_dim)
201 | return query.view(bs * self.heads, self.k_dim)
202 |
203 |
204 | class QueryConv(nn.Module):
205 |
206 | def __init__(
207 | self, input_dim, heads, k_dim, product_quantization, multi_query_net,
208 | sizes, kernel_sizes, bias=True, batchnorm=True,
209 | residual=False, grouped_conv=False
210 | ):
211 | super().__init__()
212 | self.input_dim = input_dim
213 | self.heads = heads
214 | self.k_dim = k_dim
215 | self.sizes = sizes
216 | self.grouped_conv = grouped_conv
217 | assert not multi_query_net or product_quantization or heads >= 2
218 | assert sizes[0] == input_dim
219 |         assert sizes[-1] == ((k_dim // 2) if multi_query_net else (heads * k_dim))
220 | assert self.grouped_conv is False or len(sizes) > 2
221 | assert len(sizes) == len(kernel_sizes) + 1 >= 2 and all(ks % 2 == 1 for ks in kernel_sizes)
222 |
223 | # number of required CNNs
224 | self.groups = (2 * heads) if multi_query_net else 1
225 |
226 | # CNNs
227 | if self.grouped_conv:
228 | self.query_convs = convs(sizes, kernel_sizes, bias=bias, batchnorm=batchnorm, residual=residual, groups=self.groups)
229 | elif len(self.sizes) == 2:
230 | sizes_ = list(sizes)
231 | sizes_[-1] = sizes_[-1] * self.groups
232 | self.query_convs = convs(sizes_, kernel_sizes, bias=bias, batchnorm=batchnorm, residual=residual, groups=1)
233 | else:
234 | self.query_convs = nn.ModuleList([
235 | convs(sizes, kernel_sizes, bias=bias, batchnorm=batchnorm, residual=residual, groups=1)
236 | for _ in range(self.groups)
237 | ])
238 |
239 | def forward(self, input):
240 |
241 | bs, nf, h, w = input.shape
242 | assert nf == self.input_dim
243 |
244 | if self.grouped_conv or len(self.sizes) == 2:
245 | query = self.query_convs(input)
246 | else:
247 | outputs = [m(input) for m in self.query_convs]
248 | query = torch.cat(outputs, 1) if len(outputs) > 1 else outputs[0]
249 |
250 | assert query.shape == (bs, self.heads * self.k_dim, h, w)
251 | query = query.transpose(1, 3).contiguous().view(bs * w * h * self.heads, self.k_dim)
252 | return query
253 |
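To make the `mlp` helper above concrete: `mlp([512, 256, 128])` builds `Linear(512, 256) → BatchNorm1d(256) → ReLU → Linear(256, 128) → BatchNorm1d(128)` (no ReLU after the last pair). A short sketch with made-up sizes:

```python
import torch
from src.model.memory.query import mlp

net = mlp([512, 256, 128])   # groups=1, batchnorm=True by default
x = torch.randn(8, 512)      # batch of 8 hypothetical query vectors
assert net(x).shape == (8, 128)
```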
--------------------------------------------------------------------------------
/src/model/memory/utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import math
3 | import numpy as np
4 | import torch
5 |
6 |
7 | # load FAISS GPU library if available (dramatically accelerates the nearest neighbor search)
8 | try:
9 | import faiss
10 | FAISS_AVAILABLE = hasattr(faiss, 'StandardGpuResources')
11 | except ImportError:
12 | FAISS_AVAILABLE = False
13 | sys.stderr.write("FAISS library was not found.\n")
14 |
15 |
16 | def get_gaussian_keys(n_keys, dim, normalized, seed):
17 | """
18 | Generate random Gaussian keys.
19 | """
20 | rng = np.random.RandomState(seed)
21 | X = rng.randn(n_keys, dim)
22 | if normalized:
23 | X /= np.linalg.norm(X, axis=1, keepdims=True)
24 | return X.astype(np.float32)
25 |
26 |
27 | def get_uniform_keys(n_keys, dim, normalized, seed):
28 | """
29 | Generate random uniform keys (same initialization as nn.Linear).
30 | """
31 | rng = np.random.RandomState(seed)
32 | bound = 1 / math.sqrt(dim)
33 | X = rng.uniform(-bound, bound, (n_keys, dim))
34 | if normalized:
35 | X /= np.linalg.norm(X, axis=1, keepdims=True)
36 | return X.astype(np.float32)
37 |
38 |
39 | def get_slices(dim, head_id):
40 | """
41 | Generate slices of hidden dimensions.
42 |     Used when there are multiple heads and/or different sets of keys,
43 |     and there is no query network.
44 | """
45 | if head_id == 0:
46 | return [(0, dim)]
47 | offset = dim // (2 ** (head_id + 1))
48 | starts = np.arange(0, dim, offset)
49 | slices1 = [(x, x + offset) for i, x in enumerate(starts) if i % 2 == 0]
50 | slices2 = [(x, x + offset) for i, x in enumerate(starts) if i % 2 == 1]
51 | return slices1 + slices2
52 |
53 |
54 | def cartesian_product(a, b):
55 | """
56 | Compute the batched cartesian product between two matrices.
57 | Input:
58 | a: Tensor(n, d1)
59 | b: Tensor(n, d2)
60 | Output:
61 | output: Tensor(n, d1 * d2, 2)
62 | """
63 | n1, d1 = a.shape
64 | n2, d2 = b.shape
65 | assert n1 == n2
66 | return torch.cat([
67 | a.unsqueeze(-1).repeat(1, 1, d2).unsqueeze(-1),
68 | b.repeat(1, d1).view(n2, d1, d2).unsqueeze(-1)
69 | ], 3).view(n1, d1 * d2, 2)
70 |
71 |
72 | def swig_ptr_from_FloatTensor(x):
73 | assert x.is_contiguous()
74 | assert x.dtype == torch.float32
75 | return faiss.cast_integer_to_float_ptr(x.storage().data_ptr() + x.storage_offset() * 4)
76 |
77 |
78 | def swig_ptr_from_LongTensor(x):
79 | assert x.is_contiguous()
80 | assert x.dtype == torch.int64, 'dtype=%s' % x.dtype
81 | return faiss.cast_integer_to_long_ptr(x.storage().data_ptr() + x.storage_offset() * 8)
82 |
83 |
84 | def get_knn_pytorch(a, b, k, distance='dot_product'):
85 | """
86 | Input:
87 | - matrix of size (m, d) (keys)
88 | - matrix of size (n, d) (queries)
89 | - number of nearest neighbors
90 | - distance metric
91 | Output:
92 |         - `scores` matrix of size (n, k) with nearest neighbors scores
93 |         - `indices` matrix of size (n, k) with nearest neighbors indices
94 | """
95 | m, d = a.size()
96 | n, _ = b.size()
97 | assert b.size(1) == d
98 | assert k > 0
99 | assert distance in ['dot_product', 'cosine', 'l2']
100 |
101 | with torch.no_grad():
102 |
103 | if distance == 'dot_product':
104 | scores = a.mm(b.t()) # (m, n)
105 |
106 | elif distance == 'cosine':
107 | scores = a.mm(b.t()) # (m, n)
108 | scores /= (a.norm(2, 1)[:, None] + 1e-9) # (m, n)
109 | scores /= (b.norm(2, 1)[None, :] + 1e-9) # (m, n)
110 |
111 | elif distance == 'l2':
112 | scores = a.mm(b.t()) # (m, n)
113 | scores *= 2 # (m, n)
114 | scores -= (a ** 2).sum(1)[:, None] # (m, n)
115 | scores -= (b ** 2).sum(1)[None, :] # (m, n)
116 |
117 | scores, indices = scores.topk(k=k, dim=0, largest=True) # (k, n)
118 | scores = scores.t() # (n, k)
119 | indices = indices.t() # (n, k)
120 |
121 | return scores, indices
122 |
123 |
124 | def get_knn_faiss(xb, xq, k, distance='dot_product'):
125 | """
126 | `metric` can be faiss.METRIC_INNER_PRODUCT or faiss.METRIC_L2
127 | https://github.com/facebookresearch/faiss/blob/master/gpu/test/test_pytorch_faiss.py
128 | """
129 | assert xb.device == xq.device
130 | assert distance in ['dot_product', 'l2']
131 | metric = faiss.METRIC_INNER_PRODUCT if distance == 'dot_product' else faiss.METRIC_L2
132 |
133 | xq_ptr = swig_ptr_from_FloatTensor(xq)
134 | xb_ptr = swig_ptr_from_FloatTensor(xb)
135 |
136 | nq, d1 = xq.size()
137 | nb, d2 = xb.size()
138 | assert d1 == d2
139 |
140 | D = torch.empty(nq, k, device=xb.device, dtype=torch.float32)
141 | I = torch.empty(nq, k, device=xb.device, dtype=torch.int64)
142 |
143 | D_ptr = swig_ptr_from_FloatTensor(D)
144 | I_ptr = swig_ptr_from_LongTensor(I)
145 |
146 | faiss.bruteForceKnn(
147 | FAISS_RES, metric,
148 | xb_ptr, nb,
149 | xq_ptr, nq,
150 | d1, k, D_ptr, I_ptr
151 | )
152 |
153 | return D, I
154 |
155 |
156 | if FAISS_AVAILABLE:
157 | FAISS_RES = faiss.StandardGpuResources()
158 | FAISS_RES.setDefaultNullStreamAllDevices()
159 | FAISS_RES.setTempMemory(1200 * 1024 * 1024)
160 | get_knn = get_knn_faiss
161 | else:
162 | sys.stderr.write("FAISS not available. Switching to standard nearest neighbors search implementation.\n")
163 | get_knn = get_knn_pytorch
164 |
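A worked example of the helpers above: `get_slices(8, 1)` uses an offset of `8 // 4 = 2` and returns `[(0, 2), (4, 6), (2, 4), (6, 8)]` (even-indexed chunks first, then odd-indexed ones), and `get_knn_pytorch` can be exercised directly on random data:

```python
import torch
from src.model.memory.utils import get_slices, get_knn_pytorch

print(get_slices(8, 1))          # -> [(0, 2), (4, 6), (2, 4), (6, 8)]

keys = torch.randn(100, 16)      # (m, d) memory keys (random, for illustration)
queries = torch.randn(5, 16)     # (n, d) queries
scores, indices = get_knn_pytorch(keys, queries, k=10, distance='l2')
assert scores.shape == indices.shape == (5, 10)
```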
--------------------------------------------------------------------------------
/src/model/pretrain.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | from logging import getLogger
9 | import io
10 | import numpy as np
11 | import torch
12 |
13 |
14 | logger = getLogger()
15 |
16 |
17 | def load_fasttext_model(path):
18 | """
19 | Load a binarized fastText model.
20 | """
21 | try:
22 | import fastText
23 | except ImportError:
24 | raise Exception("Unable to import fastText. Please install fastText for Python: "
25 | "https://github.com/facebookresearch/fastText")
26 | return fastText.load_model(path)
27 |
28 |
29 | def read_txt_embeddings(path, params):
30 | """
31 | Reload pretrained embeddings from a text file.
32 | """
33 | word2id = {}
34 | vectors = []
35 |
36 | # load pretrained embeddings
37 | _emb_dim_file = params.emb_dim
38 | with io.open(path, 'r', encoding='utf-8', newline='\n', errors='ignore') as f:
39 | for i, line in enumerate(f):
40 | if i == 0:
41 | split = line.split()
42 | assert len(split) == 2
43 | assert _emb_dim_file == int(split[1])
44 | continue
45 | word, vect = line.rstrip().split(' ', 1)
46 | vect = np.fromstring(vect, sep=' ')
47 | if word in word2id:
48 | logger.warning("Word \"%s\" found twice!" % word)
49 | continue
50 | if not vect.shape == (_emb_dim_file,):
51 | logger.warning("Invalid dimension (%i) for word \"%s\" in line %i."
52 | % (vect.shape[0], word, i))
53 | continue
54 | assert vect.shape == (_emb_dim_file,)
55 | word2id[word] = len(word2id)
56 | vectors.append(vect[None])
57 |
58 | assert len(word2id) == len(vectors)
59 | logger.info("Loaded %i pretrained word embeddings from %s" % (len(vectors), path))
60 |
61 | # compute new vocabulary / embeddings
62 | embeddings = np.concatenate(vectors, 0)
63 | embeddings = torch.from_numpy(embeddings).float()
64 |
65 | assert embeddings.size() == (len(word2id), params.emb_dim)
66 | return word2id, embeddings
67 |
68 |
69 | def load_bin_embeddings(path, params):
70 | """
71 | Reload pretrained embeddings from a fastText binary file.
72 | """
73 | model = load_fasttext_model(path)
74 | assert model.get_dimension() == params.emb_dim
75 | words = model.get_labels()
76 | logger.info("Loaded binary model from %s" % path)
77 |
78 | # compute new vocabulary / embeddings
79 | embeddings = np.concatenate([model.get_word_vector(w)[None] for w in words], 0)
80 | embeddings = torch.from_numpy(embeddings).float()
81 | word2id = {w: i for i, w in enumerate(words)}
82 | logger.info("Generated embeddings for %i words." % len(words))
83 |
84 | assert embeddings.size() == (len(word2id), params.emb_dim)
85 | return word2id, embeddings
86 |
87 |
88 | def load_embeddings(path, params):
89 | """
90 | Reload pretrained embeddings.
91 | """
92 | if path.endswith('.bin'):
93 | return load_bin_embeddings(path, params)
94 | else:
95 | return read_txt_embeddings(path, params)
96 |
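`read_txt_embeddings` expects the usual fastText/GloVe text format: a header line `<n_words> <dim>` (with `<dim>` equal to `params.emb_dim`) followed by one `word v1 ... v<dim>` line per word. A toy sketch, with made-up vectors and `emb_dim = 4`:

```python
# Write a hypothetical 3-word embedding file and note how it would be loaded.
toy = """3 4
the 0.12 -0.43 0.88 0.05
of 0.31 0.02 -0.17 0.64
and -0.25 0.77 0.09 -0.40
"""
with open('toy.vec', 'w', encoding='utf-8') as f:
    f.write(toy)
# load_embeddings('toy.vec', params) returns (word2id, a 3 x 4 FloatTensor),
# provided params.emb_dim == 4.
```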
--------------------------------------------------------------------------------
/src/optim.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import re
9 | import math
10 | import inspect
11 |
12 | import torch
13 | from torch import optim
14 |
15 |
16 | class Adam(optim.Optimizer):
17 | """
18 | Same as https://github.com/pytorch/pytorch/blob/master/torch/optim/adam.py,
19 |     without amsgrad, with the step counter kept as a plain Python int, and with
20 |     the optimizer state initialized eagerly in __init__ rather than lazily in step().
21 | """
22 |
23 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
24 | if not 0.0 <= lr:
25 | raise ValueError("Invalid learning rate: {}".format(lr))
26 | if not 0.0 <= eps:
27 | raise ValueError("Invalid epsilon value: {}".format(eps))
28 | if not 0.0 <= betas[0] < 1.0:
29 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
30 | if not 0.0 <= betas[1] < 1.0:
31 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
32 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
33 | super().__init__(params, defaults)
34 |
35 | for group in self.param_groups:
36 | for p in group['params']:
37 | state = self.state[p]
38 | state['step'] = 0 # torch.zeros(1)
39 | state['exp_avg'] = torch.zeros_like(p.data)
40 | state['exp_avg_sq'] = torch.zeros_like(p.data)
41 |
42 | def __setstate__(self, state):
43 | super().__setstate__(state)
44 |
45 | def step(self, closure=None):
46 | """
47 | Step.
48 | """
49 | loss = None
50 | if closure is not None:
51 | loss = closure()
52 |
53 | for group in self.param_groups:
54 | for p in group['params']:
55 | if p.grad is None:
56 | continue
57 | grad = p.grad.data
58 | if grad.is_sparse:
59 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
60 |
61 | state = self.state[p]
62 |
63 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
64 | beta1, beta2 = group['betas']
65 |
66 | state['step'] += 1
67 |
68 | # if group['weight_decay'] != 0:
69 | # grad.add_(group['weight_decay'], p.data)
70 |
71 | # Decay the first and second moment running average coefficient
72 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
73 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
74 | denom = exp_avg_sq.sqrt().add_(group['eps'])
75 | # denom = exp_avg_sq.sqrt().clamp_(min=group['eps'])
76 |
77 | bias_correction1 = 1 - beta1 ** state['step'] # .item()
78 | bias_correction2 = 1 - beta2 ** state['step'] # .item()
79 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
80 |
81 | if group['weight_decay'] != 0:
82 | p.data.add_(-group['weight_decay'] * group['lr'], p.data)
83 |
84 | p.data.addcdiv_(-step_size, exp_avg, denom)
85 |
86 | return loss
87 |
88 |
89 | class AdamInverseSqrtWithWarmup(Adam):
90 | """
91 | Decay the LR based on the inverse square root of the update number.
92 | We also support a warmup phase where we linearly increase the learning rate
93 | from some initial learning rate (`warmup-init-lr`) until the configured
94 | learning rate (`lr`). Thereafter we decay proportional to the number of
95 | updates, with a decay factor set to align with the configured learning rate.
96 | During warmup:
97 | lrs = torch.linspace(warmup_init_lr, lr, warmup_updates)
98 | lr = lrs[update_num]
99 | After warmup:
100 | lr = decay_factor / sqrt(update_num)
101 | where
102 | decay_factor = lr * sqrt(warmup_updates)
103 | """
104 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
105 | weight_decay=0, warmup_updates=4000, warmup_init_lr=1e-7,
106 | exp_factor=0.5):
107 | super().__init__(
108 | params,
109 | lr=warmup_init_lr,
110 | betas=betas,
111 | eps=eps,
112 | weight_decay=weight_decay,
113 | )
114 |
115 | # linearly warmup for the first warmup_updates
116 | self.warmup_updates = warmup_updates
117 | self.warmup_init_lr = warmup_init_lr
118 | warmup_end_lr = lr
119 | self.lr_step = (warmup_end_lr - warmup_init_lr) / warmup_updates
120 |
121 | # then, decay prop. to the inverse square root of the update number
122 | self.exp_factor = exp_factor
123 | self.decay_factor = warmup_end_lr * warmup_updates ** self.exp_factor
124 |
125 | # total number of updates
126 | for param_group in self.param_groups:
127 | param_group['num_updates'] = 0
128 |
129 | def get_lr_for_step(self, num_updates):
130 | if num_updates < self.warmup_updates:
131 | return self.warmup_init_lr + num_updates * self.lr_step
132 | else:
133 | return self.decay_factor * (num_updates ** -self.exp_factor)
134 |
135 | def step(self, closure=None):
136 | super().step(closure)
137 | for param_group in self.param_groups:
138 | param_group['num_updates'] += 1
139 | param_group['lr'] = self.get_lr_for_step(param_group['num_updates'])
140 |
141 |
142 | class AdamCosineWithWarmup(Adam):
143 | """
144 | Assign LR based on a cyclical schedule that follows the cosine function.
145 | See https://arxiv.org/pdf/1608.03983.pdf for details.
146 | We also support a warmup phase where we linearly increase the learning rate
147 | from some initial learning rate (``--warmup-init-lr``) until the configured
148 | learning rate (``--lr``).
149 | During warmup::
150 | lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
151 | lr = lrs[update_num]
152 | After warmup::
153 |         lr = lr_min + 0.5*(lr_max - lr_min)*(1 + cos(pi * t_curr / t_i))
154 |     where ``t_curr`` is the number of updates since the start of the current period,
155 |     ``t_i`` is the length of the current period (scaled by ``period_mult`` after each
156 |     period), and ``lr_min``/``lr_max`` are shrunk by ``lr_shrink`` at every period.
157 | """
158 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
159 | weight_decay=0, warmup_updates=4000, warmup_init_lr=1e-7,
160 | min_lr=1e-9, init_period=1000000, period_mult=1, lr_shrink=0.75):
161 | super().__init__(
162 | params,
163 | lr=warmup_init_lr,
164 | betas=betas,
165 | eps=eps,
166 | weight_decay=weight_decay,
167 | )
168 |
169 | # linearly warmup for the first warmup_updates
170 | self.warmup_updates = warmup_updates
171 | self.warmup_init_lr = warmup_init_lr
172 | warmup_end_lr = lr
173 | self.lr_step = (warmup_end_lr - warmup_init_lr) / warmup_updates
174 |
175 | # then, apply cosine scheduler
176 | self.min_lr = min_lr
177 | self.max_lr = lr
178 | self.period = init_period
179 | self.period_mult = period_mult
180 | self.lr_shrink = lr_shrink
181 |
182 | # total number of updates
183 | for param_group in self.param_groups:
184 | param_group['num_updates'] = 0
185 |
186 | def get_lr_for_step(self, num_updates):
187 | if num_updates < self.warmup_updates:
188 | return self.warmup_init_lr + num_updates * self.lr_step
189 | else:
190 | t = num_updates - self.warmup_updates
191 | if self.period_mult == 1:
192 | pid = math.floor(t / self.period)
193 | t_i = self.period
194 | t_curr = t - (self.period * pid)
195 | else:
196 | pid = math.floor(math.log(1 - t / self.period * (1 - self.period_mult), self.period_mult))
197 | t_i = self.period * (self.period_mult ** pid)
198 | t_curr = t - (1 - self.period_mult ** pid) / (1 - self.period_mult) * self.period
199 | lr_shrink = self.lr_shrink ** pid
200 | min_lr = self.min_lr * lr_shrink
201 | max_lr = self.max_lr * lr_shrink
202 | return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t_curr / t_i))
203 |
204 | def step(self, closure=None):
205 | super().step(closure)
206 | for param_group in self.param_groups:
207 | param_group['num_updates'] += 1
208 | param_group['lr'] = self.get_lr_for_step(param_group['num_updates'])
209 |
210 |
211 | def get_optimizer(parameters, s):
212 | """
213 | Parse optimizer parameters.
214 | Input should be of the form:
215 | - "sgd,lr=0.01"
216 | - "adagrad,lr=0.1,lr_decay=0.05"
217 | """
218 | if "," in s:
219 | method = s[:s.find(',')]
220 | optim_params = {}
221 | for x in s[s.find(',') + 1:].split(','):
222 | split = x.split('=')
223 | assert len(split) == 2
224 | assert re.match("^[+-]?(\d+(\.\d*)?|\.\d+)$", split[1]) is not None
225 | optim_params[split[0]] = float(split[1])
226 | else:
227 | method = s
228 | optim_params = {}
229 |
230 | if method == 'adadelta':
231 | optim_fn = optim.Adadelta
232 | elif method == 'adagrad':
233 | optim_fn = optim.Adagrad
234 | elif method == 'adam':
235 | optim_fn = Adam
236 | optim_params['betas'] = (optim_params.get('beta1', 0.9), optim_params.get('beta2', 0.999))
237 | optim_params.pop('beta1', None)
238 | optim_params.pop('beta2', None)
239 | elif method == 'adam_inverse_sqrt':
240 | optim_fn = AdamInverseSqrtWithWarmup
241 | optim_params['betas'] = (optim_params.get('beta1', 0.9), optim_params.get('beta2', 0.999))
242 | optim_params.pop('beta1', None)
243 | optim_params.pop('beta2', None)
244 | elif method == 'adam_cosine':
245 | optim_fn = AdamCosineWithWarmup
246 | optim_params['betas'] = (optim_params.get('beta1', 0.9), optim_params.get('beta2', 0.999))
247 | optim_params.pop('beta1', None)
248 | optim_params.pop('beta2', None)
249 | elif method == 'adamax':
250 | optim_fn = optim.Adamax
251 | elif method == 'asgd':
252 | optim_fn = optim.ASGD
253 | elif method == 'rmsprop':
254 | optim_fn = optim.RMSprop
255 | elif method == 'rprop':
256 | optim_fn = optim.Rprop
257 | elif method == 'sgd':
258 | optim_fn = optim.SGD
259 | assert 'lr' in optim_params
260 | else:
261 | raise Exception('Unknown optimization method: "%s"' % method)
262 |
263 | # check that we give good parameters to the optimizer
264 | expected_args = inspect.getargspec(optim_fn.__init__)[0]
265 | assert expected_args[:2] == ['self', 'params']
266 | if not all(k in expected_args[2:] for k in optim_params.keys()):
267 | raise Exception('Unexpected parameters: expected "%s", got "%s"' % (
268 | str(expected_args[2:]), str(optim_params.keys())))
269 |
270 | return optim_fn(parameters, **optim_params)
271 |
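The optimizer string parsed by `get_optimizer` only accepts plain decimal values (the regex has no exponent support, so write `0.0001` rather than `1e-4`). A minimal sketch with hypothetical hyper-parameters:

```python
import torch
from src.optim import get_optimizer

model = torch.nn.Linear(10, 10)  # stand-in module, for illustration only

# Plain Adam with custom betas.
opt = get_optimizer(model.parameters(), "adam,lr=0.0003,beta1=0.9,beta2=0.98")

# Inverse-sqrt decay with linear warmup (values are made up).
opt = get_optimizer(model.parameters(),
                    "adam_inverse_sqrt,lr=0.0001,warmup_updates=4000,beta1=0.9,beta2=0.98")
```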
--------------------------------------------------------------------------------
/src/slurm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | from logging import getLogger
9 | import os
10 | import sys
11 | import torch
12 | import socket
13 | import signal
14 | import subprocess
15 |
16 |
17 | logger = getLogger()
18 |
19 |
20 | def sig_handler(signum, frame):
21 | logger.warning("Signal handler called with signal " + str(signum))
22 | prod_id = int(os.environ['SLURM_PROCID'])
23 | logger.warning("Host: %s - Global rank: %i" % (socket.gethostname(), prod_id))
24 | if prod_id == 0:
25 | logger.warning("Requeuing job " + os.environ['SLURM_JOB_ID'])
26 | os.system('scontrol requeue ' + os.environ['SLURM_JOB_ID'])
27 | else:
28 | logger.warning("Not the master process, no need to requeue.")
29 | sys.exit(-1)
30 |
31 |
32 | def term_handler(signum, frame):
33 | logger.warning("Signal handler called with signal " + str(signum))
34 | logger.warning("Bypassing SIGTERM.")
35 |
36 |
37 | def init_signal_handler():
38 | """
39 | Handle signals sent by SLURM for time limit / pre-emption.
40 | """
41 | signal.signal(signal.SIGUSR1, sig_handler)
42 | signal.signal(signal.SIGTERM, term_handler)
43 | logger.warning("Signal handler installed.")
44 |
45 |
46 | def init_distributed_mode(params):
47 | """
48 | Handle single and multi-GPU / multi-node / SLURM jobs.
49 | Initialize the following variables:
50 | - n_nodes
51 | - node_id
52 | - local_rank
53 | - global_rank
54 | - world_size
55 | """
56 | params.is_slurm_job = 'SLURM_JOB_ID' in os.environ and not params.debug_slurm
57 | print("SLURM job: %s" % str(params.is_slurm_job))
58 |
59 | # SLURM job
60 | if params.is_slurm_job:
61 |
62 | assert params.local_rank == -1 # on the cluster, this is handled by SLURM
63 |
64 | SLURM_VARIABLES = [
65 | 'SLURM_JOB_ID',
66 | 'SLURM_JOB_NODELIST', 'SLURM_JOB_NUM_NODES', 'SLURM_NTASKS', 'SLURM_TASKS_PER_NODE',
67 | 'SLURM_MEM_PER_NODE', 'SLURM_MEM_PER_CPU',
68 | 'SLURM_NODEID', 'SLURM_PROCID', 'SLURM_LOCALID', 'SLURM_TASK_PID'
69 | ]
70 |
71 | PREFIX = "%i - " % int(os.environ['SLURM_PROCID'])
72 | for name in SLURM_VARIABLES:
73 | value = os.environ.get(name, None)
74 | print(PREFIX + "%s: %s" % (name, str(value)))
75 |
76 | # # job ID
77 | # params.job_id = os.environ['SLURM_JOB_ID']
78 |
79 | # number of nodes / node ID
80 | params.n_nodes = int(os.environ['SLURM_JOB_NUM_NODES'])
81 | params.node_id = int(os.environ['SLURM_NODEID'])
82 |
83 | # local rank on the current node / global rank
84 | params.local_rank = int(os.environ['SLURM_LOCALID'])
85 | params.global_rank = int(os.environ['SLURM_PROCID'])
86 |
87 | # number of processes / GPUs per node
88 | params.world_size = int(os.environ['SLURM_NTASKS'])
89 | params.n_gpu_per_node = params.world_size // params.n_nodes
90 |
91 | # define master address and master port
92 | hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', os.environ['SLURM_JOB_NODELIST']])
93 | params.master_addr = hostnames.split()[0].decode('utf-8')
94 | assert 10001 <= params.master_port <= 20000 or params.world_size == 1
95 | print(PREFIX + "Master address: %s" % params.master_addr)
96 | print(PREFIX + "Master port : %i" % params.master_port)
97 |
98 | # set environment variables for 'env://'
99 | os.environ['MASTER_ADDR'] = params.master_addr
100 | os.environ['MASTER_PORT'] = str(params.master_port)
101 | os.environ['WORLD_SIZE'] = str(params.world_size)
102 | os.environ['RANK'] = str(params.global_rank)
103 |
104 | # multi-GPU job (local or multi-node) - jobs started with torch.distributed.launch
105 | elif params.local_rank != -1:
106 |
107 | assert params.master_port == -1
108 |
109 | # read environment variables
110 | params.global_rank = int(os.environ['RANK'])
111 | params.world_size = int(os.environ['WORLD_SIZE'])
112 | params.n_gpu_per_node = int(os.environ['NGPU'])
113 |
114 | # number of nodes / node ID
115 | params.n_nodes = params.world_size // params.n_gpu_per_node
116 | params.node_id = params.global_rank // params.n_gpu_per_node
117 |
118 | # local job (single GPU)
119 | else:
120 | assert params.local_rank == -1
121 | assert params.master_port == -1
122 | params.n_nodes = 1
123 | params.node_id = 0
124 | params.local_rank = 0
125 | params.global_rank = 0
126 | params.world_size = 1
127 | params.n_gpu_per_node = 1
128 |
129 | # sanity checks
130 | assert params.n_nodes >= 1
131 | assert 0 <= params.node_id < params.n_nodes
132 | assert 0 <= params.local_rank <= params.global_rank < params.world_size
133 | assert params.world_size == params.n_nodes * params.n_gpu_per_node
134 |
135 | # define whether this is the master process / if we are in distributed mode
136 | params.is_master = params.node_id == 0 and params.local_rank == 0
137 | params.multi_node = params.n_nodes > 1
138 | params.multi_gpu = params.world_size > 1
139 |
140 | # summary
141 | PREFIX = "%i - " % params.global_rank
142 | print(PREFIX + "Number of nodes: %i" % params.n_nodes)
143 | print(PREFIX + "Node ID : %i" % params.node_id)
144 | print(PREFIX + "Local rank : %i" % params.local_rank)
145 | print(PREFIX + "Global rank : %i" % params.global_rank)
146 | print(PREFIX + "World size : %i" % params.world_size)
147 | print(PREFIX + "GPUs per node : %i" % params.n_gpu_per_node)
148 | print(PREFIX + "Master : %s" % str(params.is_master))
149 | print(PREFIX + "Multi-node : %s" % str(params.multi_node))
150 | print(PREFIX + "Multi-GPU : %s" % str(params.multi_gpu))
151 | print(PREFIX + "Hostname : %s" % socket.gethostname())
152 |
153 | # set GPU device
154 | torch.cuda.set_device(params.local_rank)
155 |
156 | # initialize multi-GPU
157 | if params.multi_gpu:
158 |
159 | # http://pytorch.apachecn.org/en/0.3.0/distributed.html#environment-variable-initialization
160 | # 'env://' will read these environment variables:
161 | # MASTER_PORT - required; has to be a free port on machine with rank 0
162 | # MASTER_ADDR - required (except for rank 0); address of rank 0 node
163 | # WORLD_SIZE - required; can be set either here, or in a call to init function
164 | # RANK - required; can be set either here, or in a call to init function
165 |
166 | print("Initializing PyTorch distributed ...")
167 | torch.distributed.init_process_group(
168 | init_method='env://',
169 | backend='nccl',
170 | )
171 |
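For the non-SLURM multi-GPU branch above (`local_rank != -1`), everything is read from environment variables, including a custom `NGPU` (GPUs per node). Such jobs are typically started with something like `export NGPU=4; python -m torch.distributed.launch --nproc_per_node=$NGPU train.py ...`; the sketch below only lists the variables this branch and `init_process_group('env://')` expect (values are placeholders):

```python
import os

# Normally set by torch.distributed.launch plus a manual `export NGPU=...`.
os.environ.update({
    'RANK': '0',               # global rank of this process
    'WORLD_SIZE': '4',         # total number of processes
    'NGPU': '4',               # processes (GPUs) per node -- specific to this codebase
    'MASTER_ADDR': '127.0.0.1',
    'MASTER_PORT': '29500',
})
```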
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import os
9 | import re
10 | import sys
11 | import pickle
12 | import random
13 | import getpass
14 | import argparse
15 | import subprocess
16 | import numpy as np
17 | import torch
18 |
19 | from .logger import create_logger
20 |
21 |
22 | FALSY_STRINGS = {'off', 'false', '0'}
23 | TRUTHY_STRINGS = {'on', 'true', '1'}
24 |
25 | DUMP_PATH = '/checkpoint/%s/dumped' % getpass.getuser()
26 | DYNAMIC_COEFF = ['lambda_clm', 'lambda_mlm', 'lambda_pc', 'lambda_ae', 'lambda_mt', 'lambda_bt']
27 |
28 |
29 | class AttrDict(dict):
30 | def __init__(self, *args, **kwargs):
31 | super(AttrDict, self).__init__(*args, **kwargs)
32 | self.__dict__ = self
33 |
34 |
35 | def bool_flag(s):
36 | """
37 | Parse boolean arguments from the command line.
38 | """
39 | if s.lower() in FALSY_STRINGS:
40 | return False
41 | elif s.lower() in TRUTHY_STRINGS:
42 | return True
43 | else:
44 | raise argparse.ArgumentTypeError("Invalid value for a boolean flag!")
45 |
46 |
47 | def initialize_exp(params):
48 | """
49 |     Initialize the experiment:
50 | - dump parameters
51 | - create a logger
52 | """
53 | # dump parameters
54 | get_dump_path(params)
55 | pickle.dump(params, open(os.path.join(params.dump_path, 'params.pkl'), 'wb'))
56 |
57 | # get running command
58 | command = ["python", sys.argv[0]]
59 | for x in sys.argv[1:]:
60 | if x.startswith('--'):
61 | assert '"' not in x and "'" not in x
62 | command.append(x)
63 | else:
64 | assert "'" not in x
65 | if re.match('^[a-zA-Z0-9_]+$', x):
66 | command.append("%s" % x)
67 | else:
68 | command.append("'%s'" % x)
69 | command = ' '.join(command)
70 | params.command = command + ' --exp_id "%s"' % params.exp_id
71 |
72 | # check experiment name
73 | assert len(params.exp_name.strip()) > 0
74 |
75 | # create a logger
76 | logger = create_logger(os.path.join(params.dump_path, 'train.log'), rank=getattr(params, 'global_rank', 0))
77 | logger.info("============ Initialized logger ============")
78 | logger.info("\n".join("%s: %s" % (k, str(v))
79 | for k, v in sorted(dict(vars(params)).items())))
80 | logger.info("The experiment will be stored in %s\n" % params.dump_path)
81 | logger.info("Running command: %s" % command)
82 | logger.info("")
83 | return logger
84 |
85 |
86 | def get_dump_path(params):
87 | """
88 | Create a directory to store the experiment.
89 | """
90 | dump_path = DUMP_PATH if params.dump_path == '' else params.dump_path
91 | assert len(params.exp_name) > 0
92 |
93 | # create the sweep path if it does not exist
94 | sweep_path = os.path.join(dump_path, params.exp_name)
95 | if not os.path.exists(sweep_path):
96 | subprocess.Popen("mkdir -p %s" % sweep_path, shell=True).wait()
97 |
98 | # create an ID for the job if it is not given in the parameters.
99 |     # if we run on the cluster, the job ID is taken from Chronos or SLURM.
100 | # otherwise, it is randomly generated
101 | if params.exp_id == '':
102 | chronos_job_id = os.environ.get('CHRONOS_JOB_ID')
103 | slurm_job_id = os.environ.get('SLURM_JOB_ID')
104 | assert chronos_job_id is None or slurm_job_id is None
105 | exp_id = chronos_job_id if chronos_job_id is not None else slurm_job_id
106 | if exp_id is None:
107 | chars = 'abcdefghijklmnopqrstuvwxyz0123456789'
108 | while True:
109 | exp_id = ''.join(random.choice(chars) for _ in range(10))
110 | if not os.path.isdir(os.path.join(sweep_path, exp_id)):
111 | break
112 | else:
113 | assert exp_id.isdigit()
114 | params.exp_id = exp_id
115 |
116 | # create the dump folder / update parameters
117 | params.dump_path = os.path.join(sweep_path, params.exp_id)
118 | if not os.path.isdir(params.dump_path):
119 | subprocess.Popen("mkdir -p %s" % params.dump_path, shell=True).wait()
120 |
121 |
122 | def to_cuda(*args):
123 | """
124 | Move tensors to CUDA.
125 | """
126 | return [None if x is None else x.cuda() for x in args]
127 |
128 |
129 | def restore_segmentation(path, bpe_type='fastBPE'):
130 | """
131 | Take a file segmented with BPE and restore it to its original segmentation.
132 | """
133 | assert os.path.isfile(path)
134 | if bpe_type == 'fastBPE':
135 | restore_cmd = "sed -i -r 's/(@@ )|(@@ ?$)//g' %s"
136 | elif bpe_type == 'sentencepiece':
137 | restore_cmd = u"sed -i -e 's/ //g' -e 's/^\u2581//g' -e 's/\u2581/ /g' %s"
138 | else:
139 | raise NotImplementedError
140 | subprocess.Popen(restore_cmd % path, shell=True).wait()
141 |
142 |
143 | def parse_lambda_config(params):
144 | """
145 | Parse the configuration of lambda coefficient (for scheduling).
146 | x = "3" # lambda will be a constant equal to x
147 | x = "0:1,1000:0" # lambda will start from 1 and linearly decrease to 0 during the first 1000 iterations
148 | x = "0:0,1000:0,2000:1" # lambda will be equal to 0 for the first 1000 iterations, then will linearly increase to 1 until iteration 2000
149 | """
150 | for name in DYNAMIC_COEFF:
151 | x = getattr(params, name)
152 | split = x.split(',')
153 | if len(split) == 1:
154 | setattr(params, name, float(x))
155 | setattr(params, name + '_config', None)
156 | else:
157 | split = [s.split(':') for s in split]
158 | assert all(len(s) == 2 for s in split)
159 | assert all(k.isdigit() for k, _ in split)
160 | assert all(int(split[i][0]) < int(split[i + 1][0]) for i in range(len(split) - 1))
161 | setattr(params, name, float(split[0][1]))
162 | setattr(params, name + '_config', [(int(k), float(v)) for k, v in split])
163 |
164 |
165 | def get_lambda_value(config, n_iter):
166 | """
167 | Compute a lambda value according to its schedule configuration.
168 | """
169 | ranges = [i for i in range(len(config) - 1) if config[i][0] <= n_iter < config[i + 1][0]]
170 | if len(ranges) == 0:
171 | assert n_iter >= config[-1][0]
172 | return config[-1][1]
173 | assert len(ranges) == 1
174 | i = ranges[0]
175 | x_a, y_a = config[i]
176 | x_b, y_b = config[i + 1]
177 | return y_a + (n_iter - x_a) * float(y_b - y_a) / float(x_b - x_a)
178 |
179 |
180 | def update_lambdas(params, n_iter):
181 | """
182 | Update all lambda coefficients.
183 | """
184 | for name in DYNAMIC_COEFF:
185 | config = getattr(params, name + '_config')
186 | if config is not None:
187 | setattr(params, name, get_lambda_value(config, n_iter))
188 |
189 |
190 | def set_sampling_probs(data, params):
191 | """
192 | Set the probability of sampling specific languages / language pairs during training.
193 | """
194 | coeff = params.lg_sampling_factor
195 | if coeff == -1:
196 | return
197 | assert coeff > 0
198 |
199 | # monolingual data
200 | params.mono_list = [k for k, v in data['mono_stream'].items() if 'train' in v]
201 | if len(params.mono_list) > 0:
202 | probs = np.array([1.0 * len(data['mono_stream'][lang]['train']) for lang in params.mono_list])
203 | probs /= probs.sum()
204 | probs = np.array([p ** coeff for p in probs])
205 | probs /= probs.sum()
206 | params.mono_probs = probs
207 |
208 | # parallel data
209 | params.para_list = [k for k, v in data['para'].items() if 'train' in v]
210 | if len(params.para_list) > 0:
211 | probs = np.array([1.0 * len(data['para'][(l1, l2)]['train']) for (l1, l2) in params.para_list])
212 | probs /= probs.sum()
213 | probs = np.array([p ** coeff for p in probs])
214 | probs /= probs.sum()
215 | params.para_probs = probs
216 |
217 |
218 | def concat_batches(x1, len1, lang1_id, x2, len2, lang2_id, pad_idx, eos_idx, reset_positions):
219 | """
220 | Concat batches with different languages.
221 | """
222 | assert reset_positions is False or lang1_id != lang2_id
223 | lengths = len1 + len2
224 | if not reset_positions:
225 | lengths -= 1
226 | slen, bs = lengths.max().item(), lengths.size(0)
227 |
228 | x = x1.new(slen, bs).fill_(pad_idx)
229 | x[:len1.max().item()].copy_(x1)
230 | positions = torch.arange(slen)[:, None].repeat(1, bs).to(x1.device)
231 | langs = x1.new(slen, bs).fill_(lang1_id)
232 |
233 | for i in range(bs):
234 | l1 = len1[i] if reset_positions else len1[i] - 1
235 | x[l1:l1 + len2[i], i].copy_(x2[:len2[i], i])
236 | if reset_positions:
237 | positions[l1:, i] -= len1[i]
238 | langs[l1:, i] = lang2_id
239 |
240 | assert (x == eos_idx).long().sum().item() == (4 if reset_positions else 3) * bs
241 |
242 | return x, lengths, positions, langs
243 |
244 |
245 | def truncate(x, lengths, max_len, eos_index):
246 | """
247 | Truncate long sentences.
248 | """
249 | if lengths.max().item() > max_len:
250 | x = x[:max_len].clone()
251 | lengths = lengths.clone()
252 | for i in range(len(lengths)):
253 | if lengths[i] > max_len:
254 | lengths[i] = max_len
255 | x[max_len - 1, i] = eos_index
256 | return x, lengths
257 |
258 |
259 | def shuf_order(langs, params=None, n=5):
260 | """
261 | Randomize training order.
262 | """
263 | if len(langs) == 0:
264 | return []
265 |
266 | if params is None:
267 | return [langs[i] for i in np.random.permutation(len(langs))]
268 |
269 | # sample monolingual and parallel languages separately
270 | mono = [l1 for l1, l2 in langs if l2 is None]
271 | para = [(l1, l2) for l1, l2 in langs if l2 is not None]
272 |
273 | # uniform / weighted sampling
274 | if params.lg_sampling_factor == -1:
275 | p_mono = None
276 | p_para = None
277 | else:
278 | p_mono = np.array([params.mono_probs[params.mono_list.index(k)] for k in mono])
279 | p_para = np.array([params.para_probs[params.para_list.index(tuple(sorted(k)))] for k in para])
280 | p_mono = p_mono / p_mono.sum()
281 | p_para = p_para / p_para.sum()
282 |
283 | s_mono = [mono[i] for i in np.random.choice(len(mono), size=min(n, len(mono)), p=p_mono, replace=True)] if len(mono) > 0 else []
284 | s_para = [para[i] for i in np.random.choice(len(para), size=min(n, len(para)), p=p_para, replace=True)] if len(para) > 0 else []
285 |
286 | assert len(s_mono) + len(s_para) > 0
287 | return [(lang, None) for lang in s_mono] + s_para
288 |
289 |
290 | def find_modules(module, module_name, module_instance, found):
291 | """
292 | Recursively find all instances of a specific module inside a module.
293 | """
294 | if isinstance(module, module_instance):
295 | found.append((module_name, module))
296 | else:
297 | for name, child in module.named_children():
298 | name = ('%s[%s]' if name.isdigit() else '%s.%s') % (module_name, name)
299 | find_modules(child, name, module_instance, found)
300 |
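A worked example of the lambda scheduling helpers: with the config string `"0:0,1000:0,2000:1"`, the coefficient stays at 0 for the first 1000 iterations and then ramps linearly to 1 by iteration 2000, so its value at iteration 1500 is 0.5. Sketch using a throwaway namespace rather than the real training params:

```python
from argparse import Namespace
from src.utils import parse_lambda_config, get_lambda_value, DYNAMIC_COEFF

# Constant "0" for every coefficient except a scheduled lambda_bt.
params = Namespace(**{name: "0" for name in DYNAMIC_COEFF})
params.lambda_bt = "0:0,1000:0,2000:1"
parse_lambda_config(params)

assert params.lambda_bt == 0.0                                 # value at iteration 0
assert get_lambda_value(params.lambda_bt_config, 1500) == 0.5  # halfway up the ramp
```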
--------------------------------------------------------------------------------
/tools/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jind11/DAMT/3caa22822b936137e5da3e827d7a5a2078c3115e/tools/.DS_Store
--------------------------------------------------------------------------------
/tools/README.md:
--------------------------------------------------------------------------------
1 | # Tools
2 |
3 | In this repository's `tools/` folder, you will need to install the following tools:
4 |
5 | ## Tokenizers
6 |
7 | [Moses](https://github.com/moses-smt/mosesdecoder/tree/master/scripts/tokenizer) tokenizer:
8 | ```
9 | git clone https://github.com/moses-smt/mosesdecoder
10 | ```
11 |
12 | Thai [PythaiNLP](https://github.com/PyThaiNLP/pythainlp) tokenizer:
13 | ```
14 | pip install pythainlp
15 | ```
16 |
17 | Japanese [KyTea](http://www.phontron.com/kytea) tokenizer:
18 | ```
19 | wget http://www.phontron.com/kytea/download/kytea-0.4.7.tar.gz
20 | tar -xzf kytea-0.4.7.tar.gz
21 | cd kytea-0.4.7
22 | ./configure
23 | make
24 | make install
25 | kytea --help
26 | ```
27 |
28 | Chinese Stanford segmenter:
29 | ```
30 | wget https://nlp.stanford.edu/software/stanford-segmenter-2018-10-16.zip
31 | unzip stanford-segmenter-2018-10-16.zip
32 | ```
33 |
34 | ## fastBPE
35 |
36 | ```
37 | git clone https://github.com/glample/fastBPE
38 | cd fastBPE
39 | g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o fast
40 | ```
41 |
--------------------------------------------------------------------------------
/tools/lowercase_and_remove_accent.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import sys
9 | import unicodedata
10 | import six
11 |
12 |
13 | def convert_to_unicode(text):
14 | """
15 | Converts `text` to Unicode (if it's not already), assuming UTF-8 input.
16 | """
17 | # six_ensure_text is copied from https://github.com/benjaminp/six
18 | def six_ensure_text(s, encoding='utf-8', errors='strict'):
19 | if isinstance(s, six.binary_type):
20 | return s.decode(encoding, errors)
21 | elif isinstance(s, six.text_type):
22 | return s
23 | else:
24 | raise TypeError("not expecting type '%s'" % type(s))
25 |
26 | return six_ensure_text(text, encoding="utf-8", errors="ignore")
27 |
28 |
29 | def run_strip_accents(text):
30 | """
31 | Strips accents from a piece of text.
32 | """
33 | text = unicodedata.normalize("NFD", text)
34 | output = []
35 | for char in text:
36 | cat = unicodedata.category(char)
37 | if cat == "Mn":
38 | continue
39 | output.append(char)
40 | return "".join(output)
41 |
42 |
43 | for line in sys.stdin:
44 | line = convert_to_unicode(line.rstrip().lower())
45 | line = run_strip_accents(line)
46 | print(u'%s' % line.lower())
47 |
--------------------------------------------------------------------------------
/tools/segment_th.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import sys
9 | from pythainlp.tokenize import word_tokenize
10 |
11 | for line in sys.stdin.readlines():
12 | line = line.rstrip('\n')
13 | print(' '.join(word_tokenize(line)))
14 |
--------------------------------------------------------------------------------
/tools/tokenize.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | # Tokenize text data in various languages
9 | # Usage: e.g. cat wiki.ar | tokenize.sh ar
10 |
11 | set -e
12 |
13 | N_THREADS=8
14 |
15 | lg=$1
16 | TOOLS_PATH=$PWD/tools
17 |
18 | # moses
19 | MOSES=$TOOLS_PATH/mosesdecoder
20 | REPLACE_UNICODE_PUNCT=$MOSES/scripts/tokenizer/replace-unicode-punctuation.perl
21 | NORM_PUNC=$MOSES/scripts/tokenizer/normalize-punctuation.perl
22 | REM_NON_PRINT_CHAR=$MOSES/scripts/tokenizer/remove-non-printing-char.perl
23 | TOKENIZER=$MOSES/scripts/tokenizer/tokenizer.perl
24 |
25 | # Chinese
26 | if [ "$lg" = "zh" ]; then
27 | $TOOLS_PATH/stanford-segmenter-*/segment.sh pku /dev/stdin UTF-8 0 | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $lg | $REM_NON_PRINT_CHAR
28 | # Thai
29 | elif [ "$lg" = "th" ]; then
30 | cat - | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $lg | $REM_NON_PRINT_CHAR | python $TOOLS_PATH/segment_th.py
31 | # Japanese
32 | elif [ "$lg" = "ja" ]; then
33 | cat - | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $lg | $REM_NON_PRINT_CHAR | kytea -notags
34 | # other languages
35 | else
36 | cat - | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $lg | $REM_NON_PRINT_CHAR | $TOKENIZER -no-escape -threads $N_THREADS -l $lg
37 | fi
38 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import json
9 | import random
10 | import argparse
11 |
12 | from src.slurm import init_signal_handler, init_distributed_mode
13 | from src.data.loader import check_data_params, load_data
14 | from src.utils import bool_flag, initialize_exp, set_sampling_probs, shuf_order
15 | from src.model import check_model_params, build_model
16 | from src.model.memory import HashingMemory
17 | from src.trainer import SingleTrainer, EncDecTrainer
18 | from src.evaluation.evaluator import SingleEvaluator, EncDecEvaluator
19 |
20 |
21 | def get_parser():
22 | """
23 | Generate a parameters parser.
24 | """
25 | # parse parameters
26 | parser = argparse.ArgumentParser(description="Language transfer")
27 |
28 | # main parameters
29 | parser.add_argument("--dump_path", type=str, default="./dumped/",
30 | help="Experiment dump path")
31 | parser.add_argument("--exp_name", type=str, default="",
32 | help="Experiment name")
33 | parser.add_argument("--save_periodic", type=int, default=0,
34 | help="Save the model periodically (0 to disable)")
35 | parser.add_argument("--exp_id", type=str, default="",
36 | help="Experiment ID")
37 |
38 | # float16 / AMP API
39 | parser.add_argument("--fp16", type=bool_flag, default=False,
40 | help="Run model with float16")
41 | parser.add_argument("--amp", type=int, default=-1,
42 | help="Use AMP wrapper for float16 / distributed / gradient accumulation. Level of optimization. -1 to disable.")
43 |
44 | # only use an encoder (use a specific decoder for machine translation)
45 | parser.add_argument("--encoder_only", type=bool_flag, default=True,
46 | help="Only use an encoder")
47 |
48 | # model parameters
49 | parser.add_argument("--emb_dim", type=int, default=512,
50 | help="Embedding layer size")
51 | parser.add_argument("--n_layers", type=int, default=4,
52 | help="Number of Transformer layers")
53 | parser.add_argument("--n_heads", type=int, default=8,
54 | help="Number of Transformer heads")
55 | parser.add_argument("--dropout", type=float, default=0,
56 | help="Dropout")
57 | parser.add_argument("--attention_dropout", type=float, default=0,
58 | help="Dropout in the attention layer")
59 | parser.add_argument("--gelu_activation", type=bool_flag, default=False,
60 | help="Use a GELU activation instead of ReLU")
61 | parser.add_argument("--share_inout_emb", type=bool_flag, default=True,
62 | help="Share input and output embeddings")
63 | parser.add_argument("--sinusoidal_embeddings", type=bool_flag, default=False,
64 | help="Use sinusoidal embeddings")
65 | parser.add_argument("--use_lang_emb", type=bool_flag, default=True,
66 | help="Use language embedding")
67 |
68 | # memory parameters
69 | parser.add_argument("--use_memory", type=bool_flag, default=False,
70 | help="Use an external memory")
71 | if parser.parse_known_args()[0].use_memory:
72 | HashingMemory.register_args(parser)
73 | parser.add_argument("--mem_enc_positions", type=str, default="",
74 | help="Memory positions in the encoder ('4' for inside layer 4, '7,10+' for inside layer 7 and after layer 10)")
75 | parser.add_argument("--mem_dec_positions", type=str, default="",
76 | help="Memory positions in the decoder. Same syntax as `mem_enc_positions`.")
77 |
78 | # adaptive softmax
79 | parser.add_argument("--asm", type=bool_flag, default=False,
80 | help="Use adaptive softmax")
81 | if parser.parse_known_args()[0].asm:
82 | parser.add_argument("--asm_cutoffs", type=str, default="8000,20000",
83 | help="Adaptive softmax cutoffs")
84 | parser.add_argument("--asm_div_value", type=float, default=4,
85 | help="Adaptive softmax cluster sizes ratio")
86 |
87 | # causal language modeling task parameters
88 | parser.add_argument("--context_size", type=int, default=0,
89 | help="Context size (0 means that the first elements in sequences won't have any context)")
90 |
91 | # masked language modeling task parameters
92 | parser.add_argument("--word_pred", type=float, default=0.15,
93 | help="Fraction of words for which we need to make a prediction")
94 | parser.add_argument("--sample_alpha", type=float, default=0,
95 | help="Exponent for transforming word counts to probabilities (~word2vec sampling)")
96 | parser.add_argument("--word_mask_keep_rand", type=str, default="0.8,0.1,0.1",
97 | help="Fraction of words to mask out / keep / randomize, among the words to predict")
98 |
99 | # input sentence noise
100 | parser.add_argument("--word_shuffle", type=float, default=0,
101 | help="Randomly shuffle input words (0 to disable)")
102 | parser.add_argument("--word_dropout", type=float, default=0,
103 | help="Randomly dropout input words (0 to disable)")
104 | parser.add_argument("--word_blank", type=float, default=0,
105 | help="Randomly blank input words (0 to disable)")
106 |
107 | # data
108 | parser.add_argument("--data_path", type=str, default="",
109 | help="Data path")
110 | parser.add_argument("--para_data_path", type=str, default="",
111 | help="Parallel Data path")
112 | parser.add_argument("--lgs", type=str, default="",
113 | help="Languages (lg1-lg2-lg3 .. ex: en-fr-es-de)")
114 | parser.add_argument("--max_vocab", type=int, default=-1,
115 | help="Maximum vocabulary size (-1 to disable)")
116 | parser.add_argument("--min_count", type=int, default=0,
117 | help="Minimum vocabulary count")
118 | parser.add_argument("--lg_sampling_factor", type=float, default=-1,
119 | help="Language sampling factor")
120 |
121 | # batch parameters
122 | parser.add_argument("--bptt", type=int, default=256,
123 | help="Sequence length")
124 | parser.add_argument("--max_len", type=int, default=100,
125 | help="Maximum length of sentences (after BPE)")
126 | parser.add_argument("--group_by_size", type=bool_flag, default=True,
127 | help="Sort sentences by size during the training")
128 | parser.add_argument("--batch_size", type=int, default=32,
129 | help="Number of sentences per batch")
130 | parser.add_argument("--max_batch_size", type=int, default=0,
131 | help="Maximum number of sentences per batch (used in combination with tokens_per_batch, 0 to disable)")
132 | parser.add_argument("--tokens_per_batch", type=int, default=-1,
133 | help="Number of tokens per batch")
134 |
135 | # training parameters
136 | parser.add_argument("--split_data", type=bool_flag, default=False,
137 | help="Split data across workers of a same node")
138 | parser.add_argument("--optimizer", type=str, default="adam,lr=0.0001",
139 | help="Optimizer (SGD / RMSprop / Adam, etc.)")
140 | parser.add_argument("--clip_grad_norm", type=float, default=5,
141 | help="Clip gradients norm (0 to disable)")
142 | parser.add_argument("--epoch_size", type=int, default=100000,
143 | help="Epoch size / evaluation frequency (-1 for parallel data size)")
144 | parser.add_argument("--max_epoch", type=int, default=100000,
145 |                         help="Maximum number of epochs")
146 | parser.add_argument("--stopping_criterion", type=str, default="",
147 |                         help="Stopping criterion, and number of validations without improvement before stopping the experiment")
148 | parser.add_argument("--validation_metrics", type=str, default="",
149 | help="Validation metrics")
150 | parser.add_argument("--accumulate_gradients", type=int, default=1,
151 | help="Accumulate model gradients over N iterations (N times larger batch sizes)")
152 |
153 | # training coefficients
154 | parser.add_argument("--lambda_mlm", type=str, default="1",
155 | help="Prediction coefficient (MLM)")
156 | parser.add_argument("--lambda_clm", type=str, default="1",
157 | help="Causal coefficient (LM)")
158 | parser.add_argument("--lambda_pc", type=str, default="1",
159 | help="PC coefficient")
160 | parser.add_argument("--lambda_ae", type=str, default="1",
161 | help="AE coefficient")
162 | parser.add_argument("--lambda_mt", type=str, default="1",
163 | help="MT coefficient")
164 | parser.add_argument("--lambda_bt", type=str, default="1",
165 | help="BT coefficient")
166 |
167 | # training steps
168 | parser.add_argument("--clm_steps", type=str, default="",
169 | help="Causal prediction steps (CLM)")
170 | parser.add_argument("--mlm_steps", type=str, default="",
171 | help="Masked prediction steps (MLM / TLM)")
172 | parser.add_argument("--mt_steps", type=str, default="",
173 | help="Machine translation steps")
174 | parser.add_argument("--ae_steps", type=str, default="",
175 | help="Denoising auto-encoder steps")
176 | parser.add_argument("--bt_steps", type=str, default="",
177 | help="Back-translation steps")
178 | parser.add_argument("--pc_steps", type=str, default="",
179 | help="Parallel classification steps")
180 | parser.add_argument("--delay_umt_epoch_num", type=int, default=0,
181 | help="The number of epochs to delay the umt steps")
182 |
183 | # reload pretrained embeddings / pretrained model / checkpoint
184 | parser.add_argument("--reload_emb", type=str, default="",
185 | help="Reload pretrained word embeddings")
186 | parser.add_argument("--reload_model", type=str, default="",
187 | help="Reload a pretrained model")
188 | parser.add_argument("--reload_checkpoint", type=str, default="",
189 | help="Reload a checkpoint")
190 |
191 | # beam search (for MT only)
192 | parser.add_argument("--beam_size", type=int, default=1,
193 | help="Beam size, default = 1 (greedy decoding)")
194 | parser.add_argument("--length_penalty", type=float, default=1,
195 | help="Length penalty, values < 1.0 favor shorter sentences, while values > 1.0 favor longer ones.")
196 | parser.add_argument("--early_stopping", type=bool_flag, default=False,
197 | help="Early stopping, stop as soon as we have `beam_size` hypotheses, although longer ones may have better scores.")
198 |
199 | # evaluation
200 | parser.add_argument("--eval_bleu", type=bool_flag, default=False,
201 | help="Evaluate BLEU score during MT training")
202 | parser.add_argument("--eval_only", type=bool_flag, default=False,
203 | help="Only run evaluations")
204 | parser.add_argument("--bpe_type", type=str, default='fastBPE',
205 | help="Approach to implement BPE such as: fastBPE, sentencepiece")
206 |
207 | # debug
208 | parser.add_argument("--debug_train", type=bool_flag, default=False,
209 | help="Use valid sets for train sets (faster loading)")
210 | parser.add_argument("--debug_slurm", type=bool_flag, default=False,
211 | help="Debug multi-GPU / multi-node within a SLURM job")
212 | parser.add_argument("--debug", help="Enable all debug flags",
213 | action="store_true")
214 |
215 | # multi-gpu / multi-node
216 | parser.add_argument("--local_rank", type=int, default=-1,
217 | help="Multi-GPU - Local rank")
218 | parser.add_argument("--master_port", type=int, default=-1,
219 | help="Master port (for multi-node SLURM jobs)")
220 |
221 | return parser
222 |
223 |
224 | def main(params):
225 |
226 | # initialize the multi-GPU / multi-node training
227 | init_distributed_mode(params)
228 |
229 | # initialize the experiment
230 | logger = initialize_exp(params)
231 |
232 | # initialize SLURM signal handler for time limit / pre-emption
233 | init_signal_handler()
234 |
235 | # load data
236 | data = load_data(params)
237 |
238 | # build model
239 | if params.encoder_only:
240 | model = build_model(params, data['dico'])
241 | else:
242 | encoder, decoder = build_model(params, data['dico'])
243 |
244 | # build trainer, reload potential checkpoints / build evaluator
245 | if params.encoder_only:
246 | trainer = SingleTrainer(model, data, params)
247 | evaluator = SingleEvaluator(trainer, data, params)
248 | else:
249 | trainer = EncDecTrainer(encoder, decoder, data, params)
250 | evaluator = EncDecEvaluator(trainer, data, params)
251 |
252 | # evaluation
253 | if params.eval_only:
254 | scores = evaluator.run_all_evals(trainer)
255 | for k, v in scores.items():
256 | logger.info("%s -> %.6f" % (k, v))
257 | logger.info("__log__:%s" % json.dumps(scores))
258 | exit()
259 |
260 | # set sampling probabilities for training
261 | set_sampling_probs(data, params)
262 |
263 |     # training loop (CLM / MLM / PC / AE / MT / BT steps)
264 | for epoch in range(params.max_epoch):
265 |
266 | logger.info("============ Starting epoch %i ... ============" % trainer.epoch)
267 |
268 | trainer.n_sentences = 0
269 |
270 | while trainer.n_sentences < trainer.epoch_size:
271 |
272 | # CLM steps
273 | if epoch >= params.delay_umt_epoch_num:
274 | for lang1, lang2 in shuf_order(params.clm_steps, params):
275 | trainer.clm_step(lang1, lang2, params.lambda_clm)
276 |
277 | # MLM steps (also includes TLM if lang2 is not None)
278 | if epoch >= params.delay_umt_epoch_num:
279 | for lang1, lang2 in shuf_order(params.mlm_steps, params):
280 | trainer.mlm_step(lang1, lang2, params.lambda_mlm)
281 |
282 | # parallel classification steps
283 | for lang1, lang2 in shuf_order(params.pc_steps, params):
284 | trainer.pc_step(lang1, lang2, params.lambda_pc)
285 |
286 | # denoising auto-encoder steps
287 | if epoch >= params.delay_umt_epoch_num:
288 | for lang in shuf_order(params.ae_steps):
289 | trainer.mt_step(lang, lang, params.lambda_ae)
290 |
291 | # machine translation steps
292 | for lang1, lang2 in shuf_order(params.mt_steps, params):
293 | trainer.mt_step(lang1, lang2, params.lambda_mt)
294 |
295 | # back-translation steps
296 | if epoch >= params.delay_umt_epoch_num:
297 | for lang1, lang2, lang3 in shuf_order(params.bt_steps):
298 | trainer.bt_step(lang1, lang2, lang3, params.lambda_bt)
299 |
300 | trainer.iter()
301 |
302 | logger.info("============ End of epoch %i ============" % trainer.epoch)
303 |
304 | # evaluate perplexity
305 | scores = evaluator.run_all_evals(trainer)
306 |
307 | # print / JSON log
308 | for k, v in scores.items():
309 | logger.info("%s -> %.6f" % (k, v))
310 | if params.is_master:
311 | logger.info("__log__:%s" % json.dumps(scores))
312 |
313 | # end of epoch
314 | trainer.save_best_model(scores)
315 | trainer.save_periodic()
316 | trainer.end_epoch(scores)
317 |
318 |
319 | if __name__ == '__main__':
320 |
321 | # generate parser / parse parameters
322 | parser = get_parser()
323 | params = parser.parse_args()
324 |
325 | # debug mode
326 | if params.debug:
327 | params.exp_name = 'debug'
328 | params.exp_id = 'debug_%08i' % random.randint(0, 100000000)
329 | params.debug_slurm = True
330 | params.debug_train = True
331 |
332 | # check parameters
333 | check_data_params(params)
334 | check_model_params(params)
335 |
336 | # run experiment
337 | main(params)
338 |
--------------------------------------------------------------------------------
/train_IBT.sh:
--------------------------------------------------------------------------------
1 | #
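2 | # Iterative back-translation (IBT): fine-tune the pretrained XLM model (reloaded as both
3 | # encoder and decoder) with denoising auto-encoding and back-translation steps on the
4 | # monolingual sides of the selected dataset.
5 | # Example call (argument values are illustrative):
6 | #   bash train_IBT.sh --src de --tgt en --data_name it --pretrained_model_dir models
7 | #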
2 | # Read arguments
3 | #
4 | POSITIONAL=()
5 | while [[ $# -gt 0 ]]
6 | do
7 | key="$1"
8 | case $key in
9 | --src)
10 | SRC="$2"; shift 2;;
11 | --tgt)
12 | TGT="$2"; shift 2;;
13 | --data_name)
14 | DATA_NAME="$2"; shift 2;;
15 | --pretrained_model_dir)
16 | PRETRAINED_MODEL_DIR="$2"; shift 2;;
17 | *)
18 | POSITIONAL+=("$1")
19 | shift
20 | ;;
21 | esac
22 | done
23 | set -- "${POSITIONAL[@]}"
24 |
25 | if [ "$SRC" != 'en' ]; then
26 | OTHER_LANG=$SRC
27 | else
28 | OTHER_LANG=$TGT
29 | fi
30 | echo $OTHER_LANG
31 |
32 | if [ "$SRC" \< "$TGT" ]; then
33 | ORDERED_SRC=$SRC
34 | ORDERED_TGT=$TGT
35 | else
36 | ORDERED_SRC=$TGT
37 | ORDERED_TGT=$SRC
38 | fi
39 |
40 |
41 | epoch_size=$(cat data/$ORDERED_SRC-$ORDERED_TGT/$DATA_NAME/processed/$ORDERED_SRC-$ORDERED_TGT/train.$SRC | wc -l)
42 | max_epoch_size=300000
43 | epoch_size=$((epoch_size>max_epoch_size ? max_epoch_size : epoch_size))
44 | echo $epoch_size
45 |
46 | python -W ignore train.py \
47 | --exp_name ibt_$DATA_NAME\_$SRC\_$TGT \
48 | --dump_path ./tmp/ \
49 | --reload_model ${PRETRAINED_MODEL_DIR}/mlm_en${OTHER_LANG}_1024.pth,${PRETRAINED_MODEL_DIR}/mlm_en${OTHER_LANG}_1024.pth \
50 | --data_path data/$ORDERED_SRC-$ORDERED_TGT/$DATA_NAME/processed/$ORDERED_SRC-$ORDERED_TGT \
51 | --lgs $SRC-$TGT \
52 | --ae_steps $SRC,$TGT \
53 | --bt_steps $SRC-$TGT-$SRC,$TGT-$SRC-$TGT \
54 | --word_shuffle 3 \
55 | --word_dropout 0.1 \
56 | --word_blank 0.1 \
57 | --lambda_ae '0:1,100000:0.1,300000:0' \
58 | --encoder_only false \
59 | --emb_dim 1024 \
60 | --n_layers 6 \
61 | --n_heads 8 \
62 | --dropout 0.1 \
63 | --attention_dropout 0.1 \
64 | --gelu_activation true \
65 | --tokens_per_batch 1000 \
66 | --batch_size 32 \
67 | --bptt 256 \
68 | --optimizer adam_inverse_sqrt,beta1=0.9,beta2=0.98,lr=0.0001 \
69 | --epoch_size $epoch_size \
70 | --eval_bleu true \
71 | --stopping_criterion valid_$SRC-$TGT\_mt_bleu,3 \
72 | --validation_metrics valid_$SRC-$TGT\_mt_bleu \
73 | --max_epoch 100 \
74 | --max_len 150 \
--------------------------------------------------------------------------------
/train_IBT_plus_BACK.sh:
--------------------------------------------------------------------------------
1 | #
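2 | # IBT plus back-translated data: same setup as train_IBT.sh, with additional supervised
3 | # MT steps on the pseudo-parallel corpus read from --para_data_path
4 | # (data/.../$TGT_DATA_NAME/back_translate/$SRC_DATA_NAME, as produced by translate_exe.sh).
5 | # Example call (argument values are illustrative):
6 | #   bash train_IBT_plus_BACK.sh --src de --tgt en --src_data_name acquis --tgt_data_name it --pretrained_model_dir models
7 | #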
2 | # Read arguments
3 | #
4 | POSITIONAL=()
5 | while [[ $# -gt 0 ]]
6 | do
7 | key="$1"
8 | case $key in
9 | --src)
10 | SRC="$2"; shift 2;;
11 | --tgt)
12 | TGT="$2"; shift 2;;
13 | --src_data_name)
14 | SRC_DATA_NAME="$2"; shift 2;;
15 | --tgt_data_name)
16 | TGT_DATA_NAME="$2"; shift 2;;
17 | --pretrained_model_dir)
18 | PRETRAINED_MODEL_DIR="$2"; shift 2;;
19 | *)
20 | POSITIONAL+=("$1")
21 | shift
22 | ;;
23 | esac
24 | done
25 | set -- "${POSITIONAL[@]}"
26 |
27 | if [ "$SRC" != 'en' ]; then
28 | OTHER_LANG=$SRC
29 | else
30 | OTHER_LANG=$TGT
31 | fi
32 | echo $OTHER_LANG
33 |
34 | if [ "$SRC" \< "$TGT" ]; then
35 | ORDERED_SRC=$SRC
36 | ORDERED_TGT=$TGT
37 | else
38 | ORDERED_SRC=$TGT
39 | ORDERED_TGT=$SRC
40 | fi
41 |
42 | epoch_size=$(cat data/$ORDERED_SRC-$ORDERED_TGT/$TGT_DATA_NAME/processed/$ORDERED_SRC-$ORDERED_TGT/train.$SRC | wc -l)
43 | max_epoch_size=300000
44 | epoch_size=$((epoch_size>max_epoch_size ? max_epoch_size : epoch_size))
45 | echo $epoch_size
46 |
47 |
48 | python -W ignore train.py \
49 | --exp_name IBT_BACK_src_$SRC_DATA_NAME\_tgt_$TGT_DATA_NAME\_$SRC\_$TGT \
50 | --dump_path ./tmp/ \
51 |   --reload_model ${PRETRAINED_MODEL_DIR}/mlm_en${OTHER_LANG}_1024.pth,${PRETRAINED_MODEL_DIR}/mlm_en${OTHER_LANG}_1024.pth \
52 | --data_path data/$ORDERED_SRC-$ORDERED_TGT/$TGT_DATA_NAME/processed/$ORDERED_SRC-$ORDERED_TGT \
53 | --para_data_path data/$ORDERED_SRC-$ORDERED_TGT/$TGT_DATA_NAME/back_translate/$SRC_DATA_NAME \
54 | --lgs $SRC-$TGT \
55 | --ae_steps $SRC,$TGT \
56 | --bt_steps $SRC-$TGT-$SRC,$TGT-$SRC-$TGT \
57 | --mt_steps $SRC-$TGT \
58 | --word_shuffle 3 \
59 | --word_dropout 0.1 \
60 | --word_blank 0.1 \
61 | --lambda_ae '0:1,100000:0.1,300000:0' \
62 | --encoder_only false \
63 | --emb_dim 1024 \
64 | --n_layers 6 \
65 | --n_heads 8 \
66 | --dropout 0.1 \
67 | --attention_dropout 0.1 \
68 | --gelu_activation true \
69 | --tokens_per_batch 1500 \
70 | --batch_size 32 \
71 | --bptt 256 \
72 | --optimizer adam_inverse_sqrt,beta1=0.9,beta2=0.98,lr=0.0001 \
73 | --epoch_size $epoch_size \
74 | --eval_bleu true \
75 | --stopping_criterion valid_$SRC-$TGT\_mt_bleu,3 \
76 | --validation_metrics valid_$SRC-$TGT\_mt_bleu \
77 | --max_epoch 50 \
78 | --max_len 150 \
--------------------------------------------------------------------------------
/train_IBT_plus_SRC.sh:
--------------------------------------------------------------------------------
1 | #
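2 | # IBT plus source-domain parallel data: same setup as train_IBT.sh on the target-domain
3 | # dataset, with additional supervised MT steps on the source-domain parallel corpus
4 | # (read from --para_data_path).
5 | # Example call (argument values are illustrative):
6 | #   bash train_IBT_plus_SRC.sh --src de --tgt en --src_data_name acquis --tgt_data_name it --pretrained_model_dir models
7 | #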
2 | # Read arguments
3 | #
4 | POSITIONAL=()
5 | while [[ $# -gt 0 ]]
6 | do
7 | key="$1"
8 | case $key in
9 | --src)
10 | SRC="$2"; shift 2;;
11 | --tgt)
12 | TGT="$2"; shift 2;;
13 | --src_data_name)
14 | SRC_DATA_NAME="$2"; shift 2;;
15 | --tgt_data_name)
16 | TGT_DATA_NAME="$2"; shift 2;;
17 | --pretrained_model_dir)
18 | PRETRAINED_MODEL_DIR="$2"; shift 2;;
19 | *)
20 | POSITIONAL+=("$1")
21 | shift
22 | ;;
23 | esac
24 | done
25 | set -- "${POSITIONAL[@]}"
26 |
27 | if [ "$SRC" != 'en' ]; then
28 | OTHER_LANG=$SRC
29 | else
30 | OTHER_LANG=$TGT
31 | fi
32 | echo $OTHER_LANG
33 |
34 | if [ "$SRC" \< "$TGT" ]; then
35 | ORDERED_SRC=$SRC
36 | ORDERED_TGT=$TGT
37 | else
38 | ORDERED_SRC=$TGT
39 | ORDERED_TGT=$SRC
40 | fi
41 |
42 | epoch_size=$(cat data/$ORDERED_SRC-$ORDERED_TGT/$SRC_DATA_NAME/processed/$ORDERED_SRC-$ORDERED_TGT/train.$ORDERED_SRC-$ORDERED_TGT.$SRC | wc -l)
43 | max_epoch_size=500000
44 | epoch_size=$((epoch_size>max_epoch_size ? max_epoch_size : epoch_size))
45 | echo $epoch_size
46 |
47 | python -W ignore train.py \
48 | --exp_name semi_sup_bidir_src_$SRC_DATA_NAME\_tgt_$TGT_DATA_NAME\_$SRC\_$TGT \
49 | --dump_path ./tmp/ \
50 | --reload_model ${PRETRAINED_MODEL_DIR}/mlm_en${OTHER_LANG}_1024.pth,${PRETRAINED_MODEL_DIR}/mlm_en${OTHER_LANG}_1024.pth \
51 | --data_path data/$ORDERED_SRC-$ORDERED_TGT/$TGT_DATA_NAME/processed/$ORDERED_SRC-$ORDERED_TGT \
52 | --para_data_path data/$ORDERED_SRC-$ORDERED_TGT/$SRC_DATA_NAME/processed/$ORDERED_SRC-$ORDERED_TGT \
53 | --lgs $SRC-$TGT \
54 | --ae_steps $SRC,$TGT \
55 | --bt_steps $SRC-$TGT-$SRC,$TGT-$SRC-$TGT \
56 | --mt_steps $SRC-$TGT \
57 | --word_shuffle 3 \
58 | --word_dropout 0.1 \
59 | --word_blank 0.1 \
60 | --lambda_ae '0:1,100000:0.1,300000:0' \
61 | --encoder_only false \
62 | --emb_dim 1024 \
63 | --n_layers 6 \
64 | --n_heads 8 \
65 | --dropout 0.1 \
66 | --attention_dropout 0.1 \
67 | --gelu_activation true \
68 | --tokens_per_batch 1200 \
69 | --batch_size 32 \
70 | --bptt 256 \
71 | --optimizer adam_inverse_sqrt,beta1=0.9,beta2=0.98,lr=0.0001 \
72 | --epoch_size $epoch_size \
73 | --eval_bleu true \
74 | --stopping_criterion valid_$SRC-$TGT\_mt_bleu,3 \
75 | --validation_metrics valid_$SRC-$TGT\_mt_bleu \
76 | --max_epoch 50 \
77 | --max_len 150 \
--------------------------------------------------------------------------------
/train_sup.sh:
--------------------------------------------------------------------------------
1 | #
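2 | # Supervised baseline: fine-tune the pretrained XLM model with MT steps only, on the
3 | # parallel data of a single dataset.
4 | # Example call (argument values are illustrative):
5 | #   bash train_sup.sh --src de --tgt en --data_name acquis --pretrained_model_dir models
6 | #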
2 | # Read arguments
3 | #
4 | POSITIONAL=()
5 | while [[ $# -gt 0 ]]
6 | do
7 | key="$1"
8 | case $key in
9 | --src)
10 | SRC="$2"; shift 2;;
11 | --tgt)
12 | TGT="$2"; shift 2;;
13 | --data_name)
14 | DATA_NAME="$2"; shift 2;;
15 | --pretrained_model_dir)
16 | PRETRAINED_MODEL_DIR="$2"; shift 2;;
17 | *)
18 | POSITIONAL+=("$1")
19 | shift
20 | ;;
21 | esac
22 | done
23 | set -- "${POSITIONAL[@]}"
24 |
25 | if [ "$SRC" != 'en' ]; then
26 | OTHER_LANG=$SRC
27 | else
28 | OTHER_LANG=$TGT
29 | fi
30 | echo $OTHER_LANG
31 |
32 | if [ "$SRC" \< "$TGT" ]; then
33 | ORDERED_SRC=$SRC
34 | ORDERED_TGT=$TGT
35 | else
36 | ORDERED_SRC=$TGT
37 | ORDERED_TGT=$SRC
38 | fi
39 |
40 | epoch_size=$(cat data/$ORDERED_SRC-$ORDERED_TGT/$DATA_NAME/processed/$ORDERED_SRC-$ORDERED_TGT/train.$ORDERED_SRC-$ORDERED_TGT.$SRC | wc -l)
41 | max_epoch_size=300000
42 | epoch_size=$((epoch_size>max_epoch_size ? max_epoch_size : epoch_size))
43 | echo $epoch_size
44 |
45 | python -W ignore train.py \
46 | --exp_name sup_$DATA_NAME\_$SRC\_$TGT \
47 | --dump_path ./tmp/ \
48 | --reload_model ${PRETRAINED_MODEL_DIR}/mlm_en${OTHER_LANG}_1024.pth,${PRETRAINED_MODEL_DIR}/mlm_en${OTHER_LANG}_1024.pth \
49 | --data_path data/$ORDERED_SRC-$ORDERED_TGT/$DATA_NAME/processed/$ORDERED_SRC-$ORDERED_TGT \
50 | --lgs $SRC-$TGT \
51 | --mt_steps $SRC-$TGT \
52 | --encoder_only false \
53 | --emb_dim 1024 \
54 | --n_layers 6 \
55 | --n_heads 8 \
56 | --dropout 0.1 \
57 | --attention_dropout 0.1 \
58 | --gelu_activation true \
59 | --tokens_per_batch 2500 \
60 | --batch_size 32 \
61 | --bptt 256 \
62 | --optimizer adam_inverse_sqrt,beta1=0.9,beta2=0.98,lr=0.0001 \
63 | --epoch_size $epoch_size \
64 | --eval_bleu true \
65 | --stopping_criterion valid_$SRC-$TGT\_mt_bleu,2 \
66 | --validation_metrics valid_$SRC-$TGT\_mt_bleu \
67 | --max_epoch 50 \
68 | --max_len 150 \
--------------------------------------------------------------------------------
/translate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | # Translate sentences from an input file (--src_data_path).
8 | # The model will be faster if sentences are sorted by length.
9 | # Input sentences must have the same tokenization and BPE codes as the ones used by the model.
10 | #
11 | # Usage:
12 | #   python translate.py --exp_name translate \
13 | #   --src_lang en --tgt_lang fr \
14 | #   --model_path trained_model.pth --src_data_path source_sentences.bpe \
15 | #   --output_path_source output.src --output_path_target output.tgt
16 | #
17 |
18 | import os
19 | import io
20 | import sys
21 | import argparse
22 | import torch
23 |
24 | from src.utils import AttrDict
25 | from src.utils import bool_flag, initialize_exp
26 | from src.data.dictionary import Dictionary
27 | from src.model.transformer import TransformerModel
28 |
29 |
30 | def get_parser():
31 | """
32 | Generate a parameters parser.
33 | """
34 | # parse parameters
35 | parser = argparse.ArgumentParser(description="Translate sentences")
36 |
37 | # main parameters
38 | parser.add_argument("--dump_path", type=str, default="./dumped/", help="Experiment dump path")
39 | parser.add_argument("--exp_name", type=str, default="", help="Experiment name")
40 | parser.add_argument("--exp_id", type=str, default="", help="Experiment ID")
41 | parser.add_argument("--batch_size", type=int, default=32, help="Number of sentences per batch")
42 |
43 | # model / output paths
44 | parser.add_argument("--model_path", type=str, default="", help="Model path")
45 | parser.add_argument("--output_path_source", type=str, default="", help="Output path for source")
46 | parser.add_argument("--output_path_target", type=str, default="", help="Output path for target")
47 |
48 | # parser.add_argument("--max_vocab", type=int, default=-1, help="Maximum vocabulary size (-1 to disable)")
49 | # parser.add_argument("--min_count", type=int, default=0, help="Minimum vocabulary count")
50 |
51 | # source language / target language
52 | parser.add_argument("--src_lang", type=str, default="", help="Source language")
53 | parser.add_argument("--tgt_lang", type=str, default="", help="Target language")
54 | parser.add_argument("--src_data_path", type=str, default="", help="Input data path")
55 |
56 | return parser
57 |
58 |
59 | def main(params):
60 |
61 | # initialize the experiment
62 | logger = initialize_exp(params)
63 |
64 | # generate parser / parse parameters
65 | parser = get_parser()
66 | params = parser.parse_args()
67 | reloaded = torch.load(params.model_path)
68 | model_params = AttrDict(reloaded['params'])
69 | logger.info("Supported languages: %s" % ", ".join(model_params.lang2id.keys()))
70 |
71 | # update dictionary parameters
72 | for name in ['n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index', 'mask_index']:
73 | setattr(params, name, getattr(model_params, name))
74 |
75 | # build dictionary / build encoder / build decoder / reload weights
76 | dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])
77 | encoder = TransformerModel(model_params, dico, is_encoder=True, with_output=True).cuda().eval()
78 | decoder = TransformerModel(model_params, dico, is_encoder=False, with_output=True).cuda().eval()
79 | encoder.load_state_dict(reloaded['encoder'])
80 | decoder.load_state_dict(reloaded['decoder'])
81 | params.src_id = model_params.lang2id[params.src_lang]
82 | params.tgt_id = model_params.lang2id[params.tgt_lang]
83 |
84 |     # read sentences from the input file
85 | src_sent = []
86 | for line in open(params.src_data_path, 'r').readlines():
87 | if line.strip() and len(line.split()) <= 130:
88 | src_sent.append(line)
89 |     logger.info("Read %i sentences from %s. Translating ..." % (len(src_sent), params.src_data_path))
90 |
91 | f_src = io.open(params.output_path_source, 'w', encoding='utf-8')
92 | f_tgt = io.open(params.output_path_target, 'w', encoding='utf-8')
93 |
94 | for i in range(0, len(src_sent), params.batch_size):
95 |
96 | # prepare batch
97 | word_ids = [torch.LongTensor([dico.index(w) for w in s.strip().split()])
98 | for s in src_sent[i:i + params.batch_size]]
99 | lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
100 | batch = torch.LongTensor(lengths.max().item(), lengths.size(0)).fill_(params.pad_index)
101 | batch[0] = params.eos_index
102 | for j, s in enumerate(word_ids):
103 | if lengths[j] > 2: # if sentence not empty
104 | batch[1:lengths[j] - 1, j].copy_(s)
105 | batch[lengths[j] - 1, j] = params.eos_index
106 | langs = batch.clone().fill_(params.src_id)
107 |
108 | # encode source batch and translate it
109 | encoded = encoder('fwd', x=batch.cuda(), lengths=lengths.cuda(), langs=langs.cuda(), causal=False)
110 | encoded = encoded.transpose(0, 1)
111 | try:
112 | decoded, dec_lengths = decoder.generate(encoded, lengths.cuda(), params.tgt_id, max_len=int(1.5 * lengths.max().item() + 10))
113 |         except Exception:
114 | print(max([len(line.split()) for line in src_sent[i:i + params.batch_size]]))
115 | else:
116 | # convert sentences to words
117 | for j in range(decoded.size(1)):
118 |
119 | # remove delimiters
120 | sent = decoded[:, j]
121 | delimiters = (sent == params.eos_index).nonzero().view(-1)
122 | assert len(delimiters) >= 1 and delimiters[0].item() == 0
123 | sent = sent[1:] if len(delimiters) == 1 else sent[1:delimiters[1]]
124 |
125 | # output translation
126 | source = src_sent[i + j].strip()
127 | target = " ".join([dico[sent[k].item()] for k in range(len(sent))])
128 | sys.stderr.write("%i / %i: %s -> %s\n" % (i + j, len(src_sent), source, target))
129 | f_src.write(source + "\n")
130 | f_tgt.write(target + "\n")
131 |
132 | f_src.close()
133 | f_tgt.close()
134 |
135 |
136 | if __name__ == '__main__':
137 |
138 | # generate parser / parse parameters
139 | parser = get_parser()
140 | params = parser.parse_args()
141 |
142 | # check parameters
143 | assert os.path.isfile(params.model_path)
144 | assert params.src_lang != '' and params.tgt_lang != '' and params.src_lang != params.tgt_lang
145 | # assert params.output_path and not os.path.isfile(params.output_path)
146 |
147 | # translate
148 | with torch.no_grad():
149 | main(params)
150 |
--------------------------------------------------------------------------------
/translate_exe.sh:
--------------------------------------------------------------------------------
1 | #
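2 | # Translate the $SRC side of a dataset's training file with a trained model (its
3 | # best-valid_*_mt_bleu.pth checkpoint) and write the source/translation pair to
4 | # data/.../$DATA_NAME/back_translate/$MODEL_NAME, for later use as --para_data_path
5 | # in train_IBT_plus_BACK.sh.
6 | # Example call (argument values are illustrative):
7 | #   bash translate_exe.sh --src en --tgt de --data_name it --model_name acquis --model_dir tmp/ibt_acquis_en_de/<exp_id>
8 | #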
2 | # Read arguments
3 | #
4 | POSITIONAL=()
5 | while [[ $# -gt 0 ]]
6 | do
7 | key="$1"
8 | case $key in
9 | --src)
10 | SRC="$2"; shift 2;;
11 | --tgt)
12 | TGT="$2"; shift 2;;
13 | --data_name)
14 | DATA_NAME="$2"; shift 2;;
15 | --model_name)
16 | MODEL_NAME="$2"; shift 2;;
17 | --model_dir)
18 | MODEL_DIR="$2"; shift 2;;
19 | *)
20 | POSITIONAL+=("$1")
21 | shift
22 | ;;
23 | esac
24 | done
25 | set -- "${POSITIONAL[@]}"
26 |
27 | if [ "$SRC" \< "$TGT" ]; then
28 | ORDERED_SRC=$SRC
29 | ORDERED_TGT=$TGT
30 | else
31 | ORDERED_SRC=$TGT
32 | ORDERED_TGT=$SRC
33 | fi
34 |
35 | OUT_DIR=data/$ORDERED_SRC-$ORDERED_TGT/$DATA_NAME/back_translate/$MODEL_NAME
36 | mkdir -p $OUT_DIR
37 |
38 | python -W ignore translate.py \
39 | --exp_name $MODEL_NAME\_$SRC\_to_$TGT \
40 | --dump_path ./back_translate/ \
41 | --model_path $MODEL_DIR/best-valid_$SRC-$TGT\_mt_bleu.pth \
42 | --src_data_path data/$ORDERED_SRC-$ORDERED_TGT/$DATA_NAME/processed/$ORDERED_SRC-$ORDERED_TGT/train.$ORDERED_SRC-$ORDERED_TGT.$SRC \
43 | --output_path_source $OUT_DIR/train.$ORDERED_SRC-$ORDERED_TGT.$SRC \
44 | --output_path_target $OUT_DIR/train.$ORDERED_SRC-$ORDERED_TGT.$TGT \
45 | --src_lang $SRC \
46 | --tgt_lang $TGT \
47 | --batch_size 128 \
--------------------------------------------------------------------------------