├── .gitignore ├── README.md ├── heatmap ├── __init__.py ├── heatmap_utils.py └── sequence_scorer.py ├── parse_nmt.py ├── preprocess_nstack2seq_merge.py ├── runs ├── get_last_checkpoint.py ├── infer_model.sh ├── run_nstack_nmt.sh ├── train_classification.py ├── train_fairseq.sh └── treetfm_train.py ├── scripts ├── __init__.py ├── average_checkpoints.py ├── build_sym_alignment.py ├── compound_split_bleu.sh ├── convert_dictionary.lua ├── convert_imdb.py ├── convert_model.lua ├── read_binarized.py └── sacrebleu_pregen.sh └── src ├── __init__.py ├── binarization.py ├── bpe ├── __init__.py └── bpe_utils.py ├── criterions ├── __init__.py ├── classification_cross_entropy.py └── masked_lm_loss.py ├── data ├── __init__.py ├── dptree2seq_dataset.py ├── dptree2seq_sep_dataset.py ├── dptree_dictionary.py ├── dptree_index_dataset.py ├── dptree_mono_class_dataset.py ├── dptree_sep_mono_class_dataset.py ├── monolingual_classification_dataset.py ├── nstack2seq_dataset.py ├── nstack_merge_monoclass_dataset.py ├── nstack_mono_class_dataset.py ├── task_utils.py └── transforms │ ├── __init__.py │ ├── gpt2_bpe.py │ ├── moses_tokenizer.py │ ├── nltk_tokenizer.py │ ├── sentencepiece_bpe.py │ ├── space_tokenizer.py │ └── subword_nmt_bpe.py ├── dptree ├── __init__.py ├── nstack_process.py ├── tree_builder.py └── tree_process.py ├── dptree2seq_generator.py ├── dptree_tokenizer.py ├── models ├── __init__.py ├── dptree2seg_transformer.py ├── nstack_archs.py └── nstack_transformer.py ├── modules ├── __init__.py ├── default_dy_conv.py ├── default_multihead_attention.py ├── dptree_individual_multihead_attention.py ├── dptree_multihead_attention.py ├── dptree_onseq_multihead_attention.py ├── dptree_sep_multihead_attention.py ├── dptree_transformer_layer.py ├── embeddings.py ├── nstack_merge_tree_attention.py ├── nstack_transformer_layers.py ├── nstack_tree_attention.py └── nstack_tree_attention_eff.py ├── nstack2seq_generator.py ├── nstack_tokenizer.py ├── optim ├── __init__.py └── lr_scheduler │ ├── __init__.py │ └── flex_inv_sqrt_schedule.py ├── tasks ├── __init__.py ├── dptree2seq_sep_translation.py ├── dptree2seq_translation.py ├── dptree_classification.py ├── dptree_sep_classification.py ├── fairseq_classification.py ├── nstack2seq_translation.py ├── nstack_classification.py └── nstack_from_dptree_classification.py ├── trainers └── __init__.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest
32 | *.spec
33 | 
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 | 
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 | 
50 | # Translations
51 | *.mo
52 | *.pot
53 | 
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 | 
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 | 
63 | # Scrapy stuff:
64 | .scrapy
65 | 
66 | # Sphinx documentation
67 | docs/_build/
68 | 
69 | # PyBuilder
70 | target/
71 | 
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 | 
75 | # pyenv
76 | .python-version
77 | 
78 | # celery beat schedule file
79 | celerybeat-schedule
80 | 
81 | # SageMath parsed files
82 | *.sage.py
83 | 
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 | 
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 | 
97 | # Rope project settings
98 | .ropeproject
99 | 
100 | # mkdocs documentation
101 | /site
102 | 
103 | # mypy
104 | .mypy_cache/
105 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # tree_transformer
2 | Submission to ICLR 2020: [https://openreview.net/forum?id=HJxK5pEYvr](https://openreview.net/forum?id=HJxK5pEYvr)
3 | 
4 | This is unofficial example code for IWSLT'14 En-De.
5 | 
6 | # Installation
7 | 
8 | Install fairseq:
9 | ```bash
10 | # install the latest pytorch first
11 | pip install --upgrade fairseq==0.6.2
12 | pip install -U nltk[corenlp]
13 | 
14 | git clone https://github.com/XXXXX/tree_transformer.git
15 | ```
16 | 
17 | Install the Stanford CoreNLP parser as described [here](https://github.com/nltk/nltk/wiki/Stanford-CoreNLP-API-in-NLTK).
18 | Suppose the parser is stored in `stanford-corenlp-full-2018-02-27`.
19 | 
20 | # Parsing and preprocessing the translation data
21 | 
22 | Follow the data preparation steps [here - Fairseq](https://github.com/pytorch/fairseq/blob/master/examples/translation/prepare-iwslt14.sh).
23 | Suppose the data is saved in `raw_data/iwslt14.tokenized.de-en.v2`, which contains the files train.en, train.de, valid.en, valid.de, test.en, and test.de.
24 | 
25 | Run the CoreNLP server in a separate terminal:
26 | ```bash
27 | cd stanford-corenlp-full-2018-02-27/
28 | port=9000
29 | java -Xmx12g -cp "*" edu.stanford.nlp.pipeline.StanfordCoreNLPServer -preload tokenize,ssplit,pos,lemma,ner,parse,depparse -status_port $port -port $port -timeout 15000000
30 | ```
31 | 
32 | Parse the data:
33 | ```bash
34 | # ----------------- German-English ------------------------
35 | 
36 | export PARSER_PORT=9000
37 | export prefix=train
38 | 
39 | export cur=`pwd`
40 | export root=${cur}/raw_data/iwslt14.tokenized.de-en.v2
41 | export before=$root/$prefix.en
42 | export after=$root/$prefix.tree-ende.en
43 | export before_t=$root/$prefix.de
44 | export after_t=$root/$prefix.tree-ende.de
45 | export bpe=${root}/code
46 | python -u tree_transformer/parse_nmt.py --ignore_error --bpe_code ${bpe} --bpe_tree --before $before --after $after --before_tgt ${before_t} --after_tgt ${after_t}
47 | 
48 | # do the same for prefix=valid and prefix=test
49 | # this produces the files train.tree-ende.en, train.tree-ende.de, valid.tree-ende.en, valid.tree-ende.de, ....
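# a hedged sketch, not in the original README: the same command can be looped
# over all three splits instead of re-exporting prefix by hand
# for prefix in train valid test; do
#     export before=$root/$prefix.en;   export after=$root/$prefix.tree-ende.en
#     export before_t=$root/$prefix.de; export after_t=$root/$prefix.tree-ende.de
#     python -u tree_transformer/parse_nmt.py --ignore_error --bpe_code ${bpe} --bpe_tree \
#         --before $before --after $after --before_tgt ${before_t} --after_tgt ${after_t}
# done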
50 | ```
51 | 
52 | 
53 | Preprocess the data into Fairseq format:
54 | ```bash
55 | # IWSLT - En-De
56 | export ROOT_DIR=`pwd`
57 | export PROJDIR=tree_transformer
58 | export user_dir=${ROOT_DIR}/${PROJDIR}
59 | export RAW_DIR=${ROOT_DIR}/raw_data/iwslt14.tokenized.de-en.v2
60 | export BPE=${RAW_DIR}/code
61 | export train_r=${RAW_DIR}/train.tree-ende
62 | export valid_r=${RAW_DIR}/valid.tree-ende
63 | export test_r=${RAW_DIR}/test.tree-ende
64 | export OUT=${ROOT_DIR}/data_fairseq/nstack_merge_translate_ende_iwslt_32k
65 | rm -rf $OUT
66 | python -m tree_transformer.preprocess_nstack2seq_merge \
67 |     --source-lang en --target-lang de \
68 |     --user-dir ${user_dir} \
69 |     --trainpref ${train_r} \
70 |     --validpref ${valid_r} \
71 |     --testpref ${test_r} \
72 |     --destdir $OUT \
73 |     --joined-dictionary \
74 |     --nwordssrc 32768 --nwordstgt 32768 \
75 |     --bpe_code ${BPE} \
76 |     --no_remove_root \
77 |     --workers 8 \
78 |     --eval_workers 0
79 | 
80 | # the processed data is saved in data_fairseq/nstack_merge_translate_ende_iwslt_32k
81 | ```
82 | 
83 | 
84 | 
85 | # Training
86 | Run the following from inside `runs/`, where `run_nstack_nmt.sh` and `train_fairseq.sh` live:
87 | ```bash
88 | export MAXTOKENS=1024
89 | export INFER=y
90 | export dis_port_str=--master_port=6102
91 | export problem=nstack_merge_iwslt_ende_32k
92 | export MAX_UPDATE=61000
93 | export UPDATE_FREQ=1
94 | export att_dropout=0.2
95 | export DROPOUT=0.3 &&
96 | bash run_nstack_nmt.sh dwnstack_merge2seq_node_iwslt_onvalue_base_upmean_mean_mlesubenc_allcross_hier 0,1,2,3,4,5,6,7
97 | 
98 | ```
99 | 
100 | 
101 | 
102 | 
103 | 
104 | 
105 | 
--------------------------------------------------------------------------------
/heatmap/__init__.py:
--------------------------------------------------------------------------------
1 | from . import heatmap_utils
2 | from . import sequence_scorer
--------------------------------------------------------------------------------
/heatmap/sequence_scorer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the LICENSE file in
5 | # the root directory of this source tree. An additional grant of patent rights
6 | # can be found in the PATENTS file in the same directory.
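# Note added for orientation (not part of the original header): this module
# scores reference translations under trained models and collects cross-attention
# maps for heatmap visualization. The env vars get_encout, get_inner_att and
# inner_att read below switch between encoder-output-derived and inner-layer
# attention maps.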
7 | 8 | import torch 9 | import os 10 | import torch.nn.functional as F 11 | from fairseq import utils 12 | 13 | from fairseq.sequence_generator import SequenceGenerator, EnsembleModel 14 | from ..src.nstack2seq_generator import Nstack2SeqGenerator 15 | 16 | GET_ENCOUT = bool(int(os.environ.get('get_encout', 0))) 17 | GET_INNER_ATT = bool(int(os.environ.get('get_inner_att', 0))) 18 | INNER_ATT = int(os.environ.get('inner_att', -1)) 19 | 20 | 21 | class HeatmapSequenceScorer(object): 22 | """Scores the target for a given source sentence.""" 23 | 24 | def __init__(self, models, pad): 25 | self.models = models 26 | # self.pad = tgt_dict.pad() 27 | self.pad = pad 28 | 29 | def cuda(self): 30 | for model in self.models: 31 | model.cuda() 32 | return self 33 | 34 | def score_batched_itr(self, data_itr, cuda=False, timer=None): 35 | """Iterate over a batched dataset and yield scored translations.""" 36 | for sample in data_itr: 37 | s = utils.move_to_cuda(sample) if cuda else sample 38 | if timer is not None: 39 | timer.start() 40 | pos_scores, attn = self.score(s) 41 | for i, id in enumerate(s['id'].data): 42 | # remove padding from ref 43 | src = utils.strip_pad(s['net_input']['src_tokens'].data[i, :], self.pad) 44 | ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None 45 | tgt_len = ref.numel() 46 | pos_scores_i = pos_scores[i][:tgt_len] 47 | score_i = pos_scores_i.sum() / tgt_len 48 | if attn is not None: 49 | attn_i = attn[i] 50 | _, alignment = attn_i.max(dim=0) 51 | else: 52 | attn_i = alignment = None 53 | hypos = [{ 54 | 'tokens': ref, 55 | 'score': score_i, 56 | 'attention': attn_i, 57 | 'alignment': alignment, 58 | 'positional_scores': pos_scores_i, 59 | }] 60 | if timer is not None: 61 | timer.stop(s['ntokens']) 62 | # return results in the same format as SequenceGenerator 63 | yield id, src, ref, hypos 64 | 65 | def score(self, sample): 66 | """Score a batch of translations.""" 67 | net_input = sample['net_input'] 68 | 69 | # compute scores for each model in the ensemble 70 | avg_probs = None 71 | avg_attn = None 72 | for model in self.models: 73 | with torch.no_grad(): 74 | model.eval() 75 | # decoder_out = model.forward(**net_input) 76 | prev_output_tokens = net_input['prev_output_tokens'] 77 | del net_input['prev_output_tokens'] 78 | encoder_out = model.encoder(**net_input) 79 | decoder_out = model.decoder(prev_output_tokens, encoder_out) 80 | # return decoder_out 81 | if GET_ENCOUT: 82 | attn = F.softmax(100 * encoder_out['encoder_out'].transpose(0, 1), 1).mean(-1) 83 | bsz, tk = attn.size() 84 | tq = prev_output_tokens.size(1) 85 | attn = attn.unsqueeze_(1).expand(bsz, tq, tk) 86 | cross_attn = decoder_out[1]['attn'] 87 | assert list(attn.size()) == list(cross_attn.size()), f'{attn.size()} != {cross_attn.size()}, {prev_output_tokens.size()}' 88 | # attn: [b, tk, C] 89 | else: 90 | attn = decoder_out[1] 91 | 92 | probs = model.get_normalized_probs(decoder_out, log_probs=len(self.models) == 1, sample=sample).data 93 | if avg_probs is None: 94 | avg_probs = probs 95 | else: 96 | avg_probs.add_(probs) 97 | 98 | if attn is not None: 99 | # {'attn': attn, 'inner_states': inner_states} 100 | if not torch.is_tensor(attn): 101 | if GET_INNER_ATT: 102 | attn = attn['inner_atts'][INNER_ATT] 103 | else: 104 | attn = attn['attn'] 105 | 106 | assert torch.is_tensor(attn), f'attn: {attn}' 107 | attn = attn.data 108 | if avg_attn is None: 109 | avg_attn = attn 110 | else: 111 | avg_attn.add_(attn) 112 | if len(self.models) > 1: 113 | 
avg_probs.div_(len(self.models)) 114 | avg_probs.log_() 115 | if avg_attn is not None: 116 | avg_attn.div_(len(self.models)) 117 | avg_probs = avg_probs.gather( 118 | dim=2, 119 | index=sample['target'].data.unsqueeze(-1), 120 | ) 121 | return avg_probs.squeeze(2), avg_attn 122 | 123 | 124 | class HeatmapSequenceAttentionEntropyScorer(HeatmapSequenceScorer): 125 | def score(self, sample): 126 | """Score a batch of translations.""" 127 | net_input = sample['net_input'] 128 | 129 | # compute scores for each model in the ensemble 130 | avg_probs = None 131 | avg_attn = None 132 | assert len(self.models) == 1, f'{len(self.models)} not 1' 133 | for model in self.models: 134 | with torch.no_grad(): 135 | model.eval() 136 | # decoder_out = model.forward(**net_input) 137 | prev_output_tokens = net_input['prev_output_tokens'] 138 | del net_input['prev_output_tokens'] 139 | encoder_out = model.encoder(**net_input) 140 | decoder_out = model.decoder(prev_output_tokens, encoder_out) 141 | # return decoder_out 142 | if GET_ENCOUT: 143 | # attn = F.softmax(100 * encoder_out['encoder_out'].transpose(0, 1), 1).mean(-1) 144 | # bsz, tk = attn.size() 145 | # tq = prev_output_tokens.size(1) 146 | # attn = attn.unsqueeze_(1).expand(bsz, tq, tk) 147 | # cross_attn = decoder_out[1]['attn'] 148 | # assert list(attn.size()) == list(cross_attn.size()), f'{attn.size()} != {cross_attn.size()}, {prev_output_tokens.size()}' 149 | # # attn: [b, tk, C] 150 | raise NotImplementedError 151 | else: 152 | attn = decoder_out[1] 153 | 154 | probs = model.get_normalized_probs(decoder_out, log_probs=len(self.models) == 1, sample=sample).data 155 | if avg_probs is None: 156 | avg_probs = probs 157 | else: 158 | avg_probs.add_(probs) 159 | 160 | assert 'inner_atts' in attn 161 | inner_atts = attn['inner_atts'] 162 | avg_attn = inner_atts[-1] 163 | # [b, tq, tk] 164 | inner_att_entropies = [-(x * x.log()).sum(dim=-1) for x in inner_atts] 165 | # [b, tq] 166 | 167 | inner_atts = torch.cat([x.unsqueeze_(-1) for x in inner_atts], dim=-1) 168 | inner_att_entropies = torch.cat([x.unsqueeze_(-1) for x in inner_att_entropies], dim=-1) 169 | # [b, tq, tk, 6] 170 | # [b, tq, 6] 171 | 172 | # if attn is not None: 173 | # # {'attn': attn, 'inner_states': inner_states} 174 | # if not torch.is_tensor(attn): 175 | # if GET_INNER_ATT: 176 | # attn = attn['inner_atts'][INNER_ATT] 177 | # else: 178 | # attn = attn['attn'] 179 | # 180 | # assert torch.is_tensor(attn), f'attn: {attn}' 181 | # attn = attn.data 182 | # if avg_attn is None: 183 | # avg_attn = attn 184 | # else: 185 | # avg_attn.add_(attn) 186 | 187 | if len(self.models) > 1: 188 | avg_probs.div_(len(self.models)) 189 | avg_probs.log_() 190 | if avg_attn is not None: 191 | avg_attn.div_(len(self.models)) 192 | 193 | avg_probs = avg_probs.gather( 194 | dim=2, 195 | index=sample['target'].data.unsqueeze(-1), 196 | ) 197 | return avg_probs.squeeze(2), avg_attn, inner_atts, inner_att_entropies 198 | 199 | def score_batched_itr(self, data_itr, cuda=False, timer=None): 200 | """Iterate over a batched dataset and yield scored translations.""" 201 | for sample in data_itr: 202 | s = utils.move_to_cuda(sample) if cuda else sample 203 | if timer is not None: 204 | timer.start() 205 | pos_scores, attn, inner_atts, inner_att_entropies = self.score(s) 206 | for i, id in enumerate(s['id'].data): 207 | # remove padding from ref 208 | src = utils.strip_pad(s['net_input']['src_tokens'].data[i, :], self.pad) 209 | ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None 
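                # comment added: the sentence-level score below is the mean of the
                # positional log-probabilities over the tgt_len reference tokens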
210 |                 tgt_len = ref.numel()
211 |                 pos_scores_i = pos_scores[i][:tgt_len]
212 |                 score_i = pos_scores_i.sum() / tgt_len
213 |                 if attn is not None:
214 |                     attn_i = attn[i]
215 |                     _, alignment = attn_i.max(dim=0)
216 |                 else:
217 |                     attn_i = alignment = None
218 | 
219 |                 inner_att = inner_atts[i]
220 |                 inner_att_entropy = inner_att_entropies[i]
221 | 
222 |                 hypos = [{
223 |                     'tokens': ref,
224 |                     'score': score_i,
225 |                     'attention': attn_i,
226 |                     'inner_att': inner_att,
227 |                     'inner_att_entropy': inner_att_entropy,
228 |                     'alignment': alignment,
229 |                     'positional_scores': pos_scores_i,
230 |                 }]
231 |                 if timer is not None:
232 |                     timer.stop(s['ntokens'])
233 |                 # return results in the same format as SequenceGenerator
234 |                 yield id, src, ref, hypos
235 | 
236 | 
237 | class Nstack2SeqHeatmapScorer(object):
238 | 
239 |     def __init__(self, generator, image_dir, **kwargs) -> None:
240 |         super().__init__()
241 |         self.generator = generator
242 |         self.image_dir = image_dir
243 | 
244 |     def generate(
245 |         self,
246 |         models,
247 |         sample,
248 |         prefix_tokens=None,
249 |         bos_token=None,
250 |         **kwargs
251 |     ):
252 |         hypos = self.generator.generate(models, sample, prefix_tokens=prefix_tokens)
253 |         target = sample['target']
254 | 
255 |         flipped_src_tokens = sample['net_input']['src_tokens']
256 |         src_tokens = torch.flip(flipped_src_tokens, [2])
257 |         attention = hypos['attention']
258 | 
259 |         # src_tokens = hypos['tokens']
260 |         """
261 |         'tokens': tokens_clone[i],
262 |         'score': score,
263 |         'attention': hypo_attn, # src_len x tgt_len
264 |         'alignment': alignment,
265 |         'positional_scores': pos_scores[i],
266 |         """
267 | 
268 |         assert src_tokens.size(0) == 1, f'bsz should be 1 ->{src_tokens.size()}'
269 | 
270 | 
271 | 
272 | 
--------------------------------------------------------------------------------
/runs/get_last_checkpoint.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import collections
3 | import torch
4 | import os
5 | import re
6 | 
7 | 
8 | def last_n_checkpoint_index(paths, n, update_based, upper_bound=None):
9 |     # assert len(paths) == 1
10 |     # path = paths[0]
11 |     # assert len(paths) == 1
12 |     path = paths
13 | 
14 |     if update_based:
15 |         pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt')
16 |     else:
17 |         pt_regexp = re.compile(r'checkpoint(\d+)\.pt')
18 |     files = os.listdir(path)
19 |     # print(files)
20 |     # print(pt_regexp)
21 |     entries = []
22 |     for f in files:
23 |         m = pt_regexp.fullmatch(f)
24 |         if m is not None:
25 |             sort_key = int(m.group(1))
26 |             if upper_bound is None or sort_key <= upper_bound:
27 |                 entries.append((sort_key, m.group(0)))
28 |     if len(entries) < n:
29 |         # print(paths)
30 |         raise Exception('Found {} checkpoint files but need at least {}'.format(len(entries), n))
31 |     last_checkpoints = [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]]
32 |     last_checkpoint_index = [x[0] for x in sorted(entries, reverse=True)[:n]][0]
33 |     return last_checkpoint_index
34 | 
35 | 
36 | if __name__ == '__main__':
37 |     parser = argparse.ArgumentParser()
38 | 
39 |     parser.add_argument('--dir', required=True, help='Input checkpoint file paths.')
40 |     args = parser.parse_args()
41 | 
42 |     assert os.path.exists(args.dir)
43 |     last_checkpoint_index = last_n_checkpoint_index(
44 |         args.dir, 1, False, upper_bound=None,
45 |     )
46 |     print(last_checkpoint_index)
47 | 
48 | 
--------------------------------------------------------------------------------
/runs/infer_model.sh:
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | 5 | set -e 6 | # specify machines 7 | 8 | [ -z "$CUDA_VISIBLE_DEVICES" ] && { echo "Must set export CUDA_VISIBLE_DEVICES="; exit 1; } || echo "CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES}" 9 | IFS=',' read -r -a GPUS <<< "$CUDA_VISIBLE_DEVICES" 10 | export NUM_GPU=${#GPUS[@]} 11 | 12 | 13 | export MACHINE="${MACHINE:-ntu}" 14 | 15 | echo "MACHINE -> ${MACHINE}" 16 | 17 | export ROOT_DIR=`pwd` 18 | export PROJDIR=tree_transformer 19 | export ROOT_DIR="${ROOT_DIR/\/tree_transformer\/runs/}" 20 | 21 | 22 | export user_dir=${ROOT_DIR}/${PROJDIR} 23 | 24 | 25 | 26 | if [ -d ${TRAIN_DIR} ]; then 27 | # if train exists 28 | echo "directory train exists!: ${TRAIN_DIR}" 29 | else 30 | echo "directory train not exists!: ${TRAIN_DIR}" 31 | exit 1 32 | fi 33 | 34 | 35 | #export EPOCHS="${EPOCHS:-300}" 36 | export PROBLEM="${PROBLEM:-nstack_merge_translate_ende_iwslt_32k}" 37 | 38 | export RAW_DATA_DIR=${ROOT_DIR}/raw_data_fairseq/${PROBLEM} 39 | export DATA_DIR=${ROOT_DIR}/data_fairseq/${PROBLEM} 40 | export TRAIN_DIR_PREFIX=${ROOT_DIR}/train_tree_transformer/${PROBLEM} 41 | 42 | export EXP="${EXP:-transformer_wmt_ende_8gpu1}" 43 | export ID="${ID:-1}" 44 | export INFER_ID="${INFER_ID:-1}" 45 | 46 | export extra_params="${extra_params:-}" 47 | 48 | export TGT_LANG="${TGT_LANG:-de}" 49 | export SRC_LANG="${SRC_LANG:-en}" 50 | export TESTSET="${TESTSET:-newstest2014}" 51 | 52 | 53 | export INFERMODE="${INFERMODE:-avg}" 54 | export INFER_DIR=${TRAIN_DIR}/infer 55 | mkdir -p ${INFER_DIR} 56 | 57 | # generate parameters 58 | 59 | export TASK="${TASK:-translation}" 60 | export BEAM="${BEAM:-5}" 61 | #export INFER_BSZ="${INFER_BSZ:-128}" 62 | #export INFER_BSZ="${INFER_BSZ:-4096}" 63 | export INFER_BSZ="${INFER_BSZ:-2048}" 64 | export LENPEN="${LENPEN:-0.6}" 65 | #export LEFT_PAD_SRC="${LEFT_PAD_SRC:-True}" 66 | export LEFT_PAD_SRC="${LEFT_PAD_SRC:-False}" 67 | export RMBPE="${RMBPE:-y}" 68 | export GETBLEU="${GETBLEU:-y}" 69 | export NEWCODE="${NEWCODE:-y}" 70 | export INFER_TASK="${INFER_TASK:-mt}" 71 | export HEAT="${HEAT:-n}" 72 | export rm_srceos="${rm_srceos:-0}" 73 | export rm_lastpunct="${rm_lastpunct:-0}" 74 | export get_entropies="${get_entropies:-0}" 75 | export GENSET="${GENSET:-test}" 76 | 77 | export GEN_DIR=${INFER_DIR}/${GENSET}.tok.rmBpe${RMBPE}.genout.${TGT_LANG}.b${BEAM}.lenpen${LENPEN}.leftpad${LEFT_PAD_SRC}.${INFERMODE} 78 | 79 | 80 | [ ${rm_srceos} -eq 1 ] && export rm_srceos_s="--remove-eos-from-source " || export rm_srceos_s= 81 | [ ${rm_lastpunct} -eq 1 ] && export rm_lastpunct_s="--remove-last-punct-source " || export rm_lastpunct_s= 82 | [ ${get_entropies} -eq 1 ] && export get_entropies_s="--layer-att-entropy " || export get_entropies_s= 83 | [ ${RMBPE} == "y" ] && export rm_bpe_s="--remove-bpe " || export rm_bpe_s= 84 | 85 | #/projects/nmt/train_tree_transformer/wmt16_en_de_new_bpe/me_vaswani_wmt_en_de_big-transformer_big_128-b5120-gpu8-upfre16-1fp16-id24 86 | echo "========== INFERENCE =================" 87 | echo "TASK = ${TASK}" 88 | echo "infermode = ${INFERMODE}" 89 | echo "BEAM = ${BEAM}" 90 | echo "INFER_BSZ = ${INFER_BSZ}" 91 | echo "LENPEN = ${LENPEN}" 92 | echo "LEFT_PAD_SRC = ${LEFT_PAD_SRC}" 93 | echo "RMBPE = ${RMBPE}" 94 | echo "GETBLEU = ${GETBLEU}" 95 | echo "NEWCODE = ${NEWCODE}" 96 | echo "rm_srceos = ${rm_srceos} - string=${rm_srceos_s}" 97 | echo "rm_lastpunct = ${rm_lastpunct} - string=${rm_lastpunct_s}" 98 | echo "========== INFERENCE 
=================" 99 | 100 | # selecting infermode 101 | # --------------------------------------------------------------------------------------------------- 102 | if [ ${INFERMODE} == "best" ]; then 103 | 104 | export CHECKPOINT=${TRAIN_DIR}/checkpoint_best.pt 105 | mkdir -p ${GEN_DIR} 106 | 107 | export GEN_OUT=${GEN_DIR}/infer 108 | export HYPO=${GEN_OUT}.hypo 109 | export REF=${GEN_OUT}.ref 110 | export BLEU_OUT=${GEN_OUT}.bleu 111 | 112 | echo "GEN_OUT = ${GEN_OUT}" 113 | 114 | # --------------------------------------------------------------------------------------------------------- 115 | elif [ ${INFERMODE} == "avg" ]; then 116 | 117 | export AVG_NUM="${AVG_NUM:-5}" 118 | # export UPPERBOUND="${UPPERBOUND:-22}" 119 | export UPPERBOUND="${UPPERBOUND:-100000000}" 120 | # export AVG_CHECKPOINT_OUT="${AVG_CHECKPOINT_OUT:-$TRAIN_DIR/averaged_model.${AVG_NUM}.u${UPPERBOUND}.pt}" 121 | export LAST_EPOCH=`python get_last_checkpoint.py --dir=${TRAIN_DIR}` 122 | 123 | export GEN_DIR=${GEN_DIR}.avg${AVG_NUM}.e${LAST_EPOCH}.u${UPPERBOUND} 124 | mkdir -p ${GEN_DIR} 125 | 126 | export AVG_CHECKPOINT_OUT="${AVG_CHECKPOINT_OUT:-$GEN_DIR/averaged_model.id${INFER_ID}.avg${AVG_NUM}.e${LAST_EPOCH}.u${UPPERBOUND}.pt}" 127 | export GEN_OUT=${GEN_DIR}/infer 128 | export GEN_OUT=${GEN_OUT}.avg${AVG_NUM}.b${BEAM}.lp${LENPEN} 129 | export HYPO=${GEN_OUT}.hypo 130 | export REF=${GEN_OUT}.ref 131 | export BLEU_OUT=${GEN_OUT}.bleu 132 | 133 | echo "GEN_DIR = ${GEN_DIR}" 134 | echo "GEN_OUT = ${GEN_OUT}" 135 | 136 | 137 | echo "AVG_NUM = ${AVG_NUM}" 138 | echo "LAST_EPOCH = ${LAST_EPOCH}" 139 | echo "AVG_CHECKPOINT_OUT = ${AVG_CHECKPOINT_OUT}" 140 | echo "---- Score by averaging last checkpoints ${AVG_NUM} -> ${AVG_CHECKPOINT_OUT}" 141 | echo "Generating average checkpoints..." 142 | # exit 1 143 | 144 | if [ -f ${AVG_CHECKPOINT_OUT} ]; then 145 | echo "File ${AVG_CHECKPOINT_OUT} exists...." 146 | else 147 | python ../scripts/average_checkpoints.py \ 148 | --user-dir ${user_dir} \ 149 | --inputs ${TRAIN_DIR} \ 150 | --num-epoch-checkpoints ${AVG_NUM} \ 151 | --checkpoint-upper-bound ${UPPERBOUND} \ 152 | --output ${AVG_CHECKPOINT_OUT} 153 | echo "Finish generating averaged, start generating samples" 154 | fi 155 | 156 | export CHECKPOINT=${AVG_CHECKPOINT_OUT} 157 | 158 | else 159 | echo "INFERMODE invalid: ${INFERMODE}" 160 | exit 1 161 | fi 162 | 163 | 164 | echo "Start generating" 165 | 166 | 167 | 168 | export command="$(which fairseq-generate) ${DATA_DIR} \ 169 | --task ${TASK} \ 170 | --user-dir ${user_dir} \ 171 | --path ${CHECKPOINT} \ 172 | --left-pad-source ${LEFT_PAD_SRC} \ 173 | --max-tokens ${INFER_BSZ} \ 174 | --beam ${BEAM} \ 175 | --gen-subset ${GENSET} \ 176 | --lenpen ${LENPEN} \ 177 | ${extra_params} \ 178 | ${rm_bpe_s} ${rm_srceos_s} ${rm_lastpunct_s} | dd of=${GEN_OUT}" 179 | # ${rm_bpe_s} ${rm_srceos_s} ${rm_lastpunct_s} | tee ${GEN_OUT}" 180 | 181 | echo "Command: ${command}" 182 | echo "----------------------------------------" 183 | eval ${command} 184 | 185 | 186 | echo "---- Score by score.py for mode=${INFERMODE}, avg=${AVG_NUM} -----" 187 | echo "decode bleu from model ${AVG_CHECKPOINT_OUT}" 188 | echo "decode bleu from file ${GEN_OUT}" 189 | echo ".............................." 
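# comment added for clarity: fairseq-generate writes lines prefixed S-<id>
# (source), T-<id> (reference) and H-<id> <score> <hypothesis>; the H lines carry
# an extra score column, hence `cut -f3-` below versus `cut -f2-` for S and T.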
190 | 191 | export SRC=${GEN_OUT}.src 192 | export HYPO=${GEN_OUT}.hypo 193 | export REF=${GEN_OUT}.ref 194 | export REF_TW=${GEN_OUT}.ref.tweak 195 | export BLEU_OUT=${GEN_OUT}.bleu 196 | 197 | grep ^S ${GEN_OUT} | cut -f2- > ${SRC} 198 | grep ^T ${GEN_OUT} | cut -f2- > ${REF} 199 | grep ^H ${GEN_OUT} | cut -f3- > ${HYPO} 200 | 201 | 202 | 203 | $(which fairseq-score) --sys ${HYPO} --ref ${REF} > ${BLEU_OUT} 204 | cat ${BLEU_OUT} 205 | echo "" 206 | 207 | 208 | -------------------------------------------------------------------------------- /runs/run_nstack_nmt.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | set -e 5 | 6 | export arch=$1 7 | export gpus=$2 8 | 9 | 10 | export CUDA_VISIBLE_DEVICES=$gpus 11 | export id="${id:-50}" 12 | export fp16="${fp16:-0}" 13 | 14 | 15 | #export problem="${problem:-nstack_iwslt_ende_v2}" 16 | export problem="${problem:-nstack_merge_iwslt_ende_32k}" 17 | 18 | export RM_EXIST_DIR="${RM_EXIST_DIR:-y}" 19 | 20 | echo "CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES}, problem=${problem}" 21 | 22 | 23 | 24 | export dis_port_str="${dis_port_str:-}" 25 | 26 | export KEEP_LAS_CHECKPOINT="${KEEP_LAS_CHECKPOINT:-20}" 27 | 28 | 29 | #if [ ${problem} == "nstack_merge_iwslt_ende_32k" ]; then 30 | export PROBLEM=nstack_merge_translate_ende_iwslt_32k 31 | export TASK=nstack_merge2seq 32 | 33 | export HPARAMS=transformer_base 34 | export LEFT_PAD_SRC=False 35 | 36 | # export MAXTOKENS=2048 37 | export MAXTOKENS="${MAXTOKENS:-1024}" 38 | export UPDATE_FREQ="${UPDATE_FREQ:-1}" 39 | export RM_EXIST_DIR="${RM_EXIST_DIR:-n}" 40 | 41 | export INFER="${INFER:-y}" 42 | export max_pos="${max_pos:-1024}" 43 | export max_tgt_pos="${max_tgt_pos:-1024}" 44 | export more_params="${more_params:-}" 45 | # export more_params=--on_filter_nsent 46 | 47 | # export extra_params="--append-eos-to-target --max-source-positions ${max_pos} --max-target-positions ${max_tgt_pos}" 48 | export extra_params="--append-eos-to-target ${more_params} --max-source-positions ${max_pos} --max-target-positions ${max_tgt_pos}" 49 | 50 | export ID=${id}msp${max_pos}default 51 | export DDP_BACKEND="${DDP_BACKEND:-no_c10d}" 52 | 53 | export out_log="${out_log:-y}" 54 | export log_dir=`pwd`/../../${problem}-logs/ 55 | #export log_dir=`pwd`/../../gpu4-logs/ 56 | mkdir -p ${log_dir} 57 | export log_file=${log_dir}/${PROBLEM}-${arch}-${ID}.log 58 | # export MAX_UPDATE="${MAX_UPDATE:-35500}" 59 | export MAX_UPDATE="${MAX_UPDATE:-45000}" 60 | export nobar=1 61 | 62 | export LR="${LR:-0.0005}" 63 | export DROPOUT="${DROPOUT:-0.3}" 64 | export WDECAY="${WDECAY:-0.0001}" 65 | export AVG_NUM=10 66 | export LENPEN=1 67 | 68 | export ARCH=${arch} 69 | 70 | 71 | 72 | bash train_fairseq.sh -------------------------------------------------------------------------------- /runs/train_fairseq.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | pip install fairseq==0.6.2 6 | pip install tensorboardX 7 | 8 | # todo: specify gpus 9 | [ -z "$CUDA_VISIBLE_DEVICES" ] && { echo "Must set export CUDA_VISIBLE_DEVICES="; exit 1; } || echo "CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES}" 10 | IFS=',' read -r -a GPUS <<< "$CUDA_VISIBLE_DEVICES" 11 | export NUM_GPU=${#GPUS[@]} 12 | 13 | export ROOT_DIR=`pwd` 14 | export PROJDIR=tree_transformer 15 | export ROOT_DIR="${ROOT_DIR/\/tree_transformer\/runs/}" 16 | 17 | export user_dir=${ROOT_DIR}/${PROJDIR} 18 | 19 | export 
PROBLEM="${PROBLEM:-translate_ende_wmt_bpe32k}" 20 | 21 | export RAW_DATA_DIR=${ROOT_DIR}/raw_data_fairseq/${PROBLEM} 22 | export DATA_DIR=${ROOT_DIR}/data_fairseq/${PROBLEM} 23 | export TRAIN_DIR_PREFIX=${ROOT_DIR}/train_tree_transformer/${PROBLEM} 24 | 25 | 26 | export ID="${ID:-1}" 27 | export HPARAMS="${HPARAMS:-transformer_base}" 28 | 29 | [ -z "$ARCH" ] && { echo "Must set export ARCH="; exit 1; } || echo "ARCH = ${ARCH}" 30 | 31 | 32 | if [ ${HPARAMS} == "transformer_base" ]; then 33 | export TASK="${TASK:-translation}" 34 | 35 | export OPTIM=adam 36 | export ADAMBETAS='(0.9, 0.98)' 37 | export CLIPNORM=0.0 38 | export LRSCHEDULE=inverse_sqrt 39 | export WARMUP_INIT=1e-07 40 | # export 41 | # wamrup 4000 for 8 gpus, 16000 for 1 gpus 42 | export WARMUP="${WARMUP:-4000}" 43 | export LR="${LR:-0.0007}" 44 | export MIN_LR=1e-09 45 | export DROPOUT="${DROPOUT:-0.1}" 46 | export WDECAY="${WDECAY:-0.0}" 47 | export LB_SMOOTH=0.1 48 | export MAXTOKENS="${MAXTOKENS:-4096}" # 8gpus 49 | export UPDATE_FREQ="${UPDATE_FREQ:-8}" 50 | # export LEFT_PAD_SRC="${LEFT_PAD_SRC:-False}" 51 | 52 | elif [ ${HPARAMS} == "transformer_base_stt2" ]; then 53 | 54 | # export MODEL=transformer 55 | # export HPARAMS=lm_gbw 56 | # export ARCH=${MODEL}_${HPARAMS} 57 | export ARCH="${ARCH:-fi_transformer_encoder_class_tiny}" 58 | export TASK="${TASK:-seq_classification}" 59 | export CRITERION="${CRITERION:-classification_cross_entropy}" 60 | 61 | export OPTIM="${OPTIM:-adam}" 62 | export ADAMBETAS='(0.9, 0.98)' 63 | export CLIPNORM=0.0 64 | export LRSCHEDULE=inverse_sqrt 65 | export WARMUP_INIT=1e-07 66 | # wamrup 4000 for 8 gpus, 16000 for 1 gpus 67 | export WARMUP="${WARMUP:-4000}" 68 | export LR="${LR:-0.0007}" 69 | export MIN_LR=1e-09 70 | export DROPOUT="${DROPOUT:-0.1}" 71 | export WDECAY=0.0 72 | export LB_SMOOTH=0.1 73 | export MAXTOKENS="${MAXTOKENS:-4096}" # 8gpus 74 | export UPDATE_FREQ="${UPDATE_FREQ:-1}" 75 | export MAX_UPDATE="${MAX_UPDATE:-2000}" 76 | export LEFT_PAD_SRC="${LEFT_PAD_SRC:-True}" 77 | # export LEFT_PAD_SRC=True 78 | export log_interval="${log_interval:-1000}" 79 | export max_sent_valid="--max-sentences-valid 1" 80 | export NCLASSES="${NCLASSES:-2}" 81 | export train_command="${train_command:-fairseq-train}" 82 | else 83 | 84 | echo "undefined HPARAMS: ${HPARAMS}" 85 | exit 1 86 | fi 87 | 88 | 89 | export LR_PERIOD_UPDATES="${LR_PERIOD_UPDATES:-20000}" 90 | 91 | export MAX_UPDATE="${MAX_UPDATE:-103000}" 92 | export KEEP_LAS_CHECKPOINT="${KEEP_LAS_CHECKPOINT:-10}" 93 | 94 | export DDP_BACKEND="${DDP_BACKEND:-c10d}" 95 | export LRSRINK="${LRSRINK:-0.1}" 96 | export MAX_LR="${MAX_LR:-0.001}" 97 | export WORKERS="${WORKERS:-0}" 98 | export INFER="${INFER:-y}" 99 | export DISTRIBUTED="${DISTRIBUTED:-y}" 100 | export CRITERION="${CRITERION:-label_smoothed_cross_entropy}" 101 | 102 | export VALID_SET="${VALID_SET:-valid}" 103 | 104 | export CRAWL_TEST="${CRAWL_TEST:-n}" 105 | export extra_params="${extra_params:-}" 106 | 107 | export RM_EXIST_DIR="${RM_EXIST_DIR:-n}" 108 | 109 | export src="${src:-en}" 110 | 111 | export optim_text="${optim_text:---optimizer ${OPTIM} --clip-norm ${CLIPNORM} }" 112 | export scheduler_text="${scheduler_text:---lr-scheduler ${LRSCHEDULE} --warmup-init-lr ${WARMUP_INIT} --warmup-updates ${WARMUP} }" 113 | 114 | export fp16="${fp16:-0}" 115 | export rm_srceos="${rm_srceos:-0}" 116 | export rm_lastpunct="${rm_lastpunct:-0}" 117 | export nobar="${nobar:-1}" 118 | export shareemb="${shareemb:-1}" 119 | export shareemb_dec="${shareemb_dec:-0}" 120 | export 
usetfboard="${usetfboard:-0}" 121 | 122 | export dis_port_str="${dis_port_str:-}" 123 | export nrank_str="${nrank_str:-}" 124 | 125 | export max_sent_valid="${max_sent_valid:-}" 126 | 127 | export att_dropout="${att_dropout:-0}" 128 | export weight_dropout="${weight_dropout:-0}" 129 | #--max-sentences-valid 1 130 | 131 | # todo: specify distributed and fp16 132 | #[ ${fp16} -eq 0 ] && export fp16s="#" || export fp16s= 133 | [ ${fp16} -eq 1 ] && export fp16s="--fp16 " || export fp16s= 134 | [ ${rm_srceos} -eq 1 ] && export rm_srceos_s="--remove-eos-from-source " || export rm_srceos_s= 135 | [ ${rm_lastpunct} -eq 1 ] && export rm_lastpunct_s="--remove-last-punct-source " || export rm_lastpunct_s= 136 | [ ${nobar} -eq 1 ] && export nobarstr="--no-progress-bar" || export nobarstr= 137 | [ ${NUM_GPU} -gt 1 ] && export distro= || export distro="#" 138 | 139 | [ ${shareemb} -eq 1 ] && export shareemb_str="--share-all-embeddings " || export shareemb_str= 140 | [ ${shareemb_dec} -eq 1 ] && export shareemb_dec_str="--share-decoder-input-output-embed " || export shareemb_dec_str= 141 | [ ${usetfboard} -eq 1 ] && export tfboardstr="--tensorboard-logdir ${TFBOARD_DIR} " || export tfboardstr= 142 | 143 | [ ${att_dropout} -eq 0 ] && export att_dropout_str= || export att_dropout_str="--attention-dropout ${att_dropout} " 144 | [ ${weight_dropout} -eq 0 ] && export weight_dropout_str= || export weight_dropout_str="--weight-dropout ${weight_dropout} " 145 | 146 | 147 | #--attention-dropout 0.1 \ 148 | # --weight-dropout 0.1 \ 149 | 150 | 151 | export LEFT_PAD_SRC="${LEFT_PAD_SRC:-False}" 152 | export log_interval="${log_interval:-100}" 153 | 154 | export TRAIN_DIR=${TRAIN_DIR_PREFIX}/${ARCH}-${HPARAMS}-b${MAXTOKENS}-gpu${NUM_GPU}-upfre${UPDATE_FREQ}-${fp16}fp16-id${ID} 155 | export TFBOARD_DIR=${TRAIN_DIR}/tfboard 156 | 157 | #rm -rf ${TRAIN_DIR} 158 | echo "=====================================================" 159 | echo "START TRAINING: ${TRAIN_DIR}" 160 | echo "PROJDIR: ${PROJDIR}" 161 | echo "user_dir: ${user_dir}" 162 | echo "ARCH: ${ARCH}" 163 | echo "HPARAMS: ${HPARAMS}" 164 | echo "DISTRO: ${distro}" 165 | echo "INFER: ${INFER}" 166 | echo "CRITERION: ${CRITERION}" 167 | echo "fp16: ${fp16}" 168 | echo "rm_srceos: ${rm_srceos}" 169 | echo "rm_lastpunct_s: ${rm_lastpunct_s}" 170 | echo "TFBOARD_DIR: ${TFBOARD_DIR}" 171 | echo "=====================================================" 172 | 173 | if [ ${RM_EXIST_DIR} == "y" ]; then 174 | echo "Removing existing folder ${TRAIN_DIR}...." 
175 | rm -rf ${TRAIN_DIR} 176 | fi 177 | 178 | mkdir -p ${TRAIN_DIR} 179 | 180 | export out_log="${out_log:-n}" 181 | export LOGFILE="${LOGFILE:-$TRAIN_DIR/train.log}" 182 | export tee_begin="" 183 | export tee_end="" 184 | 185 | if [ $out_log == "y" ]; then 186 | echo "Printing logs to log file ${LOGFILE}" 187 | export tee_begin=" -u " 188 | export tee_end=" | tee ${LOGFILE}" 189 | touch ${LOGFILE} 190 | fi 191 | 192 | #export last_params 193 | 194 | export train_command="${train_command:-fairseq-train}" 195 | 196 | 197 | if [ ${DISTRIBUTED} == "y" ]; then 198 | if [ ${NUM_GPU} -gt 1 ]; then 199 | export init_command="python ${tee_begin} -m torch.distributed.launch ${dis_port_str} --nproc_per_node ${NUM_GPU} $(which fairseq-train) ${DATA_DIR} --ddp-backend=${DDP_BACKEND} ${nobarstr} ${nrank_str}" 200 | else 201 | export init_command="$(which fairseq-train) ${tee_begin} ${DATA_DIR} --ddp-backend=${DDP_BACKEND} ${nobarstr} " 202 | fi 203 | else 204 | export init_command="${train_command} ${DATA_DIR} --ddp-backend=${DDP_BACKEND} ${nobarstr} " 205 | fi 206 | 207 | 208 | 209 | echo "init_command = ${init_command}" 210 | 211 | 212 | echo "Run model ${ARCH}, ${HPARAMS}" 213 | 214 | 215 | if [ ${HPARAMS} == "transformer_base_stt2" ]; then 216 | export full_command="${init_command} \ 217 | --user-dir ${user_dir} \ 218 | --arch ${ARCH} \ 219 | --task ${TASK} \ 220 | --valid-subset ${VALID_SET} \ 221 | --source-lang ${src} \ 222 | --log-interval ${log_interval} \ 223 | --num-workers ${WORKERS} \ 224 | --share-all-embeddings \ 225 | ${optim_text} \ 226 | ${scheduler_text} \ 227 | --lr ${LR} \ 228 | --min-lr ${MIN_LR} \ 229 | --dropout ${DROPOUT} \ 230 | --weight-decay ${WDECAY} \ 231 | --update-freq ${UPDATE_FREQ} \ 232 | --criterion ${CRITERION} \ 233 | --label-smoothing ${LB_SMOOTH} \ 234 | --max-tokens ${MAXTOKENS} \ 235 | --left-pad-source ${LEFT_PAD_SRC} \ 236 | --max-update ${MAX_UPDATE} \ 237 | --save-dir ${TRAIN_DIR} \ 238 | --keep-last-epochs ${KEEP_LAS_CHECKPOINT} \ 239 | --nclasses ${NCLASSES} \ 240 | ${max_sent_valid} \ 241 | ${extra_params} \ 242 | ${fp16s} ${rm_srceos_s} ${rm_lastpunct_s} | tee ${LOGFILE}" 243 | 244 | else 245 | echo "run else commands" 246 | export full_command="${init_command} \ 247 | --user-dir ${user_dir} \ 248 | --arch ${ARCH} \ 249 | --task ${TASK} \ 250 | --source-lang ${src} \ 251 | --log-interval ${log_interval} \ 252 | --num-workers ${WORKERS} \ 253 | --optimizer ${OPTIM} \ 254 | --clip-norm ${CLIPNORM} \ 255 | --lr-scheduler ${LRSCHEDULE} \ 256 | --warmup-init-lr ${WARMUP_INIT} \ 257 | --warmup-updates ${WARMUP} \ 258 | --lr ${LR} \ 259 | --min-lr ${MIN_LR} \ 260 | --dropout ${DROPOUT} \ 261 | --weight-decay ${WDECAY} \ 262 | --update-freq ${UPDATE_FREQ} \ 263 | --criterion ${CRITERION} \ 264 | --label-smoothing ${LB_SMOOTH} \ 265 | --adam-betas '(0.9, 0.98)' \ 266 | --max-tokens ${MAXTOKENS} \ 267 | --left-pad-source ${LEFT_PAD_SRC} \ 268 | --max-update ${MAX_UPDATE} \ 269 | --save-dir ${TRAIN_DIR} \ 270 | --keep-last-epochs ${KEEP_LAS_CHECKPOINT} \ 271 | ${att_dropout_str} \ 272 | ${weight_dropout_str} \ 273 | ${tfboardstr} \ 274 | ${shareemb_str} \ 275 | ${shareemb_dec_str} \ 276 | ${max_sent_valid} \ 277 | ${extra_params} \ 278 | ${fp16s} ${rm_srceos_s} ${rm_lastpunct_s} ${tee_end}" 279 | fi 280 | 281 | echo "full command: " 282 | echo $full_command 283 | 284 | eval $full_command 285 | 286 | echo "==================" 287 | echo "==================" 288 | echo "finish training at ${TRAIN_DIR}" 289 | 290 | if [ ${INFER} == "y" ]; then 291 | echo "Start 
inference ...." 292 | bash infer_model.sh 293 | fi 294 | 295 | 296 | 297 | 298 | 299 | 300 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nxphi47/tree_transformer/8ac39e40441b14011b440dece6374bb4231632cc/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/average_checkpoints.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import collections 5 | import torch 6 | import os 7 | import re 8 | from fairseq.utils import import_user_module 9 | 10 | 11 | def default_avg_params(params_dict): 12 | averaged_params = collections.OrderedDict() 13 | 14 | # v should be a list of torch Tensor. 15 | for k, v in params_dict.items(): 16 | summed_v = None 17 | for x in v: 18 | summed_v = summed_v + x if summed_v is not None else x 19 | averaged_params[k] = summed_v / len(v) 20 | 21 | return averaged_params 22 | 23 | 24 | def ema_avg_params(params_dict, ema_decay): 25 | averaged_params = collections.OrderedDict() 26 | lens = [len(v) for k, v in params_dict.items()] 27 | assert all(x == lens[0] for x in lens), f'lens params: {lens}' 28 | num_checkpoints = lens[0] 29 | # y = x 30 | 31 | for k, v in params_dict.items(): 32 | # order: newest to oldest 33 | # reverse the order 34 | # y_t = x_t * decay + y_{t-1} * (1 - decay) 35 | total_v = None 36 | for x in reversed(v): 37 | if total_v is None: 38 | total_v = x 39 | else: 40 | total_v = x * ema_decay + total_v * (1.0 - ema_decay) 41 | 42 | averaged_params[k] = total_v 43 | return averaged_params 44 | 45 | 46 | def average_checkpoints(inputs, ema_decay=1.0): 47 | """Loads checkpoints from inputs and returns a model with averaged weights. 48 | 49 | Args: 50 | inputs: An iterable of string paths of checkpoints to load from. 51 | 52 | Returns: 53 | A dict of string keys mapping to various values. The 'model' key 54 | from the returned dict should correspond to an OrderedDict mapping 55 | string parameter names to torch Tensors. 
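    (note added) ema_decay selects the averaging scheme: with the default
    1.0 a uniform average is taken, while any value below 1.0 applies the
    exponential moving average computed in ema_avg_params.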
56 | """ 57 | params_dict = collections.OrderedDict() 58 | params_keys = None 59 | new_state = None 60 | for i, f in enumerate(inputs): 61 | state = torch.load( 62 | f, 63 | map_location=( 64 | lambda s, _: torch.serialization.default_restore_location(s, 'cpu') 65 | ), 66 | ) 67 | # Copies over the settings from the first checkpoint 68 | if new_state is None: 69 | new_state = state 70 | 71 | model_params = state['model'] 72 | 73 | model_params_keys = list(model_params.keys()) 74 | if params_keys is None: 75 | params_keys = model_params_keys 76 | elif params_keys != model_params_keys: 77 | raise KeyError( 78 | 'For checkpoint {}, expected list of params: {}, ' 79 | 'but found: {}'.format(f, params_keys, model_params_keys) 80 | ) 81 | 82 | for k in params_keys: 83 | if k not in params_dict: 84 | params_dict[k] = [] 85 | p = model_params[k] 86 | if isinstance(p, torch.HalfTensor): 87 | p = p.float() 88 | params_dict[k].append(p) 89 | 90 | if ema_decay < 1.0: 91 | print(f'Exponential moving averaging, decay={ema_decay}') 92 | averaged_params = ema_avg_params(params_dict, ema_decay) 93 | else: 94 | print(f'Default averaging') 95 | averaged_params = default_avg_params(params_dict) 96 | new_state['model'] = averaged_params 97 | return new_state 98 | 99 | 100 | def last_n_checkpoints(paths, n, update_based, upper_bound=None): 101 | assert len(paths) == 1 102 | path = paths[0] 103 | if update_based: 104 | pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt') 105 | else: 106 | pt_regexp = re.compile(r'checkpoint(\d+)\.pt') 107 | files = os.listdir(path) 108 | 109 | entries = [] 110 | for f in files: 111 | m = pt_regexp.fullmatch(f) 112 | if m is not None: 113 | sort_key = int(m.group(1)) 114 | if upper_bound is None or sort_key <= upper_bound: 115 | entries.append((sort_key, m.group(0))) 116 | if len(entries) < n: 117 | raise Exception('Found {} checkpoint files but need at least {}', len(entries), n) 118 | return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]] 119 | 120 | 121 | def main(): 122 | parser = argparse.ArgumentParser( 123 | description='Tool to average the params of input checkpoints to ' 124 | 'produce a new checkpoint', 125 | ) 126 | # fmt: off 127 | parser.add_argument('--inputs', required=True, nargs='+', 128 | help='Input checkpoint file paths.') 129 | parser.add_argument('--output', required=True, metavar='FILE', 130 | help='Write the new checkpoint containing the averaged weights to this path.') 131 | num_group = parser.add_mutually_exclusive_group() 132 | num_group.add_argument('--num-epoch-checkpoints', type=int, 133 | help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, ' 134 | 'and average last this many of them.') 135 | num_group.add_argument('--num-update-checkpoints', type=int, 136 | help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by input, ' 137 | 'and average last this many of them.') 138 | parser.add_argument('--checkpoint-upper-bound', type=int, 139 | help='when using --num-epoch-checkpoints, this will set an upper bound on which checkpoint to use, ' 140 | 'e.g., with --num-epoch-checkpoints=10 --checkpoint-upper-bound=50, checkpoints 41-50 would be averaged.') 141 | 142 | # parser.add_argument('--ema', type=float, default=1.0, help='exponential moving average decay') 143 | # parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar') 144 | parser.add_argument('--ema', default='False', type=str, metavar='BOOL', help='ema') 
145 |     parser.add_argument('--ema_decay', type=float, default=1.0, help='exponential moving average decay')
146 |     parser.add_argument('--user-dir', default=None)
147 | 
148 |     # fmt: on
149 |     args = parser.parse_args()
150 | 
151 |     import_user_module(args)
152 |     print(args)
153 | 
154 |     num = None
155 |     is_update_based = False
156 |     if args.num_update_checkpoints is not None:
157 |         num = args.num_update_checkpoints
158 |         is_update_based = True
159 |     elif args.num_epoch_checkpoints is not None:
160 |         num = args.num_epoch_checkpoints
161 | 
162 |     assert args.checkpoint_upper_bound is None or args.num_epoch_checkpoints is not None, \
163 |         '--checkpoint-upper-bound requires --num-epoch-checkpoints'
164 |     assert args.num_epoch_checkpoints is None or args.num_update_checkpoints is None, \
165 |         'Cannot combine --num-epoch-checkpoints and --num-update-checkpoints'
166 | 
167 |     if num is not None:
168 |         args.inputs = last_n_checkpoints(
169 |             args.inputs, num, is_update_based, upper_bound=args.checkpoint_upper_bound,
170 |         )
171 |     # print('averaging checkpoints: ', args.inputs)
172 |     print('averaging checkpoints: ')
173 |     for checkpoint in args.inputs:
174 |         print(checkpoint)
175 |     print('-' * 40)
176 | 
177 |     # ema = args.ema
178 |     # assert isinstance(args.ema, bool)
179 |     print(f'Start averaging with ema={args.ema}, ema_decay={args.ema_decay}')
180 |     new_state = average_checkpoints(args.inputs, args.ema_decay)
181 |     torch.save(new_state, args.output)
182 |     print('Finished writing averaged checkpoint to {}.'.format(args.output))
183 | 
184 | 
185 | if __name__ == '__main__':
186 |     main()
187 | 
--------------------------------------------------------------------------------
/scripts/build_sym_alignment.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the LICENSE file in
5 | # the root directory of this source tree. An additional grant of patent rights
6 | # can be found in the PATENTS file in the same directory.
7 | #
8 | 
9 | """
10 | Use this script in order to build symmetric alignments for your translation
11 | dataset.
12 | This script depends on fast_align and mosesdecoder tools. You will need to
13 | build those before running the script.
14 | fast_align:
15 |     github: http://github.com/clab/fast_align
16 |     instructions: follow the instructions in README.md
17 | mosesdecoder:
18 |     github: http://github.com/moses-smt/mosesdecoder
19 |     instructions: http://www.statmt.org/moses/?n=Development.GetStarted
20 | The script produces the following files under --output_dir:
21 |     text.joined - concatenation of lines from the source_file and the
22 |     target_file.
23 |     align.forward - forward pass of fast_align.
24 |     align.backward - backward pass of fast_align.
25 |     aligned.sym_heuristic - symmetrized alignment.
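Example invocation (hypothetical paths, added for illustration):
    python scripts/build_sym_alignment.py \
        --fast_align_dir fast_align/build --mosesdecoder_dir mosesdecoder \
        --source_file train.de --target_file train.en --output_dir alignments/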
26 | """ 27 | 28 | import argparse 29 | import os 30 | from itertools import zip_longest 31 | 32 | 33 | def main(): 34 | parser = argparse.ArgumentParser(description='symmetric alignment builer') 35 | # fmt: off 36 | parser.add_argument('--fast_align_dir', 37 | help='path to fast_align build directory') 38 | parser.add_argument('--mosesdecoder_dir', 39 | help='path to mosesdecoder root directory') 40 | parser.add_argument('--sym_heuristic', 41 | help='heuristic to use for symmetrization', 42 | default='grow-diag-final-and') 43 | parser.add_argument('--source_file', 44 | help='path to a file with sentences ' 45 | 'in the source language') 46 | parser.add_argument('--target_file', 47 | help='path to a file with sentences ' 48 | 'in the target language') 49 | parser.add_argument('--output_dir', 50 | help='output directory') 51 | # fmt: on 52 | args = parser.parse_args() 53 | 54 | fast_align_bin = os.path.join(args.fast_align_dir, 'fast_align') 55 | symal_bin = os.path.join(args.mosesdecoder_dir, 'bin', 'symal') 56 | sym_fast_align_bin = os.path.join( 57 | args.mosesdecoder_dir, 'scripts', 'ems', 58 | 'support', 'symmetrize-fast-align.perl') 59 | 60 | # create joined file 61 | joined_file = os.path.join(args.output_dir, 'text.joined') 62 | with open(args.source_file, 'r', encoding='utf-8') as src, open(args.target_file, 'r', encoding='utf-8') as tgt: 63 | with open(joined_file, 'w', encoding='utf-8') as joined: 64 | for s, t in zip_longest(src, tgt): 65 | print('{} ||| {}'.format(s.strip(), t.strip()), file=joined) 66 | 67 | bwd_align_file = os.path.join(args.output_dir, 'align.backward') 68 | 69 | # run forward alignment 70 | fwd_align_file = os.path.join(args.output_dir, 'align.forward') 71 | fwd_fast_align_cmd = '{FASTALIGN} -i {JOINED} -d -o -v > {FWD}'.format( 72 | FASTALIGN=fast_align_bin, 73 | JOINED=joined_file, 74 | FWD=fwd_align_file) 75 | assert os.system(fwd_fast_align_cmd) == 0 76 | 77 | # run backward alignment 78 | bwd_align_file = os.path.join(args.output_dir, 'align.backward') 79 | bwd_fast_align_cmd = '{FASTALIGN} -i {JOINED} -d -o -v -r > {BWD}'.format( 80 | FASTALIGN=fast_align_bin, 81 | JOINED=joined_file, 82 | BWD=bwd_align_file) 83 | assert os.system(bwd_fast_align_cmd) == 0 84 | 85 | # run symmetrization 86 | sym_out_file = os.path.join(args.output_dir, 'aligned') 87 | sym_cmd = '{SYMFASTALIGN} {FWD} {BWD} {SRC} {TGT} {OUT} {HEURISTIC} {SYMAL}'.format( 88 | SYMFASTALIGN=sym_fast_align_bin, 89 | FWD=fwd_align_file, 90 | BWD=bwd_align_file, 91 | SRC=args.source_file, 92 | TGT=args.target_file, 93 | OUT=sym_out_file, 94 | HEURISTIC=args.sym_heuristic, 95 | SYMAL=symal_bin 96 | ) 97 | assert os.system(sym_cmd) == 0 98 | 99 | 100 | if __name__ == '__main__': 101 | main() 102 | -------------------------------------------------------------------------------- /scripts/compound_split_bleu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ $# -ne 1 ]; then 4 | echo "usage: $0 GENERATE_PY_OUTPUT" 5 | exit 1 6 | fi 7 | 8 | GEN=$1 9 | 10 | SYS=$GEN.sys 11 | REF=$GEN.ref 12 | 13 | if [ $(tail -n 1 $GEN | grep BLEU | wc -l) -ne 1 ]; then 14 | echo "not done generating" 15 | exit 16 | fi 17 | 18 | grep ^H $GEN | cut -f3- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $SYS 19 | grep ^T $GEN | cut -f2- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $REF 20 | 21 | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $REF 22 | python score.py --sys $SYS --ref $REF 23 | 
-------------------------------------------------------------------------------- /scripts/convert_dictionary.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | -- 8 | -- Usage: convert_dictionary.lua 9 | require 'fairseq' 10 | require 'torch' 11 | require 'paths' 12 | 13 | if #arg < 1 then 14 | print('usage: convert_dictionary.lua ') 15 | os.exit(1) 16 | end 17 | if not paths.filep(arg[1]) then 18 | print('error: file does not exit: ' .. arg[1]) 19 | os.exit(1) 20 | end 21 | 22 | dict = torch.load(arg[1]) 23 | dst = paths.basename(arg[1]):gsub('.th7', '.txt') 24 | assert(dst:match('.txt$')) 25 | 26 | f = io.open(dst, 'w') 27 | for idx, symbol in ipairs(dict.index_to_symbol) do 28 | if idx > dict.cutoff then 29 | break 30 | end 31 | f:write(symbol) 32 | f:write(' ') 33 | f:write(dict.index_to_freq[idx]) 34 | f:write('\n') 35 | end 36 | f:close() 37 | -------------------------------------------------------------------------------- /scripts/convert_model.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | -- 8 | -- Usage: convert_model.lua 9 | require 'torch' 10 | local fairseq = require 'fairseq' 11 | 12 | model = torch.load(arg[1]) 13 | 14 | function find_weight_norm(container, module) 15 | for _, wn in ipairs(container:listModules()) do 16 | if torch.type(wn) == 'nn.WeightNorm' and wn.modules[1] == module then 17 | return wn 18 | end 19 | end 20 | end 21 | 22 | function push_state(dict, key, module) 23 | if torch.type(module) == 'nn.Linear' then 24 | local wn = find_weight_norm(model.module, module) 25 | assert(wn) 26 | dict[key .. '.weight_v'] = wn.v:float() 27 | dict[key .. '.weight_g'] = wn.g:float() 28 | elseif torch.type(module) == 'nn.TemporalConvolutionTBC' then 29 | local wn = find_weight_norm(model.module, module) 30 | assert(wn) 31 | local v = wn.v:float():view(wn.viewOut):transpose(2, 3) 32 | dict[key .. '.weight_v'] = v 33 | dict[key .. '.weight_g'] = wn.g:float():view(module.weight:size(3), 1, 1) 34 | else 35 | dict[key .. '.weight'] = module.weight:float() 36 | end 37 | if module.bias then 38 | dict[key .. '.bias'] = module.bias:float() 39 | end 40 | end 41 | 42 | encoder_dict = {} 43 | decoder_dict = {} 44 | combined_dict = {} 45 | 46 | function encoder_state(encoder) 47 | luts = encoder:findModules('nn.LookupTable') 48 | push_state(encoder_dict, 'embed_tokens', luts[1]) 49 | push_state(encoder_dict, 'embed_positions', luts[2]) 50 | 51 | fcs = encoder:findModules('nn.Linear') 52 | assert(#fcs >= 2) 53 | local nInputPlane = fcs[1].weight:size(1) 54 | push_state(encoder_dict, 'fc1', table.remove(fcs, 1)) 55 | push_state(encoder_dict, 'fc2', table.remove(fcs, #fcs)) 56 | 57 | for i, module in ipairs(encoder:findModules('nn.TemporalConvolutionTBC')) do 58 | push_state(encoder_dict, 'convolutions.' .. 
tostring(i - 1), module) 59 | if nInputPlane ~= module.weight:size(3) / 2 then 60 | push_state(encoder_dict, 'projections.' .. tostring(i - 1), table.remove(fcs, 1)) 61 | end 62 | nInputPlane = module.weight:size(3) / 2 63 | end 64 | assert(#fcs == 0) 65 | end 66 | 67 | function decoder_state(decoder) 68 | luts = decoder:findModules('nn.LookupTable') 69 | push_state(decoder_dict, 'embed_tokens', luts[1]) 70 | push_state(decoder_dict, 'embed_positions', luts[2]) 71 | 72 | fcs = decoder:findModules('nn.Linear') 73 | local nInputPlane = fcs[1].weight:size(1) 74 | push_state(decoder_dict, 'fc1', table.remove(fcs, 1)) 75 | push_state(decoder_dict, 'fc2', fcs[#fcs - 1]) 76 | push_state(decoder_dict, 'fc3', fcs[#fcs]) 77 | 78 | table.remove(fcs, #fcs) 79 | table.remove(fcs, #fcs) 80 | 81 | for i, module in ipairs(decoder:findModules('nn.TemporalConvolutionTBC')) do 82 | if nInputPlane ~= module.weight:size(3) / 2 then 83 | push_state(decoder_dict, 'projections.' .. tostring(i - 1), table.remove(fcs, 1)) 84 | end 85 | nInputPlane = module.weight:size(3) / 2 86 | 87 | local prefix = 'attention.' .. tostring(i - 1) 88 | push_state(decoder_dict, prefix .. '.in_projection', table.remove(fcs, 1)) 89 | push_state(decoder_dict, prefix .. '.out_projection', table.remove(fcs, 1)) 90 | push_state(decoder_dict, 'convolutions.' .. tostring(i - 1), module) 91 | end 92 | assert(#fcs == 0) 93 | end 94 | 95 | 96 | _encoder = model.module.modules[2] 97 | _decoder = model.module.modules[3] 98 | 99 | encoder_state(_encoder) 100 | decoder_state(_decoder) 101 | 102 | for k, v in pairs(encoder_dict) do 103 | combined_dict['encoder.' .. k] = v 104 | end 105 | for k, v in pairs(decoder_dict) do 106 | combined_dict['decoder.' .. k] = v 107 | end 108 | 109 | 110 | torch.save('state_dict.t7', combined_dict) 111 | -------------------------------------------------------------------------------- /scripts/read_binarized.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 
8 | # 9 | 10 | import argparse 11 | 12 | from fairseq.data import dictionary 13 | from fairseq.data import IndexedDataset 14 | 15 | 16 | def get_parser(): 17 | parser = argparse.ArgumentParser( 18 | description='writes text from binarized file to stdout') 19 | # fmt: off 20 | parser.add_argument('--dict', metavar='FP', required=True, help='dictionary containing known words') 21 | parser.add_argument('--input', metavar='FP', required=True, help='binarized file to read') 22 | # fmt: on 23 | 24 | return parser 25 | 26 | 27 | def main(args): 28 | dict = dictionary.Dictionary.load(args.dict) 29 | ds = IndexedDataset(args.input, fix_lua_indexing=True) 30 | for tensor_line in ds: 31 | print(dict.string(tensor_line)) 32 | 33 | 34 | if __name__ == '__main__': 35 | parser = get_parser() 36 | args = parser.parse_args() 37 | main(args) 38 | -------------------------------------------------------------------------------- /scripts/sacrebleu_pregen.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ $# -ne 4 ]; then 4 | echo "usage: $0 TESTSET SRCLANG TGTLANG GEN" 5 | exit 1 6 | fi 7 | 8 | TESTSET=$1 9 | SRCLANG=$2 10 | TGTLANG=$3 11 | 12 | GEN=$4 13 | 14 | echo 'Cloning Moses github repository (for tokenization scripts)...' 15 | git clone https://github.com/moses-smt/mosesdecoder.git 16 | 17 | SCRIPTS=mosesdecoder/scripts 18 | DETOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl 19 | 20 | grep ^H $GEN \ 21 | | sed 's/^H\-//' \ 22 | | sort -n -k 1 \ 23 | | cut -f 3 \ 24 | | perl $DETOKENIZER -l $TGTLANG \ 25 | | sed "s/ - /-/g" \ 26 | > $GEN.sorted.detok 27 | 28 | sacrebleu --test-set $TESTSET --language-pair "${SRCLANG}-${TGTLANG}" < $GEN.sorted.detok 29 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | # from .multiprocessing_pdb import pdb 9 | # 10 | # __all__ = ['pdb'] 11 | # __version__ = '0.6.0' 12 | 13 | 14 | # import src.models 15 | # import src.modules 16 | from . import utils 17 | from . import optim 18 | from . import bpe 19 | from . import criterions 20 | from . import models 21 | from . import modules 22 | from . import dptree 23 | from . import trainers 24 | 25 | from . import data 26 | from . import tasks 27 | 28 | from . import dptree_tokenizer 29 | from . import nstack_tokenizer 30 | 31 | 32 | from . import binarization 33 | 34 | 35 | 36 | 37 | # import fairseq.optim 38 | # import fairseq.optim.lr_scheduler 39 | # import fairseq.tasks 40 | # import fairseq.dptree 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /src/bpe/__init__.py: -------------------------------------------------------------------------------- 1 | from . import bpe_utils -------------------------------------------------------------------------------- /src/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import classification_cross_entropy 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /src/criterions/masked_lm_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | 9 | import math 10 | import torch 11 | import torch.nn.functional as F 12 | 13 | from fairseq import utils 14 | from fairseq.criterions import FairseqCriterion, register_criterion 15 | 16 | 17 | def compute_cross_entropy_loss(logits, targets, ignore_index=-100): 18 | """ 19 | Function to compute the cross entropy loss. The default value of 20 | ignore_index is the same as the default value for F.cross_entropy in 21 | pytorch. 22 | """ 23 | assert logits.size(0) == targets.size(-1), \ 24 | "Logits and Targets tensor shapes don't match up" 25 | 26 | loss = F.nll_loss( 27 | F.log_softmax(logits, -1, dtype=torch.float32), 28 | targets, 29 | reduction="sum", 30 | ignore_index=ignore_index, 31 | ) 32 | return loss 33 | 34 | 35 | @register_criterion('masked_lm_loss') 36 | class MaskedLmLoss(FairseqCriterion): 37 | """ 38 | Implementation for the loss used in masked language model (MLM) training. 39 | This optionally also computes the next sentence prediction (NSP) loss and 40 | adds it to the overall loss based on the specified args. There are three 41 | cases to consider: 42 | 1) Generic MLM training without NSP loss. In this case sentence_targets 43 | and sentence_logits are both None. 44 | 2) BERT training without NSP loss. In this case sentence_targets is 45 | not None but sentence_logits is None and we should not be computing 46 | a sentence level loss. 47 | 3) BERT training with NSP loss. In this case both sentence_targets and 48 | sentence_logits are not None and we should be computing a sentence 49 | level loss. The weight of the sentence level loss is specified as 50 | an argument. 51 | """ 52 | 53 | def __init__(self, args, task): 54 | super().__init__(args, task) 55 | 56 | @staticmethod 57 | def add_args(parser): 58 | """Args for MaskedLM Loss""" 59 | # Default for masked_lm_only is False so as to not break BERT training 60 | parser.add_argument('--masked-lm-only', default=False, 61 | action='store_true', help='compute MLM loss only') 62 | parser.add_argument('--nsp-loss-weight', default=1.0, type=float, 63 | help='weight for next sentence prediction' 64 | ' loss (default 1)') 65 | 66 | def forward(self, model, sample, reduce=True): 67 | """Compute the loss for the given sample. 68 | Returns a tuple with three elements: 69 | 1) the loss 70 | 2) the sample size, which is used as the denominator for the gradient 71 | 3) logging outputs to display while training 72 | """ 73 | lm_logits, output_metadata = model(**sample["net_input"]) 74 | 75 | # reshape lm_logits from (N,T,C) to (N*T,C) 76 | lm_logits = lm_logits.view(-1, lm_logits.size(-1)) 77 | lm_targets = sample['lm_target'].view(-1) 78 | lm_loss = compute_cross_entropy_loss( 79 | lm_logits, lm_targets, self.padding_idx) 80 | 81 | # compute the number of tokens for which loss is computed. 
This is used 82 | # to normalize the loss 83 | ntokens = utils.strip_pad(lm_targets, self.padding_idx).numel() 84 | loss = lm_loss / ntokens 85 | nsentences = sample['nsentences'] 86 | # nsentences = 0 87 | 88 | # Compute sentence loss if masked_lm_only is False 89 | sentence_loss = None 90 | if not self.args.masked_lm_only: 91 | sentence_logits = output_metadata['sentence_logits'] 92 | sentence_targets = sample['sentence_target'].view(-1) 93 | # This needs to be recomputed due to some differences between 94 | # TokenBlock and BlockPair dataset. This can be resolved with a 95 | # refactor of BERTModel which we will do in the future. 96 | # TODO: Remove this after refactor of BERTModel 97 | nsentences = sentence_targets.size(0) 98 | 99 | # Check for logits being none which can happen when remove_heads 100 | # is set to true in the BERT model. Ideally we should set 101 | # masked_lm_only to true in this case, but that requires some 102 | # refactor in the BERT model. 103 | if sentence_logits is not None: 104 | sentence_loss = compute_cross_entropy_loss( 105 | sentence_logits, sentence_targets) 106 | 107 | loss += self.args.nsp_loss_weight * (sentence_loss / nsentences) 108 | 109 | # NOTE: as we are summing up per token mlm loss and per sentence nsp loss 110 | # we don't need to use sample_size as denominator for the gradient 111 | # here sample_size is just used for logging 112 | sample_size = 1 113 | logging_output = { 114 | 'loss': utils.item(loss.data) if reduce else loss.data, 115 | 'lm_loss': utils.item(lm_loss.data) if reduce else lm_loss.data, 116 | # sentence loss is not always computed 117 | 'sentence_loss': ( 118 | ( 119 | utils.item(sentence_loss.data) if reduce 120 | else sentence_loss.data 121 | ) if sentence_loss is not None else 0.0 122 | ), 123 | 'ntokens': ntokens, 124 | 'nsentences': nsentences, 125 | 'sample_size': sample_size, 126 | } 127 | return loss, sample_size, logging_output 128 | 129 | @staticmethod 130 | def aggregate_logging_outputs(logging_outputs): 131 | """Aggregate logging outputs from data parallel training.""" 132 | lm_loss_sum = sum(log.get('lm_loss', 0) for log in logging_outputs) 133 | sentence_loss_sum = sum( 134 | log.get('sentence_loss', 0) for log in logging_outputs) 135 | ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) 136 | nsentences = sum(log.get('nsentences', 0) for log in logging_outputs) 137 | sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) 138 | agg_loss = sum(log.get('loss', 0) for log in logging_outputs) 139 | 140 | agg_output = { 141 | 'loss': agg_loss / sample_size / math.log(2), 142 | 'lm_loss': lm_loss_sum / ntokens / math.log(2), 143 | 'sentence_loss': sentence_loss_sum / nsentences / math.log(2), 144 | 'nll_loss': lm_loss_sum / ntokens / math.log(2), 145 | 'ntokens': ntokens, 146 | 'nsentences': nsentences, 147 | 'sample_size': sample_size, 148 | } 149 | return agg_output 150 | -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dptree_index_dataset import * 2 | from .dptree2seq_dataset import * 3 | from .dptree2seq_sep_dataset import * 4 | 5 | 6 | from .dptree_mono_class_dataset import * 7 | from .dptree_sep_mono_class_dataset import * 8 | from .task_utils import * 9 | 10 | 11 | from fairseq.data.dictionary import TruncatedDictionary 12 | from .dptree_dictionary import * 13 | 14 | from .nstack_mono_class_dataset import * 15 | from .nstack2seq_dataset 
import * 16 | 17 | from .monolingual_classification_dataset import * 18 | from .nstack_merge_monoclass_dataset import * 19 | 20 | __all__ = [ 21 | 'DPTreeWrapperDictionary', 22 | 'DPTreeIndexedCachedDataset', 23 | 'TruncatedDictionary', 24 | 25 | 'DPTree2SeqPairDataset', 26 | 'DPTREE_KEYS', 27 | 'DPTreeMonoClassificationDataset', 28 | 'DPTreeSeparateMonoClassificationDataset', 29 | 'DPTreeSeparateLIClassificationDataset', 30 | 'DPTreeSeparateIndexedDatasetBuilder', 31 | 32 | 'DPTree2SeqSeparatePairDataset', 33 | 'MonolingualClassificationDataset', 34 | 35 | 'NodeStackFromDPTreeSepMonoClassificationDataset', 36 | 'NodeStackFromDPTreeSepNodeTargetMonoClassificationDataset', 37 | 'NodeStackTreeMonoClassificationDataset', 38 | 'Nstack2SeqPairDataset', 39 | ] 40 | -------------------------------------------------------------------------------- /src/data/dptree2seq_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from fairseq import utils 5 | 6 | from fairseq.data import data_utils, FairseqDataset 7 | 8 | DPTREE_KEYS = ['nodes', 'labels', 'indices', 'length'] 9 | 10 | 11 | def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=False): 12 | """Convert a list of 1d tensors into a padded 2d tensor.""" 13 | size = max(v.size(0) for v in values) 14 | res = values[0].new(len(values), size).fill_(pad_idx) 15 | 16 | def copy_tensor(src, dst): 17 | assert dst.numel() == src.numel() 18 | if move_eos_to_beginning: 19 | assert src[-1] == eos_idx 20 | dst[0] = eos_idx 21 | dst[1:] = src[:-1] 22 | else: 23 | dst.copy_(src) 24 | 25 | for i, v in enumerate(values): 26 | copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)]) 27 | return res 28 | 29 | 30 | def dptree2seq_collate_indices(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning): 31 | """convert list of 2d tensors into padded 3d tensors""" 32 | assert not left_pad 33 | size = max(v.size(0) for v in values) 34 | res = values[0].new(len(values), size, 2).fill_(pad_idx) 35 | 36 | def copy_tensor(src, dst): 37 | assert dst.numel() == src.numel() 38 | if move_eos_to_beginning: 39 | assert src[-1] == eos_idx 40 | dst[0] = eos_idx 41 | dst[1:] = src[:-1] 42 | else: 43 | dst.copy_(src) 44 | 45 | for i, v in enumerate(values): 46 | copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)]) 47 | 48 | return res 49 | 50 | 51 | def dptree2seq_collate( 52 | samples, pad_idx, eos_idx, left_pad_source=False, left_pad_target=False, input_feeding=True, 53 | ): 54 | if len(samples) == 0: 55 | return {} 56 | 57 | # print(samples) 58 | # raise NotImplementedError 59 | 60 | def merge(key, left_pad, move_eos_to_beginning=False): 61 | return data_utils.collate_tokens( 62 | [s[key] for s in samples], 63 | pad_idx, eos_idx, left_pad, move_eos_to_beginning, 64 | ) 65 | 66 | def merge_source(left_pad, move_eos_to_beginning=False): 67 | # src = [s['source'] for s in samples] 68 | assert samples[0]['source'] is not None 69 | src = {k: [dic['source'][k] for dic in samples] for k in samples[0]['source']} 70 | 71 | nodes = src['nodes'] 72 | labels = src['labels'] 73 | indices = src['indices'] 74 | length = src['length'] 75 | 76 | nodes = data_utils.collate_tokens(nodes, pad_idx, eos_idx, left_pad, move_eos_to_beginning) 77 | labels = data_utils.collate_tokens(labels, pad_idx, eos_idx, left_pad, move_eos_to_beginning) 78 | indices = dptree2seq_collate_indices(indices, 0, 0, left_pad, move_eos_to_beginning) 79 | length = 
torch.cat([x.unsqueeze_(0) for x in length], 0) 80 | 81 | src_o = { 82 | 'nodes': nodes, 83 | 'labels': labels, 84 | 'indices': indices, 85 | 'length': length 86 | } 87 | return src_o 88 | 89 | id = torch.LongTensor([s['id'] for s in samples]) 90 | src = merge_source(left_pad_source) 91 | src_lengths = torch.LongTensor([s['source']['nodes'].numel() for s in samples]) 92 | src_lengths, sort_order = src_lengths.sort(descending=True) 93 | id = id.index_select(0, sort_order) 94 | 95 | # reorder 96 | src = {k: v.index_select(0, sort_order) for k, v in src.items()} 97 | 98 | prev_output_tokens = None 99 | target = None 100 | if samples[0].get('target', None) is not None: 101 | target = merge('target', left_pad=left_pad_target) 102 | target = target.index_select(0, sort_order) 103 | ntokens = sum(len(s['target']) for s in samples) 104 | 105 | if input_feeding: 106 | # we create a shifted version of targets for feeding the 107 | # previous output token(s) into the next decoder step 108 | prev_output_tokens = merge( 109 | 'target', 110 | left_pad=left_pad_target, 111 | move_eos_to_beginning=True, 112 | ) 113 | prev_output_tokens = prev_output_tokens.index_select(0, sort_order) 114 | else: 115 | ntokens = sum(len(s['source']['nodes']) for s in samples) 116 | 117 | batch = { 118 | 'id': id, 119 | 'nsentences': len(samples), 120 | 'ntokens': ntokens, 121 | 'net_input': { 122 | 'src_tokens': src['nodes'], 123 | 'src_labels': src['labels'], 124 | 'src_indices': src['indices'], 125 | 'src_sent_lengths': src['length'], 126 | 'src_lengths': src_lengths, 127 | }, 128 | 'target': target, 129 | } 130 | # sizes = {k: v.size() for k, v in batch['net_input'].items()} 131 | # print(f'batch-net-inputs: {sizes}') 132 | if prev_output_tokens is not None: 133 | batch['net_input']['prev_output_tokens'] = prev_output_tokens 134 | return batch 135 | 136 | 137 | class DPTree2SeqPairDataset(FairseqDataset): 138 | def __init__( 139 | self, srcs, src_sizes, src_dict, 140 | tgt=None, tgt_sizes=None, tgt_dict=None, 141 | left_pad_source=True, left_pad_target=False, 142 | max_source_positions=1024, max_target_positions=1024, 143 | shuffle=True, input_feeding=True, remove_eos_from_source=False, 144 | append_eos_to_target=False, 145 | ): 146 | if tgt_dict is not None: 147 | assert src_dict.pad() == tgt_dict.pad() 148 | assert src_dict.eos() == tgt_dict.eos() 149 | assert src_dict.unk() == tgt_dict.unk() 150 | self.srcs = srcs 151 | self.src = srcs['nodes'] 152 | self.tgt = tgt 153 | self.src_sizes = np.array(src_sizes) 154 | self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None 155 | self.src_dict = src_dict 156 | self.tgt_dict = tgt_dict 157 | self.left_pad_source = left_pad_source 158 | self.left_pad_target = left_pad_target 159 | self.max_source_positions = max_source_positions 160 | self.max_target_positions = max_target_positions 161 | self.shuffle = shuffle 162 | self.input_feeding = input_feeding 163 | self.remove_eos_from_source = remove_eos_from_source 164 | self.append_eos_to_target = append_eos_to_target 165 | same = self.tgt_dict.eos() == self.src_dict.eos() 166 | print(f'| ATTENTION ! EOS same: {same}') 167 | 168 | def __getitem__(self, index): 169 | tgt_item = self.tgt[index] if self.tgt is not None else None 170 | # src_item = self.src[index] 171 | src_item = {k: v[index] for k, v in self.srcs.items()} 172 | 173 | # Append EOS to end of tgt sentence if it does not have an EOS and remove 174 | # EOS from end of src sentence if it exists.
This is useful when we use 175 | # existing datasets for opposite directions i.e., when we want to 176 | # use tgt_dataset as src_dataset and vice versa 177 | if self.append_eos_to_target: 178 | # if self.tgt_dict: 179 | # # same = self.tgt_dict.eos() == self.src_dict.eos() 180 | # # print(f'EOS same: {same}') 181 | # eos = self.tgt_dict.eos() 182 | # else: 183 | # eos = self.src_dict.eos() 184 | 185 | eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos() 186 | if self.tgt and self.tgt[index][-1] != eos: 187 | tgt_item = torch.cat([self.tgt[index], torch.LongTensor([eos])]) 188 | 189 | if self.remove_eos_from_source: 190 | 191 | # eos = self.src_dict.eos() 192 | # if self.src[index][-1] == eos: 193 | # src_item = self.src[index][:-1] 194 | raise NotImplementedError(f'remove_eos_from_source not supported, the tree should remove the eos already!') 195 | 196 | return { 197 | 'id': index, 198 | 'source': src_item, 199 | 'target': tgt_item, 200 | } 201 | 202 | def __len__(self): 203 | return len(self.src) 204 | 205 | def collater(self, samples): 206 | """Merge a list of samples to form a mini-batch. 207 | 208 | Args: 209 | samples (List[dict]): samples to collate 210 | 211 | Returns: 212 | dict: a mini-batch with the following keys: 213 | 214 | - `id` (LongTensor): example IDs in the original input order 215 | - `ntokens` (int): total number of tokens in the batch 216 | - `net_input` (dict): the input to the Model, containing keys: 217 | 218 | - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in 219 | the source sentence of shape `(bsz, src_len)`. Padding will 220 | appear on the left if *left_pad_source* is ``True``. 221 | - `src_lengths` (LongTensor): 1D Tensor of the unpadded 222 | lengths of each source sentence of shape `(bsz)` 223 | - `prev_output_tokens` (LongTensor): a padded 2D Tensor of 224 | tokens in the target sentence, shifted right by one position 225 | for input feeding/teacher forcing, of shape `(bsz, 226 | tgt_len)`. This key will not be present if *input_feeding* 227 | is ``False``. Padding will appear on the left if 228 | *left_pad_target* is ``True``. 229 | 230 | - `target` (LongTensor): a padded 2D Tensor of tokens in the 231 | target sentence of shape `(bsz, tgt_len)`. Padding will appear 232 | on the left if *left_pad_target* is ``True``.
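        A minimal usage sketch (batch size and shapes are illustrative only)::

            samples = [dataset[i] for i in range(8)]
            batch = dataset.collater(samples)
            batch['net_input']['src_tokens']   # (8, max_node_len) padded node ids
            batch['net_input']['src_indices']  # (8, max_node_len, 2) node span indices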
233 | """ 234 | return dptree2seq_collate( 235 | samples, pad_idx=self.src_dict.pad(), eos_idx=self.src_dict.eos(), 236 | left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target, 237 | input_feeding=self.input_feeding, 238 | ) 239 | 240 | def get_dummy_batch(self, num_tokens, max_positions, src_len=128, tgt_len=128): 241 | """Return a dummy batch with a given number of tokens.""" 242 | src_len, tgt_len = utils.resolve_max_positions( 243 | (src_len, tgt_len), 244 | max_positions, 245 | (self.max_source_positions, self.max_target_positions), 246 | ) 247 | bsz = max(num_tokens // max(src_len, tgt_len), 1) 248 | return self.collater([ 249 | { 250 | 'id': i, 251 | 'source': self._get_dummy_source_example(src_len), 252 | 'target': self.tgt_dict.dummy_sentence(tgt_len) if self.tgt_dict is not None else None, 253 | } 254 | for i in range(bsz) 255 | ]) 256 | 257 | def _get_dummy_source_example(self, src_len): 258 | # 'nodes', 'labels', 'indices', 'length'] 259 | nodes = self.src_dict.dummy_sentence(src_len) 260 | labels = self.src_dict.dummy_sentence(src_len) 261 | node_len = nodes.size()[0] 262 | seq_len = int((node_len + 1) // 2) # w/o pad 263 | 264 | # t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long() 265 | length = torch.tensor([seq_len]).long() 266 | # test dummy zeros for indices 267 | fl_indices = torch.arange(node_len).long().unsqueeze(1) 268 | row_indices = fl_indices // seq_len 269 | col_indices = fl_indices - row_indices * seq_len 270 | indices = torch.cat([row_indices, col_indices], 1) 271 | 272 | example = { 273 | 'nodes': nodes, 274 | 'labels': labels, 275 | 'indices': indices, 276 | 'length': length 277 | } 278 | return example 279 | 280 | def num_tokens(self, index): 281 | """Return the number of tokens in a sample. This value is used to 282 | enforce ``--max-tokens`` during batching.""" 283 | return max(self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0) 284 | 285 | def size(self, index): 286 | """Return an example's size as a float or tuple. This value is used when 287 | filtering a dataset with ``--max-positions``.""" 288 | return (self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0) 289 | 290 | def ordered_indices(self): 291 | """Return an ordered list of indices. 
Batches will be constructed based 292 | on this order.""" 293 | if self.shuffle: 294 | indices = np.random.permutation(len(self)) 295 | else: 296 | indices = np.arange(len(self)) 297 | if self.tgt_sizes is not None: 298 | indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')] 299 | return indices[np.argsort(self.src_sizes[indices], kind='mergesort')] 300 | 301 | def prefetch(self, indices): 302 | # self.src.prefetch(indices) 303 | print(f'| {self.__class__.__name__}:prefetch:starting...') 304 | for k, v in self.srcs.items(): 305 | v.prefetch(indices) 306 | print(f'| {self.__class__.__name__}:prefetch:{k}') 307 | print(f'| {self.__class__.__name__}:prefetch:tgt') 308 | self.tgt.prefetch(indices) 309 | print(f'| {self.__class__.__name__}:prefetch:finished...') 310 | 311 | @property 312 | def supports_prefetch(self): 313 | return ( 314 | hasattr(self.src, 'supports_prefetch') 315 | and self.src.supports_prefetch 316 | and hasattr(self.tgt, 'supports_prefetch') 317 | and self.tgt.supports_prefetch 318 | ) 319 | 320 | 321 | -------------------------------------------------------------------------------- /src/data/dptree2seq_sep_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from fairseq import utils 5 | 6 | from fairseq.data import data_utils, FairseqDataset 7 | from .dptree2seq_dataset import * 8 | 9 | 10 | def collate_token_list(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=False): 11 | try: 12 | nsent = max(v.size(0) for v in values) 13 | size = max(v.size(1) for v in values) 14 | res = values[0].new(len(values), nsent, size).fill_(pad_idx) 15 | except RuntimeError as e: 16 | print(f'values: {[v.size() for v in values]}') 17 | raise e 18 | 19 | def copy_tensor(src, dst): 20 | assert dst.numel() == src.numel(), f'{res.size()}, {src.size()}, {dst.size()}' 21 | if move_eos_to_beginning: 22 | # assert src[-1] == eos_idx 23 | # dst[0] = eos_idx 24 | # dst[1:] = src[:-1] 25 | raise NotImplementedError 26 | else: 27 | dst.copy_(src) 28 | 29 | for i, v in enumerate(values): 30 | # dest = res[i][size - len(v):] if left_pad else res[i][:len(v)] 31 | nsent, length = v.size() 32 | if left_pad: 33 | dest = res[i, :nsent, size - len(v[0]):] 34 | else: 35 | dest = res[i, :nsent, :len(v[0])] 36 | copy_tensor(v, dest) 37 | return res 38 | 39 | 40 | def dptree_collate_sep_indices(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning): 41 | """convert list of 2d tensors into padded 3d tensors""" 42 | assert not left_pad 43 | nsent = max(v.size(0) for v in values) 44 | size = max(v.size(1) for v in values) 45 | res = values[0].new(len(values), nsent, size, 2).fill_(pad_idx) 46 | 47 | def copy_tensor(src, dst): 48 | assert dst.numel() == src.numel(), f'{res.size()}, {src.size()}, {dst.size()}' 49 | if move_eos_to_beginning: 50 | # assert src[-1] == eos_idx 51 | # dst[0] = eos_idx 52 | # dst[1:] = src[:-1] 53 | raise NotImplementedError 54 | else: 55 | dst.copy_(src) 56 | 57 | for i, v in enumerate(values): 58 | nsent, length, _ = v.size() 59 | copy_tensor(v, res[i, :nsent, size - len(v[0]):] if left_pad else res[i, :nsent, :len(v[0])]) 60 | 61 | return res 62 | 63 | 64 | def dptree2seq_sep_collate( 65 | samples, pad_idx, eos_idx, left_pad_source=False, left_pad_target=False, input_feeding=True, 66 | ): 67 | if len(samples) == 0: 68 | return {} 69 | 70 | # print(samples) 71 | # raise NotImplementedError 72 | 73 | def merge_target(key, left_pad, move_eos_to_beginning=False): 74 | try: 75 
| output = data_utils.collate_tokens( 76 | [s[key] for s in samples], 77 | pad_idx, eos_idx, left_pad, move_eos_to_beginning, 78 | ) 79 | except AssertionError as ae: 80 | print([s[key][-1] for s in samples]) 81 | print([len(s[key]) for s in samples]) 82 | raise ae 83 | return output 84 | 85 | def merge_source_backup(left_pad, move_eos_to_beginning=False): 86 | # src = [s['source'] for s in samples] 87 | assert samples[0]['source'] is not None 88 | src = {k: [dic['source'][k] for dic in samples] for k in samples[0]['source']} 89 | 90 | nodes = src['nodes'] 91 | labels = src['labels'] 92 | indices = src['indices'] 93 | length = src['length'] 94 | 95 | nodes = data_utils.collate_tokens(nodes, pad_idx, eos_idx, left_pad, move_eos_to_beginning) 96 | labels = data_utils.collate_tokens(labels, pad_idx, eos_idx, left_pad, move_eos_to_beginning) 97 | indices = dptree2seq_collate_indices(indices, 0, 0, left_pad, move_eos_to_beginning) 98 | length = torch.cat([x.unsqueeze_(0) for x in length], 0) 99 | 100 | src_o = { 101 | 'nodes': nodes, 102 | 'labels': labels, 103 | 'indices': indices, 104 | 'length': length 105 | } 106 | return src_o 107 | 108 | def merge_source(left_pad, move_eos_to_beginning=False): 109 | assert samples[0]['source'] is not None 110 | src = {k: [dic['source'][k] for dic in samples] for k in samples[0]['source']} 111 | 112 | nodes = src['nodes'] 113 | labels = src['labels'] 114 | indices = src['indices'] 115 | length = src['length'] 116 | 117 | nodes = collate_token_list(nodes, pad_idx, eos_idx, left_pad, move_eos_to_beginning) 118 | labels = collate_token_list(labels, pad_idx, eos_idx, left_pad, move_eos_to_beginning) 119 | indices = dptree_collate_sep_indices(indices, 0, 0, left_pad, move_eos_to_beginning) 120 | # length = torch.cat([x.unsqueeze_(0) for x in length], 0) 121 | length = data_utils.collate_tokens(length, 0, 0, False) 122 | 123 | src_o = { 124 | 'nodes': nodes, 125 | 'labels': labels, 126 | 'indices': indices, 127 | 'length': length 128 | } 129 | return src_o 130 | 131 | id = torch.LongTensor([s['id'] for s in samples]) 132 | src = merge_source(left_pad_source) 133 | src_lengths = torch.LongTensor([s['source']['nodes'].numel() for s in samples]) 134 | src_lengths, sort_order = src_lengths.sort(descending=True) 135 | id = id.index_select(0, sort_order) 136 | 137 | # reorder 138 | src = {k: v.index_select(0, sort_order) for k, v in src.items()} 139 | 140 | prev_output_tokens = None 141 | target = None 142 | if samples[0].get('target', None) is not None: 143 | target = merge_target('target', left_pad=left_pad_target) 144 | target = target.index_select(0, sort_order) 145 | ntokens = sum(len(s['target']) for s in samples) 146 | 147 | if input_feeding: 148 | # we create a shifted version of targets for feeding the 149 | # previous output token(s) into the next decoder step 150 | prev_output_tokens = merge_target( 151 | 'target', 152 | left_pad=left_pad_target, 153 | move_eos_to_beginning=True, 154 | ) 155 | prev_output_tokens = prev_output_tokens.index_select(0, sort_order) 156 | else: 157 | ntokens = sum(len(s['source']['nodes']) for s in samples) 158 | 159 | batch = { 160 | 'id': id, 161 | 'nsentences': len(samples), 162 | 'ntokens': ntokens, 163 | 'net_input': { 164 | 'src_tokens': src['nodes'], 165 | 'src_labels': src['labels'], 166 | 'src_indices': src['indices'], 167 | 'src_sent_lengths': src['length'], 168 | 'src_lengths': src_lengths, 169 | }, 170 | 'target': target, 171 | } 172 | # sizes = {k: v.size() for k, v in batch['net_input'].items()} 173 | #
print(f'batch-net-inputs: {sizes}') 174 | if prev_output_tokens is not None: 175 | batch['net_input']['prev_output_tokens'] = prev_output_tokens 176 | return batch 177 | 178 | 179 | class DPTree2SeqSeparatePairDataset(DPTree2SeqPairDataset): 180 | def collater(self, samples): 181 | """Merge a list of samples to form a mini-batch. 182 | 183 | Args: 184 | samples (List[dict]): samples to collate 185 | 186 | Returns: 187 | dict: a mini-batch with the following keys: 188 | 189 | - `id` (LongTensor): example IDs in the original input order 190 | - `ntokens` (int): total number of tokens in the batch 191 | - `net_input` (dict): the input to the Model, containing keys: 192 | 193 | - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in 194 | the source sentence of shape `(bsz, src_len)`. Padding will 195 | appear on the left if *left_pad_source* is ``True``. 196 | - `src_lengths` (LongTensor): 1D Tensor of the unpadded 197 | lengths of each source sentence of shape `(bsz)` 198 | - `prev_output_tokens` (LongTensor): a padded 2D Tensor of 199 | tokens in the target sentence, shifted right by one position 200 | for input feeding/teacher forcing, of shape `(bsz, 201 | tgt_len)`. This key will not be present if *input_feeding* 202 | is ``False``. Padding will appear on the left if 203 | *left_pad_target* is ``True``. 204 | 205 | - `target` (LongTensor): a padded 2D Tensor of tokens in the 206 | target sentence of shape `(bsz, tgt_len)`. Padding will appear 207 | on the left if *left_pad_target* is ``True``. 208 | """ 209 | return dptree2seq_sep_collate( 210 | samples, pad_idx=self.src_dict.pad(), eos_idx=self.src_dict.eos(), 211 | left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target, 212 | input_feeding=self.input_feeding, 213 | ) 214 | 215 | def _get_dummy_source_example(self, src_len): 216 | # 'nodes', 'labels', 'indices', 'length'] 217 | nodes = self.src_dict.dummy_sentence(src_len) 218 | labels = self.src_dict.dummy_sentence(src_len) 219 | node_len = nodes.size()[0] 220 | seq_len = int((node_len + 1) // 2) # w/o pad 221 | 222 | # t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long() 223 | length = torch.tensor([seq_len]).long() 224 | # test dummy zeros for indices 225 | fl_indices = torch.arange(node_len).long().unsqueeze(1) 226 | row_indices = fl_indices // seq_len 227 | col_indices = fl_indices - row_indices * seq_len 228 | indices = torch.cat([row_indices, col_indices], 1) 229 | 230 | example = { 231 | 'nodes': nodes.unsqueeze_(0), 232 | 'labels': labels.unsqueeze_(0), 233 | 'indices': indices.unsqueeze_(0), 234 | 'length': length 235 | } 236 | return example 237 | -------------------------------------------------------------------------------- /src/data/dptree_dictionary.py: -------------------------------------------------------------------------------- 1 | 2 | from collections import Counter 3 | from multiprocessing import Pool 4 | import os 5 | 6 | import torch 7 | 8 | from fairseq.tokenizer import tokenize_line 9 | from fairseq.binarizer import safe_readline 10 | from fairseq.data import data_utils, Dictionary 11 | 12 | 13 | class DPTreeWrapperDictionary(Dictionary): 14 | 15 | def __init__(self, pad='<pad>', eos='</s>', unk='<unk>', no_strip_node_label=False): 16 | super().__init__(pad, eos, unk) 17 | self.no_strip_node_label = no_strip_node_label 18 | 19 | @classmethod 20 | def load(cls, f, ignore_utf_errors=False, no_strip_node_label=False): 21 | """Loads the dictionary from a text file with the format: 22 | 23 | ``` 24 | <symbol0> <count0> 25 | <symbol1> <count1> 26 | ...
27 | ``` 28 | """ 29 | if isinstance(f, str): 30 | try: 31 | if not ignore_utf_errors: 32 | with open(f, 'r', encoding='utf-8') as fd: 33 | return cls.load(fd) 34 | else: 35 | with open(f, 'r', encoding='utf-8', errors='ignore') as fd: 36 | return cls.load(fd) 37 | except FileNotFoundError as fnfe: 38 | raise fnfe 39 | except UnicodeError: 40 | raise Exception("Incorrect encoding detected in {}, please " 41 | "rebuild the dataset".format(f)) 42 | 43 | d = cls(no_strip_node_label=no_strip_node_label) 44 | lines = f.readlines() 45 | indices_start_line = d._load_meta(lines) 46 | for line in lines[indices_start_line:]: 47 | idx = line.rfind(' ') 48 | if idx == -1: 49 | raise ValueError("Incorrect dictionary format, expected '<token> <cnt>'") 50 | word = line[:idx] 51 | count = int(line[idx + 1:]) 52 | d.indices[word] = len(d.symbols) 53 | d.symbols.append(word) 54 | d.count.append(count) 55 | return d 56 | 57 | def string(self, tensor, bpe_symbol=None, escape_unk=False): 58 | """Helper for converting a tensor of token indices to a string. 59 | 60 | Can optionally remove BPE symbols or escape words. 61 | """ 62 | if torch.is_tensor(tensor) and tensor.dim() == 2: 63 | return '\n'.join(self.string(t) for t in tensor) 64 | 65 | def token_string(i): 66 | if i == self.unk(): 67 | return self.unk_string(escape_unk) 68 | else: 69 | return self[i] 70 | if self.no_strip_node_label: 71 | sent = ' '.join(token_string(i) for i in tensor if i != self.eos()) 72 | else: 73 | sent = ' '.join(token_string(i) for i in tensor if i != self.eos() and "_node_label" not in token_string(i)) 74 | return data_utils.process_bpe_symbol(sent, bpe_symbol) 75 | 76 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /src/data/dptree_index_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import struct 3 | 4 | import numpy as np 5 | import torch 6 | from fairseq.data.indexed_dataset import IndexedDatasetBuilder 7 | 8 | # from fairseq.tokenizer import Tokenizer 9 | from fairseq.data import * 10 | 11 | def read_longs(f, n): 12 | a = np.empty(n, dtype=np.int64) 13 | f.readinto(a) 14 | return a 15 | 16 | 17 | def write_longs(f, a): 18 | f.write(np.array(a, dtype=np.int64)) 19 | 20 | 21 | dtypes = { 22 | 1: np.uint8, 23 | 2: np.int8, 24 | 3: np.int16, 25 | 4: np.int32, 26 | 5: np.int64, 27 | 6: np.float, 28 | 7: np.double, 29 | } 30 | 31 | 32 | def code(dtype): 33 | for k in dtypes.keys(): 34 | if dtypes[k] == dtype: 35 | return k 36 | 37 | 38 | def index_file_path(prefix_path): 39 | return prefix_path + '.idx' 40 | 41 | 42 | def data_file_path(prefix_path): 43 | return prefix_path + '.bin' 44 | 45 | 46 | class DPTreeIndexedCachedDataset(IndexedDataset): 47 | 48 | def __init__(self, path, fix_lua_indexing=False): 49 | super().__init__(path, fix_lua_indexing=fix_lua_indexing) 50 | self.cache = None 51 | self.cache_index = {} 52 | 53 | @property 54 | def supports_prefetch(self): 55 | return True 56 | 57 | def prefetch(self, indices): 58 | if all(i in self.cache_index for i in indices): 59 | return 60 | if not self.data_file: 61 | self.read_data(self.path) 62 | indices = sorted(set(indices)) 63 | total_size = 0 64 | for i in indices: 65 | total_size += self.data_offsets[i + 1] - self.data_offsets[i] 66 | self.cache = np.empty(total_size, dtype=self.dtype) 67 | ptx = 0 68 | self.cache_index.clear() 69 | for i in indices: 70 | self.cache_index[i] = ptx 71 | size = self.data_offsets[i + 1] - self.data_offsets[i] 72 | a = self.cache[ptx : ptx +
size] 73 | self.data_file.seek(self.data_offsets[i] * self.element_size) 74 | self.data_file.readinto(a) 75 | ptx += size 76 | 77 | def __getitem__(self, i): 78 | self.check_index(i) 79 | tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] 80 | flat_t_size = [np.prod(tensor_size)] 81 | # a = np.empty(tensor_size, dtype=self.dtype) 82 | a = np.empty(flat_t_size, dtype=self.dtype) 83 | ptx = self.cache_index[i] 84 | np.copyto(a, self.cache[ptx : ptx + a.size]) 85 | a = np.reshape(a, tensor_size) 86 | item = torch.from_numpy(a).long() 87 | if self.fix_lua_indexing: 88 | item -= 1 # subtract 1 for 0-based indexing 89 | return item 90 | 91 | 92 | class FloatIndexedCachedDataset(IndexedDataset): 93 | 94 | def __init__(self, path, fix_lua_indexing=False): 95 | super().__init__(path, fix_lua_indexing=fix_lua_indexing) 96 | self.cache = None 97 | self.cache_index = {} 98 | 99 | @property 100 | def supports_prefetch(self): 101 | return True 102 | 103 | def prefetch(self, indices): 104 | if all(i in self.cache_index for i in indices): 105 | return 106 | if not self.data_file: 107 | self.read_data(self.path) 108 | indices = sorted(set(indices)) 109 | total_size = 0 110 | for i in indices: 111 | total_size += self.data_offsets[i + 1] - self.data_offsets[i] 112 | self.cache = np.empty(total_size, dtype=self.dtype) 113 | ptx = 0 114 | self.cache_index.clear() 115 | for i in indices: 116 | self.cache_index[i] = ptx 117 | size = self.data_offsets[i + 1] - self.data_offsets[i] 118 | a = self.cache[ptx : ptx + size] 119 | self.data_file.seek(self.data_offsets[i] * self.element_size) 120 | self.data_file.readinto(a) 121 | ptx += size 122 | 123 | def __getitem__(self, i): 124 | self.check_index(i) 125 | tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] 126 | a = np.empty(tensor_size, dtype=self.dtype) 127 | ptx = self.cache_index[i] 128 | np.copyto(a, self.cache[ptx: ptx + a.size]) 129 | item = torch.from_numpy(a).float() 130 | if self.fix_lua_indexing: 131 | item -= 1 # subtract 1 for 0-based indexing 132 | return item 133 | 134 | 135 | # class DPTreeSeparateIndexedCachedDataset(DPTreeIndexedCachedDataset): 136 | 137 | 138 | class DPTreeSeparateIndexedDatasetBuilder(IndexedDatasetBuilder): 139 | def add_item(self, tensor): 140 | # +1 for Lua compatibility 141 | bytes = self.out_file.write(np.array(tensor.numpy() + 1, dtype=self.dtype)) 142 | self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size) 143 | 144 | # flat = tensor.view(-1) 145 | # for s in tensor.view(-1).size(): 146 | # for s in tensor.size(): 147 | # self.sizes.append(s) 148 | # x += list(xx.size() 149 | self.sizes += list(tensor.size()) 150 | self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size())) 151 | 152 | def merge_file_(self, another_file): 153 | index = IndexedDataset(another_file) 154 | assert index.dtype == self.dtype 155 | 156 | begin = self.data_offsets[-1] 157 | for offset in index.data_offsets[1:]: 158 | self.data_offsets.append(begin + offset) 159 | 160 | self.sizes.extend(index.sizes) 161 | 162 | begin = self.dim_offsets[-1] 163 | for dim_offset in index.dim_offsets[1:]: 164 | self.dim_offsets.append(begin + dim_offset) 165 | 166 | with open(data_file_path(another_file), 'rb') as f: 167 | while True: 168 | data = f.read(1024) 169 | if data: 170 | self.out_file.write(data) 171 | else: 172 | break 173 | 174 | def finalize(self, index_file): 175 | self.out_file.close() 176 | index = open(index_file, 'wb') 177 | index.write(b'TNTIDX\x00\x00') 178 | 
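# Layout of the .idx header written below, inferred from the reads/writes in
# this file rather than from any separate spec: 8 magic bytes b'TNTIDX\x00\x00',
# a version word, a (dtype code, element size) pair, a (num items, num size
# entries) pair, and then the dim_offsets, data_offsets and sizes arrays as
# int64 via write_longs(). The reader side (read_longs() plus IndexedDataset's
# read_data()) must consume the fields in exactly this order.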
index.write(struct.pack('<Q', 1)) 179 | index.write(struct.pack('<QQ', code(self.dtype), self.element_size)) 180 | index.write(struct.pack('<QQ', len(self.data_offsets) - 1, len(self.sizes))) 181 | write_longs(index, self.dim_offsets) 182 | write_longs(index, self.data_offsets) 183 | write_longs(index, self.sizes) 184 | index.close() 185 | -------------------------------------------------------------------------------- /src/data/task_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from fairseq.data import data_utils 4 | 5 | 6 | 7 | 8 | def infer_language(path): 9 | """Infer language from filename: <split>.<src>-<dst>.(...).idx""" 10 | """Infer language from filename: <split>.<lang>.(...).idx""" 11 | # fixme: eg: train.en.(...).idx 12 | src, dst = None, None 13 | src = None 14 | for filename in os.listdir(path): 15 | parts = filename.split('.') 16 | # if len(parts) >= 3 and len(parts[1].split('-')) == 2: 17 | # return parts[1].split('-') 18 | if len(parts) >= 3: 19 | return parts[1] 20 | return src 21 | 22 | 23 | # def filter_by_size(indices, size_fn, max_positions, raise_exception=False): 24 | # """ 25 | # Filter indices based on their size. 26 | # 27 | # Args: 28 | # indices (List[int]): ordered list of dataset indices 29 | # size_fn (callable): function that returns the size of a given index 30 | # max_positions (tuple): filter elements larger than this size. 31 | # Comparisons are done component-wise. 32 | # raise_exception (bool, optional): if ``True``, raise an exception if 33 | # any elements are filtered (default: False). 34 | # """ 35 | # def check_size(idx): 36 | # if isinstance(max_positions, float) or isinstance(max_positions, int): 37 | # return size_fn(idx) <= max_positions 38 | # elif isinstance(max_positions, dict): 39 | # idx_size = size_fn(idx) 40 | # assert isinstance(idx_size, dict) 41 | # intersect_keys = set(max_positions.keys()) & set(idx_size.keys()) 42 | # return all( 43 | # all(a is None or b is None or a <= b 44 | # for a, b in zip(idx_size[key], max_positions[key])) 45 | # for key in intersect_keys 46 | # ) 47 | # else: 48 | # return all(a is None or b is None or a <= b 49 | # for a, b in zip(size_fn(idx), max_positions)) 50 | # 51 | # ignored = [] 52 | # itr = collect_filtered(check_size, indices, ignored) 53 | # 54 | # for idx in itr: 55 | # if len(ignored) > 0 and raise_exception: 56 | # raise Exception(( 57 | # 'Size of sample #{} is invalid (={}) since max_positions={}, ' 58 | # 'skip this example with --skip-invalid-size-inputs-valid-test' 59 | # ).format(ignored[0], size_fn(ignored[0]), max_positions)) 60 | # yield idx 61 | # 62 | # if len(ignored) > 0: 63 | # print(( 64 | # '| WARNING: {} samples have invalid sizes and will be skipped, ' 65 | # 'max_positions={}, first few sample ids={}' 66 | # ).format(len(ignored), max_positions, ignored[:10])) 67 | 68 | 69 | def filter_by_class_size(indices, size_fn, max_positions, class_fn, class_index, raise_exception=False): 70 | """ 71 | Filter indices based on their size and class. 72 | 73 | Args: 74 | indices (List[int]): ordered list of dataset indices 75 | size_fn (callable): function that returns the size of a given index 76 | max_positions (tuple): filter elements larger than this size. 77 | Comparisons are done component-wise. 78 | raise_exception (bool, optional): if ``True``, raise an exception if 79 | any elements are filtered (default: False).
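        class_fn (callable): function that returns the class label of a given index
        class_index (int): samples whose class label equals this index are
            dropped together with the oversized ones.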
80 | """ 81 | 82 | assert isinstance(class_index, int) 83 | 84 | def check_class(idx): 85 | return class_fn(idx) != class_index 86 | 87 | def check_size(idx): 88 | if isinstance(max_positions, float) or isinstance(max_positions, int): 89 | return size_fn(idx) <= max_positions 90 | elif isinstance(max_positions, dict): 91 | idx_size = size_fn(idx) 92 | assert isinstance(idx_size, dict) 93 | intersect_keys = set(max_positions.keys()) & set(idx_size.keys()) 94 | return all( 95 | all(a is None or b is None or a <= b 96 | for a, b in zip(idx_size[key], max_positions[key])) 97 | for key in intersect_keys 98 | ) 99 | else: 100 | return all(a is None or b is None or a <= b 101 | for a, b in zip(size_fn(idx), max_positions)) 102 | 103 | def check_both(idx): 104 | return check_class(idx) and check_size(idx) 105 | 106 | ignored = [] 107 | itr = data_utils.collect_filtered(check_both, indices, ignored) 108 | 109 | for idx in itr: 110 | if len(ignored) > 0 and raise_exception: 111 | raise Exception(( 112 | 'Size of sample #{} is invalid (={}) since max_positions={}, ' 113 | 'skip this example with --skip-invalid-size-inputs-valid-test' 114 | ).format(ignored[0], size_fn(ignored[0]), max_positions)) 115 | yield idx 116 | 117 | if len(ignored) > 0: 118 | print(( 119 | '| WARNING: {} samples have invalid sizes and will be skipped, ' 120 | 'max_positions={}, first few sample ids={}' 121 | ).format(len(ignored), max_positions, ignored[:10])) 122 | 123 | 124 | def filter_by_class(indices, class_fn, class_index, raise_exception=False): 125 | """ 126 | Filter indices based on their class. 127 | 128 | Args: 129 | indices (List[int]): ordered list of dataset indices 130 | class_fn (callable): function that returns the class label of a given index 131 | class_index (int): filter out elements whose class label equals 132 | this index. 133 | raise_exception (bool, optional): if ``True``, raise an exception if 134 | any elements are filtered (default: False).
135 | """ 136 | 137 | assert isinstance(class_index, int) 138 | 139 | def check_class(idx): 140 | return class_fn(idx) != class_index 141 | 142 | # def check_size(idx): 143 | # if isinstance(max_positions, float) or isinstance(max_positions, int): 144 | # return size_fn(idx) <= max_positions 145 | # elif isinstance(max_positions, dict): 146 | # idx_size = size_fn(idx) 147 | # assert isinstance(idx_size, dict) 148 | # intersect_keys = set(max_positions.keys()) & set(idx_size.keys()) 149 | # return all( 150 | # all(a is None or b is None or a <= b 151 | # for a, b in zip(idx_size[key], max_positions[key])) 152 | # for key in intersect_keys 153 | # ) 154 | # else: 155 | # return all(a is None or b is None or a <= b 156 | # for a, b in zip(size_fn(idx), max_positions)) 157 | 158 | # def check_both(idx): 159 | # return check_class(idx) and check_size(idx) 160 | 161 | ignored = [] 162 | itr = data_utils.collect_filtered(check_class, indices, ignored) 163 | 164 | for idx in itr: 165 | if len(ignored) > 0 and raise_exception: 166 | raise Exception(( 167 | 'Class of sample #{} is invalid (={}) since class_index={}, ' 168 | 'skip this example with --skip-invalid-size-inputs-valid-test' 169 | ).format(ignored[0], class_fn(ignored[0]), class_index)) 170 | yield idx 171 | 172 | # if len(ignored) > 0: 173 | # print(( 174 | # '| WARNING: {} samples have invalid sizes and will be skipped, ' 175 | # 'max_positions={}, first few sample ids={}' 176 | # ).format(len(ignored), class_index, ignored[:10])) 177 | 178 | 179 | def filter_by_nsent(indices, size_fn, max_positions, raise_exception=False): 180 | """ 181 | Filter indices based on their size. 182 | 183 | Args: 184 | indices (List[int]): ordered list of dataset indices 185 | size_fn (callable): function that returns the size of a given index 186 | max_positions (tuple): filter elements larger than this size. 187 | Comparisons are done component-wise. 188 | raise_exception (bool, optional): if ``True``, raise an exception if 189 | any elements are filtered (default: False).
190 | """ 191 | def check_size(idx): 192 | if isinstance(max_positions, float) or isinstance(max_positions, int): 193 | return size_fn(idx) <= max_positions 194 | elif isinstance(max_positions, dict): 195 | idx_size = size_fn(idx) 196 | assert isinstance(idx_size, dict) 197 | intersect_keys = set(max_positions.keys()) & set(idx_size.keys()) 198 | return all( 199 | all(a is None or b is None or a <= b 200 | for a, b in zip(idx_size[key], max_positions[key])) 201 | for key in intersect_keys 202 | ) 203 | else: 204 | return all(a is None or b is None or a <= b 205 | for a, b in zip(size_fn(idx), max_positions)) 206 | 207 | ignored = [] 208 | itr = data_utils.collect_filtered(check_size, indices, ignored) 209 | 210 | for idx in itr: 211 | if len(ignored) > 0 and raise_exception: 212 | raise Exception(( 213 | 'Size of sample #{} is invalid (={}) since max_positions={}, ' 214 | 'skip this example with --skip-invalid-size-inputs-valid-test' 215 | ).format(ignored[0], size_fn(ignored[0]), max_positions)) 216 | yield idx 217 | 218 | if len(ignored) > 0: 219 | print(( 220 | '| WARNING: {} samples have invalid sizes and will be skipped, ' 221 | 'max_positions={}, first few sample ids={}' 222 | ).format(len(ignored), max_positions, ignored[:10])) 223 | 224 | 225 | 226 | -------------------------------------------------------------------------------- /src/data/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | 9 | import importlib 10 | import os 11 | 12 | from fairseq import registry 13 | 14 | 15 | build_tokenizer, register_tokenizer, TOKENIZER_REGISTRY = registry.setup_registry( 16 | '--tokenizer', 17 | default='space', 18 | ) 19 | 20 | 21 | build_bpe, register_bpe, BPE_REGISTRY = registry.setup_registry( 22 | '--bpe', 23 | default=None, 24 | ) 25 | 26 | 27 | # automatically import any Python files in the transforms/ directory 28 | for file in os.listdir(os.path.dirname(__file__)): 29 | if file.endswith('.py') and not file.startswith('_'): 30 | module = file[:file.find('.py')] 31 | importlib.import_module('fairseq.data.transforms.' + module) 32 | -------------------------------------------------------------------------------- /src/data/transforms/gpt2_bpe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
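# A minimal usage sketch of the --bpe registry set up in transforms/__init__.py
# (the args namespace and its values here are illustrative, not part of this file):
#   from fairseq.data.transforms import build_bpe
#   args.bpe = 'gpt2'
#   bpe = build_bpe(args)              # resolves to the GPT2BPE class below
#   ids = bpe.encode('Hello world!')   # space-joined GPT-2 token ids as a string
#   text = bpe.decode(ids)             # byte-level BPE, so this should round-trip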
7 | 8 | from fairseq import file_utils 9 | from fairseq.data.transforms import register_bpe 10 | 11 | 12 | @register_bpe('gpt2') 13 | class GPT2BPE(object): 14 | 15 | @staticmethod 16 | def add_args(parser): 17 | # fmt: off 18 | parser.add_argument('--gpt2-encoder-json', type=str, 19 | default='https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json', 20 | help='path to encoder.json') 21 | parser.add_argument('--gpt2-vocab-bpe', type=str, 22 | default='https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe', 23 | help='path to vocab.bpe') 24 | # fmt: on 25 | 26 | def __init__(self, args): 27 | encoder_json = file_utils.cached_path(args.gpt2_encoder_json) 28 | vocab_bpe = file_utils.cached_path(args.gpt2_vocab_bpe) 29 | self.bpe = get_encoder(encoder_json, vocab_bpe) 30 | 31 | def encode(self, x: str) -> str: 32 | return ' '.join(map(str, self.bpe.encode(x))) 33 | 34 | def decode(self, x: str) -> str: 35 | return self.bpe.decode(map(int, x.split())) 36 | 37 | 38 | """Byte pair encoding utilities from GPT-2""" 39 | 40 | from functools import lru_cache 41 | import json 42 | import os 43 | 44 | 45 | @lru_cache() 46 | def bytes_to_unicode(): 47 | """ 48 | Returns a list of utf-8 bytes and a corresponding list of unicode strings. 49 | The reversible bpe codes work on unicode strings. 50 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 51 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 52 | This is a significant percentage of your normal, say, 32K bpe vocab. 53 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 54 | And avoids mapping to whitespace/control characters the bpe code barfs on. 55 | """ 56 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 57 | cs = bs[:] 58 | n = 0 59 | for b in range(2**8): 60 | if b not in bs: 61 | bs.append(b) 62 | cs.append(2**8+n) 63 | n += 1 64 | cs = [chr(n) for n in cs] 65 | return dict(zip(bs, cs)) 66 | 67 | def get_pairs(word): 68 | """Return set of symbol pairs in a word. 69 | Word is represented as tuple of symbols (symbols being variable-length strings).
70 | """ 71 | pairs = set() 72 | prev_char = word[0] 73 | for char in word[1:]: 74 | pairs.add((prev_char, char)) 75 | prev_char = char 76 | return pairs 77 | 78 | class Encoder: 79 | 80 | def __init__(self, encoder, bpe_merges, errors='replace'): 81 | self.encoder = encoder 82 | self.decoder = {v:k for k,v in self.encoder.items()} 83 | self.errors = errors # how to handle errors in decoding 84 | self.byte_encoder = bytes_to_unicode() 85 | self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} 86 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 87 | self.cache = {} 88 | 89 | try: 90 | import regex as re 91 | self.re = re 92 | except ImportError: 93 | raise ImportError('Please install regex with: pip install regex') 94 | 95 | # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions 96 | self.pat = self.re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") 97 | 98 | def bpe(self, token): 99 | if token in self.cache: 100 | return self.cache[token] 101 | word = tuple(token) 102 | pairs = get_pairs(word) 103 | 104 | if not pairs: 105 | return token 106 | 107 | while True: 108 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 109 | if bigram not in self.bpe_ranks: 110 | break 111 | first, second = bigram 112 | new_word = [] 113 | i = 0 114 | while i < len(word): 115 | try: 116 | j = word.index(first, i) 117 | new_word.extend(word[i:j]) 118 | i = j 119 | except: 120 | new_word.extend(word[i:]) 121 | break 122 | 123 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 124 | new_word.append(first+second) 125 | i += 2 126 | else: 127 | new_word.append(word[i]) 128 | i += 1 129 | new_word = tuple(new_word) 130 | word = new_word 131 | if len(word) == 1: 132 | break 133 | else: 134 | pairs = get_pairs(word) 135 | word = ' '.join(word) 136 | self.cache[token] = word 137 | return word 138 | 139 | def encode(self, text): 140 | bpe_tokens = [] 141 | for token in self.re.findall(self.pat, text): 142 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 143 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 144 | return bpe_tokens 145 | 146 | def decode(self, tokens): 147 | text = ''.join([self.decoder[token] for token in tokens]) 148 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) 149 | return text 150 | 151 | def get_encoder(encoder_json_path, vocab_bpe_path): 152 | with open(encoder_json_path, 'r') as f: 153 | encoder = json.load(f) 154 | with open(vocab_bpe_path, 'r', encoding="utf-8") as f: 155 | bpe_data = f.read() 156 | bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]] 157 | return Encoder( 158 | encoder=encoder, 159 | bpe_merges=bpe_merges, 160 | ) 161 | -------------------------------------------------------------------------------- /src/data/transforms/moses_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory.
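# A usage sketch of the tokenizer below (args values are illustrative; the
# class itself requires `pip install sacremoses`):
#   tok = MosesTokenizer(args)         # args.source_lang='en', args.target_lang='en'
#   enc = tok.encode('Hello, world!')  # -> 'Hello , world !'
#   dec = tok.decode(enc)              # -> 'Hello, world!'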
7 | 8 | from fairseq.data.transforms import register_tokenizer 9 | 10 | 11 | @register_tokenizer('moses') 12 | class MosesTokenizer(object): 13 | 14 | @staticmethod 15 | def add_args(parser): 16 | # fmt: off 17 | parser.add_argument('-s', '--source-lang', default='en', metavar='SRC', 18 | help='source language') 19 | parser.add_argument('-t', '--target-lang', default='en', metavar='TARGET', 20 | help='target language') 21 | parser.add_argument('--aggressive-dash-splits', action='store_true', default=False, 22 | help='triggers dash split rules') 23 | parser.add_argument('--no-escape', action='store_true', default=False, 24 | help='don\'t perform HTML escaping on apostrophe, quotes, etc.') 25 | # fmt: on 26 | 27 | def __init__(self, args): 28 | self.args = args 29 | try: 30 | from sacremoses import MosesTokenizer, MosesDetokenizer 31 | self.tok = MosesTokenizer(args.source_lang) 32 | self.detok = MosesDetokenizer(args.target_lang) 33 | except ImportError: 34 | raise ImportError('Please install Moses tokenizer with: pip install sacremoses') 35 | 36 | def encode(self, x: str) -> str: 37 | return self.tok.tokenize( 38 | x, 39 | aggressive_dash_splits=self.args.aggressive_dash_splits, 40 | return_str=True, 41 | escape=(not self.args.no_escape), 42 | ) 43 | 44 | def decode(self, x: str) -> str: 45 | return self.detok.detokenize(x.split()) 46 | -------------------------------------------------------------------------------- /src/data/transforms/nltk_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from fairseq.data.transforms import register_tokenizer 9 | 10 | 11 | @register_tokenizer('nltk') 12 | class NLTKTokenizer(object): 13 | 14 | def __init__(self, source_lang=None, target_lang=None): 15 | try: 16 | from nltk.tokenize import word_tokenize 17 | self.word_tokenize = word_tokenize 18 | except ImportError: 19 | raise ImportError('Please install nltk with: pip install nltk') 20 | 21 | def encode(self, x: str) -> str: 22 | return ' '.join(self.word_tokenize(x)) 23 | 24 | def decode(self, x: str) -> str: 25 | return x 26 | -------------------------------------------------------------------------------- /src/data/transforms/sentencepiece_bpe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory.
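# A usage sketch (the model path is a placeholder; requires `pip install sentencepiece`):
#   args.sentencepiece_vocab = '/path/to/spm.model'
#   bpe = SentencepieceBPE(args)
#   pieces = bpe.encode('Hello world')  # space-joined pieces with '\u2581' word markers
#   text = bpe.decode(pieces)           # -> 'Hello world'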
7 | 
8 | from fairseq import file_utils
9 | from fairseq.data.transforms import register_bpe
10 | 
11 | 
12 | @register_bpe('sentencepiece')
13 | class SentencepieceBPE(object):
14 | 
15 |     @staticmethod
16 |     def add_args(parser):
17 |         # fmt: off
18 |         parser.add_argument('--sentencepiece-vocab', type=str,
19 |                             help='path to sentencepiece vocab')
20 |         # fmt: on
21 | 
22 |     def __init__(self, args):
23 |         vocab = file_utils.cached_path(args.sentencepiece_vocab)
24 |         try:
25 |             import sentencepiece as spm
26 |             self.sp = spm.SentencePieceProcessor()
27 |             self.sp.Load(vocab)
28 |         except ImportError:
29 |             raise ImportError('Please install sentencepiece with: pip install sentencepiece')
30 | 
31 |     def encode(self, x: str) -> str:
32 |         return ' '.join(self.sp.EncodeAsPieces(x))
33 | 
34 |     def decode(self, x: str) -> str:
35 |         return x.replace(' ', '').replace('\u2581', ' ').strip()
36 | 
-------------------------------------------------------------------------------- /src/data/transforms/space_tokenizer.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the LICENSE file in
5 | # the root directory of this source tree. An additional grant of patent rights
6 | # can be found in the PATENTS file in the same directory.
7 | 
8 | import re
9 | 
10 | from fairseq.data.transforms import register_tokenizer
11 | 
12 | 
13 | @register_tokenizer('space')
14 | class SpaceTokenizer(object):
15 | 
16 |     def __init__(self, source_lang=None, target_lang=None):
17 |         self.space_tok = re.compile(r"\s+")
18 | 
19 |     def encode(self, x: str) -> str:
20 |         return self.space_tok.sub(" ", x).strip()
21 | 
22 |     def decode(self, x: str) -> str:
23 |         return x
24 | 
-------------------------------------------------------------------------------- /src/data/transforms/subword_nmt_bpe.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the LICENSE file in
5 | # the root directory of this source tree. An additional grant of patent rights
6 | # can be found in the PATENTS file in the same directory.
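# Usage sketch (illustrative, not part of the original file; the codes path is
# a placeholder and the segmentation depends on the learned merge codes):
#
#   import argparse
#   args = argparse.Namespace(bpe_codes='/path/to/codes', bpe_separator='@@')
#   bpe = SubwordNMTBPE(args)
#   bpe.encode('newer words')     # e.g. 'new@@ er words'
#   bpe.decode('new@@ er words')  # -> 'newer words'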
7 | 8 | from fairseq import file_utils 9 | from fairseq.data.transforms import register_bpe 10 | 11 | 12 | @register_bpe('subword_nmt') 13 | class SubwordNMTBPE(object): 14 | 15 | @staticmethod 16 | def add_args(parser): 17 | # fmt: off 18 | parser.add_argument('--bpe-codes', type=str, 19 | help='path to subword NMT BPE') 20 | parser.add_argument('--bpe-separator', default='@@', 21 | help='BPE separator') 22 | # fmt: on 23 | 24 | def __init__(self, args): 25 | codes = file_utils.cached_path(args.bpe_codes) 26 | try: 27 | from subword_nmt import apply_bpe 28 | bpe_parser = apply_bpe.create_parser() 29 | bpe_args = bpe_parser.parse_args([ 30 | '--codes', codes, 31 | '--separator', args.bpe_separator, 32 | ]) 33 | self.bpe = apply_bpe.BPE( 34 | bpe_args.codes, 35 | bpe_args.merges, 36 | bpe_args.separator, 37 | None, 38 | bpe_args.glossaries, 39 | ) 40 | self.bpe_symbol = bpe_args.separator + ' ' 41 | except ImportError: 42 | raise ImportError('Please install subword_nmt with: pip install subword-nmt') 43 | 44 | def encode(self, x: str) -> str: 45 | return self.bpe.process_line(x) 46 | 47 | def decode(self, x: str) -> str: 48 | return (x + ' ').replace(self.bpe_symbol, '').rstrip() 49 | -------------------------------------------------------------------------------- /src/dptree/__init__.py: -------------------------------------------------------------------------------- 1 | from .tree_process import * 2 | from .tree_builder import * 3 | from .nstack_process import * 4 | 5 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from . import dptree2seg_transformer 3 | from . import nstack_transformer 4 | from . import nstack_archs 5 | -------------------------------------------------------------------------------- /src/modules/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .dptree_multihead_attention import * 3 | from .dptree_transformer_layer import * 4 | 5 | from .dptree_sep_multihead_attention import * 6 | from .dptree_individual_multihead_attention import * 7 | from .dptree_onseq_multihead_attention import * 8 | 9 | 10 | from .default_multihead_attention import * 11 | from .default_dy_conv import * 12 | 13 | from .nstack_tree_attention import * 14 | from .nstack_merge_tree_attention import * 15 | from .nstack_tree_attention_eff import * 16 | from .nstack_transformer_layers import * 17 | 18 | 19 | __all__ = [ 20 | 21 | 'DPTreeMultiheadAttention', 22 | 'DPTreeOnlyKeyAttention', 23 | 'DPTreeSeparateOnlyKeyWeightSplitAttention', 24 | 'DPTreeSeparateOnlyKeyMatSumAttention', 25 | 'DPTreeSeparateOnlyKeyWeightSplitMatSumAttention', 26 | 'DPTreeSeparateOnlyKeyRightUpAttention', 27 | 28 | 'DPTreeIndividualOnlyKeyAttention', 29 | 'DPTreeIndividualRNNOnlyKeyAttention', 30 | 'DPTreeIndividualRootAverageTransformerEncoder', 31 | 32 | 'DPTreeOnSeqAttention', 33 | 34 | 'DPTreeTransformerEncoderLayer', 35 | 'DPTree2SeqTransformerDecoderLayer', 36 | 'DPTreeTransformerEncoder', 37 | 'DPTree2SeqTransformerDecoder', 38 | 39 | 'DefaultMultiheadAttention', 40 | 41 | 'LearnedPositionalEmbedding', 42 | 'PositionalEmbedding', 43 | 44 | 'NodeStackOnKeyAttention', 45 | 'NodeStackOnValueAttention', 46 | ] 47 | -------------------------------------------------------------------------------- /src/modules/default_dy_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 
2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from fairseq import utils 13 | from fairseq.modules import unfold1d 14 | 15 | 16 | def Linear(in_features, out_features, bias=True): 17 | m = nn.Linear(in_features, out_features, bias) 18 | nn.init.xavier_uniform_(m.weight) 19 | if bias: 20 | nn.init.constant_(m.bias, 0.) 21 | return m 22 | 23 | 24 | # def unfold1d_(x, kernel_size, padding_l, pad_value=0): 25 | # '''unfold T x B x C to T x B x C x K''' 26 | # if kernel_size > 1: 27 | # T, B, C = x.size() 28 | # x = F.pad(x, (0, 0, 0, 0, padding_l, kernel_size - 1 - padding_l), value=pad_value) 29 | # x = x.as_strided((T, B, C, kernel_size), (B*C, C, 1, B*C)) 30 | # else: 31 | # x = x.unsqueeze(3) 32 | # return x 33 | # 34 | 35 | 36 | class DefaultDynamicConv1dTBC(nn.Module): 37 | '''Dynamic lightweight convolution taking T x B x C inputs 38 | Args: 39 | input_size: # of channels of the input 40 | kernel_size: convolution channels 41 | padding_l: padding to the left when using "same" padding 42 | num_heads: number of heads used. The weight is of shape (num_heads, 1, kernel_size) 43 | weight_dropout: the drop rate of the DropConnect to drop the weight 44 | weight_softmax: normalize the weight with softmax before the convolution 45 | renorm_padding: re-normalize the filters to ignore the padded part (only the non-padding parts sum up to 1) 46 | bias: use bias 47 | conv_bias: bias of the convolution 48 | query_size: specified when feeding a different input as the query 49 | in_proj: project the input and generate the filter together 50 | 51 | Shape: 52 | Input: TxBxC, i.e. (timesteps, batch_size, input_size) 53 | Output: TxBxC, i.e. 
(timesteps, batch_size, input_size) 54 | 55 | Attributes: 56 | weight: the learnable weights of the module of shape 57 | `(num_heads, 1, kernel_size)` 58 | bias: the learnable bias of the module of shape `(input_size)` 59 | ''' 60 | def __init__(self, args, input_size, kernel_size=1, padding_l=None, num_heads=1, 61 | weight_dropout=0., weight_softmax=False, 62 | renorm_padding=False, bias=False, conv_bias=False, 63 | query_size=None, in_proj=False): 64 | super().__init__() 65 | 66 | self.args = args 67 | self.softmax_x = getattr(args, 'softmax_x', False) 68 | 69 | self.input_size = input_size 70 | self.query_size = input_size if query_size is None else query_size 71 | self.kernel_size = kernel_size 72 | self.padding_l = padding_l 73 | self.num_heads = num_heads 74 | self.weight_dropout = weight_dropout 75 | self.weight_softmax = weight_softmax 76 | self.renorm_padding = renorm_padding 77 | 78 | if in_proj: 79 | self.weight_linear = Linear(self.input_size, self.input_size + num_heads * kernel_size * 1) 80 | else: 81 | self.weight_linear = Linear(self.query_size, num_heads * kernel_size * 1, bias=bias) 82 | if conv_bias: 83 | self.conv_bias = nn.Parameter(torch.Tensor(input_size)) 84 | else: 85 | self.conv_bias = None 86 | self.reset_parameters() 87 | 88 | @property 89 | def in_proj(self): 90 | return self.weight_linear.out_features == self.input_size + self.num_heads * self.kernel_size 91 | 92 | def reset_parameters(self): 93 | self.weight_linear.reset_parameters() 94 | if self.conv_bias is not None: 95 | nn.init.constant_(self.conv_bias, 0.) 96 | 97 | def forward(self, x, incremental_state=None, query=None, unfold=None): 98 | '''Assuming the input, x, of the shape T x B x C and producing an output in the shape T x B x C 99 | args: 100 | x: Input of shape T x B x C, i.e. (timesteps, batch_size, input_size) 101 | incremental_state: A dict to keep the state 102 | unfold: unfold the input or not. If not, we use the matrix trick instead 103 | query: use the specified query to predict the conv filters 104 | ''' 105 | unfold = x.size(0) > 512 if unfold is None else unfold # use unfold mode as default for long sequence to save memory 106 | unfold = unfold or (incremental_state is not None) 107 | assert query is None or not self.in_proj 108 | 109 | if query is None: 110 | query = x 111 | 112 | if unfold: 113 | output = self._forward_unfolded(x, incremental_state, query) 114 | else: 115 | output = self._forward_expanded(x, incremental_state, query) 116 | 117 | if self.conv_bias is not None: 118 | output = output + self.conv_bias.view(1, 1, -1) 119 | return output 120 | 121 | def _forward_unfolded(self, x, incremental_state, query): 122 | '''The conventional implementation of convolutions. 
123 | Unfolding the input by having a window shifting to the right.''' 124 | T, B, C = x.size() 125 | K, H = self.kernel_size, self.num_heads 126 | R = C // H 127 | assert R * H == C == self.input_size 128 | 129 | if self.in_proj: 130 | proj = self.weight_linear(x) 131 | x = proj.narrow(2, 0, self.input_size).contiguous() 132 | weight = proj.narrow(2, self.input_size, H*K).contiguous().view(T*B*H, -1) 133 | else: 134 | weight = self.weight_linear(query).view(T*B*H, -1) 135 | 136 | # renorm_padding is only implemented in _forward_expanded 137 | assert not self.renorm_padding or incremental_state is not None 138 | 139 | if incremental_state is not None: 140 | input_buffer = self._get_input_buffer(incremental_state) 141 | if input_buffer is None: 142 | input_buffer = x.new() 143 | x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3) 144 | if self.kernel_size > 1: 145 | self._set_input_buffer(incremental_state, x_unfold[:, :, :, -self.kernel_size+1:]) 146 | x_unfold = x_unfold.view(T*B*H, R, -1) 147 | else: 148 | padding_l = self.padding_l 149 | if K > T and padding_l == K-1: 150 | weight = weight.narrow(1, K-T, T) 151 | K, padding_l = T, T-1 152 | # unfold the input: T x B x C --> T' x B x C x K 153 | x_unfold = unfold1d(x, K, padding_l, 0) 154 | x_unfold = x_unfold.view(T*B*H, R, K) 155 | 156 | if self.weight_softmax and not self.renorm_padding: 157 | weight = F.softmax(weight, dim=1) 158 | weight = weight.narrow(1, 0, K) 159 | 160 | if incremental_state is not None: 161 | weight = weight[:, -x_unfold.size(2):] 162 | K = weight.size(1) 163 | 164 | if self.weight_softmax and self.renorm_padding: 165 | weight = F.softmax(weight, dim=1) 166 | 167 | weight = F.dropout(weight, self.weight_dropout, training=self.training, inplace=False) 168 | 169 | output = torch.bmm(x_unfold, weight.unsqueeze(2)) # T*B*H x R x 1 170 | output = output.view(T, B, C) 171 | return output 172 | 173 | def _forward_expanded(self, x, incremental_stat, query): 174 | '''Turn the convolution filters into band matrices and do matrix multiplication. 175 | This is faster when the sequence is short, but less memory efficient. 176 | This is not used in the decoder during inference. 
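        For example (with renorm_padding off): T=4, K=3 and padding_l=1 scatter
        the predicted filters into a zero buffer of shape (B*H, 4, 6), narrow it
        to a (B*H, 4, 4) band matrix, and apply it to x of shape (B*H, 4, R)
        with a single bmm.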
177 |         '''
178 |         T, B, C = x.size()
179 |         K, H = self.kernel_size, self.num_heads
180 |         R = C // H
181 |         assert R * H == C == self.input_size
182 |         if self.in_proj:
183 |             proj = self.weight_linear(x)
184 |             x = proj.narrow(2, 0, self.input_size).contiguous()
185 |             weight = proj.narrow(2, self.input_size, H*K).contiguous().view(T*B*H, -1)
186 |         else:
187 |             weight = self.weight_linear(query).view(T*B*H, -1)
188 | 
189 |         if not self.renorm_padding:
190 |             if self.weight_softmax:
191 |                 weight = F.softmax(weight, dim=1)
192 |             weight = F.dropout(weight, self.weight_dropout, training=self.training, inplace=False)
193 |         weight = weight.narrow(1, 0, K).contiguous()
194 |         weight = weight.view(T, B*H, K).transpose(0, 1)
195 | 
196 |         x = x.view(T, B*H, R).transpose(0, 1)
197 |         if self.weight_softmax and self.renorm_padding:
198 |             # turn the convolution filters into band matrices
199 |             weight_expanded = weight.new(B*H, T, T+K-1).fill_(float('-inf'))
200 |             weight_expanded.as_strided((B*H, T, K), (T*(T+K-1), T+K, 1)).copy_(weight)
201 |             weight_expanded = weight_expanded.narrow(2, self.padding_l, T)
202 |             # normalize the weight over valid positions like self-attention
203 |             weight_expanded = F.softmax(weight_expanded, dim=2)
204 |             weight_expanded = F.dropout(weight_expanded, self.weight_dropout, training=self.training, inplace=False)
205 |         else:
206 |             P = self.padding_l
207 |             # For efficiency, we cut the kernel size and reduce the padding when the kernel is larger than the length
208 |             if K > T and P == K-1:
209 |                 weight = weight.narrow(2, K-T, T)
210 |                 K, P = T, T-1
211 |             # turn the convolution filters into band matrices
212 |             weight_expanded = weight.new_zeros(B*H, T, T+K-1, requires_grad=False)
213 |             weight_expanded.as_strided((B*H, T, K), (T*(T+K-1), T+K, 1)).copy_(weight)
214 |             weight_expanded = weight_expanded.narrow(2, P, T)  # B*H x T x T
215 | 
216 |         output = torch.bmm(weight_expanded, x)
217 |         output = output.transpose(0, 1).contiguous().view(T, B, C)
218 |         return output
219 | 
220 |     def reorder_incremental_state(self, incremental_state, new_order):
221 |         input_buffer = self._get_input_buffer(incremental_state)
222 |         if input_buffer is not None:
223 |             input_buffer = input_buffer.index_select(1, new_order)
224 |             self._set_input_buffer(incremental_state, input_buffer)
225 | 
226 |     def _get_input_buffer(self, incremental_state):
227 |         return utils.get_incremental_state(self, incremental_state, 'input_buffer')
228 | 
229 |     def _set_input_buffer(self, incremental_state, new_buffer):
230 |         return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer)
231 | 
232 |     def extra_repr(self):
233 |         s = '{}, kernel_size={}, padding_l={}, num_heads={}, weight_softmax={}, conv_bias={}, renorm_padding={}, in_proj={}'.format(
234 |             self.input_size, self.kernel_size, self.padding_l,
235 |             self.num_heads, self.weight_softmax, self.conv_bias is not None, self.renorm_padding,
236 |             self.in_proj,
237 |         )
238 | 
239 |         if self.query_size != self.input_size:
240 |             s += ', query_size={}'.format(self.query_size)
241 |         if self.weight_dropout > 0.:
242 |             s += ', weight_dropout={}'.format(self.weight_dropout)
243 |         return s
244 | 
-------------------------------------------------------------------------------- /src/modules/default_multihead_attention.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch 9 | from torch import nn 10 | from torch.nn import Parameter 11 | import torch.nn.functional as F 12 | 13 | from fairseq import utils 14 | 15 | 16 | class DefaultMultiheadAttention(nn.Module): 17 | """Multi-headed attention. 18 | 19 | See "Attention Is All You Need" for more details. 20 | """ 21 | 22 | def __init__(self, args, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False): 23 | super().__init__() 24 | self.args = args 25 | self.embed_dim = embed_dim 26 | self.num_heads = num_heads 27 | self.dropout = dropout 28 | self.head_dim = embed_dim // num_heads 29 | assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" 30 | self.scaling = self.head_dim ** -0.5 31 | 32 | self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim)) 33 | if bias: 34 | self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim)) 35 | else: 36 | self.register_parameter('in_proj_bias', None) 37 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 38 | 39 | if add_bias_kv: 40 | self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) 41 | self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) 42 | else: 43 | self.bias_k = self.bias_v = None 44 | 45 | self.add_zero_attn = add_zero_attn 46 | 47 | self.reset_parameters() 48 | 49 | self.onnx_trace = False 50 | 51 | def prepare_for_onnx_export_(self): 52 | self.onnx_trace = True 53 | 54 | def reset_parameters(self): 55 | nn.init.xavier_uniform_(self.in_proj_weight) 56 | nn.init.xavier_uniform_(self.out_proj.weight) 57 | if self.in_proj_bias is not None: 58 | nn.init.constant_(self.in_proj_bias, 0.) 59 | nn.init.constant_(self.out_proj.bias, 0.) 60 | if self.bias_k is not None: 61 | nn.init.xavier_normal_(self.bias_k) 62 | if self.bias_v is not None: 63 | nn.init.xavier_normal_(self.bias_v) 64 | 65 | def forward(self, query, key, value, key_padding_mask=None, incremental_state=None, 66 | need_weights=True, static_kv=False, attn_mask=None, **kwargs): 67 | """Input shape: Time x Batch x Channel 68 | 69 | Self-attention can be implemented by passing in the same arguments for 70 | query, key and value. Timesteps can be masked by supplying a T x T mask in the 71 | `attn_mask` argument. Padding elements can be excluded from 72 | the key by passing a binary ByteTensor (`key_padding_mask`) with shape: 73 | batch x src_len, where padding elements are indicated by 1s. 
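        Illustrative shapes: in self-attention with tgt_len=10, bsz=4,
        embed_dim=512 and num_heads=8, query/key/value are each (10, 4, 512),
        attn_mask is (10, 10), key_padding_mask is (4, 10), and the module
        returns attn of shape (10, 4, 512) with head-averaged attn_weights of
        shape (4, 10, 10).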
74 | """ 75 | 76 | qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr() 77 | kv_same = key.data_ptr() == value.data_ptr() 78 | 79 | tgt_len, bsz, embed_dim = query.size() 80 | assert embed_dim == self.embed_dim 81 | assert list(query.size()) == [tgt_len, bsz, embed_dim] 82 | assert key.size() == value.size() 83 | 84 | if incremental_state is not None: 85 | saved_state = self._get_input_buffer(incremental_state) 86 | if 'prev_key' in saved_state: 87 | # previous time steps are cached - no need to recompute 88 | # key and value if they are static 89 | if static_kv: 90 | assert kv_same and not qkv_same 91 | key = value = None 92 | else: 93 | saved_state = None 94 | 95 | if qkv_same: 96 | # self-attention 97 | q, k, v = self.in_proj_qkv(query) 98 | elif kv_same: 99 | # encoder-decoder attention 100 | q = self.in_proj_q(query) 101 | if key is None: 102 | assert value is None 103 | k = v = None 104 | else: 105 | k, v = self.in_proj_kv(key) 106 | else: 107 | q = self.in_proj_q(query) 108 | k = self.in_proj_k(key) 109 | v = self.in_proj_v(value) 110 | q *= self.scaling 111 | 112 | if self.bias_k is not None: 113 | assert self.bias_v is not None 114 | k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) 115 | v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) 116 | if attn_mask is not None: 117 | attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) 118 | if key_padding_mask is not None: 119 | key_padding_mask = torch.cat( 120 | [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1) 121 | 122 | q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) 123 | if k is not None: 124 | k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) 125 | if v is not None: 126 | v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) 127 | 128 | if saved_state is not None: 129 | # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) 130 | if 'prev_key' in saved_state: 131 | prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim) 132 | if static_kv: 133 | k = prev_key 134 | else: 135 | k = torch.cat((prev_key, k), dim=1) 136 | if 'prev_value' in saved_state: 137 | prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim) 138 | if static_kv: 139 | v = prev_value 140 | else: 141 | v = torch.cat((prev_value, v), dim=1) 142 | saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim) 143 | saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim) 144 | 145 | self._set_input_buffer(incremental_state, saved_state) 146 | 147 | src_len = k.size(1) 148 | 149 | if key_padding_mask is not None: 150 | assert key_padding_mask.size(0) == bsz 151 | assert key_padding_mask.size(1) == src_len 152 | 153 | if self.add_zero_attn: 154 | src_len += 1 155 | k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) 156 | v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) 157 | if attn_mask is not None: 158 | attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) 159 | if key_padding_mask is not None: 160 | key_padding_mask = torch.cat( 161 | [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1) 162 | 163 | attn_weights = torch.bmm(q, k.transpose(1, 2)) 164 | assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] 165 | 166 | if attn_mask is not None: 167 | attn_mask = 
attn_mask.unsqueeze(0) 168 | if self.onnx_trace: 169 | attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1) 170 | attn_weights += attn_mask 171 | 172 | if key_padding_mask is not None: 173 | # don't attend to padding symbols 174 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 175 | if self.onnx_trace: 176 | attn_weights = torch.where( 177 | key_padding_mask.unsqueeze(1).unsqueeze(2), 178 | torch.Tensor([float("-Inf")]), 179 | attn_weights.float() 180 | ).type_as(attn_weights) 181 | else: 182 | attn_weights = attn_weights.float().masked_fill( 183 | key_padding_mask.unsqueeze(1).unsqueeze(2), 184 | float('-inf'), 185 | ).type_as(attn_weights) # FP16 support: cast to float and back 186 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 187 | 188 | attn_weights = utils.softmax( 189 | attn_weights, dim=-1, onnx_trace=self.onnx_trace, 190 | ).type_as(attn_weights) 191 | attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training) 192 | 193 | attn = torch.bmm(attn_weights, v) 194 | assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] 195 | if (self.onnx_trace and attn.size(1) == 1): 196 | # when ONNX tracing a single decoder step (sequence length == 1) 197 | # the transpose is a no-op copy before view, thus unnecessary 198 | attn = attn.contiguous().view(tgt_len, bsz, embed_dim) 199 | else: 200 | attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) 201 | attn = self.out_proj(attn) 202 | 203 | if need_weights: 204 | # average attention weights over heads 205 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 206 | attn_weights = attn_weights.sum(dim=1) / self.num_heads 207 | else: 208 | attn_weights = None 209 | 210 | return attn, attn_weights 211 | 212 | def in_proj_qkv(self, query): 213 | return self._in_proj(query).chunk(3, dim=-1) 214 | 215 | def in_proj_kv(self, key): 216 | return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1) 217 | 218 | def in_proj_q(self, query): 219 | return self._in_proj(query, end=self.embed_dim) 220 | 221 | def in_proj_k(self, key): 222 | return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim) 223 | 224 | def in_proj_v(self, value): 225 | return self._in_proj(value, start=2 * self.embed_dim) 226 | 227 | def _in_proj(self, input, start=0, end=None): 228 | weight = self.in_proj_weight 229 | bias = self.in_proj_bias 230 | weight = weight[start:end, :] 231 | if bias is not None: 232 | bias = bias[start:end] 233 | return F.linear(input, weight, bias) 234 | 235 | def reorder_incremental_state(self, incremental_state, new_order): 236 | """Reorder buffered internal state (for incremental generation).""" 237 | input_buffer = self._get_input_buffer(incremental_state) 238 | if input_buffer is not None: 239 | for k in input_buffer.keys(): 240 | input_buffer[k] = input_buffer[k].index_select(0, new_order) 241 | self._set_input_buffer(incremental_state, input_buffer) 242 | 243 | def _get_input_buffer(self, incremental_state): 244 | return utils.get_incremental_state( 245 | self, 246 | incremental_state, 247 | 'attn_state', 248 | ) or {} 249 | 250 | def _set_input_buffer(self, incremental_state, buffer): 251 | utils.set_incremental_state( 252 | self, 253 | incremental_state, 254 | 'attn_state', 255 | buffer, 256 | ) 257 | 258 | 259 | 260 | -------------------------------------------------------------------------------- /src/modules/dptree_individual_multihead_attention.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import Parameter 4 | import torch.nn.functional as F 5 | 6 | from fairseq import utils 7 | 8 | from fairseq.modules.multihead_attention import * 9 | 10 | DEBUG = False 11 | from .dptree_multihead_attention import * 12 | from .dptree_sep_multihead_attention import * 13 | 14 | 15 | class DPTreeIndividualOnlyKeyAttention(DPTreeSeparateOnlyKeyAttention): 16 | 17 | @classmethod 18 | def indices2gt_indices(cls, indices, seq_len, query_len, heads, nsent, head_dim=None): 19 | """ 20 | 21 | :param indices: [b, m, tk, 2] 22 | :param seq_len: 23 | :param query_len: 24 | :param heads: 25 | :param head_dim: 26 | :return: [B * h, m, tq, Tk] 27 | """ 28 | gt_idx = (indices[:, :, :, 0] * seq_len + indices[:, :, :, 1]).unsqueeze_(1).unsqueeze_(3) 29 | # size = gt_idx.size() 30 | bsz, _, nsent, _, tk = gt_idx.size() 31 | gt_idx = gt_idx.expand(-1, heads, -1, query_len, -1).contiguous().view(bsz * heads, nsent, query_len, tk) 32 | return gt_idx 33 | 34 | @classmethod 35 | def indices2flat_indices(cls, indices, seq_len, head_dim, heads, nsent, query_len=None): 36 | assert query_len is not None 37 | fl_idx = cls.indices2gt_indices(indices, seq_len, query_len, heads, nsent=nsent) 38 | return fl_idx 39 | 40 | def dptree_dot_product(self, q, k, fl_idx, gt_idx, seq_len): 41 | """ 42 | 43 | :param q: [bq * h, m, tq, d] 44 | :param k: [bk * h, m, tk, d] 45 | :param fl_idx: [bk * h, m, tq, tk] 46 | :param gt_idx: [bk * h, m, tq, tk] 47 | :param seq_len: 48 | :return:dp_scores [bk * h, m, tq, tk] 49 | """ 50 | bqh, m, tq, d = q.size() 51 | # q = q.unsqueeze(1) 52 | # q: [bq * h, m, tq, d] 53 | linear_scores = torch.matmul(q, k.transpose(2, 3)) 54 | # linear_scores: [bq * h, m, tq, tk] 55 | 56 | matrix_scores = self.__class__.scores_affinity2dptable(linear_scores, fl_idx, seq_len) 57 | # matrix: [b, * h, m, tq, t, t] 58 | """ 59 | a b c d 60 | 0 f g h 61 | 0 0 i j 62 | 0 0 0 k 63 | 64 | a 0 0 0 65 | b f 0 0 66 | c g i 0 67 | d h j k 68 | """ 69 | acc_fw_matrix = torch.cumsum(matrix_scores, dim=4) 70 | acc_bw_matrix = torch.cumsum(matrix_scores, dim=3).transpose(3, 4) 71 | dp_matrix = torch.matmul(acc_fw_matrix, acc_bw_matrix) 72 | bszh, nsent, tq_dp_mat, t1, t2 = dp_matrix.size() 73 | assert t1 == t2, f"{t1} != {t2}" 74 | assert tq == tq_dp_mat, f'{dp_matrix.size()} ??? {q.size()} ??? {linear_scores.size()} ??? {fl_idx.size()} ??? 
{gt_idx.size()}' 75 | 76 | dp_linear_mat = dp_matrix.view(bszh, nsent, tq, t1 * t2) 77 | dp_scores = torch.gather(dp_linear_mat, dim=3, index=gt_idx) 78 | 79 | return dp_scores 80 | 81 | def compute_dptree_att(self, q, k, v, fl_idx, gt_idx, attn_mask, key_padding_mask, src_len, tgt_len, 82 | bsz, need_weights): 83 | """ 84 | 85 | :param q: [B * h, m, tq, d] 86 | :param k: [B * h, m, tk, d] 87 | :param v: [B * h, m, tk, d] 88 | :param fl_idx: [B * h, m, tq, Tk] nor [B * h, d, Tk] 89 | :param gt_idx: [B * h, m, tq, tk] 90 | :param attn_mask: 91 | :param key_padding_mask: [B, m, tk] 92 | :param src_len: 93 | :param tgt_len: 94 | :param bsz: 95 | :param need_weights: 96 | :return: 97 | """ 98 | k_size = k.size() 99 | node_len = k_size[2] 100 | nsent = k_size[1] 101 | seq_len = int((node_len + 1) // 2) + 1 102 | 103 | attn_weights = self.dptree_dot_product(q, k, fl_idx, gt_idx, seq_len) 104 | assert not torch.isnan(attn_weights).any() 105 | 106 | # assert list(attn_weights.size()) == [bsz * self.num_heads, nsent, tgt_len, src_len], 107 | if list(attn_weights.size()) != [bsz * self.num_heads, nsent, tgt_len, src_len]: 108 | raise ValueError(f'{attn_weights.size()} != {[bsz * self.num_heads, nsent, tgt_len, src_len]}, q={q.size()}, fl_idx={fl_idx.size()}') 109 | 110 | if attn_mask is not None: 111 | raise NotImplementedError('attn_mask for decoder not yet') 112 | 113 | # if key_padding_mask is not None: 114 | # attn_weights = attn_weights.view(bsz, self.num_heads, nsent, tgt_len, src_len) 115 | # 116 | # exp_key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(3) 117 | # # exp_key_padding_mask: [bsz, 1, nsent, 1, src_len] 118 | # assert not self.onnx_trace 119 | # 120 | # # src_lens = (1.0 - exp_key_padding_mask.type_as(attn_weights)).sum(dim=-1, keepdim=True) 121 | # # src_lens = (1.0 - exp_key_padding_mask.type_as(attn_weights)).sum(dim=-1, keepdim=True).clamp_(min=1.0, max=1e9) 122 | # src_lens_denom = (~exp_key_padding_mask).type_as(attn_weights).sum(dim=-1, keepdim=True).clamp_(min=1.0, max=10000) 123 | # # assert not (src_lens.int() == 0).any(), f'{key_padding_mask}' 124 | # src_lens_denom = src_lens_denom.sqrt_() 125 | # attn_weights /= src_lens_denom 126 | # 127 | # attn_weights = attn_weights.float().masked_fill( 128 | # exp_key_padding_mask, float('-inf')).type_as(attn_weights) 129 | # 130 | # attn_weights = attn_weights.view(bsz * self.num_heads, nsent, tgt_len, src_len) 131 | # 132 | # else: 133 | # src_lens_denom = torch.tensor(node_len, dtype=attn_weights.dtype, device=attn_weights.device) 134 | # src_lens_denom = src_lens_denom.sqrt_() 135 | # attn_weights /= src_lens_denom 136 | if key_padding_mask is not None: 137 | attn_weights = attn_weights.view(bsz, self.num_heads, nsent, tgt_len, src_len) 138 | 139 | exp_pad_mask = key_padding_mask.unsqueeze(1).unsqueeze(3) 140 | assert not self.onnx_trace 141 | 142 | src_lens_denom = (~exp_pad_mask).type_as(attn_weights).sum(dim=-1, keepdim=True).clamp_(min=1.0, max=10000) 143 | src_lens_denom = self.norm_src_len(src_lens_denom) 144 | attn_weights /= src_lens_denom 145 | 146 | attn_weights = attn_weights.float().masked_fill( 147 | exp_pad_mask, float('-inf')).type_as(attn_weights) 148 | 149 | attn_weights = attn_weights.view(bsz * self.num_heads, nsent, tgt_len, src_len) 150 | 151 | else: 152 | src_lens_denom = torch.tensor(node_len, dtype=attn_weights.dtype, device=attn_weights.device) 153 | src_lens_denom = self.norm_src_len(src_lens_denom) 154 | attn_weights /= src_lens_denom 155 | # src_lens = None 156 | # assert not 
torch.isnan(attn_weights).any() 157 | assert not torch.isnan(attn_weights).any(), f'src_lens: {src_lens_denom is None}: {src_lens_denom == 0}' 158 | 159 | attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights) 160 | # Since some docs have empty tree in batch, softmax(-inf all) -> NaN -> replace with zeros 161 | # attn_weights[attn_weights != attn_weights] = 0 162 | attn_weights = torch.where(torch.isnan(attn_weights), torch.zeros_like(attn_weights), attn_weights) 163 | 164 | attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training) 165 | 166 | # attn_weights: [b * h, m, tgt_len, src_len] 167 | # values: [b * h, m, tk, d] 168 | 169 | # attn_weights = attn_weights.permute(0, 2, 3, 1).contiguous().view(bsz * self.num_heads, tgt_len, src_len * nsent) 170 | # attn_weights: [b * h, tgt_len, src_len * m] 171 | assert not torch.isnan(attn_weights).any() 172 | 173 | attn = torch.matmul(attn_weights, v) 174 | assert not torch.isnan(attn).any() 175 | # attn: [b * h, m, tq, d] 176 | 177 | assert list(attn.size()) == [bsz * self.num_heads, nsent, tgt_len, self.head_dim] 178 | assert not self.onnx_trace 179 | # attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim) 180 | attn = attn.permute(2, 1, 0, 3).contiguous().view(tgt_len, nsent, bsz, self.embed_dim) 181 | # attn: [tq, m, b, h * dim] 182 | 183 | if need_weights: 184 | # average attention weights over heads 185 | attn_weights = attn_weights.view(bsz, self.num_heads, nsent, tgt_len, src_len) 186 | attn_weights = attn_weights.sum(dim=1) / self.num_heads 187 | else: 188 | attn_weights = None 189 | 190 | return attn, attn_weights 191 | 192 | def forward( 193 | self, query, key, value, flat_indices, gt_indices, key_padding_mask=None, 194 | incremental_state=None, 195 | need_weights=True, static_kv=False, attn_mask=None, force_self_att=False): 196 | """Input shape: Time x Batch x Channel 197 | 198 | :param query: [Tq, B, m, C] 199 | :param key: [Tk, B, m, C] 200 | :param value: [Tk, B, m, C] 201 | :param flat_indices: [B * h, m, Tq, Tk] 202 | :param gt_indices: [B * h, m, Tq, Tk] 203 | :param key_padding_mask: [B, m, Tk] 204 | """ 205 | 206 | assert flat_indices.size() == gt_indices.size(), f'{flat_indices.size()} != {gt_indices.size()}' 207 | tq, query_bsz, qnsent, dim = query.size() 208 | tk, key_bsz, nsent, dim_k = key.size() 209 | assert query_bsz == key_bsz 210 | assert qnsent == nsent 211 | 212 | assert attn_mask is None, f'not None attn_mask (decoder self-attention) not ready' 213 | 214 | f_query = query.view(tq, query_bsz * qnsent, dim) 215 | f_key = key.view(tk, key_bsz * nsent, dim_k) 216 | f_value = value.view(tk, key_bsz * nsent, dim_k) 217 | 218 | # f_key_pad_mask = key_padding_mask = key_padding_mask.view(key_bsz * nsent, tk) 219 | assert not torch.isnan(query).any() 220 | ( 221 | f_query, f_key, f_value, key_padding_mask, saved_state, src_len, tgt_len, query_bsz_) = self.prepare_dptree_qkv( 222 | f_query, f_key, f_value, key_padding_mask, incremental_state, need_weights, static_kv, 223 | force_self_att=force_self_att 224 | ) 225 | # q: [bq * m * h, tq, d] 226 | # fk: [bk * m * h, tk, d] 227 | # fv: [bk * m * h, tk, d] 228 | # fpad: [bk, m, tk] 229 | assert not torch.isnan(f_query).any() 230 | assert not torch.isnan(f_key).any() 231 | assert not torch.isnan(f_value).any() 232 | if key_padding_mask is not None: 233 | assert not torch.isnan(key_padding_mask).any() 234 | 235 | # f_key = f_key.view(key_bsz, nsent, self.num_heads, tk, self.head_dim).contiguous().permute(0, 2, 1, 3, 
4) 236 | # f_key = f_key.view(key_bsz * self.num_heads, nsent, tk, self.head_dim) 237 | # 238 | # f_value = f_value.view(key_bsz, nsent, self.num_heads, tk, self.head_dim).contiguous().permute(0, 2, 3, 1, 4) 239 | # f_value = f_value.view(key_bsz * self.num_heads, tk * nsent, self.head_dim) 240 | # 241 | 242 | f_query = f_query.view(query_bsz, qnsent, self.num_heads, tq, self.head_dim).permute(0, 2, 1, 3, 4).contiguous() 243 | # f_query: [b, h, m, tq, d] 244 | f_query = f_query.view(query_bsz * self.num_heads, qnsent, tq, self.head_dim) 245 | 246 | f_key = f_key.view(key_bsz, nsent, self.num_heads, tk, self.head_dim).permute(0, 2, 1, 3, 4).contiguous() 247 | # f_key: [b, h, m, tk, d] 248 | f_key = f_key.view(key_bsz * self.num_heads, nsent, tk, self.head_dim) 249 | 250 | f_value = f_value.view(key_bsz, nsent, self.num_heads, tk, self.head_dim).permute(0, 2, 1, 3, 4).contiguous() 251 | # f_value: [b, h, m, tk, d] 252 | f_value = f_value.view(key_bsz * self.num_heads, nsent, tk, self.head_dim) 253 | 254 | # q: [bq * h, tq, d] 255 | # fk: [bk * h, m, tk, d] 256 | # fv: [bk * h, m, tk, d] 257 | # fpad: [bk, m, tk] 258 | 259 | (attn, attn_weights) = self.compute_dptree_att( 260 | f_query, f_key, f_value, flat_indices, gt_indices, attn_mask, key_padding_mask, 261 | src_len, tgt_len, query_bsz, need_weights) 262 | 263 | assert not torch.isnan(attn).any() 264 | # attn: [tq, m, b, h * dim] 265 | return attn, attn_weights 266 | 267 | 268 | class DPTreeIndividualRNNOnlyKeyAttention(DPTreeIndividualOnlyKeyAttention): 269 | 270 | def forward(self, query, key, value, flat_indices, gt_indices, key_padding_mask=None, incremental_state=None, 271 | need_weights=True, static_kv=False, attn_mask=None, force_self_att=False): 272 | """Input shape: Time x Batch x Channel 273 | *** Only Apply to Encoder-Self-Attention 274 | :param query: [Tq, B, m, C] 275 | :param key: [Tk, B, m, C] 276 | :param value: [Tk, B, m, C] 277 | :param flat_indices: [B * h, m, Tq, Tk] 278 | :param gt_indices: [B * h, m, Tq, Tk] 279 | :param key_padding_mask: [B, m, Tk] 280 | """ 281 | 282 | assert flat_indices.size() == gt_indices.size(), f'{flat_indices.size()} != {gt_indices.size()}' 283 | tq, query_bsz, qnsent, dim = query.size() 284 | tk, key_bsz, nsent, dim_k = key.size() 285 | assert query_bsz == key_bsz 286 | assert qnsent == nsent 287 | 288 | assert attn_mask is None, f'not None attn_mask (decoder self-attention) not ready' 289 | assert incremental_state is None, f'not None incremental_state' 290 | 291 | queries = query.chunk(qnsent, 2) 292 | keys = key.chunk(nsent, 2) 293 | values = value.chunk(nsent, 2) 294 | fl_indices_list = flat_indices.chunk(nsent, 1) 295 | gt_indices_list = gt_indices.chunk(nsent, 1) 296 | pad_masks = key_padding_mask.chunk(nsent, 1) 297 | 298 | assert attn_mask is None 299 | 300 | reduce_lengths = [] 301 | original_lengths = [] 302 | attentions = [] 303 | for i, (q, k, v, fi, gt, mask) in enumerate(zip( 304 | queries, keys, values, fl_indices_list, gt_indices_list, pad_masks 305 | )): 306 | # reduce padding.... 
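            # Crop this sentence's tensors to the longest non-pad length in the
            # batch (plus one) before running attention, then re-pad the output
            # back to tq below, so short sentences skip the padded positions.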
307 | # mask: [b, 1, tk] 308 | ori_length = mask.size(-1) 309 | original_lengths.append(ori_length) 310 | length_mask = (~mask).int().sum(-1) + 1 311 | max_length = length_mask.max() 312 | reduce_lengths.append(max_length) 313 | 314 | q = q[:max_length] 315 | k = k[:max_length] 316 | v = v[:max_length] 317 | # FIXME: check this one 318 | fi = fi[:, :, :max_length, :max_length] 319 | gt = gt[:, :, :max_length, :max_length] 320 | mask = mask[:, :, :max_length] 321 | 322 | attn, attn_weights_ = super().forward( 323 | q, k, v, fi, gt, mask, incremental_state, 324 | need_weights, static_kv, attn_mask, force_self_att) 325 | 326 | # pad_attention 327 | # attn: [tq, m, b, h * dim] 328 | attn = F.pad(attn, [0, 0, 0, 0, 0, 0, 0, tq - max_length]) 329 | attentions.append(attn) 330 | # attn: [tq, m, b, h * dim] 331 | 332 | out_attn = torch.cat(attentions, 1) 333 | return out_attn, None 334 | 335 | 336 | 337 | 338 | -------------------------------------------------------------------------------- /src/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from . import lr_scheduler -------------------------------------------------------------------------------- /src/optim/lr_scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | from . import flex_inv_sqrt_schedule 2 | 3 | 4 | -------------------------------------------------------------------------------- /src/optim/lr_scheduler/flex_inv_sqrt_schedule.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler 9 | 10 | 11 | @register_lr_scheduler('capped_inverse_sqrt') 12 | class CappedInverseSquareRootSchedule(FairseqLRScheduler): 13 | """Decay the LR based on the inverse square root of the update number. 14 | 15 | We also support a warmup phase where we linearly increase the learning rate 16 | from some initial learning rate (``--warmup-init-lr``) until the configured 17 | learning rate (``--lr``). Thereafter we decay proportional to the number of 18 | updates, with a decay factor set to align with the configured learning rate. 19 | 20 | During warmup:: 21 | 22 | lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates) 23 | lr = lrs[update_num] 24 | 25 | After warmup:: 26 | 27 | decay_factor = args.lr * sqrt(inv_decay) 28 | lr = decay_factor / sqrt(update_num) 29 | """ 30 | 31 | def __init__(self, args, optimizer): 32 | super().__init__(args, optimizer) 33 | if len(args.lr) > 1: 34 | raise ValueError( 35 | 'Cannot use a fixed learning rate schedule with inverse_sqrt.' 36 | ' Consider --lr-scheduler=fixed instead.' 37 | ) 38 | warmup_end_lr = args.lr[0] 39 | if args.warmup_init_lr < 0: 40 | args.warmup_init_lr = warmup_end_lr 41 | 42 | self.inv_decay = args.inv_decay if args.inv_decay > 0 else args.warmup_updates 43 | # linearly warmup for the first args.warmup_updates 44 | self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates 45 | 46 | self.max_lr = warmup_end_lr 47 | 48 | # then, decay prop. 
to the inverse square root of the update number 49 | # self.decay_factor = warmup_end_lr * args.warmup_updates**0.5 50 | self.decay_factor = warmup_end_lr * self.inv_decay ** 0.5 51 | 52 | # initial learning rate 53 | self.lr = args.warmup_init_lr 54 | self.optimizer.set_lr(self.lr) 55 | 56 | print(f'inv_decay={self.inv_decay},warmup={args.warmup_updates},decay_factor={self.decay_factor}') 57 | 58 | @staticmethod 59 | def add_args(parser): 60 | """Add arguments to the parser for this LR scheduler.""" 61 | # fmt: off 62 | parser.add_argument('--warmup-updates', default=4000, type=int, metavar='N', 63 | help='warmup the learning rate linearly for the first N updates') 64 | parser.add_argument('--inv-decay', default=-1, type=int, metavar='N', 65 | help='inv-decay, set to warmup if not >0') 66 | parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR', 67 | help='initial learning rate during warmup phase; default is args.lr') 68 | # fmt: on 69 | 70 | def step(self, epoch, val_loss=None): 71 | """Update the learning rate at the end of the given epoch.""" 72 | super().step(epoch, val_loss) 73 | # we don't change the learning rate at epoch boundaries 74 | return self.optimizer.get_lr() 75 | 76 | def step_update(self, num_updates): 77 | """Update the learning rate after each update.""" 78 | if num_updates < self.args.warmup_updates: 79 | self.lr = self.args.warmup_init_lr + num_updates*self.lr_step 80 | else: 81 | self.lr = min(self.decay_factor * num_updates**-0.5, self.max_lr) 82 | self.optimizer.set_lr(self.lr) 83 | return self.lr 84 | -------------------------------------------------------------------------------- /src/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .dptree2seq_translation import * 2 | from .fairseq_classification import * 3 | from .dptree_classification import * 4 | from .dptree_sep_classification import * 5 | from .dptree2seq_sep_translation import * 6 | from .nstack_from_dptree_classification import * 7 | from .nstack_classification import * 8 | from .nstack2seq_translation import * 9 | 10 | 11 | -------------------------------------------------------------------------------- /src/tasks/dptree2seq_sep_translation.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import numpy as np 3 | import os 4 | import torch 5 | from fairseq import options, utils 6 | from fairseq.data import ( 7 | data_utils, Dictionary, LanguagePairDataset, ConcatDataset, 8 | IndexedRawTextDataset, IndexedCachedDataset, IndexedDataset 9 | ) 10 | from fairseq.sequence_generator import SequenceGenerator, EnsembleModel 11 | import math 12 | from ..data import ( 13 | DPTree2SeqPairDataset, DPTreeIndexedCachedDataset, DPTree2SeqSeparatePairDataset, DPTreeWrapperDictionary 14 | ) 15 | 16 | from .dptree2seq_translation import * 17 | from fairseq.tasks import FairseqTask, register_task 18 | 19 | 20 | @register_task('dptree2seq_sep') 21 | class DPTree2SeqSeparateTranslationTask(DPTree2SeqTranslationTask): 22 | 23 | 24 | @classmethod 25 | def setup_task(cls, args, **kwargs): 26 | """Setup the task (e.g., load dictionaries). 
27 | 
28 |         Args:
29 |             args (argparse.Namespace): parsed command-line arguments
30 |         """
31 |         args.left_pad_source = options.eval_bool(args.left_pad_source)
32 |         args.left_pad_target = options.eval_bool(args.left_pad_target)
33 | 
34 |         assert not args.left_pad_source, 'args.left_pad_source must be False'
35 | 
36 |         # find language pair automatically
37 |         if args.source_lang is None or args.target_lang is None:
38 |             args.source_lang, args.target_lang = data_utils.infer_language_pair(args.data[0])
39 |         if args.source_lang is None or args.target_lang is None:
40 |             raise Exception('Could not infer language pair, please provide it explicitly')
41 | 
42 |         # load dictionaries
43 |         args.no_strip_node_label = getattr(args, 'no_strip_node_label', False)
44 |         src_dict = DPTreeWrapperDictionary.load(
45 |             os.path.join(args.data[0], 'dict.{}.txt'.format(args.source_lang)),
46 |             no_strip_node_label=args.no_strip_node_label)
47 |         tgt_dict = Dictionary.load(os.path.join(args.data[0], 'dict.{}.txt'.format(args.target_lang)))
48 |         assert src_dict.pad() == tgt_dict.pad()
49 |         assert src_dict.eos() == tgt_dict.eos()
50 |         assert src_dict.unk() == tgt_dict.unk()
51 |         print('| [{}] DPtree-dictionary: {} types'.format(args.source_lang, len(src_dict)))
52 |         print('| [{}] dictionary: {} types'.format(args.target_lang, len(tgt_dict)))
53 | 
54 |         return cls(args, src_dict, tgt_dict)
55 | 
56 |     def build_generator(self, args):
57 |         if args.score_reference:
58 |             from fairseq.sequence_scorer import SequenceScorer
59 |             return SequenceScorer(self.target_dictionary)
60 | 
61 |         else:
62 |             from ..dptree2seq_generator import DPtree2SeqSeparateGenerator
63 |             assert self.target_dictionary.eos() == self.source_dictionary.eos(), f'{self.target_dictionary.eos()} - {self.source_dictionary.eos()}'
64 |             return DPtree2SeqSeparateGenerator(
65 |                 self.target_dictionary,
66 |                 beam_size=args.beam,
67 |                 max_len_a=args.max_len_a,
68 |                 max_len_b=args.max_len_b,
69 |                 min_len=args.min_len,
70 |                 stop_early=(not args.no_early_stop),
71 |                 normalize_scores=(not args.unnormalized),
72 |                 len_penalty=args.lenpen,
73 |                 unk_penalty=args.unkpen,
74 |                 sampling=args.sampling,
75 |                 sampling_topk=args.sampling_topk,
76 |                 sampling_temperature=args.sampling_temperature,
77 |                 diverse_beam_groups=args.diverse_beam_groups,
78 |                 diverse_beam_strength=args.diverse_beam_strength,
79 |                 match_source_len=args.match_source_len,
80 |                 no_repeat_ngram_size=args.no_repeat_ngram_size,
81 |             )
82 | 
83 |     def load_dataset(self, split, combine=False, **kwargs):
84 |         def split_exists(split, src, tgt, lang, data_path):
85 |             filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
86 |             if self.args.raw_text and IndexedRawTextDataset.exists(filename):
87 |                 return True
88 |             elif not self.args.raw_text and IndexedDataset.exists(filename):
89 |                 return True
90 |             return False
91 | 
92 |         def indexed_dataset(path, dictionary):
93 |             if self.args.raw_text:
94 |                 raise NotImplementedError
95 |             elif IndexedDataset.exists(path):
96 |                 return DPTreeIndexedCachedDataset(path, fix_lua_indexing=True)
97 |             return None
98 | 
99 |         src_datasets = []
100 |         src_datasets_dict = {k: [] for k in DPTREE_KEYS}
101 |         tgt_datasets = []
102 | 
103 |         data_paths = self.args.data
104 |         print(f'| split = {split}')
105 |         print(f'| self.args.data = {self.args.data}')
106 | 
107 |         for dk, data_path in enumerate(data_paths):
108 |             for k in itertools.count():
109 |                 split_k = split + (str(k) if k > 0 else '')
110 | 
111 |                 # infer langcode
112 |                 src, tgt = self.args.source_lang,
self.args.target_lang 113 | if split_exists(split_k, src, tgt, tgt, data_path): 114 | prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt)) 115 | elif split_exists(split_k, tgt, src, tgt, data_path): 116 | prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src)) 117 | else: 118 | if k > 0 or dk > 0: 119 | break 120 | else: 121 | raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path)) 122 | 123 | for modality in src_datasets_dict.keys(): 124 | src_datasets_dict[modality].append(indexed_dataset(f'{prefix}{src}.{modality}', self.src_dict)) 125 | 126 | # src_datasets.append(indexed_dataset(prefix + src, self.src_dict)) 127 | tgt_datasets.append(indexed_dataset(prefix + tgt, self.tgt_dict)) 128 | 129 | print('| {} {} {} examples'.format(data_path, split_k, len(tgt_datasets[-1]))) 130 | 131 | if not combine: 132 | break 133 | 134 | assert len(src_datasets_dict[DPTREE_KEYS[0]]) == len(tgt_datasets) 135 | 136 | if len(tgt_datasets) == 1: 137 | # src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0] 138 | 139 | src_dataset_dict = {k: v[0] for k, v in src_datasets_dict.items()} 140 | tgt_dataset = tgt_datasets[0] 141 | else: 142 | sample_ratios = [1] * len(tgt_datasets) 143 | sample_ratios[0] = self.args.upsample_primary 144 | # src_dataset = ConcatDataset(src_datasets, sample_ratios) 145 | 146 | src_dataset_dict = {k: ConcatDataset(v, sample_ratios) for k, v in src_datasets_dict.items()} 147 | tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) 148 | 149 | # src_sizes = src_dataset_dict['nodes'].sizes 150 | src_sizes = src_dataset_dict['nodes'].sizes.reshape(-1, 2).sum(-1) 151 | 152 | self.datasets[split] = DPTree2SeqSeparatePairDataset( 153 | src_dataset_dict, src_sizes, self.src_dict, 154 | tgt_dataset, tgt_dataset.sizes, self.tgt_dict, 155 | left_pad_source=self.args.left_pad_source, 156 | left_pad_target=self.args.left_pad_target, 157 | max_source_positions=self.args.max_source_positions, 158 | max_target_positions=self.args.max_target_positions, 159 | remove_eos_from_source=self.args.remove_eos_from_source, 160 | append_eos_to_target=self.args.append_eos_to_target, 161 | input_feeding=self.args.input_feeding, 162 | ) 163 | 164 | 165 | 166 | 167 | -------------------------------------------------------------------------------- /src/tasks/dptree_classification.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import numpy as np 3 | import os 4 | import torch 5 | from fairseq import options, utils 6 | from fairseq.data import ( 7 | data_utils, Dictionary, LanguagePairDataset, ConcatDataset, 8 | IndexedRawTextDataset, IndexedCachedDataset, IndexedDataset, 9 | TruncatedDictionary, 10 | ) 11 | 12 | from ..data import task_utils, monolingual_classification_dataset, MonolingualClassificationDataset, \ 13 | DPTreeMonoClassificationDataset, DPTreeIndexedCachedDataset 14 | 15 | from fairseq import tokenizer 16 | from fairseq.sequence_generator import SequenceGenerator, EnsembleModel 17 | import math 18 | 19 | from fairseq.tasks import FairseqTask, register_task 20 | 21 | from .dptree2seq_translation import DPTREE_KEYS 22 | from .fairseq_classification import SequenceClassification, MonolingualClassificationDataset 23 | 24 | 25 | 26 | def try_load_dictionary(args): 27 | try: 28 | dict_path = os.path.join(args.data, f'dict.txt') 29 | print(f'| dict_path = {dict_path}') 30 | dictionary = Dictionary.load(dict_path) 31 | except FileNotFoundError as e: 32 | dict_path = 
os.path.join(args.data, f'dict.{args.source_lang}.txt') 33 | print(f'| dict_path = {dict_path}') 34 | dictionary = Dictionary.load(dict_path) 35 | return dictionary 36 | 37 | 38 | @register_task('dptree_classification') 39 | class DPTreeClassification(FairseqTask): 40 | """ 41 | Model following language_modeling task. 42 | with target as label 43 | """ 44 | 45 | @staticmethod 46 | def add_args(parser): 47 | """Add task-specific arguments to the parser.""" 48 | # fmt: off 49 | parser.add_argument('data', help='path to data directory') 50 | # parser.add_argument('--sample-break-mode', 51 | # choices=['none', 'complete', 'eos'], 52 | # help='If omitted or "none", fills each sample with tokens-per-sample ' 53 | # 'tokens. If set to "complete", splits samples only at the end ' 54 | # 'of sentence, but may include multiple sentences per sample. ' 55 | # 'If set to "eos", includes only one sentence per sample.') 56 | # parser.add_argument('--tokens-per-sample', default=1024, type=int, 57 | # help='max number of tokens per sample for LM dataset') 58 | parser.add_argument('-s', '--source-lang', default=None, metavar='SRC', 59 | help='source language') 60 | parser.add_argument('--lazy-load', action='store_true', 61 | help='load the dataset lazily') 62 | parser.add_argument('--raw-text', default=False, action='store_true', 63 | help='load raw text dataset') 64 | parser.add_argument('--output-dictionary-size', default=-1, type=int, 65 | help='limit the size of output dictionary') 66 | parser.add_argument('--self-target', action='store_true', 67 | help='include self target') 68 | parser.add_argument('--future-target', action='store_true', 69 | help='include future target') 70 | parser.add_argument('--past-target', action='store_true', 71 | help='include past target') 72 | parser.add_argument('--left-pad-source', default='False', type=str, metavar='BOOL', 73 | help='pad the source on the left') 74 | # parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N', 75 | # help='max number of tokens in the source sequence') 76 | parser.add_argument('--max-source-positions', default=1000000, type=int, metavar='N', 77 | help='max number of tokens in the source sequence') 78 | 79 | def __init__(self, args, dictionary, output_dictionary): 80 | super().__init__(args) 81 | self.dictionary = dictionary 82 | self.output_dictionary = output_dictionary 83 | 84 | @classmethod 85 | def setup_task(cls, args, **kwargs): 86 | """Setup the task (e.g., load dictionaries). 
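        The dictionary is loaded from ``dict.txt`` under ``args.data`` when it
        exists, otherwise from ``dict.<source_lang>.txt`` (see the path logic
        below).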
87 | 
88 |         Args:
89 |             args (argparse.Namespace): parsed command-line arguments
90 |         """
91 |         print(f'| args.data = {args.data}')
92 |         args.left_pad_source = options.eval_bool(args.left_pad_source)
93 |         # assert args.left_pad_source, 'Need left_pad_source True as EOS is used as the classification token'
94 |         assert not args.left_pad_source, 'args.left_pad_source must be False as the root is used for classification'
95 | 
96 |         assert args.source_lang is not None
97 |         if args.source_lang is None:
98 |             args.source_lang = task_utils.infer_language_mono(args.data)
99 | 
100 |         dict_path = os.path.join(args.data, 'dict.txt')
101 |         if not os.path.exists(dict_path):
102 |             dict_path = os.path.join(args.data, f'dict.{args.source_lang}.txt')
103 | 
104 |         dictionary = None
105 |         output_dictionary = None
106 |         if args.data:
107 |             dictionary = Dictionary.load(dict_path)
108 |             print('| dictionary: {} types'.format(len(dictionary)))
109 |             output_dictionary = dictionary
110 |             if args.output_dictionary_size >= 0:
111 |                 output_dictionary = TruncatedDictionary(dictionary, args.output_dictionary_size)
112 | 
113 |         if args.source_lang is None:
114 |             args.source_lang = task_utils.infer_language_mono(args.data)
115 | 
116 |         # dict_path = os.path.join(args.data, 'dict.txt')
117 |         # src_dict = Dictionary.load(dict_path)
118 |         # print('| [{}] dictionary: {} types'.format(args.source_lang, len(src_dict)))
119 |         return cls(args, dictionary, output_dictionary)
120 | 
121 |     def load_dataset(self, split, combine=False, **kwargs):
122 |         """Load a given dataset split.
123 | 
124 |         Args:
125 |             split (str): name of the split (e.g., train, valid, test)
126 |         """
127 | 
128 |         def split_exists(split, data_type, data_path):
129 |             filename = os.path.join(data_path, f'{split}.{data_type}')
130 |             assert not self.args.raw_text
131 |             # if self.args.raw_text and IndexedRawTextDataset.exists(filename):
132 |             #     return True
133 |             # elif not self.args.raw_text and IndexedDataset.exists(filename):
134 |             #     return True
135 |             # return False
136 |             exists = [IndexedDataset.exists(os.path.join(data_path, f'{split}.{data_type}.{k}')) for k in DPTREE_KEYS]
137 |             if all(exists):
138 |                 return True
139 |             else:
140 |                 print(f'The following modalities do not exist: {exists}')
141 |                 return False
142 | 
143 |         # def indexed_dataset(path, dictionary):
144 |         def indexed_dataset(path):
145 |             assert IndexedCachedDataset.exists(path), f'IndexedCachedDataset.exists({path})'
146 |             return IndexedCachedDataset(path, fix_lua_indexing=True)
147 | 
148 |         def dptree_indexed_dataset(path):
149 |             assert DPTreeIndexedCachedDataset.exists(path), f'DPTreeIndexedCachedDataset.exists({path})'
150 |             return DPTreeIndexedCachedDataset(path, fix_lua_indexing=True)
151 | 
152 |         src_datasets = []
153 |         tgt_datasets = []
154 |         src_datasets_dict = {k: [] for k in DPTREE_KEYS}
155 | 
156 |         # data_paths = self.args.data
157 |         data_path = self.args.data
158 |         print(f'| split = {split}')
159 |         print(f'| self.args.data = {self.args.data}')
160 |         # singular data path
161 |         lang = self.args.source_lang
162 | 
163 |         src, tgt = 'input', 'target'
164 | 
165 |         for k in itertools.count():
166 |             split_k = split + (str(k) if k > 0 else '')
167 |             if split_exists(split_k, src, data_path):
168 |                 prefix = os.path.join(data_path, f'{split}.')
169 |             else:
170 |                 if k > 0:
171 |                     break
172 |                 else:
173 |                     raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
174 |             # src_datasets.append(indexed_dataset(prefix + src))
175 |             for modality in src_datasets_dict.keys():
176 | 
209 |     def max_positions(self):
210 |         """Return the max sentence length allowed by the task."""
211 |         # return (self.args.max_source_positions, self.args.max_target_positions)
212 |         # return (self.args.max_source_positions, self.args.max_source_positions)
213 |         return self.args.max_source_positions
214 | 
215 |     @property
216 |     def source_dictionary(self):
217 |         """Return the source :class:`~fairseq.data.Dictionary`."""
218 |         return self.dictionary
219 | 
220 |     @property
221 |     def target_dictionary(self):
222 |         """Return the target :class:`~fairseq.data.Dictionary`."""
223 |         return self.dictionary
224 | 
225 |     @classmethod
226 |     def build_dictionary(cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8):
227 |         d = Dictionary()
228 |         for filename in filenames:
229 |             Dictionary.add_file_to_dictionary(filename, d, tokenizer.tokenize_line, workers)
230 |         d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor)
231 |         return d
232 | 
233 | 
--------------------------------------------------------------------------------
/src/tasks/dptree_sep_classification.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import numpy as np
3 | import os
4 | import torch
5 | from fairseq import options, utils
6 | from fairseq.data import (
7 |     data_utils, Dictionary, LanguagePairDataset, ConcatDataset,
8 |     IndexedRawTextDataset, IndexedCachedDataset, IndexedDataset,
9 |     TruncatedDictionary,
10 | )
11 | 
12 | from . import data_utils
13 | 
14 | from ..data import task_utils, monolingual_classification_dataset, MonolingualClassificationDataset, \
15 |     DPTreeMonoClassificationDataset, \
16 |     DPTreeIndexedCachedDataset, \
17 |     DPTreeSeparateMonoClassificationDataset, \
18 |     DPTreeSeparateNodeMonoClassificationDataset, \
19 |     DPTreeSeparateLIClassificationDataset
20 | 
21 | from fairseq import tokenizer
22 | from fairseq.sequence_generator import SequenceGenerator, EnsembleModel
23 | import math
24 | 
25 | from fairseq.tasks import FairseqTask, register_task
26 | 
27 | from .dptree2seq_translation import DPTREE_KEYS
28 | from .fairseq_classification import SequenceClassification, MonolingualClassificationDataset
29 | 
30 | from .dptree_classification import DPTreeClassification
31 | 
32 | 
33 | @register_task('dptree_sep_classification')
34 | class DPTreeSeparateClassification(DPTreeClassification):
35 |     def load_dataset(self, split, combine=False, **kwargs):
36 |         """Load a given dataset split.
37 | 
38 |         Args:
39 |             split (str): name of the split (e.g., train, valid, test)
40 |         """
41 | 
42 |         def split_exists(split, data_type, data_path):
43 |             filename = os.path.join(data_path, f'{split}.{data_type}')
44 |             assert not self.args.raw_text
45 |             exists = [IndexedDataset.exists(os.path.join(data_path, f'{split}.{data_type}.{k}')) for k in DPTREE_KEYS]
46 |             if all(exists):
47 |                 return True
48 |             else:
49 |                 print(f'Some modalities are missing (exists flags: {exists})')
50 |                 return False
51 | 
52 |         # def indexed_dataset(path, dictionary):
53 |         def indexed_dataset(path):
54 |             assert IndexedCachedDataset.exists(path), f'IndexedCachedDataset.exists({path})'
55 |             return IndexedCachedDataset(path, fix_lua_indexing=True)
56 | 
57 |         def dptree_indexed_dataset(path):
58 |             assert DPTreeIndexedCachedDataset.exists(path), f'DPTreeIndexedCachedDataset.exists({path})'
59 |             return DPTreeIndexedCachedDataset(path, fix_lua_indexing=True)
60 | 
61 |         src_datasets = []
62 |         tgt_datasets = []
63 |         src_datasets_dict = {k: [] for k in DPTREE_KEYS}
64 | 
65 |         # data_paths = self.args.data
66 |         data_path = self.args.data
67 |         print(f'| split = {split}')
68 |         print(f'| self.args.data = {self.args.data}')
69 |         # singular data path
70 |         lang = self.args.source_lang
71 | 
72 |         src, tgt = 'input', 'target'
73 | 
74 |         for k in itertools.count():
75 |             split_k = split + (str(k) if k > 0 else '')
76 |             if split_exists(split_k, src, data_path):
77 |                 prefix = os.path.join(data_path, f'{split_k}.')  # use the shard name, not the bare split
78 |             else:
79 |                 if k > 0:
80 |                     break
81 |                 else:
82 |                     raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
83 |             # src_datasets.append(indexed_dataset(prefix + src))
84 |             for modality in src_datasets_dict.keys():
85 |                 src_datasets_dict[modality].append(dptree_indexed_dataset(f'{prefix}{src}.{modality}'))
86 | 
87 |             tgt_datasets.append(indexed_dataset(prefix + tgt))
88 | 
89 |             print('| {} {} {} examples'.format(data_path, split_k, len(tgt_datasets[-1])))
90 |             if not combine:
91 |                 break
92 | 
93 |         assert len(src_datasets_dict[DPTREE_KEYS[0]]) == len(tgt_datasets)
94 | 
95 |         if len(tgt_datasets) == 1:
96 |             # src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
97 |             src_dataset_dict = {k: v[0] for k, v in src_datasets_dict.items()}
98 |             tgt_dataset = tgt_datasets[0]
99 |         else:
100 |             sample_ratios = [1] * len(tgt_datasets)  # count tgt_datasets: src_datasets is never filled
101 |             sample_ratios[0] = self.args.upsample_primary
102 |             # src_dataset = ConcatDataset(src_datasets, sample_ratios)
103 |             src_dataset_dict = {k: ConcatDataset(v, sample_ratios) for k, v in src_datasets_dict.items()}
104 |             tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
105 | 
106 |         src_sizes = src_dataset_dict['nodes'].sizes.reshape(-1, 2).sum(-1)
107 |         # print(f'src_sizes::: {src_sizes}')
108 |         self.datasets[split] = DPTreeSeparateMonoClassificationDataset(
109 |             # srcs, src_sizes, src_dict
110 |             src_dataset_dict, src_sizes, self.source_dictionary,
111 |             tgt_dataset,
112 |             left_pad_source=self.args.left_pad_source,
113 |             # left_pad_target=self.args.left_pad_target,
114 |             max_source_positions=self.args.max_source_positions,
115 |             # max_target_positions=self.args.max_target_positions,
116 |         )
117 | 
118 | 
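# --- Editor's illustration (not part of the original file): what the size arithmetic in
# DPTreeSeparateClassification.load_dataset() above is doing. The "separate" DPTree
# datasets store two size entries per example (assumption: node-sequence length plus the
# length of the flattened span-index tensor), so `sizes.reshape(-1, 2).sum(-1)` collapses
# each pair into a single per-example length used for batching:
#
#     import numpy as np
#     sizes = np.array([7, 7, 11, 11])             # two examples, two size entries each
#     per_example = sizes.reshape(-1, 2).sum(-1)   # -> array([14, 22])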
119 | @register_task('dptree_sep_node_classification')
120 | class DPTreeSeparateNodeClassification(DPTreeSeparateClassification):
121 | 
122 |     def load_dataset(self, split, combine=False, **kwargs):
123 |         """Load a given dataset split.
124 | 
125 |         Args:
126 |             split (str): name of the split (e.g., train, valid, test)
127 |         """
128 | 
129 |         def split_exists(split, data_type, data_path):
130 |             filename = os.path.join(data_path, f'{split}.{data_type}')
131 |             assert not self.args.raw_text
132 |             exists = [IndexedDataset.exists(os.path.join(data_path, f'{split}.{data_type}.{k}')) for k in DPTREE_KEYS]
133 |             if all(exists):
134 |                 return True
135 |             else:
136 |                 print(f'Some modalities are missing (exists flags: {exists})')
137 |                 return False
138 | 
139 |         # def indexed_dataset(path, dictionary):
140 |         def indexed_dataset(path):
141 |             assert IndexedCachedDataset.exists(path), f'IndexedCachedDataset.exists({path})'
142 |             return IndexedCachedDataset(path, fix_lua_indexing=True)
143 | 
144 |         def dptree_indexed_dataset(path):
145 |             assert DPTreeIndexedCachedDataset.exists(path), f'DPTreeIndexedCachedDataset.exists({path})'
146 |             return DPTreeIndexedCachedDataset(path, fix_lua_indexing=True)
147 | 
148 |         src_datasets = []
149 |         # tgt_datasets = []
150 |         src_datasets_dict = {k: [] for k in DPTREE_KEYS}
151 | 
152 |         # data_paths = self.args.data
153 |         data_path = self.args.data
154 |         # print(f'| split = {split}')
155 |         # print(f'| self.args.data = {self.args.data}')
156 |         # singular data path
157 |         lang = self.args.source_lang
158 | 
159 |         src, tgt = 'txt', 'target'
160 | 
161 |         for k in itertools.count():
162 |             split_k = split + (str(k) if k > 0 else '')
163 |             if split_exists(split_k, src, data_path):
164 |                 prefix = os.path.join(data_path, f'{split_k}.')  # use the shard name, not the bare split
165 |             else:
166 |                 if k > 0:
167 |                     break
168 |                 else:
169 |                     raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
170 | 
171 |             for modality in src_datasets_dict.keys():
172 |                 src_datasets_dict[modality].append(dptree_indexed_dataset(f'{prefix}{src}.{modality}'))
173 | 
174 |             # tgt_datasets.append(indexed_dataset(prefix + tgt))
175 | 
176 |             # print('| {} {} {} examples'.format(data_path, split, len(src_datasets[-1])))
177 |             print('| {} {} {} examples'.format(data_path, split_k, len(src_datasets_dict[DPTREE_KEYS[0]][-1])))
178 |             if not combine:
179 |                 break
180 | 
181 |         assert len(src_datasets_dict[DPTREE_KEYS[0]]) == len(src_datasets_dict[DPTREE_KEYS[1]])
182 | 
183 |         if len(src_datasets_dict[DPTREE_KEYS[0]]) == 1:
184 |             # src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
185 |             src_dataset_dict = {k: v[0] for k, v in src_datasets_dict.items()}
186 |             # tgt_dataset = tgt_datasets[0]
187 |         else:
188 |             # sample_ratios = [1] * len(src_datasets)
189 |             # sample_ratios[0] = self.args.upsample_primary
190 |             # # src_dataset = ConcatDataset(src_datasets, sample_ratios)
191 |             # src_dataset_dict = {k: ConcatDataset(v, sample_ratios) for k, v in src_datasets_dict.items()}
192 |             # tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
193 |             raise NotImplementedError
194 | 
195 |         src_sizes = src_dataset_dict['nodes'].sizes.reshape(-1, 2).sum(-1)
196 |         self.datasets[split] = DPTreeSeparateNodeMonoClassificationDataset(
197 |             # srcs, src_sizes, src_dict
198 |             src_dataset_dict, src_sizes, self.source_dictionary,
199 |             None,
200 |             left_pad_source=self.args.left_pad_source,
201 |             # left_pad_target=self.args.left_pad_target,
202 |             max_source_positions=self.args.max_source_positions,
203 |             # max_target_positions=self.args.max_target_positions,
204 |         )
205 | 
206 | 
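# --- Editor's illustration (not part of the original file): the pair-input task below
# reads two parsed sources per example, so its binarized files are suffixed input1 and
# input2 rather than input. A minimal sketch of the expected layout, assuming DPTREE_KEYS
# holds keys like 'nodes' (hypothetical key names for illustration):
#
#     def expected_pair_files(split='train', keys=('nodes',)):
#         return ([f'{split}.input{i}.{k}' for i in (1, 2) for k in keys]
#                 + [f'{split}.target'])
#
#     # expected_pair_files() -> ['train.input1.nodes', 'train.input2.nodes', 'train.target']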
207 | @register_task('dptree_sep_li_classification')
208 | class DPTreeSeparateLIClassification(DPTreeSeparateClassification):
209 | 
210 |     def load_dataset(self, split, combine=False, **kwargs):
211 |         """Load a given dataset split.
212 | 
213 |         Args:
214 |             split (str): name of the split (e.g., train, valid, test)
215 |         """
216 | 
217 |         def split_exists(split, data_type, data_path):
218 |             filename = os.path.join(data_path, f'{split}.{data_type}')
219 |             assert not self.args.raw_text
220 |             exists = [IndexedDataset.exists(os.path.join(data_path, f'{split}.{data_type}.{k}')) for k in DPTREE_KEYS]
221 |             if all(exists):
222 |                 return True
223 |             else:
224 |                 print(f'Some modalities are missing (exists flags: {exists})')
225 |                 return False
226 | 
227 |         # def indexed_dataset(path, dictionary):
228 |         def indexed_dataset(path):
229 |             assert IndexedCachedDataset.exists(path), f'IndexedCachedDataset.exists({path})'
230 |             return IndexedCachedDataset(path, fix_lua_indexing=True)
231 | 
232 |         def dptree_indexed_dataset(path):
233 |             assert DPTreeIndexedCachedDataset.exists(path), f'DPTreeIndexedCachedDataset.exists({path})'
234 |             return DPTreeIndexedCachedDataset(path, fix_lua_indexing=True)
235 | 
236 |         src_datasets = []
237 |         tgt_datasets = []
238 |         src_datasets_dict_1 = {k: [] for k in DPTREE_KEYS}
239 |         src_datasets_dict_2 = {k: [] for k in DPTREE_KEYS}
240 | 
241 |         # data_paths = self.args.data
242 |         data_path = self.args.data
243 |         print(f'| split = {split}')
244 |         print(f'| self.args.data = {self.args.data}')
245 |         # singular data path
246 |         lang = self.args.source_lang
247 | 
248 |         src1, src2, tgt = 'input1', 'input2', 'target'
249 | 
250 |         for k in itertools.count():
251 |             split_k = split + (str(k) if k > 0 else '')
252 |             if split_exists(split_k, src1, data_path):
253 |                 prefix = os.path.join(data_path, f'{split_k}.')  # use the shard name, not the bare split
254 |             else:
255 |                 if k > 0:
256 |                     break
257 |                 else:
258 |                     raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
259 | 
260 |             for modality in src_datasets_dict_1.keys():
261 |                 src_datasets_dict_1[modality].append(dptree_indexed_dataset(f'{prefix}{src1}.{modality}'))
262 |             for modality in src_datasets_dict_2.keys():
263 |                 src_datasets_dict_2[modality].append(dptree_indexed_dataset(f'{prefix}{src2}.{modality}'))
264 | 
265 |             tgt_datasets.append(indexed_dataset(prefix + tgt))
266 | 
267 |             print('| {} {} {} examples'.format(data_path, split_k, len(tgt_datasets[-1])))
268 |             if not combine:
269 |                 break
270 | 
271 |         assert len(src_datasets_dict_1[DPTREE_KEYS[0]]) == len(tgt_datasets)
272 |         assert len(src_datasets_dict_2[DPTREE_KEYS[0]]) == len(tgt_datasets)
273 | 
274 |         if len(tgt_datasets) == 1:
275 |             # src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
276 |             src_dataset_dict_1 = {k: v[0] for k, v in src_datasets_dict_1.items()}
277 |             src_dataset_dict_2 = {k: v[0] for k, v in src_datasets_dict_2.items()}
278 |             tgt_dataset = tgt_datasets[0]
279 |         else:
280 |             # sample_ratios = [1] * len(src_datasets)
281 |             # sample_ratios[0] = 
self.args.upsample_primary 282 | # # src_dataset = ConcatDataset(src_datasets, sample_ratios) 283 | # src_dataset_dict = {k: ConcatDataset(v, sample_ratios) for k, v in src_datasets_dict_1.items()} 284 | # tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) 285 | raise NotImplementedError(f'No concatenation') 286 | 287 | src1_sizes = src_dataset_dict_1['nodes'].sizes.reshape(-1, 2).sum(-1) 288 | src2_sizes = src_dataset_dict_2['nodes'].sizes.reshape(-1, 2).sum(-1) 289 | # print(f'src_sizes::: {src_sizes}') 290 | self.datasets[split] = DPTreeSeparateLIClassificationDataset( 291 | # srcs, src_sizes, src_dict 292 | # src_dataset_dict, src_sizes, self.source_dictionary, 293 | src_dataset_dict_1, src1_sizes, src_dataset_dict_2, src2_sizes, self.source_dictionary, 294 | tgt_dataset, 295 | left_pad_source=self.args.left_pad_source, 296 | max_source_positions=self.args.max_source_positions, 297 | ) 298 | -------------------------------------------------------------------------------- /src/tasks/nstack_from_dptree_classification.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import numpy as np 3 | import os 4 | import torch 5 | from fairseq import options, utils 6 | from fairseq.data import ( 7 | data_utils, Dictionary, LanguagePairDataset, ConcatDataset, 8 | IndexedRawTextDataset, IndexedCachedDataset, IndexedDataset, 9 | TruncatedDictionary, 10 | ) 11 | 12 | from . import data_utils 13 | 14 | from ..data import task_utils, monolingual_classification_dataset, MonolingualClassificationDataset, \ 15 | DPTreeMonoClassificationDataset, \ 16 | DPTreeIndexedCachedDataset, \ 17 | DPTreeSeparateMonoClassificationDataset, \ 18 | DPTreeSeparateNodeMonoClassificationDataset, \ 19 | DPTreeSeparateLIClassificationDataset, \ 20 | NodeStackFromDPTreeSepNodeTargetMonoClassificationDataset, \ 21 | NodeStackFromDPTreeSepMonoClassificationDataset 22 | 23 | from fairseq.data import data_utils, FairseqDataset, iterators, Dictionary 24 | 25 | from fairseq import tokenizer 26 | from fairseq.sequence_generator import SequenceGenerator, EnsembleModel 27 | import math 28 | 29 | from fairseq.tasks import FairseqTask, register_task 30 | 31 | from .dptree2seq_translation import DPTREE_KEYS 32 | from .fairseq_classification import SequenceClassification, MonolingualClassificationDataset 33 | 34 | from .dptree_classification import DPTreeClassification 35 | from .dptree_sep_classification import * 36 | 37 | 38 | @register_task('nstack_f_dptree_classification') 39 | class NStackFromDPTreeSeparateClassification(DPTreeSeparateClassification): 40 | def load_dataset(self, split, combine=False, **kwargs): 41 | """Load a given dataset split. 
42 | 
43 |         Args:
44 |             split (str): name of the split (e.g., train, valid, test)
45 |         """
46 | 
47 |         def split_exists(split, data_type, data_path):
48 |             filename = os.path.join(data_path, f'{split}.{data_type}')
49 |             assert not self.args.raw_text
50 |             exists = [IndexedDataset.exists(os.path.join(data_path, f'{split}.{data_type}.{k}')) for k in DPTREE_KEYS]
51 |             if all(exists):
52 |                 return True
53 |             else:
54 |                 print(f'Some modalities are missing (exists flags: {exists})')
55 |                 return False
56 | 
57 |         # def indexed_dataset(path, dictionary):
58 |         def indexed_dataset(path):
59 |             assert IndexedCachedDataset.exists(path), f'IndexedCachedDataset.exists({path})'
60 |             return IndexedCachedDataset(path, fix_lua_indexing=True)
61 | 
62 |         def dptree_indexed_dataset(path):
63 |             assert DPTreeIndexedCachedDataset.exists(path), f'DPTreeIndexedCachedDataset.exists({path})'
64 |             return DPTreeIndexedCachedDataset(path, fix_lua_indexing=True)
65 | 
66 |         src_datasets = []
67 |         tgt_datasets = []
68 |         src_datasets_dict = {k: [] for k in DPTREE_KEYS}
69 | 
70 |         # data_paths = self.args.data
71 |         data_path = self.args.data
72 |         print(f'| split = {split}')
73 |         print(f'| self.args.data = {self.args.data}')
74 |         # singular data path
75 |         lang = self.args.source_lang
76 | 
77 |         src, tgt = 'input', 'target'
78 | 
79 |         for k in itertools.count():
80 |             split_k = split + (str(k) if k > 0 else '')
81 |             if split_exists(split_k, src, data_path):
82 |                 prefix = os.path.join(data_path, f'{split_k}.')  # use the shard name, not the bare split
83 |             else:
84 |                 if k > 0:
85 |                     break
86 |                 else:
87 |                     raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
88 |             # src_datasets.append(indexed_dataset(prefix + src))
89 |             for modality in src_datasets_dict.keys():
90 |                 src_datasets_dict[modality].append(dptree_indexed_dataset(f'{prefix}{src}.{modality}'))
91 | 
92 |             tgt_datasets.append(indexed_dataset(prefix + tgt))
93 | 
94 |             print('| {} {} {} examples'.format(data_path, split_k, len(tgt_datasets[-1])))
95 |             if not combine:
96 |                 break
97 | 
98 |         assert len(src_datasets_dict[DPTREE_KEYS[0]]) == len(tgt_datasets)
99 | 
100 |         if len(tgt_datasets) == 1:
101 |             # src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
102 |             src_dataset_dict = {k: v[0] for k, v in src_datasets_dict.items()}
103 |             tgt_dataset = tgt_datasets[0]
104 |         else:
105 |             sample_ratios = [1] * len(tgt_datasets)  # count tgt_datasets: src_datasets is never filled
106 |             sample_ratios[0] = self.args.upsample_primary
107 |             # src_dataset = ConcatDataset(src_datasets, sample_ratios)
108 |             src_dataset_dict = {k: ConcatDataset(v, sample_ratios) for k, v in src_datasets_dict.items()}
109 |             tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
110 | 
111 |         src_sizes = src_dataset_dict['nodes'].sizes.reshape(-1, 2).sum(-1)
112 |         # print(f'src_sizes::: {src_sizes}')
113 |         self.datasets[split] = NodeStackFromDPTreeSepMonoClassificationDataset(
114 |             # srcs, src_sizes, src_dict
115 |             src_dataset_dict, src_sizes, self.source_dictionary,
116 |             tgt_dataset,
117 |             left_pad_source=self.args.left_pad_source,
118 |             # left_pad_target=self.args.left_pad_target,
119 |             max_source_positions=self.args.max_source_positions,
120 |             # max_target_positions=self.args.max_target_positions,
121 |         )
122 | 
123 | 
124 | @register_task('nstack_f_dptree_node_classification')
125 | class NStackFromDPTreeSeparateNodeClassification(DPTreeSeparateNodeClassification):
126 | 
127 |     @staticmethod
128 |     def add_args(parser):
129 |         DPTreeSeparateNodeClassification.add_args(parser)
130 |         parser.add_argument('--only_binary', action='store_true', help='keep only the binary subset (drop the class given by --filter_class_index)')
131 |         parser.add_argument('--filter_class_index', default=2, type=int, help='class index to drop when --only_binary is set')
132 | 
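# --- Editor's note (hypothetical invocation, not from the original repo): with the two
# flags registered above, a run that keeps only the binary subset of a 5-class sentiment
# setup might look like the following; the data path and exact argument set are
# assumptions based on this file's register_task() calls:
#
#     python train_classification.py data_bin/sst \
#         --task nstack_f_dptree_node_classification \
#         --only_binary --filter_class_index 2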
133 |     def get_batch_iterator(
134 |             self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
135 |             ignore_invalid_inputs=False, required_batch_size_multiple=1,
136 |             seed=1, num_shards=1, shard_id=0, num_workers=0, filter_class_index=None  # optional class to drop (see --only_binary)
137 |     ):
138 |         """
139 |         Get an iterator that yields batches of data from the given dataset.
140 | 
141 |         Args:
142 |             dataset (~fairseq.data.FairseqDataset): dataset to batch
143 |             max_tokens (int, optional): max number of tokens in each batch
144 |                 (default: None).
145 |             max_sentences (int, optional): max number of sentences in each
146 |                 batch (default: None).
147 |             max_positions (optional): max sentence length supported by the
148 |                 model (default: None).
149 |             ignore_invalid_inputs (bool, optional): don't raise Exception for
150 |                 sentences that are too long (default: False).
151 |             required_batch_size_multiple (int, optional): require batch size to
152 |                 be a multiple of N (default: 1).
153 |             seed (int, optional): seed for random number generator for
154 |                 reproducibility (default: 1).
155 |             num_shards (int, optional): shard the data iterator into N
156 |                 shards (default: 1).
157 |             shard_id (int, optional): which shard of the data iterator to
158 |                 return (default: 0).
159 |             num_workers (int, optional): how many subprocesses to use for data
160 |                 loading. 0 means the data will be loaded in the main process
161 |                 (default: 0).
162 | 
163 |         Returns:
164 |             ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
165 |                 given dataset split
166 |         """
167 |         assert isinstance(dataset, FairseqDataset)
168 | 
169 |         # get indices ordered by example size
170 |         with data_utils.numpy_seed(seed):
171 |             indices = dataset.ordered_indices()
172 | 
173 |         # filter examples that are too large
174 | 
175 |         indices = data_utils.filter_by_size(
176 |             indices, dataset.size, max_positions, raise_exception=(not ignore_invalid_inputs),
177 |         )
178 | 
179 |         if filter_class_index is not None or self.args.only_binary:
180 |             filter_class_index = filter_class_index if filter_class_index is not None else self.args.filter_class_index
181 |             # print(f'Filtering data of class {filter_class_index}')
182 |             class_fn = dataset.sample_class
183 |             assert class_fn is not None
184 |             # indices = task_utils.filter_by_class_size(
185 |             #     indices, dataset.size, max_positions, class_fn, filter_class_index,
186 |             #     raise_exception=(not ignore_invalid_inputs),
187 |             # )
188 |             indices = task_utils.filter_by_class(
189 |                 indices, class_fn, filter_class_index,
190 |                 raise_exception=False,
191 |             )
192 | 
193 |         # else:
194 |         #     indices = data_utils.filter_by_size(
195 |         #         indices, dataset.size, max_positions, raise_exception=(not ignore_invalid_inputs),
196 |         #     )
197 | 
198 |         # if filter_class_index is not None or self.args.only_binary:
199 |         #     # assert isinstance(filter_class_index, int)
200 |         #     filter_class_index = filter_class_index if filter_class_index is not None else self.args.filter_class_index
201 |         #     print(f'Filtering data of class {filter_class_index}')
202 |         #     class_fn = dataset.sample_class
203 |         #     assert class_fn is not None
204 |         #     indices = task_utils.filter_by_class(
205 |         #         indices, class_fn, filter_class_index, raise_exception=(not ignore_invalid_inputs),
206 |         #     )
207 | 
208 |         # create mini-batches with given size constraints
209 |         batch_sampler = data_utils.batch_by_size(
210 |             indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences,
211 |             required_batch_size_multiple=required_batch_size_multiple,
212 |         )
213 | 
214 |         # return a reusable, sharded iterator
215 |         return iterators.EpochBatchIterator(
216 |             dataset=dataset,
217 |             collate_fn=dataset.collater,
218 |             batch_sampler=batch_sampler,
219 |             seed=seed,
220 |             num_shards=num_shards,
221 |             shard_id=shard_id,
222 |             num_workers=num_workers,
223 |         )
224 | 
225 |     def load_dataset(self, split, combine=False, **kwargs):
226 |         """Load a given dataset split.
227 | 
228 |         Args:
229 |             split (str): name of the split (e.g., train, valid, test)
230 |         """
231 |         # assert split in ['train', 'valid', 'valid1', 'valid-bin', 'valid1-bin'], f'invalid: {split}'
232 |         get_binary = "-bin" in split  # note: currently unused; split_name below keeps the '-bin' suffix
233 |         split_name = split
234 |         split = split.replace('-bin', '')
235 | 
236 |         def split_exists(split, data_type, data_path):
237 |             filename = os.path.join(data_path, f'{split}.{data_type}')
238 |             assert not self.args.raw_text
239 |             exists = [IndexedDataset.exists(os.path.join(data_path, f'{split}.{data_type}.{k}')) for k in DPTREE_KEYS]
240 |             if all(exists):
241 |                 return True
242 |             else:
243 |                 print(f'Some modalities are missing (exists flags: {exists})')
244 |                 return False
245 | 
246 |         # def indexed_dataset(path, dictionary):
247 |         def indexed_dataset(path):
248 |             assert IndexedCachedDataset.exists(path), f'IndexedCachedDataset.exists({path})'
249 |             return IndexedCachedDataset(path, fix_lua_indexing=True)
250 | 
251 |         def dptree_indexed_dataset(path):
252 |             assert DPTreeIndexedCachedDataset.exists(path), f'DPTreeIndexedCachedDataset.exists({path})'
253 |             return DPTreeIndexedCachedDataset(path, fix_lua_indexing=True)
254 | 
255 |         src_datasets = []
256 |         # tgt_datasets = []
257 |         src_datasets_dict = {k: [] for k in DPTREE_KEYS}
258 | 
259 |         data_path = self.args.data
260 |         lang = self.args.source_lang
261 | 
262 |         src, tgt = 'txt', 'target'
263 | 
264 |         for k in itertools.count():
265 |             split_k = split + (str(k) if k > 0 else '')
266 |             if split_exists(split_k, src, data_path):
267 |                 prefix = os.path.join(data_path, f'{split_k}.')  # use the shard name, not the bare split
268 |             else:
269 |                 if k > 0:
270 |                     break
271 |                 else:
272 |                     raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
273 | 
274 |             for modality in src_datasets_dict.keys():
275 |                 src_datasets_dict[modality].append(dptree_indexed_dataset(f'{prefix}{src}.{modality}'))
276 | 
277 |             print('| {} {} {} examples'.format(data_path, split_k, len(src_datasets_dict[DPTREE_KEYS[0]][-1])))
278 |             if not combine:
279 |                 break
280 | 
281 |         assert len(src_datasets_dict[DPTREE_KEYS[0]]) == len(src_datasets_dict[DPTREE_KEYS[1]])
282 | 
283 |         if len(src_datasets_dict[DPTREE_KEYS[0]]) == 1:
284 |             # src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
285 |             src_dataset_dict = {k: v[0] for k, v in src_datasets_dict.items()}
286 |             # tgt_dataset = tgt_datasets[0]
287 |         else:
288 |             # sample_ratios = [1] * len(src_datasets)
289 |             # sample_ratios[0] = self.args.upsample_primary
290 |             # # src_dataset = ConcatDataset(src_datasets, sample_ratios)
291 |             # src_dataset_dict = {k: ConcatDataset(v, sample_ratios) for k, v in src_datasets_dict.items()}
292 |             # tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
293 |             raise NotImplementedError
294 | 
295 |         src_sizes = src_dataset_dict['nodes'].sizes.reshape(-1, 2).sum(-1)
296 |         self.datasets[split_name] = NodeStackFromDPTreeSepNodeTargetMonoClassificationDataset(
297 |             # srcs, src_sizes, src_dict
298 |             src_dataset_dict, src_sizes, self.source_dictionary,
299 |             None,
300 |             left_pad_source=self.args.left_pad_source,
301 |             # left_pad_target=self.args.left_pad_target,
302 |             max_source_positions=self.args.max_source_positions,
303 |             # max_target_positions=self.args.max_target_positions,
304 |         )
305 | 
306 | 
307 | 
308 | 
309 | 
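# --- Editor's illustration (not part of the original file): task_utils.filter_by_class,
# used in get_batch_iterator() above, is repo-internal; the sketch below shows the
# behavior that method relies on (exact signature assumed): keep only the sample indices
# whose class label differs from the excluded index.
#
#     def filter_by_class_sketch(indices, class_fn, exclude_index):
#         # class_fn(i) returns the label of example i (dataset.sample_class above)
#         return [i for i in indices if class_fn(i) != exclude_index]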
310 | def tree_str_post_process(tree_string):
311 |     tree_string = tree_string.replace('-LRB- (', '-LRB- -LRB-').replace('-RRB- )', '-RRB- -RRB-')
312 |     return tree_string
313 | 
314 | from nltk import Tree
315 | def tree_from_string(tree_string):
316 |     try:
317 |         s = tree_string
318 |         s = tree_str_post_process(s)
319 |         tree = Tree.fromstring(s)
320 |     except Exception:
321 |         # print('Tree.fromstring(tree_string) failed, try to omit the post_process')
322 |         try:
323 |             tree = Tree.fromstring(tree_string)
324 |         except Exception as e:
325 |             print('ERROR: unable to parse the tree')
326 |             print(tree_string)
327 |             raise e
328 |     return tree
329 | 
330 | 
331 | def convert_flat(inf, outf, tarf):
332 |     print(f'{inf} -> {outf}')
333 |     with open(inf, 'r') as f:
334 |         tree_lines = f.read().strip().split('\n')
335 |     print(f'len = {len(tree_lines)}')
336 |     flats = []
337 |     targets = []
338 |     for i, l in enumerate(tree_lines):
339 |         try:
340 |             tree = tree_from_string(l)
341 |             leave_s = ' '.join(list(tree.leaves()))
342 |         except Exception:
343 |             print(f'Problem at index {i}: {l}')
344 |             continue  # skip unparseable lines instead of re-appending the previous example
345 |         flats.append(leave_s)
346 |         targets.append(tree.label())
347 |     with open(outf, 'w') as f:
348 |         f.write('\n'.join(flats))
349 |     with open(tarf, 'w') as f:
350 |         f.write('\n'.join(targets))
351 | 
352 | 
353 | 
354 | 
355 | 
--------------------------------------------------------------------------------
/src/trainers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nxphi47/tree_transformer/8ac39e40441b14011b440dece6374bb4231632cc/src/trainers/__init__.py
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict, OrderedDict
2 | import importlib.util
3 | import logging
4 | import os
5 | import re
6 | import sys
7 | import traceback
8 | 
9 | import torch
10 | import torch.nn.functional as F
11 | from torch.serialization import default_restore_location
12 | 
13 | from fairseq import utils
14 | 
15 | from torch import nn
16 | 
17 | 
18 | # def parse_embedding(embed_path):
19 | #     """Parse embedding text file into a dictionary of word and embedding tensors.
20 | # 
21 | #     The first line can have vocabulary size and dimension. The following lines
22 | #     should contain word and embedding separated by spaces.
23 | # 
24 | #     Example:
25 | #         2 5
26 | #         the -0.0230 -0.0264 0.0287 0.0171 0.1403
27 | #         at -0.0395 -0.1286 0.0275 0.0254 -0.0932
28 | #     """
29 | #     embed_dict = {}
30 | #     with open(embed_path) as f_embed:
31 | #         next(f_embed)  # skip header
32 | #         for line in f_embed:
33 | #             pieces = line.rstrip().split(" ")
34 | #             embed_dict[pieces[0]] = torch.Tensor([float(weight) for weight in pieces[1:]])
35 | #     return embed_dict
36 | 
37 | 
38 | def load_embedding_wmode(embed_dict, vocab, embedding, mode):
39 |     """Copy pretrained vectors into an nn.Embedding (note: the `mode` argument is currently unused)."""
40 |     for idx in range(len(vocab)):
41 |         token = vocab[idx]
42 |         if token in embed_dict:
43 |             embedding.weight.data[idx] = embed_dict[token]
44 |             # nn.init.constant_(embedding.weight[idx], )
45 |     return embedding
46 | 
47 | 
48 | 
49 | 
--------------------------------------------------------------------------------
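A quick way to sanity-check the tree helpers at the bottom of `src/tasks/nstack_from_dptree_classification.py` (a minimal sketch; the sentiment-style tree string is made up, and only `nltk` is required):

```python
from nltk import Tree

# Mirror what convert_flat() does for a single line: parse a PTB-style tree,
# keep the flattened token sequence, and use the root label as the class target.
s = '(3 (2 (2 The) (2 movie)) (4 (3 (2 was) (4 great)) (2 .)))'
tree = Tree.fromstring(s)
print(' '.join(tree.leaves()))  # -> The movie was great .
print(tree.label())             # -> 3
```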