├── nn ├── __init__.py └── data_parallel.py ├── biunilm ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── loader_utils.cpython-37.pyc │ └── seq2seq_loader.cpython-37.pyc ├── loader_utils.py ├── decode_seq2seq.py ├── run_ppl.py └── seq2seq_loader.py ├── img └── model.png ├── pytorch_pretrained_bert ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── file_utils.cpython-37.pyc │ └── tokenization.cpython-37.pyc ├── __init__.py ├── __main__.py ├── loss.py ├── optimization_fp16.py ├── file_utils.py ├── tokenization.py └── optimization.py ├── .idea ├── vcs.xml ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── deployment.xml ├── MultiT-C-Dialog-git.iml ├── webServers.xml └── workspace.xml ├── utils.py ├── run_eval.sh ├── run_ppl.sh ├── run_train.sh ├── run_pretrain.sh ├── run_2step_pre.sh ├── run_2step_ft.sh ├── run_sequential_train.py ├── setup.py ├── get_tfidf.py ├── pre_tokenize.py ├── README.md ├── eval.py └── qg ├── eval.py └── eval_on_unilm_tokenized_ref.py /nn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /biunilm/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /img/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengyan-97/MultiT-C-Dialog/HEAD/img/model.png -------------------------------------------------------------------------------- /biunilm/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengyan-97/MultiT-C-Dialog/HEAD/biunilm/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /biunilm/__pycache__/loader_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengyan-97/MultiT-C-Dialog/HEAD/biunilm/__pycache__/loader_utils.cpython-37.pyc -------------------------------------------------------------------------------- /biunilm/__pycache__/seq2seq_loader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengyan-97/MultiT-C-Dialog/HEAD/biunilm/__pycache__/seq2seq_loader.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_pretrained_bert/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengyan-97/MultiT-C-Dialog/HEAD/pytorch_pretrained_bert/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_pretrained_bert/__pycache__/file_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengyan-97/MultiT-C-Dialog/HEAD/pytorch_pretrained_bert/__pycache__/file_utils.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_pretrained_bert/__pycache__/tokenization.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zengyan-97/MultiT-C-Dialog/HEAD/pytorch_pretrained_bert/__pycache__/tokenization.cpython-37.pyc -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /.idea/MultiT-C-Dialog-git.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.4.0" 2 | from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer 3 | from .modeling import (BertConfig, BertModel, BertForPreTraining, BertForMaskedLM, BertForNextSentencePrediction, BertForSequenceClassification, 4 | BertForMultipleChoice, BertForTokenClassification, BertForQuestionAnswering, BertForPreTrainingLossMask, BertPreTrainingPairRel, BertPreTrainingPairTransform) 5 | from .optimization import BertAdam, BertAdamFineTune 6 | # from .optimization_fp16 import FP16_Optimizer_State 7 | from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE 8 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def tmp(data_type, dial_mask_rate): 5 | from random import random as rand 6 | if data_type != 'dial': 7 | print(1) 8 | elif data_type == 'dial' and dial_mask_rate > 0 and rand() < dial_mask_rate: 9 | print(1) 10 | else: 11 | print(0) 12 | 13 | 14 | 15 | 16 | if __name__ == '__main__': 17 | rdir = "./data/pretrain_bert_V2/" 18 | wdir = "./data/pretrain_bert_V2/debug/" 19 | 20 | for filen in os.listdir(rdir): 21 | if filen.endswith('train') or filen.endswith('valid'): 22 | print(filen) 23 | os.system('head -100 {:} > {:}'.format(rdir+filen, wdir+filen)) 24 | -------------------------------------------------------------------------------- /.idea/webServers.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 14 | 15 | -------------------------------------------------------------------------------- /run_eval.sh: -------------------------------------------------------------------------------- 1 | # run decoding 2 | DATA_DIR=./data/newtmp 3 | export PYTORCH_PRETRAINED_BERT_CACHE=./cache_tmp/bert-base-uncased-pretrained-cache 4 | export 
CUDA_VISIBLE_DEVICES=0 5 | 6 | 7 | # run decoding 8 | python biunilm/decode_seq2seq.py \ 9 | --n_clayer 2 \ 10 | --model_recover_path ./saved/tmp/bert_save/model.e2_s1.2.bin \ 11 | --output_file ${DATA_DIR}/tmp.preds.txt \ 12 | --batch_size 64 --beam_size 4 --max_tgt_length 36 --min_len 10 --length_penalty 0 \ 13 | --data_dir ${DATA_DIR} \ 14 | --input_file dial.test --split test \ 15 | --bert_model bert-base-uncased --do_lower_case --s2s_special_token \ 16 | --mode s2s \ 17 | --tokenized_input \ 18 | --max_seq_length 80 \ 19 | --forbid_duplicate_ngrams --ngram_size 2 20 | 21 | -------------------------------------------------------------------------------- /run_ppl.sh: -------------------------------------------------------------------------------- 1 | DATA_DIR=./data/newtmp 2 | OUTPUT_DIR=./saved/tmp 3 | export PYTORCH_PRETRAINED_BERT_CACHE=./saved/bert-base-uncased-pretrained-cache 4 | export CUDA_VISIBLE_DEVICES=0 5 | 6 | 7 | python biunilm/run_ppl.py \ 8 | --n_clayer 2 \ 9 | --seed 42 \ 10 | --data_dir ${DATA_DIR} --tokenized_input \ 11 | --c_tfidf_map c_tfidf_map.pkl \ 12 | --s2s_special_token --max_pred 20 \ 13 | --skipgram_prb 0.2 --skipgram_size 3 \ 14 | --output_dir ${OUTPUT_DIR}/bert_save \ 15 | --bert_model bert-base-uncased --do_lower_case \ 16 | --log_dir ${OUTPUT_DIR}/bert_log \ 17 | --max_seq_length 80 --max_position_embeddings 80 \ 18 | --train_batch_size 80 --eval_batch_size 512 --gradient_accumulation_steps 1 \ 19 | --learning_rate 3e-5 --warmup_proportion 0.1 --label_smoothing 0 20 | 21 | -------------------------------------------------------------------------------- /run_train.sh: -------------------------------------------------------------------------------- 1 | DATA_DIR=./data/newtmp 2 | OUTPUT_DIR=./saved/tmp 3 | export PYTORCH_PRETRAINED_BERT_CACHE=./cache_tmp/bert-base-uncased-pretrained-cache 4 | export CUDA_VISIBLE_DEVICES=0 5 | 6 | 7 | python biunilm/run_train.py \ 8 | --n_clayer 2 --gate attn --FGfree --early_stop \ 9 | --seed 42 \ 10 | --num_train_epochs 10 --valid_steps 4096 \ 11 | --data_dir ${DATA_DIR} --tokenized_input --mask_source_words \ 12 | --c_tfidf_map c_tfidf_map.pkl \ 13 | --s2s_special_token --mask_prob 0.25 --max_pred 20 \ 14 | --skipgram_prb 0.2 --skipgram_size 3 \ 15 | --output_dir ${OUTPUT_DIR}/bert_save \ 16 | --bert_model bert-base-uncased --do_lower_case \ 17 | --log_dir ${OUTPUT_DIR}/bert_log \ 18 | --max_seq_length 80 --max_position_embeddings 80 \ 19 | --train_batch_size 74 --eval_batch_size 74 --gradient_accumulation_steps 1 \ 20 | --learning_rate 3e-5 --warmup_proportion 0.1 --label_smoothing 0 21 | 22 | -------------------------------------------------------------------------------- /run_pretrain.sh: -------------------------------------------------------------------------------- 1 | DATA_DIR=./data/pretrain_dial_data 2 | OUTPUT_DIR=./saved/pretrained 3 | export PYTORCH_PRETRAINED_BERT_CACHE=./cache_tmp/bert-base-uncased-pretrained-cache 4 | export CUDA_VISIBLE_DEVICES=0 5 | 6 | 7 | python biunilm/run_train.py \ 8 | --n_clayer 2 --gate attn --FGfree --early_stop --n_text 0 \ 9 | --seed 42 \ 10 | --num_train_epochs 10 --valid_steps 4096 \ 11 | --data_dir ${DATA_DIR} --tokenized_input --mask_source_words \ 12 | --c_tfidf_map c_tfidf_map.pkl \ 13 | --s2s_special_token --mask_prob 0.25 --max_pred 20 \ 14 | --skipgram_prb 0.2 --skipgram_size 3 \ 15 | --output_dir ${OUTPUT_DIR}/bert_save \ 16 | --bert_model bert-base-uncased --do_lower_case \ 17 | --log_dir ${OUTPUT_DIR}/bert_log \ 18 | --max_seq_length 80 --max_position_embeddings 80 \ 
19 | --train_batch_size 74 --eval_batch_size 74 --gradient_accumulation_steps 1 \ 20 | --learning_rate 3e-5 --warmup_proportion 0.1 --label_smoothing 0 21 | 22 | -------------------------------------------------------------------------------- /run_2step_pre.sh: -------------------------------------------------------------------------------- 1 | DATA_DIR=./data/small_2k_bert 2 | OUTPUT_DIR=./saved/usr_2step_fine_1M 3 | export PYTORCH_PRETRAINED_BERT_CACHE=./cache_tmp/bert-base-uncased-pretrained-cache 4 | export CUDA_VISIBLE_DEVICES=0 5 | 6 | 7 | # --n_text 3000000 8 | 9 | python biunilm/run_train.py \ 10 | --n_text -1 --n_dial 0 --early_stop \ 11 | --n_clayer 2 \ 12 | --seed 42 \ 13 | --do_preprocess \ 14 | --do_train --do_eval --num_train_epochs 10 --valid_steps 4096 \ 15 | --data_dir ${DATA_DIR} --tokenized_input --mask_source_words \ 16 | --c_tfidf_map c_tfidf_map.pkl \ 17 | --s2s_special_token --mask_prob 0.25 --max_pred 20 \ 18 | --skipgram_prb 0.2 --skipgram_size 3 \ 19 | --output_dir ${OUTPUT_DIR}/bert_save \ 20 | --bert_model bert-base-uncased --do_lower_case \ 21 | --log_dir ${OUTPUT_DIR}/bert_log \ 22 | --max_seq_length 80 --max_position_embeddings 80 \ 23 | --train_batch_size 80 --eval_batch_size 80 --gradient_accumulation_steps 1 \ 24 | --learning_rate 3e-5 --warmup_proportion 0.1 --label_smoothing 0 25 | 26 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/__main__.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | def main(): 3 | import sys 4 | try: 5 | from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch 6 | except ModuleNotFoundError: 7 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 8 | "In that case, it requires TensorFlow to be installed. 
Please see " 9 | "https://www.tensorflow.org/install/ for installation instructions.") 10 | raise 11 | 12 | if len(sys.argv) != 5: 13 | # pylint: disable=line-too-long 14 | print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`") 15 | else: 16 | PYTORCH_DUMP_OUTPUT = sys.argv.pop() 17 | TF_CONFIG = sys.argv.pop() 18 | TF_CHECKPOINT = sys.argv.pop() 19 | convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 20 | 21 | if __name__ == '__main__': 22 | main() 23 | -------------------------------------------------------------------------------- /run_2step_ft.sh: -------------------------------------------------------------------------------- 1 | DATA_DIR=./data/small_2k_bert 2 | OUTPUT_DIR=./saved/tmp 3 | export PYTORCH_PRETRAINED_BERT_CACHE=./cache_tmp/bert-base-uncased-pretrained-cache 4 | export CUDA_VISIBLE_DEVICES=0 5 | # export CUDA_VISIBLE_DEVICES=0,1 6 | 7 | 8 | python biunilm/run_train.py \ 9 | --fine_tune --model_recover_path ./saved/usr_2step_fine_1M/bert_save/model.e4_s12500.50000.bin \ 10 | --n_clayer 2 --n_text 0 --n_dial -1 \ 11 | --seed 42 \ 12 | --do_preprocess \ 13 | --do_train --do_eval --num_train_epochs 10 --valid_steps 4096 \ 14 | --data_dir ${DATA_DIR} --tokenized_input --mask_source_words \ 15 | --c_tfidf_map c_tfidf_map.pkl \ 16 | --s2s_special_token --mask_prob 0.25 --max_pred 20 \ 17 | --skipgram_prb 0.2 --skipgram_size 3 \ 18 | --output_dir ${OUTPUT_DIR}/bert_save \ 19 | --bert_model bert-base-uncased --do_lower_case \ 20 | --log_dir ${OUTPUT_DIR}/bert_log \ 21 | --max_seq_length 80 --max_position_embeddings 80 \ 22 | --train_batch_size 80 --eval_batch_size 80 --gradient_accumulation_steps 1 \ 23 | --learning_rate 3e-5 --warmup_proportion 0.1 --label_smoothing 0 24 | 25 | -------------------------------------------------------------------------------- /run_sequential_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import subprocess 4 | 5 | 6 | def fine_tune_scheme(op): 7 | tmp = { 8 | '1': [ 9 | "sh -x ./run_ppl.sh usr_2step_fine_1M", 10 | "sh -x ./run_ppl.sh usr_1M_1M", 11 | "sh -x ./run_ppl.sh usr_2step_fine_500K", 12 | "sh -x ./run_ppl.sh usr_1M_500K", 13 | "sh -x ./run_ppl.sh usr_2step_fine_250K", 14 | "sh -x ./run_ppl.sh usr_1M_250K", 15 | 16 | "sh -x ./run_ppl.sh usr_2step_fine_200K", 17 | "sh -x ./run_ppl.sh usr_1M_200K", 18 | "sh -x ./run_ppl.sh usr_2step_fine_150K", 19 | "sh -x ./run_ppl.sh usr_1M_150K", 20 | 21 | "sh -x ./run_ppl.sh usr_2step_fine_100K", 22 | "sh -x ./run_ppl.sh usr_1M_100K", 23 | "sh -x ./run_ppl.sh usr_2step_fine_50K", 24 | "sh -x ./run_ppl.sh usr_1M_50K", 25 | 26 | ], 27 | 28 | '2':[ 29 | "sh -x ./run_2step_ft_250K_100K.sh 200000 200K", 30 | "sh -x ./run_2step_ft_250K_100K.sh 150000 150K", 31 | ], 32 | 33 | '3':[ 34 | "sh -x ./run_eval.sh", 35 | "sh -x ./run_eval_tp.sh", 36 | ], 37 | 38 | '4':[ 39 | "sh -x ./run_ppl.sh", 40 | "sh -x ./run_eval.sh" 41 | ] 42 | 43 | } 44 | 45 | for cmd in tmp[op]: 46 | print('-'*20) 47 | print(cmd) 48 | print('-'*20) 49 | p = subprocess.Popen(cmd, shell=True) 50 | p.wait() 51 | print('\n'*2) 52 | 53 | 54 | if __name__ == '__main__': 55 | fine_tune_scheme(op=sys.argv[1]) 56 | 57 | 58 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/loss.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | from __future__ import 
absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.nn.modules.loss import _Loss 10 | 11 | 12 | class LabelSmoothingLoss(_Loss): 13 | """ 14 | With label smoothing, 15 | KL-divergence between q_{smoothed ground truth prob.}(w) 16 | and p_{prob. computed by model}(w) is minimized. 17 | """ 18 | 19 | def __init__(self, label_smoothing=0, tgt_vocab_size=0, ignore_index=0, size_average=None, reduce=None, reduction='mean'): 20 | assert 0.0 < label_smoothing <= 1.0 21 | self.ignore_index = ignore_index 22 | super(LabelSmoothingLoss, self).__init__( 23 | size_average=size_average, reduce=reduce, reduction=reduction) 24 | 25 | assert label_smoothing > 0 26 | assert tgt_vocab_size > 0 27 | 28 | smoothing_value = label_smoothing / (tgt_vocab_size - 2) 29 | one_hot = torch.full((tgt_vocab_size,), smoothing_value) 30 | one_hot[self.ignore_index] = 0 31 | self.register_buffer('one_hot', one_hot.unsqueeze(0)) 32 | self.confidence = 1.0 - label_smoothing 33 | self.tgt_vocab_size = tgt_vocab_size 34 | 35 | def forward(self, output, target): 36 | """ 37 | output (FloatTensor): batch_size * num_pos * n_classes 38 | target (LongTensor): batch_size * num_pos 39 | """ 40 | assert self.tgt_vocab_size == output.size(2) 41 | batch_size, num_pos = target.size(0), target.size(1) 42 | output = output.view(-1, self.tgt_vocab_size) 43 | target = target.view(-1) 44 | model_prob = self.one_hot.repeat(target.size(0), 1) 45 | model_prob.scatter_(1, target.unsqueeze(1), self.confidence) 46 | model_prob.masked_fill_((target == self.ignore_index).unsqueeze(1), 0) 47 | 48 | return F.kl_div(output, model_prob, reduction='none').view(batch_size, num_pos, -1).sum(2) 49 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/master/setup.py 3 | 4 | To create the package for pypi. 5 | 6 | 1. Change the version in __init__.py and setup.py. 7 | 8 | 2. Commit these changes with the message: "Release: VERSION" 9 | 10 | 3. Add a tag in git to mark the release: "git tag VERSION -m'Adds tag VERSION for pypi' " 11 | Push the tag to git: git push --tags origin master 12 | 13 | 4. Build both the sources and the wheel. Do not change anything in setup.py between 14 | creating the wheel and the source distribution (obviously). 15 | 16 | For the wheel, run: "python setup.py bdist_wheel" in the top level allennlp directory. 17 | (this will build a wheel for the python version you use to build it - make sure you use python 3.x). 18 | 19 | For the sources, run: "python setup.py sdist" 20 | You should now have a /dist directory with both .whl and .tar.gz source versions of allennlp. 21 | 22 | 5. Check that everything looks correct by uploading the package to the pypi test server: 23 | 24 | twine upload dist/* -r pypitest 25 | (pypi suggest using twine as other methods upload files via plaintext.) 26 | 27 | Check that you can install it in a virtualenv by running: 28 | pip install -i https://testpypi.python.org/pypi allennlp 29 | 30 | 6. Upload the final version to actual pypi: 31 | twine upload dist/* -r pypi 32 | 33 | 7. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory. 
34 | 35 | """ 36 | from setuptools import find_packages, setup 37 | 38 | setup( 39 | name="pytorch_pretrained_bert", 40 | version="0.4.0", 41 | author="Thomas Wolf, Victor Sanh, Tim Rault, Google AI Language Team Authors", 42 | author_email="thomas@huggingface.co", 43 | description="PyTorch version of Google AI BERT model with script to load Google pre-trained models", 44 | long_description="pytorch", 45 | long_description_content_type="text/markdown", 46 | keywords='BERT NLP deep learning google', 47 | license='Apache', 48 | url="https://github.com/huggingface/pytorch-pretrained-BERT", 49 | packages=find_packages(exclude=["*.tests", "*.tests.*", 50 | "tests.*", "tests"]), 51 | install_requires=['numpy', 52 | 'boto3', 53 | 'requests', 54 | 'tqdm'], 55 | entry_points={ 56 | 'console_scripts': [ 57 | "pytorch_pretrained_bert=pytorch_pretrained_bert.__main__:main" 58 | ] 59 | }, 60 | python_requires='>=3.5.0', 61 | tests_require=['pytest'], 62 | classifiers=[ 63 | 'Intended Audience :: Science/Research', 64 | 'License :: OSI Approved :: Apache Software License', 65 | 'Programming Language :: Python :: 3', 66 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 67 | ], 68 | ) 69 | -------------------------------------------------------------------------------- /.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 13 | 14 | 19 | 20 | 21 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 1619316875000 58 | 63 | 64 | 65 | 66 | 68 | 69 | 78 | 79 | -------------------------------------------------------------------------------- /get_tfidf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pickle 4 | from collections import Counter 5 | from numpy import sum, log10, mean 6 | 7 | 8 | def do_pickle(obj, path): 9 | if not path.endswith('.pkl'): 10 | print("Recommend to end with '.pkl'.") 11 | 12 | assert os.path.exists(os.path.dirname(path)) 13 | 14 | with open(path, 'wb') as f: 15 | pickle.dump(obj, f) 16 | print('{:} have pickled.'.format(path)) 17 | 18 | 19 | def get_tfidf(rpath, 20 | wdir, 21 | show_sample=False): 22 | 23 | if not os.path.exists(wdir): 24 | os.mkdir(wdir) 25 | 26 | wpath = os.path.join(wdir, "c_tfidf_map.pkl") 27 | 28 | print('#'*20) 29 | print("Notice: {:} should be tokenized by Bert First.".format(os.path.basename(rpath))) 30 | print('#'*20) 31 | 32 | c_vocab_map = {} 33 | c_counter = {} 34 | with open(rpath, 'rt') as f: 35 | for index, line in enumerate(f): 36 | if (index+1) % 200000 == 0: 37 | print("{:}\t{:}...".format(os.path.basename(rpath), index)) 38 | sys.stdout.flush() 39 | 40 | _, label, text, _ = line.strip().split('\t') 41 | if label not in c_vocab_map.keys(): 42 | c_vocab_map[label] = {} 43 | 44 | for w in text.strip().split(' '): 45 | try: 46 | c_vocab_map[label][w] += 1 47 | except KeyError: 48 | c_vocab_map[label][w] = 1 49 | 50 | try: 51 | c_counter[label] += 1 52 | except KeyError: 53 | c_counter[label] = 1 54 | 55 | # 一些情况 56 | print('#'*20) 57 | c_counter = list(c_counter.values()) 58 | print('{:} conditions; min: {:}, max: {:}, avg: {:}'.format(len(c_counter), min(c_counter), max(c_counter), mean(c_counter))) 59 | print('#'*20) 60 | sys.stdout.flush() 61 | 62 | print("# Get tf") 63 | sys.stdout.flush() 64 | c_sum_map = {label: sum(list(c_vocab_map[label].values())) for label in c_vocab_map.keys()} 65 | c_tfvocab_map = {} 
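# Formula used below: tf(w | label) = count(w in label's texts) / total tokens for that label,
# idf(w) = log10(n_labels / number of labels containing w); the final map stores tf * idf per label.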
66 | for label, vocab in c_vocab_map.items(): 67 | c_tfvocab_map[label] = {w: n/c_sum_map[label] for w, n in vocab.items()} 68 | 69 | print("# Get idf") 70 | sys.stdout.flush() 71 | word_counter = Counter() 72 | for _, vocab in c_vocab_map.items(): 73 | word_counter += Counter(vocab.keys()) 74 | 75 | n_labels = len(c_vocab_map) 76 | word_idf_map = {w: log10(n_labels/n_occur) for w, n_occur in word_counter.items()} 77 | 78 | print("# Get tf-idf") 79 | sys.stdout.flush() 80 | c_tfidfvocab_map = {} 81 | for label, tfvocab_map in c_tfvocab_map.items(): 82 | c_tfidfvocab_map[label] = {w: tf * word_idf_map[w] for w, tf in tfvocab_map.items()} 83 | 84 | do_pickle(c_tfidfvocab_map, wpath) 85 | 86 | # Write samples 87 | if show_sample: 88 | def write_sample(label): 89 | with open('./{:}.tmp.txt'.format(label), 'wt') as f: 90 | res = sorted(c_tfidfvocab_map[label].items(), key=lambda p:p[1], reverse=True) 91 | for w, tfidf in res: 92 | f.write("{:}\t{:.8f}\n".format(w, tfidf)) 93 | 94 | print("tfidf of {:} in ./{:}.tmp.txt".format(label, label)) 95 | 96 | write_sample('nba') 97 | write_sample('movies') 98 | 99 | return c_tfidfvocab_map 100 | 101 | 102 | if __name__ == '__main__': 103 | data_dir = sys.argv[1] 104 | read_filen = sys.argv[2] 105 | rpath = os.path.join(data_dir, read_filen) 106 | get_tfidf(rpath, data_dir, show_sample=True) 107 | -------------------------------------------------------------------------------- /pre_tokenize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import math 5 | import random 6 | 7 | from pytorch_pretrained_bert.tokenization import BertTokenizer 8 | 9 | 10 | def simple_prep(sent, bert_tokenizer): 11 | return ' '.join(bert_tokenizer.tokenize(sent.strip())) 12 | 13 | 14 | def list_to_txt(samples, wpath): 15 | with open(wpath, 'wt') as f: 16 | for s in samples: 17 | f.write('\t'.join(s)+'\n') 18 | 19 | 20 | def text_to_bert(rpath, wpath, bert_model="bert-base-uncased", do_shuffle=False): 21 | wdir = os.path.dirname(wpath) 22 | if not os.path.exists(wdir): 23 | os.mkdir(wdir) 24 | 25 | bert_model = bert_model.strip() 26 | if bert_model.endswith('uncased'): 27 | do_lower_case = True 28 | else: 29 | do_lower_case = False 30 | bert_tokenizer = BertTokenizer.from_pretrained( 31 | bert_model, do_lower_case=do_lower_case) 32 | 33 | samples = [] 34 | start = time.time() 35 | with open(rpath, 'rt') as f: 36 | for index, line in enumerate(f): 37 | if (index+1) % 100000 == 0: 38 | print('{:}\t{:}\t{:.1f}min'.format(os.path.basename(rpath), 39 | index, (time.time() - start) / 60)) 40 | sys.stdout.flush() 41 | 42 | label, text = line.strip().split('\t') 43 | samples.append(('', label, 44 | simple_prep(text, bert_tokenizer), 'mono')) 45 | 46 | sys.stdout.flush() 47 | 48 | if do_shuffle: 49 | random.shuffle(samples) 50 | 51 | list_to_txt(samples, wpath) 52 | 53 | 54 | def dial_to_bert(rpath, wpath, bert_model="bert-base-uncased", do_shuffle=False): 55 | wdir = os.path.dirname(wpath) 56 | if not os.path.exists(wdir): 57 | os.mkdir(wdir) 58 | 59 | bert_model = bert_model.strip() 60 | if bert_model.endswith('uncased'): 61 | do_lower_case = True 62 | else: 63 | do_lower_case = False 64 | tokenizer = BertTokenizer.from_pretrained( 65 | bert_model, do_lower_case=do_lower_case) 66 | 67 | samples = [] 68 | start = time.time() 69 | with open(rpath, 'rt') as f: 70 | for index, line in enumerate(f): 71 | if (index+1) % 100000 == 0: 72 | print('{:}\t{:}\t{:.1f}min'.format(os.path.basename(rpath), 73 | index, 
(time.time() - start) / 60)) 74 | sys.stdout.flush() 75 | 76 | src, label, tgt = line.strip().split('\t') 77 | samples.append((simple_prep(src, tokenizer), label, 78 | simple_prep(tgt, tokenizer), 'dial')) 79 | 80 | sys.stdout.flush() 81 | 82 | if do_shuffle: 83 | random.shuffle(samples) 84 | 85 | list_to_txt(samples, wpath) 86 | 87 | 88 | def run_to_bert_file(): 89 | rpath = sys.argv[1] 90 | wpath = sys.argv[2] 91 | 92 | datatype, _ = os.path.basename(rpath).split('.') 93 | assert datatype in ['dial', 'text'] 94 | 95 | if datatype == 'dial': 96 | dial_to_bert(rpath, wpath) 97 | else: 98 | text_to_bert(rpath, wpath) 99 | 100 | 101 | def run_to_bert_dir(): 102 | from multiprocessing import Pool 103 | 104 | rdir = sys.argv[1] 105 | wdir = sys.argv[2] 106 | p = Pool(4) 107 | for data_type in ['dial', 'text']: 108 | for label in ['train', 'valid', 'test']: 109 | target = "{:}.{:}".format(data_type, label) 110 | assert os.path.exists(os.path.join(rdir, target)) 111 | print(os.path.join(rdir, target)) 112 | if data_type == 'dial': 113 | p.apply_async(dial_to_bert, args=(os.path.join(rdir, target), 114 | os.path.join(wdir, target))) 115 | elif data_type == 'text': 116 | p.apply_async(text_to_bert, args=(os.path.join(rdir, target), 117 | os.path.join(wdir, target))) 118 | 119 | else: 120 | raise ValueError 121 | 122 | print('\n') 123 | 124 | p.close() 125 | p.join() 126 | 127 | 128 | if __name__ == '__main__': 129 | run_to_bert_dir() 130 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/optimization_fp16.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """PyTorch optimization for BERT model.""" 3 | 4 | from apex.optimizers import FP16_Optimizer 5 | 6 | 7 | class FP16_Optimizer_State(FP16_Optimizer): 8 | def __init__(self, 9 | init_optimizer, 10 | static_loss_scale=1.0, 11 | dynamic_loss_scale=False, 12 | dynamic_loss_args=None, 13 | verbose=True): 14 | super(FP16_Optimizer_State, self).__init__(init_optimizer, 15 | static_loss_scale, dynamic_loss_scale, dynamic_loss_args, verbose) 16 | 17 | def state_dict(self): 18 | """ 19 | Returns a dict containing the current state of this :class:`FP16_Optimizer` instance. 20 | This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict 21 | of the contained Pytorch optimizer. 22 | Example:: 23 | checkpoint = {} 24 | checkpoint['model'] = model.state_dict() 25 | checkpoint['optimizer'] = optimizer.state_dict() 26 | torch.save(checkpoint, "saved.pth") 27 | """ 28 | state_dict = {} 29 | state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale 30 | state_dict['cur_scale'] = self.cur_scale 31 | state_dict['cur_iter'] = self.cur_iter 32 | if state_dict['dynamic_loss_scale']: 33 | state_dict['last_overflow_iter'] = self.last_overflow_iter 34 | state_dict['scale_factor'] = self.scale_factor 35 | state_dict['scale_window'] = self.scale_window 36 | state_dict['optimizer_state_dict'] = self.optimizer.state_dict() 37 | state_dict['fp32_groups_flat'] = self.fp32_groups_flat 38 | return state_dict 39 | 40 | def load_state_dict(self, state_dict): 41 | """ 42 | Loads a state_dict created by an earlier call to state_dict(). 43 | If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, 44 | whose parameters in turn came from ``model``, it is expected that the user 45 | will call ``model.load_state_dict()`` before 46 | ``fp16_optimizer_instance.load_state_dict()`` is called. 
47 | Example:: 48 | model = torch.nn.Linear(D_in, D_out).cuda().half() 49 | optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) 50 | optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) 51 | ... 52 | checkpoint = torch.load("saved.pth") 53 | model.load_state_dict(checkpoint['model']) 54 | optimizer.load_state_dict(checkpoint['optimizer']) 55 | """ 56 | # I think it should actually be ok to reload the optimizer before the model. 57 | self.dynamic_loss_scale = state_dict['dynamic_loss_scale'] 58 | self.cur_scale = state_dict['cur_scale'] 59 | self.cur_iter = state_dict['cur_iter'] 60 | if state_dict['dynamic_loss_scale']: 61 | self.last_overflow_iter = state_dict['last_overflow_iter'] 62 | self.scale_factor = state_dict['scale_factor'] 63 | self.scale_window = state_dict['scale_window'] 64 | self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) 65 | # At this point, the optimizer's references to the model's fp32 parameters are up to date. 66 | # The optimizer's hyperparameters and internal buffers are also up to date. 67 | # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still 68 | # out of date. There are two options. 69 | # 1: Refresh the master params from the model's fp16 params. 70 | # This requires less storage but incurs precision loss. 71 | # 2: Save and restore the fp32 master copies separately. 72 | # We choose option 2. 73 | # 74 | # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device 75 | # of their associated parameters, because it's possible those buffers might not exist yet in 76 | # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been 77 | # constructed in the same way as the one whose state_dict we are loading, the same master params 78 | # are guaranteed to exist, so we can just copy_() from the saved master params. 79 | for current, saved in zip(self.fp32_groups_flat, state_dict['fp32_groups_flat']): 80 | current.data.copy_(saved.data) 81 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MultiT-C-Dialog 2 | 3 | 4 | This code is the official pytorch implementation of [A Simple and Efficient Multi-Task Learning Approach for Conditioned Dialogue Generation](https://arxiv.org/abs/2010.11140).
5 | 6 |
![model](img/model.png)
7 | 8 | 9 | In our experiments, we fine-tuned BERT for conditioned dialogue generation. \ 10 | Recently, researchers have proposed several large pre-trained dialogue models trained on Reddit/Twitter data. \ 11 | These models all use an auto-regressive training objective. \ 12 | It is easy to apply this objective to the conditioned language/dialogue generation task. \ 13 | However, the conditioned language encoding task in our approach applies bi-directional attention, so a masked language modeling objective is needed. 14 | 15 | 16 | ## Requirements 17 | ``` 18 | python3.7 19 | torch==1.1.0 20 | ``` 21 | Then, install the package: 22 | ``` 23 | pip install . 24 | ``` 25 | 26 | 27 | Note that whenever you modify the code in ./biunilm/ or ./pytorch_pretrained_bert/, \ 28 | you need to re-run this command: 29 | ``` 30 | pip install . 31 | ``` 32 | Or directly update the corresponding code in: 33 | ``` 34 | xxx/anaconda3/envs/xxx/lib/python3.7/site-packages/biunilm 35 | xxx/anaconda3/envs/xxx/lib/python3.7/site-packages/pytorch_pretrained_bert 36 | ``` 37 | 38 | ## Download Data 39 | Download [Persona Reddit](https://files.pushshift.io/reddit/) and [Topic-related Dialogue](https://github.com/nouhadziri/THRED). 40 | We leave the data cleaning / filtering process to users. 41 | Process the data into a labeled dialogue corpus: 42 | ``` 43 | dial.train 44 | dial.valid 45 | dial.test 46 | ### each file consists of lines in the form of: 47 | # dialogue-context \t condition-label \t response 48 | ### for multi-turn dialogues, concatenate the turns in the context using [SEP] 49 | ``` 50 | and a labeled text corpus: 51 | ``` 52 | text.train 53 | text.valid 54 | text.test 55 | # each file consists of lines in the form of: 56 | # condition-label \t text 57 | ``` 58 | 59 | ## Preprocessing 60 | Please tokenize the dataset in advance: 61 | ``` 62 | python ./pre_tokenize.py $rdir $wdir 63 | ``` 64 | Then, calculate TF-IDF scores in advance: 65 | ``` 66 | python ./get_tfidf.py $datadir $rfilen 67 | 68 | # $rfilen can be the combination of text.train and dial.train (after tokenization) 69 | ``` 70 | 71 | 72 | 73 | ## Model Training 74 | 75 | Further pre-train on a dialogue corpus (optional): 76 | ``` 77 | sh -x ./run_pretrain.sh 78 | # use as the condition label when preprocessing the dataset 79 | ``` 80 | 81 | 82 | Use our approach to fine-tune on a labeled dialogue corpus and a labeled text corpus: 83 | ``` 84 | sh -x ./run_train.sh 85 | ``` 86 | where DATA_DIR should contain the two corpora. Some options are: 87 | ``` 88 | --n_text: set the number of text samples 89 | --n_dial: set the number of dialogue samples 90 | --FGfree: eliminate the finetune-generation discrepancy 91 | --model_recover_path: load a pre-trained model 92 | ``` 93 | 94 | 95 | Or, apply sequential fine-tuning: 96 | ``` 97 | sh -x ./run_2step_pre.sh 98 | sh -x ./run_2step_ft.sh 99 | ``` 100 | 101 | 102 | Tips: If the labeled text corpus is limited, use our approach to avoid catastrophic forgetting (training on a small text corpus will largely erase the pre-training result). \ 103 | If the labeled text corpus is sufficient, use sequential fine-tuning; in this case the final training stage directly optimizes dialogue generation, which gives better results. 104 | 105 | 106 | 107 | ## Model Evaluation 108 | 109 | Calculate perplexity on the dialogue data: 110 | ``` 111 | sh -x ./run_ppl.sh 112 | ``` 113 | This command will automatically load the latest checkpoint in ${OUTPUT_DIR}.
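Checkpoints are written to ${OUTPUT_DIR}/bert_save with names like model.e2_s1.2.bin (the path used in run_eval.sh). As a rough, stand-alone sketch (not the selection logic used by the script itself), the most recent checkpoint in that folder could be resolved like this:
```
import glob, os

def latest_checkpoint(output_dir):  # illustrative helper, not part of this repo
    paths = glob.glob(os.path.join(output_dir, "bert_save", "model.*.bin"))
    return max(paths, key=os.path.getmtime) if paths else None

print(latest_checkpoint("./saved/tmp"))
```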
114 | 115 | 116 | Generate responses: 117 | ``` 118 | sh -x ./run_eval.sh 119 | ``` 120 | 121 | 122 | We provide an evaluation script: 123 | ``` 124 | python eval.py $rdir $model 125 | ``` 126 | 127 | 128 | ## Acknowledgments 129 | Our code is based on [UniLM](https://github.com/microsoft/unilm/tree/master/unilm-v1). Thanks! 130 | 131 | 132 | ## Citation 133 | 134 | ```bibtex 135 | @misc{zeng2021simple, 136 | title={A Simple and Efficient Multi-Task Learning Approach for Conditioned Dialogue Generation}, 137 | author={Yan Zeng and Jian-Yun Nie}, 138 | year={2021}, 139 | eprint={2010.11140}, 140 | archivePrefix={arXiv}, 141 | primaryClass={cs.CL} 142 | } 143 | ``` 144 | 145 | If you activate the --FGfree option, please cite: 146 | ```bibtex 147 | @misc{zeng2020opendomain, 148 | title={Open-Domain Dialogue Generation Based on Pre-trained Language Models}, 149 | author={Yan Zeng and Jian-Yun Nie}, 150 | year={2020}, 151 | eprint={2010.12780}, 152 | archivePrefix={arXiv}, 153 | primaryClass={cs.CL} 154 | } 155 | ``` 156 | 157 | 158 | 159 | ## Contact 160 | For help using this code, please submit a GitHub issue. 161 | For serious problems, please contact Yan Zeng ([yan.zeng@umontreal.ca](mailto:yan.zeng@umontreal.ca)). 162 | 163 | 164 | -------------------------------------------------------------------------------- /nn/data_parallel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import DataParallel 3 | from torch.cuda._utils import _get_device_index 4 | from torch.nn.parallel._functions import Scatter 5 | from itertools import chain 6 | 7 | 8 | def scatter_imbalance(inputs, target_gpus, dim=0): 9 | r""" 10 | Slices tensors into chunks (possibly of unequal sizes) and 11 | distributes them across given GPUs. Duplicates 12 | references to objects that are not tensors.
13 | """ 14 | def scatter_map(obj): 15 | if isinstance(obj, torch.Tensor): 16 | if (len(target_gpus) == 4) and (obj.size(dim) == 22): 17 | return Scatter.apply(target_gpus, (4, 6, 6, 6), dim, obj) 18 | if (len(target_gpus) == 4) and (obj.size(dim) == 60): 19 | return Scatter.apply(target_gpus, (12, 16, 16, 16), dim, obj) 20 | elif (len(target_gpus) == 4) and (obj.size(dim) == 144): 21 | return Scatter.apply(target_gpus, (24, 40, 40, 40), dim, obj) 22 | elif (len(target_gpus) == 8) and (obj.size(dim) == 46): 23 | return Scatter.apply(target_gpus, (4, 6, 6, 6, 6, 6, 6, 6), dim, obj) 24 | elif (len(target_gpus) == 8) and (obj.size(dim) == 62): 25 | return Scatter.apply(target_gpus, (6, 8, 8, 8, 8, 8, 8, 8), dim, obj) 26 | elif (len(target_gpus) == 8) and (obj.size(dim) == 94): 27 | return Scatter.apply(target_gpus, (10, 12, 12, 12, 12, 12, 12, 12), dim, obj) 28 | elif (len(target_gpus) == 8) and (obj.size(dim) == 110): 29 | return Scatter.apply(target_gpus, (12, 14, 14, 14, 14, 14, 14, 14), dim, obj) 30 | elif (len(target_gpus) == 8) and (obj.size(dim) == 118): 31 | return Scatter.apply(target_gpus, (13, 15, 15, 15, 15, 15, 15, 15), dim, obj) 32 | elif (len(target_gpus) == 8) and (obj.size(dim) == 126): 33 | return Scatter.apply(target_gpus, (14, 16, 16, 16, 16, 16, 16, 16), dim, obj) 34 | elif (len(target_gpus) == 8) and (obj.size(dim) == 134): 35 | return Scatter.apply(target_gpus, (15, 17, 17, 17, 17, 17, 17, 17), dim, obj) 36 | elif (len(target_gpus) == 8) and (obj.size(dim) == 142): 37 | return Scatter.apply(target_gpus, (16, 18, 18, 18, 18, 18, 18, 18), dim, obj) 38 | elif (len(target_gpus) == 16) and (obj.size(dim) == 222): 39 | return Scatter.apply(target_gpus, (12, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14), dim, obj) 40 | return Scatter.apply(target_gpus, None, dim, obj) 41 | if isinstance(obj, tuple) and len(obj) > 0: 42 | return list(zip(*map(scatter_map, obj))) 43 | if isinstance(obj, list) and len(obj) > 0: 44 | return list(map(list, zip(*map(scatter_map, obj)))) 45 | if isinstance(obj, dict) and len(obj) > 0: 46 | return list(map(type(obj), zip(*map(scatter_map, obj.items())))) 47 | return [obj for targets in target_gpus] 48 | 49 | # After scatter_map is called, a scatter_map cell will exist. This cell 50 | # has a reference to the actual function scatter_map, which has references 51 | # to a closure that has a reference to the scatter_map cell (because the 52 | # fn is recursive). 
To avoid this reference cycle, we set the function to 53 | # None, clearing the cell 54 | try: 55 | return scatter_map(inputs) 56 | finally: 57 | scatter_map = None 58 | 59 | 60 | def scatter_kwargs_imbalance(inputs, kwargs, target_gpus, dim=0): 61 | r"""Scatter with support for kwargs dictionary""" 62 | inputs = scatter_imbalance(inputs, target_gpus, dim) if inputs else [] 63 | kwargs = scatter_imbalance(kwargs, target_gpus, dim) if kwargs else [] 64 | if len(inputs) < len(kwargs): 65 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) 66 | elif len(kwargs) < len(inputs): 67 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) 68 | inputs = tuple(inputs) 69 | kwargs = tuple(kwargs) 70 | return inputs, kwargs 71 | 72 | 73 | class DataParallelImbalance(DataParallel): 74 | def __init__(self, module, device_ids=None, output_device=None, dim=0): 75 | super(DataParallelImbalance, self).__init__( 76 | module, device_ids, output_device, dim) 77 | 78 | if not torch.cuda.is_available(): 79 | self.module = module 80 | self.device_ids = [] 81 | return 82 | 83 | if device_ids is None: 84 | device_ids = list(range(torch.cuda.device_count())) 85 | if output_device is None: 86 | output_device = device_ids[0] 87 | 88 | if not all(t.is_cuda and t.device.index == device_ids[0] 89 | for t in chain(module.parameters(), module.buffers())): 90 | raise RuntimeError("module must have its parameters and buffers " 91 | "on device %d (device_ids[0])" % device_ids[0]) 92 | 93 | self.dim = dim 94 | self.module = module 95 | self.device_ids = list( 96 | map(lambda x: _get_device_index(x, True), device_ids)) 97 | self.output_device = _get_device_index(output_device, True) 98 | 99 | if len(self.device_ids) == 1: 100 | self.module.cuda(device_ids[0]) 101 | 102 | def forward(self, *inputs, **kwargs): 103 | if not self.device_ids: 104 | return self.module(*inputs, **kwargs) 105 | inputs, kwargs = self.scatter_imbalance( 106 | inputs, kwargs, self.device_ids) 107 | if len(self.device_ids) == 1: 108 | return self.module(*inputs[0], **kwargs[0]) 109 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) 110 | outputs = self.parallel_apply(replicas, inputs, kwargs) 111 | return self.gather(outputs, self.output_device) 112 | 113 | def scatter_imbalance(self, inputs, kwargs, device_ids): 114 | return scatter_kwargs_imbalance(inputs, kwargs, device_ids, dim=self.dim) 115 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import re 5 | from itertools import chain 6 | from numpy import mean 7 | 8 | from nltk.tokenize import TweetTokenizer 9 | 10 | 11 | NAN = '' 12 | 13 | 14 | def read_ref(rpath, do_split=False): 15 | ref_list = [] 16 | with open(rpath, 'rt') as f: 17 | for line in f: 18 | _, _, _, ref = line.strip().split('\t') 19 | ref = ref.strip() 20 | assert len(ref) > 0 21 | if do_split: 22 | ref_list.append(ref.split(' ')) 23 | else: 24 | ref_list.append(ref) 25 | 26 | return ref_list 27 | 28 | 29 | def read_tokenized_ref(rpath, do_split=False): 30 | ref_list = [] 31 | with open(rpath, 'rt') as f: 32 | for line in f: 33 | _, _, ref, _ = line.strip().split('\t') 34 | ref = ref.strip().replace(' ##', '') 35 | assert len(ref) > 0 36 | if do_split: 37 | ref_list.append(ref.split(' ')) 38 | else: 39 | ref_list.append(ref) 40 | 41 | return ref_list 42 | 43 | 44 | def read_tokenized_src(rpath, do_split=False): 45 | assert do_split is 
False 46 | ref_list = [] 47 | with open(rpath, 'rt') as f: 48 | for line in f: 49 | src, _, _, _ = line.strip().split('\t') 50 | src = src.strip().replace(' ##', '').strip() 51 | 52 | assert 'SEP' not in src 53 | assert len(src) > 0 54 | ref_list.append(src) 55 | 56 | return ref_list 57 | 58 | 59 | def read_bert_gen(rpath, do_split=False): 60 | gen_list = [] 61 | with open(rpath, 'rt') as f: 62 | for line in f: 63 | line = line.strip() 64 | 65 | if len(line) > 0: 66 | gen_list.append(line.strip()) 67 | else: 68 | gen_list.append(NAN) 69 | 70 | return gen_list 71 | 72 | 73 | def list_to_txt(samples, wpath, to_str=False): 74 | assert isinstance(samples, list) 75 | wdir = os.path.dirname(wpath) 76 | if not os.path.exists(wdir): 77 | os.mkdir(wdir) 78 | 79 | with open(wpath, 'wt') as f: 80 | for s in samples: 81 | if to_str: 82 | s = str(s) 83 | f.write(s.strip() + '\n') 84 | 85 | 86 | def pad_sequence(sequence, n, pad_left=False, pad_right=False, 87 | left_pad_symbol=None, right_pad_symbol=None): 88 | sequence = iter(sequence) 89 | if pad_left: 90 | sequence = chain((left_pad_symbol,) * (n - 1), sequence) 91 | if pad_right: 92 | sequence = chain(sequence, (right_pad_symbol,) * (n - 1)) 93 | return sequence 94 | 95 | 96 | def ngrams(sequence, n, pad_left=False, pad_right=False, 97 | left_pad_symbol=None, right_pad_symbol=None): 98 | 99 | sequence = pad_sequence(sequence, n, pad_left, pad_right, 100 | left_pad_symbol, right_pad_symbol) 101 | 102 | history = [] 103 | while n > 1: 104 | history.append(next(sequence)) 105 | n -= 1 106 | for item in sequence: 107 | history.append(item) 108 | yield tuple(history) 109 | del history[0] 110 | 111 | 112 | def distinct_n_sentence_level(sentence, n): 113 | """ 114 | Compute distinct-N for a single sentence. 115 | :param sentence: a list of words. 116 | :param n: int, ngram. 117 | :return: float, the metric value. 
118 | """ 119 | assert isinstance(sentence, list) 120 | if len(sentence) == 0: 121 | return 0.0 # Prevent a zero division 122 | distinct_ngrams = set(ngrams(sentence, n)) 123 | return len(distinct_ngrams) / len(sentence) 124 | 125 | 126 | def get_distinct(gen_list_, n, batch_size=32, ret_raw=False): 127 | assert isinstance(gen_list_, list) 128 | 129 | gen_list = [] 130 | for i in range(0, len(gen_list_), batch_size): 131 | gen_list.append(' '.join(gen_list_[i:i+batch_size])) 132 | 133 | dist_list = [] 134 | for gen in gen_list: 135 | if isinstance(gen, str): 136 | gen = gen.strip().split(' ') 137 | 138 | dist_list.append(distinct_n_sentence_level(gen, n)) 139 | 140 | assert len(dist_list) == len(gen_list) 141 | 142 | if ret_raw: 143 | return dist_list 144 | else: 145 | return mean(dist_list) 146 | 147 | 148 | def to_uni(sentence, tokenizer=None): 149 | def _replace(sentence, bef_token, aft_token): 150 | sentence = sentence.replace(" {:} ".format(bef_token), " {:} ".format(aft_token)) 151 | sentence = sentence.replace("{:} ".format(bef_token), "{:} ".format(aft_token)) 152 | sentence = sentence.replace(" {:}".format(bef_token), " {:}".format(aft_token)) 153 | return sentence 154 | 155 | if tokenizer is not None: 156 | sentence = ' '.join(tokenizer.tokenize(sentence)).strip() 157 | else: 158 | sentence = sentence.strip() 159 | 160 | sentence = sentence.replace("n ' t", "n't") 161 | sentence = sentence.replace("' m", "'m") 162 | sentence = sentence.replace("' s", "'s") 163 | sentence = sentence.replace("' re", "'re") 164 | sentence = sentence.replace("' d", "'d") 165 | sentence = sentence.replace("' ve", "'ve") 166 | sentence = sentence.replace("' ll", "'ll") 167 | 168 | # e.g. what's who's 169 | sentence = re.sub("([a-z])n't", r"\1 n't", sentence) 170 | sentence = _replace(sentence, "i'm", "i 'm") 171 | sentence = re.sub("([a-z])'s", r"\1 's", sentence) 172 | sentence = re.sub("([a-z])'re", r"\1 're", sentence) 173 | sentence = re.sub("([a-z])'d", r"\1 'd", sentence) 174 | sentence = re.sub("([a-z])'ve", r"\1 've", sentence) 175 | sentence = re.sub("([a-z])'ll", r"\1 'll", sentence) 176 | 177 | sentence = sentence.replace('. . 
.', '...') 178 | 179 | return sentence.strip() 180 | 181 | 182 | def do_eval(rdir, model_name, do_to_uni=False): 183 | 184 | if 'gpt' in model_name: 185 | print("### TweetTokenizer") 186 | sys.stdout.flush() 187 | tokenizer = TweetTokenizer() 188 | else: 189 | tokenizer = None 190 | 191 | ref_list = read_tokenized_ref(rdir+'dial.test') 192 | src_list = read_tokenized_src(rdir+'dial.test') 193 | 194 | gen_list = read_bert_gen(rdir + '{:}.preds.txt'.format(model_name)) 195 | print('Read Bert') 196 | 197 | if do_to_uni: 198 | ref_list = [to_uni(s, tokenizer) for s in ref_list] 199 | src_list = [to_uni(s, tokenizer) for s in src_list] 200 | gen_list = [to_uni(s, tokenizer) for s in gen_list] 201 | 202 | list_to_txt(ref_list, './tmp/ref.txt') 203 | list_to_txt(src_list, './tmp/src.txt') 204 | list_to_txt(gen_list, './tmp/{:}.txt'.format(model_name)) 205 | 206 | avg_len = [len(s.strip().split(' ')) for s in gen_list] 207 | print("Average Len: {:}".format(mean(avg_len))) 208 | print('\n') 209 | 210 | print('Eval {:} Distinct...'.format(model_name)) 211 | sys.stdout.flush() 212 | gen_res = [' '.join(gen_list)] 213 | dist1 = get_distinct(gen_res, 1) 214 | dist2 = get_distinct(gen_res, 2) 215 | dist3 = get_distinct(gen_res, 3) 216 | dist4 = get_distinct(gen_res, 4) 217 | print("Dist1: {:.3f}, Dist2: {:.3f}, Dist3: {:.3f}, Dist4: {:.3f}".format(dist1, dist2, dist3, dist4)) 218 | 219 | print('Eval {:}...'.format(model_name)) 220 | sys.stdout.flush() 221 | os.system("nlg-eval --hypothesis=tmp/{:}.txt --references=tmp/ref.txt".format(model_name)) 222 | 223 | 224 | if __name__ == '__main__': 225 | rdir = sys.argv[1].strip() 226 | model_name = sys.argv[1].strip() # '{:}.preds.txt'.format(model_name) 227 | do_eval(rdir, model_name, do_to_uni=True) 228 | -------------------------------------------------------------------------------- /qg/eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import print_function 3 | __author__ = 'xinya' 4 | 5 | from bleu.bleu import Bleu 6 | from meteor.meteor import Meteor 7 | from rouge.rouge import Rouge 8 | from cider.cider import Cider 9 | from collections import defaultdict 10 | from argparse import ArgumentParser 11 | import string 12 | 13 | import sys 14 | reload(sys) 15 | sys.setdefaultencoding('utf-8') 16 | 17 | _tok_dict = {"(": "-lrb-", ")": "-rrb-", 18 | "[": "-lsb-", "]": "-rsb-", 19 | "{": "-lcb-", "}": "-rcb-", 20 | "[UNK]": "UNK", '&': '&', '<': '<', '>': '>'} 21 | 22 | 23 | def _is_digit(w): 24 | for ch in w: 25 | if not(ch.isdigit() or ch == ','): 26 | return False 27 | return True 28 | 29 | 30 | def detokenize(tk_list): 31 | r_list = [] 32 | for tk in tk_list: 33 | if tk.startswith('##') and len(r_list) > 0: 34 | r_list[-1] = r_list[-1] + tk[2:] 35 | else: 36 | r_list.append(tk) 37 | return r_list 38 | 39 | 40 | def fix_tokenization(text): 41 | input_tokens = text.split() 42 | output_tokens = [] 43 | has_left_quote = False 44 | has_left_single_quote = False 45 | 46 | i = 0 47 | prev_dash = False 48 | while i < len(input_tokens): 49 | tok = input_tokens[i] 50 | flag_prev_dash = False 51 | if tok in _tok_dict.keys(): 52 | output_tokens.append(_tok_dict[tok]) 53 | i += 1 54 | elif tok == "\"": 55 | if has_left_quote: 56 | output_tokens.append("''") 57 | else: 58 | output_tokens.append("``") 59 | has_left_quote = not has_left_quote 60 | i += 1 61 | elif tok == "'" and len(output_tokens) > 0 and output_tokens[-1].endswith("n") and i < len(input_tokens) - 1 and input_tokens[i + 
1] == "t": 62 | output_tokens[-1] = output_tokens[-1][:-1] 63 | output_tokens.append("n't") 64 | i += 2 65 | elif tok == "'" and i < len(input_tokens) - 1 and input_tokens[i + 1] in ("s", "d", "ll"): 66 | output_tokens.append("'"+input_tokens[i + 1]) 67 | i += 2 68 | elif tok == "'": 69 | if has_left_single_quote: 70 | output_tokens.append("'") 71 | else: 72 | output_tokens.append("`") 73 | has_left_single_quote = not has_left_single_quote 74 | i += 1 75 | elif tok == "." and i < len(input_tokens) - 2 and input_tokens[i + 1] == "." and input_tokens[i + 2] == ".": 76 | output_tokens.append("...") 77 | i += 3 78 | elif tok == "," and len(output_tokens) > 0 and _is_digit(output_tokens[-1]) and i < len(input_tokens) - 1 and _is_digit(input_tokens[i + 1]): 79 | # $ 3 , 000 -> $ 3,000 80 | output_tokens[-1] += ','+input_tokens[i + 1] 81 | i += 2 82 | elif tok == "." and len(output_tokens) > 0 and output_tokens[-1].isdigit() and i < len(input_tokens) - 1 and input_tokens[i + 1].isdigit(): 83 | # 3 . 03 -> $ 3.03 84 | output_tokens[-1] += '.'+input_tokens[i + 1] 85 | i += 2 86 | elif tok == "." and len(output_tokens) > 0 and len(output_tokens[-1]) == 1 and output_tokens[-1].isupper() and i < len(input_tokens) - 2 and len(input_tokens[i + 1]) == 1 and input_tokens[i + 1].isupper() and input_tokens[i + 2] == '.': 87 | # U . N . -> U.N. 88 | k = i+3 89 | while k+2 < len(input_tokens): 90 | if len(input_tokens[k + 1]) == 1 and input_tokens[k + 1].isupper() and input_tokens[k + 2] == '.': 91 | k += 2 92 | else: 93 | break 94 | output_tokens[-1] += ''.join(input_tokens[i:k]) 95 | i += 2 96 | elif tok == "-": 97 | if i < len(input_tokens) - 1 and input_tokens[i + 1] == "-": 98 | output_tokens.append("--") 99 | i += 2 100 | elif i == len(input_tokens) - 1 or i == 0: 101 | output_tokens.append("-") 102 | i += 1 103 | elif output_tokens[-1] not in string.punctuation and input_tokens[i + 1][0] not in string.punctuation: 104 | output_tokens[-1] += "-" 105 | i += 1 106 | flag_prev_dash = True 107 | else: 108 | output_tokens.append("-") 109 | i += 1 110 | elif prev_dash and len(output_tokens) > 0 and tok[0] not in string.punctuation: 111 | output_tokens[-1] += tok 112 | i += 1 113 | else: 114 | output_tokens.append(tok) 115 | i += 1 116 | prev_dash = flag_prev_dash 117 | return " ".join(output_tokens) 118 | 119 | 120 | class QGEvalCap: 121 | def __init__(self, gts, res): 122 | self.gts = gts 123 | self.res = res 124 | 125 | def evaluate(self): 126 | output = [] 127 | scorers = [ 128 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), 129 | (Meteor(), "METEOR"), 130 | (Rouge(), "ROUGE_L"), 131 | # (Cider(), "CIDEr") 132 | ] 133 | 134 | # ================================================= 135 | # Compute scores 136 | # ================================================= 137 | for scorer, method in scorers: 138 | # print 'computing %s score...'%(scorer.method()) 139 | score, scores = scorer.compute_score(self.gts, self.res) 140 | if type(method) == list: 141 | for sc, scs, m in zip(score, scores, method): 142 | print("%s: %0.5f" % (m, sc)) 143 | output.append(sc) 144 | else: 145 | print("%s: %0.5f" % (method, score)) 146 | output.append(score) 147 | return output 148 | 149 | 150 | def eval(out_file, src_file, tgt_file, isDIn=False, num_pairs=500): 151 | """ 152 | Given a filename, calculate the metric scores for that prediction file 153 | 154 | isDin: boolean value to check whether input file is DirectIn.txt 155 | """ 156 | 157 | pairs = [] 158 | with open(src_file, 'r') as infile: 159 | for line in infile: 160 | 
pair = {} 161 | pair['tokenized_sentence'] = line[:-1].strip().lower() 162 | pairs.append(pair) 163 | 164 | with open(tgt_file, "r") as infile: 165 | cnt = 0 166 | for line in infile: 167 | pairs[cnt]['tokenized_question'] = line[:-1].strip() 168 | cnt += 1 169 | 170 | output = [] 171 | with open(out_file, 'r') as infile: 172 | for line in infile: 173 | line = fix_tokenization(line[:-1].strip()).lower() 174 | output.append(line) 175 | 176 | for idx, pair in enumerate(pairs): 177 | pair['prediction'] = output[idx] 178 | 179 | # eval 180 | from eval import QGEvalCap 181 | import json 182 | from json import encoder 183 | encoder.FLOAT_REPR = lambda o: format(o, '.4f') 184 | 185 | res = defaultdict(lambda: []) 186 | gts = defaultdict(lambda: []) 187 | 188 | for pair in pairs[:]: 189 | key = pair['tokenized_sentence'] 190 | res[key] = [pair['prediction'].encode('utf-8')] 191 | 192 | # gts 193 | gts[key].append(pair['tokenized_question'].encode('utf-8')) 194 | 195 | QGEval = QGEvalCap(gts, res) 196 | return QGEval.evaluate() 197 | 198 | 199 | if __name__ == "__main__": 200 | parser = ArgumentParser() 201 | parser.add_argument("-out", "--out_file", dest="out_file", 202 | default="./output/pred.txt", help="output file to compare") 203 | parser.add_argument("-src", "--src_file", dest="src_file", 204 | default="./qg_data/test/test.pa.txt", help="src file") 205 | parser.add_argument("-tgt", "--tgt_file", dest="tgt_file", 206 | default="./qg_data/nqg_processed_data/tgt-test.txt", help="target file") 207 | args = parser.parse_args() 208 | 209 | print("scores: \n") 210 | eval(args.out_file, args.src_file, args.tgt_file) 211 | -------------------------------------------------------------------------------- /qg/eval_on_unilm_tokenized_ref.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import print_function 3 | __author__ = 'xinya' 4 | 5 | from bleu.bleu import Bleu 6 | from meteor.meteor import Meteor 7 | from rouge.rouge import Rouge 8 | from cider.cider import Cider 9 | from collections import defaultdict 10 | from argparse import ArgumentParser 11 | import string 12 | 13 | import sys 14 | reload(sys) 15 | sys.setdefaultencoding('utf-8') 16 | 17 | _tok_dict = {"(": "-lrb-", ")": "-rrb-", 18 | "[": "-lsb-", "]": "-rsb-", 19 | "{": "-lcb-", "}": "-rcb-", 20 | "[UNK]": "UNK", '&': '&', '<': '<', '>': '>'} 21 | 22 | 23 | def _is_digit(w): 24 | for ch in w: 25 | if not(ch.isdigit() or ch == ','): 26 | return False 27 | return True 28 | 29 | 30 | def detokenize(tk_list): 31 | r_list = [] 32 | for tk in tk_list: 33 | if tk.startswith('##') and len(r_list) > 0: 34 | r_list[-1] = r_list[-1] + tk[2:] 35 | else: 36 | r_list.append(tk) 37 | return r_list 38 | 39 | 40 | def fix_tokenization(text): 41 | input_tokens = text.split() 42 | output_tokens = [] 43 | has_left_quote = False 44 | has_left_single_quote = False 45 | 46 | i = 0 47 | prev_dash = False 48 | while i < len(input_tokens): 49 | tok = input_tokens[i] 50 | flag_prev_dash = False 51 | if tok in _tok_dict.keys(): 52 | output_tokens.append(_tok_dict[tok]) 53 | i += 1 54 | elif tok == "\"": 55 | if has_left_quote: 56 | output_tokens.append("''") 57 | else: 58 | output_tokens.append("``") 59 | has_left_quote = not has_left_quote 60 | i += 1 61 | elif tok == "'" and len(output_tokens) > 0 and output_tokens[-1].endswith("n") and i < len(input_tokens) - 1 and input_tokens[i + 1] == "t": 62 | output_tokens[-1] = output_tokens[-1][:-1] 63 | output_tokens.append("n't") 64 | i += 2 
65 | elif tok == "'" and i < len(input_tokens) - 1 and input_tokens[i + 1] in ("s", "d", "ll"): 66 | output_tokens.append("'"+input_tokens[i + 1]) 67 | i += 2 68 | elif tok == "'": 69 | if has_left_single_quote: 70 | output_tokens.append("'") 71 | else: 72 | output_tokens.append("`") 73 | has_left_single_quote = not has_left_single_quote 74 | i += 1 75 | elif tok == "." and i < len(input_tokens) - 2 and input_tokens[i + 1] == "." and input_tokens[i + 2] == ".": 76 | output_tokens.append("...") 77 | i += 3 78 | elif tok == "," and len(output_tokens) > 0 and _is_digit(output_tokens[-1]) and i < len(input_tokens) - 1 and _is_digit(input_tokens[i + 1]): 79 | # $ 3 , 000 -> $ 3,000 80 | output_tokens[-1] += ','+input_tokens[i + 1] 81 | i += 2 82 | elif tok == "." and len(output_tokens) > 0 and output_tokens[-1].isdigit() and i < len(input_tokens) - 1 and input_tokens[i + 1].isdigit(): 83 | # 3 . 03 -> $ 3.03 84 | output_tokens[-1] += '.'+input_tokens[i + 1] 85 | i += 2 86 | elif tok == "." and len(output_tokens) > 0 and len(output_tokens[-1]) == 1 and output_tokens[-1].isupper() and i < len(input_tokens) - 2 and len(input_tokens[i + 1]) == 1 and input_tokens[i + 1].isupper() and input_tokens[i + 2] == '.': 87 | # U . N . -> U.N. 88 | k = i+3 89 | while k+2 < len(input_tokens): 90 | if len(input_tokens[k + 1]) == 1 and input_tokens[k + 1].isupper() and input_tokens[k + 2] == '.': 91 | k += 2 92 | else: 93 | break 94 | output_tokens[-1] += ''.join(input_tokens[i:k]) 95 | i += 2 96 | elif tok == "-": 97 | if i < len(input_tokens) - 1 and input_tokens[i + 1] == "-": 98 | output_tokens.append("--") 99 | i += 2 100 | elif i == len(input_tokens) - 1 or i == 0: 101 | output_tokens.append("-") 102 | i += 1 103 | elif output_tokens[-1] not in string.punctuation and input_tokens[i + 1][0] not in string.punctuation: 104 | output_tokens[-1] += "-" 105 | i += 1 106 | flag_prev_dash = True 107 | else: 108 | output_tokens.append("-") 109 | i += 1 110 | elif prev_dash and len(output_tokens) > 0 and tok[0] not in string.punctuation: 111 | output_tokens[-1] += tok 112 | i += 1 113 | else: 114 | output_tokens.append(tok) 115 | i += 1 116 | prev_dash = flag_prev_dash 117 | return " ".join(output_tokens) 118 | 119 | 120 | class QGEvalCap: 121 | def __init__(self, gts, res): 122 | self.gts = gts 123 | self.res = res 124 | 125 | def evaluate(self): 126 | output = [] 127 | scorers = [ 128 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), 129 | (Meteor(), "METEOR"), 130 | (Rouge(), "ROUGE_L"), 131 | # (Cider(), "CIDEr") 132 | ] 133 | 134 | # ================================================= 135 | # Compute scores 136 | # ================================================= 137 | for scorer, method in scorers: 138 | # print 'computing %s score...'%(scorer.method()) 139 | score, scores = scorer.compute_score(self.gts, self.res) 140 | if type(method) == list: 141 | for sc, scs, m in zip(score, scores, method): 142 | print("%s: %0.5f" % (m, sc)) 143 | output.append(sc) 144 | else: 145 | print("%s: %0.5f" % (method, score)) 146 | output.append(score) 147 | return output 148 | 149 | 150 | def eval(out_file, src_file, tgt_file, isDIn=False, num_pairs=500): 151 | """ 152 | Given a filename, calculate the metric scores for that prediction file 153 | 154 | isDin: boolean value to check whether input file is DirectIn.txt 155 | """ 156 | 157 | pairs = [] 158 | with open(src_file, 'r') as infile: 159 | for line in infile: 160 | pair = {} 161 | pair['tokenized_sentence'] = line[:-1].strip().lower() 162 | pairs.append(pair) 163 | 
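    # the reference questions in tgt_file are UniLM WordPiece output, so
    # detokenize() below merges '##' continuation pieces back into whole words
    # (e.g. ['who', '##se'] -> ['whose']) before lowercasing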
164 | with open(tgt_file, "r") as infile: 165 | cnt = 0 166 | for line in infile: 167 | pairs[cnt]['tokenized_question'] = " ".join( 168 | detokenize(line[:-1].strip().split())).lower() 169 | cnt += 1 170 | 171 | output = [] 172 | with open(out_file, 'r') as infile: 173 | for line in infile: 174 | line = line[:-1].strip().lower() 175 | output.append(line) 176 | 177 | for idx, pair in enumerate(pairs): 178 | pair['prediction'] = output[idx] 179 | 180 | # eval 181 | from eval import QGEvalCap 182 | import json 183 | from json import encoder 184 | encoder.FLOAT_REPR = lambda o: format(o, '.4f') 185 | 186 | res = defaultdict(lambda: []) 187 | gts = defaultdict(lambda: []) 188 | 189 | for pair in pairs[:]: 190 | key = pair['tokenized_sentence'] 191 | res[key] = [pair['prediction'].encode('utf-8')] 192 | 193 | # gts 194 | gts[key].append(pair['tokenized_question'].encode('utf-8')) 195 | 196 | QGEval = QGEvalCap(gts, res) 197 | return QGEval.evaluate() 198 | 199 | 200 | if __name__ == "__main__": 201 | parser = ArgumentParser() 202 | parser.add_argument("-out", "--out_file", dest="out_file", 203 | default="./output/pred.txt", help="output file to compare") 204 | parser.add_argument("-src", "--src_file", dest="src_file", 205 | default="./qg_data/test/test.pa.txt", help="src file") 206 | parser.add_argument("-tgt", "--tgt_file", dest="tgt_file", 207 | default="./qg_data/test/test.q.tok.txt", help="target file") 208 | args = parser.parse_args() 209 | 210 | print("scores: \n") 211 | eval(args.out_file, args.src_file, args.tgt_file) 212 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 5 | """ 6 | 7 | import os 8 | import logging 9 | import shutil 10 | import tempfile 11 | import json 12 | from urllib.parse import urlparse 13 | from pathlib import Path 14 | from typing import Optional, Tuple, Union, IO, Callable, Set 15 | from hashlib import sha256 16 | from functools import wraps 17 | 18 | from tqdm import tqdm 19 | 20 | import boto3 21 | from botocore.exceptions import ClientError 22 | import requests 23 | 24 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 25 | 26 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 27 | Path.home() / '.pytorch_pretrained_bert')) 28 | 29 | 30 | def url_to_filename(url: str, etag: str = None) -> str: 31 | """ 32 | Convert `url` into a hashed filename in a repeatable way. 33 | If `etag` is specified, append its hash to the url's, delimited 34 | by a period. 35 | """ 36 | url_bytes = url.encode('utf-8') 37 | url_hash = sha256(url_bytes) 38 | filename = url_hash.hexdigest() 39 | 40 | if etag: 41 | etag_bytes = etag.encode('utf-8') 42 | etag_hash = sha256(etag_bytes) 43 | filename += '.' + etag_hash.hexdigest() 44 | 45 | return filename 46 | 47 | 48 | def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]: 49 | """ 50 | Return the url and etag (which may be ``None``) stored for `filename`. 51 | Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist. 
52 | """ 53 | if cache_dir is None: 54 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 55 | if isinstance(cache_dir, Path): 56 | cache_dir = str(cache_dir) 57 | 58 | cache_path = os.path.join(cache_dir, filename) 59 | if not os.path.exists(cache_path): 60 | raise FileNotFoundError("file {} not found".format(cache_path)) 61 | 62 | meta_path = cache_path + '.json' 63 | if not os.path.exists(meta_path): 64 | raise FileNotFoundError("file {} not found".format(meta_path)) 65 | 66 | with open(meta_path) as meta_file: 67 | metadata = json.load(meta_file) 68 | url = metadata['url'] 69 | etag = metadata['etag'] 70 | 71 | return url, etag 72 | 73 | 74 | def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str: 75 | """ 76 | Given something that might be a URL (or might be a local path), 77 | determine which. If it's a URL, download the file and cache it, and 78 | return the path to the cached file. If it's already a local path, 79 | make sure the file exists and then return the path. 80 | """ 81 | if cache_dir is None: 82 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 83 | if isinstance(url_or_filename, Path): 84 | url_or_filename = str(url_or_filename) 85 | if isinstance(cache_dir, Path): 86 | cache_dir = str(cache_dir) 87 | 88 | parsed = urlparse(url_or_filename) 89 | 90 | if parsed.scheme in ('http', 'https', 's3'): 91 | # URL, so get it from the cache (downloading if necessary) 92 | return get_from_cache(url_or_filename, cache_dir) 93 | elif os.path.exists(url_or_filename): 94 | # File, and it exists. 95 | return url_or_filename 96 | elif parsed.scheme == '': 97 | # File, but it doesn't exist. 98 | raise FileNotFoundError("file {} not found".format(url_or_filename)) 99 | else: 100 | # Something unknown 101 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 102 | 103 | 104 | def split_s3_path(url: str) -> Tuple[str, str]: 105 | """Split a full s3 path into the bucket name and path.""" 106 | parsed = urlparse(url) 107 | if not parsed.netloc or not parsed.path: 108 | raise ValueError("bad s3 path {}".format(url)) 109 | bucket_name = parsed.netloc 110 | s3_path = parsed.path 111 | # Remove '/' at beginning of path. 112 | if s3_path.startswith("/"): 113 | s3_path = s3_path[1:] 114 | return bucket_name, s3_path 115 | 116 | 117 | def s3_request(func: Callable): 118 | """ 119 | Wrapper function for s3 requests in order to create more helpful error 120 | messages. 
121 | """ 122 | 123 | @wraps(func) 124 | def wrapper(url: str, *args, **kwargs): 125 | try: 126 | return func(url, *args, **kwargs) 127 | except ClientError as exc: 128 | if int(exc.response["Error"]["Code"]) == 404: 129 | raise FileNotFoundError("file {} not found".format(url)) 130 | else: 131 | raise 132 | 133 | return wrapper 134 | 135 | 136 | @s3_request 137 | def s3_etag(url: str) -> Optional[str]: 138 | """Check ETag on S3 object.""" 139 | s3_resource = boto3.resource("s3") 140 | bucket_name, s3_path = split_s3_path(url) 141 | s3_object = s3_resource.Object(bucket_name, s3_path) 142 | return s3_object.e_tag 143 | 144 | 145 | @s3_request 146 | def s3_get(url: str, temp_file: IO) -> None: 147 | """Pull a file directly from S3.""" 148 | s3_resource = boto3.resource("s3") 149 | bucket_name, s3_path = split_s3_path(url) 150 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 151 | 152 | 153 | def http_get(url: str, temp_file: IO) -> None: 154 | req = requests.get(url, stream=True) 155 | content_length = req.headers.get('Content-Length') 156 | total = int(content_length) if content_length is not None else None 157 | progress = tqdm(unit="B", total=total) 158 | for chunk in req.iter_content(chunk_size=1024): 159 | if chunk: # filter out keep-alive new chunks 160 | progress.update(len(chunk)) 161 | temp_file.write(chunk) 162 | progress.close() 163 | 164 | 165 | def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str: 166 | """ 167 | Given a URL, look for the corresponding dataset in the local cache. 168 | If it's not there, download it. Then return the path to the cached file. 169 | """ 170 | if cache_dir is None: 171 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 172 | if isinstance(cache_dir, Path): 173 | cache_dir = str(cache_dir) 174 | 175 | os.makedirs(cache_dir, exist_ok=True) 176 | 177 | # Get eTag to add to filename, if it exists. 178 | if url.startswith("s3://"): 179 | etag = s3_etag(url) 180 | else: 181 | response = requests.head(url, allow_redirects=True) 182 | if response.status_code != 200: 183 | raise IOError("HEAD request failed for url {} with status code {}" 184 | .format(url, response.status_code)) 185 | etag = response.headers.get("ETag") 186 | 187 | filename = url_to_filename(url, etag) 188 | 189 | # get cache path to put the file 190 | cache_path = os.path.join(cache_dir, filename) 191 | # *ZY* 192 | # cache_path = "cache_tmp/bert-base-uncased-pretrained-cache/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba" 193 | 194 | if not os.path.exists(cache_path): 195 | # Download to temporary file, then copy to cache dir once finished. 196 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
197 | with tempfile.NamedTemporaryFile() as temp_file: 198 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 199 | 200 | # GET file object 201 | if url.startswith("s3://"): 202 | s3_get(url, temp_file) 203 | else: 204 | http_get(url, temp_file) 205 | 206 | # we are copying the file before closing it, so flush to avoid truncation 207 | temp_file.flush() 208 | # shutil.copyfileobj() starts at the current position, so go to the start 209 | temp_file.seek(0) 210 | 211 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 212 | with open(cache_path, 'wb') as cache_file: 213 | shutil.copyfileobj(temp_file, cache_file) 214 | 215 | logger.info("creating metadata file for %s", cache_path) 216 | meta = {'url': url, 'etag': etag} 217 | meta_path = cache_path + '.json' 218 | with open(meta_path, 'w') as meta_file: 219 | json.dump(meta, meta_file) 220 | 221 | logger.info("removing temp file %s", temp_file.name) 222 | 223 | return cache_path 224 | 225 | 226 | def read_set_from_file(filename: str) -> Set[str]: 227 | ''' 228 | Extract a de-duped collection (set) of text from a file. 229 | Expected file format is one item per line. 230 | ''' 231 | collection = set() 232 | with open(filename, 'r', encoding='utf-8') as file_: 233 | for line in file_: 234 | collection.add(line.rstrip()) 235 | return collection 236 | 237 | 238 | def get_file_extension(path: str, dot=True, lower: bool = True): 239 | ext = os.path.splitext(path)[1] 240 | ext = ext if dot else ext[1:] 241 | return ext.lower() if lower else ext 242 | -------------------------------------------------------------------------------- /biunilm/loader_utils.py: -------------------------------------------------------------------------------- 1 | from random import randint, shuffle 2 | from random import random as rand 3 | import numpy as np 4 | 5 | import torch 6 | import torch.utils.data 7 | 8 | 9 | def get_random_word(vocab_words): 10 | i = randint(0, len(vocab_words)-1) 11 | return vocab_words[i] 12 | 13 | 14 | def batch_list_to_batch_tensors(batch): 15 | batch_tensors = [] 16 | for x in zip(*batch): 17 | # TODO: e.g. 18 | # batch = [(1,2,3), (4,5,6)] 19 | # x will be (1,4), then (2,5), then (3,6). 
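        #       i.e. zip(*batch) transposes per-example tuples into per-field
        #       tuples; each field is then stacked into a single batch tensor
        #       below (or passed through as None)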
20 | if x[0] is None: 21 | batch_tensors.append(None) 22 | elif isinstance(x[0], torch.Tensor): 23 | batch_tensors.append(torch.stack(x)) 24 | else: 25 | batch_tensors.append(torch.tensor(x, dtype=torch.long)) 26 | 27 | return batch_tensors 28 | 29 | 30 | class TrieNode(object): 31 | def __init__(self): 32 | self.children = {} 33 | self.is_leaf = False 34 | 35 | def try_get_children(self, key): 36 | if key not in self.children: 37 | self.children[key] = TrieNode() 38 | return self.children[key] 39 | 40 | 41 | class TrieTree(object): 42 | def __init__(self): 43 | self.root = TrieNode() 44 | 45 | def add(self, tokens): 46 | r = self.root 47 | for token in tokens: 48 | r = r.try_get_children(token) 49 | r.is_leaf = True 50 | 51 | def get_pieces(self, tokens, offset): 52 | pieces = [] 53 | r = self.root 54 | token_id = 0 55 | last_valid = 0 56 | match_count = 0 57 | while last_valid < len(tokens): 58 | if token_id < len(tokens) and tokens[token_id] in r.children: 59 | r = r.children[tokens[token_id]] 60 | match_count += 1 61 | if r.is_leaf: 62 | last_valid = token_id 63 | token_id += 1 64 | else: 65 | pieces.append( 66 | list(range(token_id - match_count + offset, last_valid + 1 + offset))) 67 | last_valid += 1 68 | token_id = last_valid 69 | r = self.root 70 | match_count = 0 71 | 72 | return pieces 73 | 74 | 75 | def _get_word_split_index(tokens, st, end): 76 | split_idx = [] 77 | i = st 78 | while i < end: 79 | if (not tokens[i].startswith('##')) or (i == st): 80 | split_idx.append(i) 81 | i += 1 82 | split_idx.append(end) 83 | return split_idx 84 | 85 | 86 | def _expand_whole_word(tokens, st, end): 87 | new_st, new_end = st, end 88 | while (new_st >= 0) and tokens[new_st].startswith('##'): 89 | new_st -= 1 90 | while (new_end < len(tokens)) and tokens[new_end].startswith('##'): 91 | new_end += 1 92 | return new_st, new_end 93 | 94 | 95 | class Pipeline(object): 96 | """ Pre-process Pipeline Class : callable """ 97 | 98 | def __init__(self): 99 | super(Pipeline).__init__() 100 | self.skipgram_prb = None 101 | self.skipgram_size = None 102 | self.pre_whole_word = None 103 | self.mask_whole_word = None 104 | self.word_subsample_prb = None 105 | self.sp_prob = None 106 | self.pieces_dir = None 107 | self.vocab_words = None 108 | self.pieces_threshold = 10 109 | self.trie = None 110 | self.call_count = 0 111 | self.offline_mode = False 112 | self.skipgram_size_geo_list = None 113 | self.span_same_mask = False 114 | 115 | def init_skipgram_size_geo_list(self, p): 116 | if p > 0: 117 | g_list = [] 118 | t = p 119 | for _ in range(self.skipgram_size): 120 | g_list.append(t) 121 | t *= (1-p) 122 | s = sum(g_list) 123 | self.skipgram_size_geo_list = [x/s for x in g_list] 124 | 125 | def create_trie_tree(self, pieces_dir): 126 | print("sp_prob = {}".format(self.sp_prob)) 127 | print("pieces_threshold = {}".format(self.pieces_threshold)) 128 | if pieces_dir is not None: 129 | self.trie = TrieTree() 130 | pieces_files = [pieces_dir] 131 | for token in self.vocab_words: 132 | self.trie.add([token]) 133 | for piece_file in pieces_files: 134 | print("Load piece file: {}".format(piece_file)) 135 | with open(piece_file, mode='r', encoding='utf-8') as reader: 136 | for line in reader: 137 | parts = line.split('\t') 138 | if int(parts[-1]) < self.pieces_threshold: 139 | pass 140 | tokens = [] 141 | for part in parts[:-1]: 142 | tokens.extend(part.split(' ')) 143 | self.trie.add(tokens) 144 | 145 | def __call__(self, instance): 146 | raise NotImplementedError 147 | 148 | # pre_whole_word: tokenize to words 
before masking 149 | # post whole word (--mask_whole_word): expand to words after masking 150 | def get_masked_pos(self, tokens, n_pred, add_skipgram=False, mask_segment=None, protect_range=None): 151 | if self.pieces_dir is not None and self.trie is None: 152 | self.create_trie_tree(self.pieces_dir) 153 | if self.pre_whole_word: 154 | if self.trie is not None: 155 | pieces = self.trie.get_pieces(tokens, 0) 156 | 157 | new_pieces = [] 158 | for piece in pieces: 159 | if len(new_pieces) > 0 and tokens[piece[0]].startswith("##"): 160 | new_pieces[-1].extend(piece) 161 | else: 162 | new_pieces.append(piece) 163 | del pieces 164 | pieces = new_pieces 165 | 166 | pre_word_split = list(_[-1] for _ in pieces) 167 | pre_word_split.append(len(tokens)) 168 | else: 169 | pre_word_split = _get_word_split_index(tokens, 0, len(tokens)) 170 | index2piece = None 171 | else: 172 | pre_word_split = list(range(0, len(tokens)+1)) 173 | 174 | if self.trie is not None: 175 | pieces = self.trie.get_pieces(tokens, 0) 176 | 177 | index2piece = {} 178 | for piece in pieces: 179 | for index in piece: 180 | index2piece[index] = (piece[0], piece[-1]) 181 | else: 182 | index2piece = None 183 | 184 | span_list = list(zip(pre_word_split[:-1], pre_word_split[1:])) 185 | 186 | # candidate positions of masked tokens 187 | cand_pos = [] 188 | special_pos = set() 189 | if mask_segment: 190 | for i, sp in enumerate(span_list): 191 | sp_st, sp_end = sp 192 | if (sp_end-sp_st == 1) and tokens[sp_st].endswith('SEP]'): 193 | segment_index = i 194 | break 195 | for i, sp in enumerate(span_list): 196 | sp_st, sp_end = sp 197 | if (sp_end-sp_st == 1) and (tokens[sp_st].endswith('CLS]') or tokens[sp_st].endswith('SEP]')): 198 | special_pos.add(i) 199 | else: 200 | if mask_segment: 201 | if ((i < segment_index) and ('a' in mask_segment)) or ((i > segment_index) and ('b' in mask_segment)): 202 | cand_pos.append(i) 203 | else: 204 | cand_pos.append(i) 205 | shuffle(cand_pos) 206 | 207 | masked_pos = set() 208 | for i_span in cand_pos: 209 | if len(masked_pos) >= n_pred: 210 | break 211 | cand_st, cand_end = span_list[i_span] 212 | if len(masked_pos)+cand_end-cand_st > n_pred: 213 | continue 214 | if any(p in masked_pos for p in range(cand_st, cand_end)): 215 | continue 216 | 217 | n_span = 1 218 | if index2piece is not None: 219 | p_start, p_end = index2piece[i_span] 220 | if p_start < p_end and (rand() < self.sp_prob): 221 | # n_span = p_end - p_start + 1 222 | st_span, end_span = p_start, p_end + 1 223 | else: 224 | st_span, end_span = i_span, i_span + 1 225 | else: 226 | rand_skipgram_size = 0 227 | # ngram 228 | if self.skipgram_size_geo_list: 229 | # sampling ngram size from geometric distribution 230 | rand_skipgram_size = np.random.choice( 231 | len(self.skipgram_size_geo_list), 1, p=self.skipgram_size_geo_list)[0] + 1 232 | else: 233 | if add_skipgram and (self.skipgram_prb > 0) and (self.skipgram_size >= 2) and (rand() < self.skipgram_prb): 234 | rand_skipgram_size = min( 235 | randint(2, self.skipgram_size), len(span_list)-i_span) 236 | for n in range(2, rand_skipgram_size+1): 237 | tail_st, tail_end = span_list[i_span+n-1] 238 | if (tail_end-tail_st == 1) and (tail_st in special_pos): 239 | break 240 | if len(masked_pos)+tail_end-cand_st > n_pred: 241 | break 242 | n_span = n 243 | st_span, end_span = i_span, i_span + n_span 244 | 245 | if self.mask_whole_word: 246 | # pre_whole_word==False: position index of span_list is the same as tokens 247 | st_span, end_span = _expand_whole_word( 248 | tokens, st_span, end_span) 249 | 
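                # _expand_whole_word grows the span to cover complete WordPiece
                # words, e.g. tokens ['un', '##aff', '##able'] with span (1, 2)
                # expand to (0, 3), so all pieces of the word are masked together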
250 | # subsampling according to frequency 251 | if self.word_subsample_prb: 252 | skip_pos = set() 253 | if self.pre_whole_word: 254 | w_span_list = span_list[st_span:end_span] 255 | else: 256 | split_idx = _get_word_split_index( 257 | tokens, st_span, end_span) 258 | w_span_list = list( 259 | zip(split_idx[:-1], split_idx[1:])) 260 | for i, sp in enumerate(w_span_list): 261 | sp_st, sp_end = sp 262 | if sp_end-sp_st == 1: 263 | w_cat = tokens[sp_st] 264 | else: 265 | w_cat = ''.join(tokens[sp_st:sp_end]) 266 | if (w_cat in self.word_subsample_prb) and (rand() < self.word_subsample_prb[w_cat]): 267 | for k in range(sp_st, sp_end): 268 | skip_pos.add(k) 269 | else: 270 | skip_pos = None 271 | 272 | for sp in range(st_span, end_span): 273 | for mp in range(span_list[sp][0], span_list[sp][1]): 274 | if not(skip_pos and (mp in skip_pos)) and (mp not in special_pos) and not(protect_range and (protect_range[0] <= mp < protect_range[1])): 275 | masked_pos.add(mp) 276 | 277 | if len(masked_pos) < n_pred: 278 | shuffle(cand_pos) 279 | for pos in cand_pos: 280 | if len(masked_pos) >= n_pred: 281 | break 282 | if pos not in masked_pos: 283 | masked_pos.add(pos) 284 | masked_pos = list(masked_pos) 285 | if len(masked_pos) > n_pred: 286 | # shuffle(masked_pos) 287 | masked_pos = masked_pos[:n_pred] 288 | return masked_pos 289 | 290 | def replace_masked_tokens(self, tokens, masked_pos): 291 | if self.span_same_mask: 292 | masked_pos = sorted(list(masked_pos)) 293 | prev_pos, prev_rand = None, None 294 | for pos in masked_pos: 295 | if self.span_same_mask and (pos-1 == prev_pos): 296 | t_rand = prev_rand 297 | else: 298 | t_rand = rand() 299 | if t_rand < 0.8: # 80% 300 | tokens[pos] = '[MASK]' 301 | elif t_rand < 0.9: # 10% 302 | tokens[pos] = get_random_word(self.vocab_words) 303 | prev_pos, prev_rand = pos, t_rand 304 | -------------------------------------------------------------------------------- /biunilm/decode_seq2seq.py: -------------------------------------------------------------------------------- 1 | """BERT finetuning runner.""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import os 7 | import sys 8 | import logging 9 | import glob 10 | import argparse 11 | import math 12 | from tqdm import tqdm, trange 13 | import numpy as np 14 | import torch 15 | from torch.utils.data import DataLoader, RandomSampler 16 | from torch.utils.data.distributed import DistributedSampler 17 | import random 18 | import pickle 19 | 20 | from pytorch_pretrained_bert.tokenization import BertTokenizer, WhitespaceTokenizer 21 | from pytorch_pretrained_bert.modeling import BertForSeq2SeqDecoder 22 | from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear 23 | 24 | from nn.data_parallel import DataParallelImbalance 25 | import biunilm.seq2seq_loader as seq2seq_loader 26 | 27 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 28 | datefmt='%m/%d/%Y %H:%M:%S', 29 | level=logging.INFO) 30 | logger = logging.getLogger(__name__) 31 | 32 | 33 | def detokenize(tk_list): 34 | r_list = [] 35 | for tk in tk_list: 36 | if tk.startswith('##') and len(r_list) > 0: 37 | r_list[-1] = r_list[-1] + tk[2:] 38 | else: 39 | r_list.append(tk) 40 | return r_list 41 | 42 | 43 | def ascii_print(text): 44 | text = text.encode("ascii", "ignore") 45 | print(text) 46 | 47 | 48 | def main(): 49 | parser = argparse.ArgumentParser() 50 | 51 | parser.add_argument('--gate', type=str, default="attn", 52 | help="gate 
method: [attn|gate|gate_x2] ") 53 | 54 | parser.add_argument('--n_clayer', type=int, required=True, 55 | help="n conditional layer") 56 | 57 | parser.add_argument("--data_dir", 58 | default=None, 59 | type=str, 60 | required=True, 61 | help="The input data dir. Should contain the .tsv files (or other data files) for the task.") 62 | 63 | parser.add_argument("--input_file", type=str, help="Input file", required=True) 64 | 65 | parser.add_argument("--output_file", type=str, required=True, help="output file") 66 | 67 | # Required parameters 68 | parser.add_argument("--bert_model", default=None, type=str, required=True, 69 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 70 | "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.") 71 | parser.add_argument("--model_recover_path", default=None, type=str, 72 | help="The file of fine-tuned pretraining model.") 73 | parser.add_argument("--max_seq_length", default=512, type=int, 74 | help="The maximum total input sequence length after WordPiece tokenization. \n" 75 | "Sequences longer than this will be truncated, and sequences shorter \n" 76 | "than this will be padded.") 77 | parser.add_argument('--ffn_type', default=0, type=int, 78 | help="0: default mlp; 1: W((Wx+b) elem_prod x);") 79 | parser.add_argument('--num_qkv', default=0, type=int, 80 | help="Number of different .") 81 | parser.add_argument('--seg_emb', action='store_true', 82 | help="Using segment embedding for self-attention.") 83 | 84 | # decoding parameters 85 | parser.add_argument('--fp16', action='store_true', 86 | help="Whether to use 16-bit float precision instead of 32-bit") 87 | parser.add_argument('--amp', action='store_true', 88 | help="Whether to use amp for fp16") 89 | 90 | parser.add_argument('--subset', type=int, default=0, 91 | help="Decode a subset of the input dataset.") 92 | parser.add_argument("--split", type=str, default="", 93 | help="Data split (train/val/test).") 94 | parser.add_argument('--tokenized_input', action='store_true', 95 | help="Whether the input is tokenized.") 96 | parser.add_argument('--seed', type=int, default=123, 97 | help="random seed for initialization") 98 | parser.add_argument("--do_lower_case", action='store_true', 99 | help="Set this flag if you are using an uncased model.") 100 | parser.add_argument('--new_segment_ids', action='store_true', 101 | help="Use new segment ids for bi-uni-directional LM.") 102 | parser.add_argument('--new_pos_ids', action='store_true', 103 | help="Use new position ids for LMs.") 104 | parser.add_argument('--batch_size', type=int, default=4, 105 | help="Batch size for decoding.") 106 | parser.add_argument('--beam_size', type=int, default=1, 107 | help="Beam size for searching") 108 | parser.add_argument('--length_penalty', type=float, default=0, 109 | help="Length penalty for beam search") 110 | 111 | parser.add_argument('--forbid_duplicate_ngrams', action='store_true') 112 | parser.add_argument('--forbid_ignore_word', type=str, default=None, 113 | help="Ignore the word during forbid_duplicate_ngrams") 114 | parser.add_argument("--min_len", default=None, type=int) 115 | parser.add_argument('--need_score_traces', action='store_true') 116 | parser.add_argument('--ngram_size', type=int, default=3) 117 | parser.add_argument('--mode', default="s2s", 118 | choices=["s2s", "l2r", "both"]) 119 | parser.add_argument('--max_tgt_length', type=int, default=128, 120 | help="maximum length of target sequence") 121 | parser.add_argument('--s2s_special_token', 
action='store_true', 122 | help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.") 123 | parser.add_argument('--s2s_add_segment', action='store_true', 124 | help="Additional segmental for the encoder of S2S.") 125 | parser.add_argument('--s2s_share_segment', action='store_true', 126 | help="Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment).") 127 | parser.add_argument('--pos_shift', action='store_true', 128 | help="Using position shift for fine-tuning.") 129 | parser.add_argument('--not_predict_token', type=str, default=None, 130 | help="Do not predict the tokens during decoding.") 131 | 132 | args = parser.parse_args() 133 | 134 | if args.need_score_traces and args.beam_size <= 1: 135 | raise ValueError( 136 | "Score trace is only available for beam search with beam size > 1.") 137 | if args.max_tgt_length >= args.max_seq_length - 2: 138 | raise ValueError("Maximum tgt length exceeds max seq length - 2.") 139 | 140 | device = torch.device( 141 | "cuda" if torch.cuda.is_available() else "cpu") 142 | n_gpu = torch.cuda.device_count() 143 | 144 | random.seed(args.seed) 145 | np.random.seed(args.seed) 146 | torch.manual_seed(args.seed) 147 | if n_gpu > 0: 148 | torch.cuda.manual_seed_all(args.seed) 149 | 150 | tokenizer = BertTokenizer.from_pretrained( 151 | args.bert_model, do_lower_case=args.do_lower_case) 152 | 153 | tokenizer.max_len = args.max_seq_length 154 | 155 | c_indexer = torch.load(os.path.join(args.data_dir, 'c_indexer.pt')) 156 | logger.info("{:} conditions.".format(len(c_indexer))) 157 | 158 | pair_num_relation = 0 159 | bi_uni_pipeline = [ 160 | seq2seq_loader.Preprocess4Decoder(list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, 161 | args.max_seq_length, max_tgt_length=args.max_tgt_length, 162 | new_segment_ids=args.new_segment_ids, 163 | mode="s2s", num_qkv=args.num_qkv, 164 | s2s_special_token=args.s2s_special_token, 165 | s2s_add_segment=args.s2s_add_segment, 166 | s2s_share_segment=args.s2s_share_segment, pos_shift=args.pos_shift, 167 | c_indexer=c_indexer)] 168 | 169 | logger.info("### Some c_indexer ###") 170 | tmp = sorted(bi_uni_pipeline[0].c_indexer.items(), key=lambda p: p[1]) 171 | print(tmp[:10]) 172 | sys.stdout.flush() 173 | 174 | amp_handle = None 175 | if args.fp16 and args.amp: 176 | raise NotImplementedError 177 | # from apex import amp 178 | # amp_handle = amp.init(enable_caching=True) 179 | # logger.info("enable fp16 with amp") 180 | 181 | # Prepare model 182 | cls_num_labels = 2 183 | type_vocab_size = 2 184 | mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids( 185 | ["[MASK]", "[SEP]", "[S2S_SOS]"]) 186 | 187 | def _get_token_id_set(s): 188 | r = None 189 | if s: 190 | w_list = [] 191 | for w in s.split('|'): 192 | if w.startswith('[') and w.endswith(']'): 193 | w_list.append(w.upper()) 194 | else: 195 | w_list.append(w) 196 | r = set(tokenizer.convert_tokens_to_ids(w_list)) 197 | return r 198 | 199 | forbid_ignore_set = _get_token_id_set(args.forbid_ignore_word) 200 | not_predict_set = _get_token_id_set(args.not_predict_token) 201 | print(args.model_recover_path) 202 | for model_recover_path in glob.glob(args.model_recover_path.strip()): 203 | logger.info("***** Recover model: %s *****", model_recover_path) 204 | model_recover = torch.load(model_recover_path) 205 | 206 | if '' not in c_indexer.keys(): 207 | n_condition = len(c_indexer) + 1 208 | else: 209 | n_condition = len(c_indexer) 210 | assert c_indexer[''] == 0 # Check 211 | 212 | n_dial = 10 # Fake 213 | model = 
BertForSeq2SeqDecoder.from_pretrained(args.bert_model, state_dict=model_recover, 214 | num_labels=cls_num_labels, num_rel=pair_num_relation, 215 | type_vocab_size=type_vocab_size, task_idx=3, 216 | mask_word_id=mask_word_id, search_beam_size=args.beam_size, 217 | length_penalty=args.length_penalty, eos_id=eos_word_ids, 218 | sos_id=sos_word_id, 219 | forbid_duplicate_ngrams=args.forbid_duplicate_ngrams, 220 | forbid_ignore_set=forbid_ignore_set, 221 | not_predict_set=not_predict_set, ngram_size=args.ngram_size, 222 | min_len=args.min_len, mode=args.mode, 223 | max_position_embeddings=args.max_seq_length, 224 | ffn_type=args.ffn_type, num_qkv=args.num_qkv, 225 | seg_emb=args.seg_emb, pos_shift=args.pos_shift, 226 | n_condition=n_condition, n_dial=n_dial, n_clayer=args.n_clayer, gate=args.gate) 227 | 228 | del model_recover 229 | 230 | model.to(device) 231 | if n_gpu > 1: 232 | model = torch.nn.DataParallel(model) 233 | 234 | torch.cuda.empty_cache() 235 | model.eval() 236 | next_i = 0 237 | max_src_length = args.max_seq_length - 2 - args.max_tgt_length 238 | 239 | with open(os.path.join(args.data_dir, args.input_file), encoding="utf-8") as fin: 240 | # *ZY* 241 | input_lines = [line.strip().split('\t')[:2] for line in fin.readlines()] 242 | if args.subset > 0: 243 | logger.info("Decoding subset: %d", args.subset) 244 | input_lines = input_lines[:args.subset] 245 | 246 | data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer 247 | 248 | input_lines = [[data_tokenizer.tokenize( 249 | src)[:max_src_length], uid] for src, uid in input_lines] 250 | 251 | input_lines = sorted(list(enumerate(input_lines)), 252 | key=lambda x: -len(x[1][0])) 253 | 254 | output_lines = [""] * len(input_lines) 255 | score_trace_list = [None] * len(input_lines) 256 | total_batch = math.ceil(len(input_lines) / args.batch_size) 257 | 258 | with tqdm(total=total_batch) as pbar: 259 | while next_i < len(input_lines): 260 | _chunk = input_lines[next_i:next_i + args.batch_size] 261 | buf_id = [x[0] for x in _chunk] 262 | buf = [x[1] for x in _chunk] 263 | next_i += args.batch_size 264 | max_a_len = max([len(x[0]) for x in buf]) 265 | instances = [] 266 | for instance in [(x[0], x[1], max_a_len) for x in buf]: 267 | for proc in bi_uni_pipeline: 268 | instances.append(proc(instance)) 269 | 270 | with torch.no_grad(): 271 | batch = seq2seq_loader.batch_list_to_batch_tensors( 272 | instances) 273 | batch = [ 274 | t.to(device) if t is not None else None for t in batch] 275 | 276 | input_ids, usrid_ids, token_type_ids, position_ids, input_mask, mask_qkv, task_idx = batch 277 | traces = model(input_ids, usrid_ids, token_type_ids, position_ids, input_mask, 278 | task_idx=task_idx, mask_qkv=mask_qkv) 279 | 280 | if args.beam_size > 1: 281 | traces = {k: v.tolist() for k, v in traces.items()} 282 | output_ids = traces['pred_seq'] 283 | # print(output_ids) # Debug 284 | else: 285 | output_ids = traces.tolist() 286 | for i in range(len(buf)): 287 | w_ids = output_ids[i] 288 | output_buf = tokenizer.convert_ids_to_tokens(w_ids) 289 | output_tokens = [] 290 | for t in output_buf: 291 | if t in ("[SEP]", "[PAD]"): 292 | break 293 | output_tokens.append(t) 294 | output_sequence = ' '.join(detokenize(output_tokens)) 295 | output_lines[buf_id[i]] = output_sequence 296 | if args.need_score_traces: 297 | score_trace_list[buf_id[i]] = { 298 | 'scores': traces['scores'][i], 'wids': traces['wids'][i], 'ptrs': traces['ptrs'][i]} 299 | pbar.update(1) 300 | 301 | if args.output_file: 302 | fn_out = args.output_file 303 | 
else: 304 | fn_out = model_recover_path + '.' + args.split 305 | 306 | len_list = [] 307 | with open(fn_out, "w", encoding="utf-8") as fout: 308 | for l in output_lines: 309 | fout.write(l) 310 | fout.write("\n") 311 | 312 | len_list.append(len(l.strip().split(' '))) 313 | 314 | print("### average len: {:}".format(np.mean(len_list))) 315 | 316 | if args.need_score_traces: 317 | with open(fn_out + ".trace.pickle", "wb") as fout_trace: 318 | pickle.dump( 319 | {"version": 0.0, "num_samples": len(input_lines)}, fout_trace) 320 | for x in score_trace_list: 321 | pickle.dump(x, fout_trace) 322 | 323 | 324 | if __name__ == "__main__": 325 | main() 326 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | import os 24 | import logging 25 | 26 | from .file_utils import cached_path 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 31 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 32 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 33 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 34 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 35 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 36 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 37 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 38 | } 39 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 40 | 'bert-base-uncased': 512, 41 | 'bert-large-uncased': 512, 42 | 'bert-base-cased': 512, 43 | 'bert-large-cased': 512, 44 | 'bert-base-multilingual-uncased': 512, 45 | 'bert-base-multilingual-cased': 512, 46 | 'bert-base-chinese': 512, 47 | } 48 | VOCAB_NAME = 'vocab.txt' 49 | 50 | 51 | def load_vocab(vocab_file): 52 | """Loads a vocabulary file into a dictionary.""" 53 | # mapping unused tokens to special tokens 54 | extra_map = {} 55 | extra_map['[unused1]'] = '[X_SEP]' 56 | for i in range(10): 57 | extra_map['[unused{}]'.format(i+2)] = '[SEP_{}]'.format(i) 58 | extra_map['[unused12]'] = '[S2S_SEP]' 59 | extra_map['[unused13]'] = '[S2S_CLS]' 60 | extra_map['[unused14]'] = '[L2R_SEP]' 61 | extra_map['[unused15]'] = 
'[L2R_CLS]' 62 | extra_map['[unused16]'] = '[R2L_SEP]' 63 | extra_map['[unused17]'] = '[R2L_CLS]' 64 | extra_map['[unused18]'] = '[S2S_SOS]' 65 | 66 | vocab = collections.OrderedDict() 67 | index = 0 68 | with open(vocab_file, "r", encoding="utf-8") as reader: 69 | while True: 70 | token = reader.readline() 71 | if not token: 72 | break 73 | token = token.strip() 74 | if token in extra_map: 75 | token = extra_map[token] 76 | vocab[token] = index 77 | index += 1 78 | return vocab 79 | 80 | 81 | def whitespace_tokenize(text): 82 | """Runs basic whitespace cleaning and splitting on a peice of text.""" 83 | text = text.strip() 84 | if not text: 85 | return [] 86 | tokens = text.split() 87 | return tokens 88 | 89 | 90 | class BertTokenizer(object): 91 | """Runs end-to-end tokenization: punctuation splitting + wordpiece""" 92 | 93 | def __init__(self, vocab_file, do_lower_case=True, max_len=None, never_split=("[UNK]", "[SEP]", "[X_SEP]", "[PAD]", "[CLS]", "[MASK]")): 94 | if not os.path.isfile(vocab_file): 95 | raise ValueError( 96 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " 97 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) 98 | self.vocab = load_vocab(vocab_file) 99 | # *ZY* 100 | self.my_special = list(never_split) + ['', '', '', '', '[SRC_SEP]'] 101 | self.vocab[''] = 20 102 | self.vocab[''] = 21 103 | self.vocab[''] = 22 104 | self.vocab[''] = 23 105 | self.vocab['[SRC_SEP]'] = 24 106 | 107 | del self.vocab['[unused19]'] 108 | del self.vocab['[unused20]'] 109 | del self.vocab['[unused21]'] 110 | del self.vocab['[unused22]'] 111 | del self.vocab['[unused23]'] 112 | 113 | self.ids_to_tokens = collections.OrderedDict( 114 | [(ids, tok) for tok, ids in self.vocab.items()]) 115 | self.basic_tokenizer = BasicTokenizer( 116 | do_lower_case=do_lower_case, never_split=self.my_special) # *ZY* 117 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 118 | self.max_len = max_len if max_len is not None else int(1e12) 119 | 120 | def tokenize(self, text): 121 | split_tokens = [] 122 | for token in self.basic_tokenizer.tokenize(text): 123 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 124 | split_tokens.append(sub_token) 125 | return split_tokens 126 | 127 | def convert_tokens_to_ids(self, tokens): 128 | """Converts a sequence of tokens into ids using the vocab.""" 129 | ids = [] 130 | for token in tokens: 131 | ids.append(self.vocab[token]) 132 | if len(ids) > self.max_len: 133 | raise ValueError( 134 | "Token indices sequence length is longer than the specified maximum " 135 | " sequence length for this BERT model ({} > {}). 
Running this" 136 | " sequence through BERT will result in indexing errors".format( 137 | len(ids), self.max_len) 138 | ) 139 | return ids 140 | 141 | def convert_tokens_to_ids_FGfree(self, tokens, ret_ids_only=True): 142 | """Converts a sequence of tokens into ids using the vocab.""" 143 | ids = [] 144 | position_ids = [] 145 | mask_pos_idx_map = {} 146 | idx_counter = 0 147 | for pos, token in enumerate(tokens): 148 | if isinstance(token, str): 149 | ids.append(self.vocab[token]) 150 | position_ids.append(pos) 151 | idx_counter += 1 152 | else: # TODO: masked tokens -- here is tuple 153 | mask_pos_idx_map[pos] = idx_counter 154 | assert len(token) == 2 155 | for t in token: # here is tuple, and the first position is [MASK] 156 | ids.append(self.vocab[t]) 157 | position_ids.append(pos) 158 | idx_counter += 1 159 | 160 | if len(ids) > self.max_len: 161 | raise ValueError( 162 | "Token indices sequence length is longer than the specified maximum " 163 | " sequence length for this BERT model ({} > {}). Running this" 164 | " sequence through BERT will result in indexing errors".format( 165 | len(ids), self.max_len) 166 | ) 167 | if ret_ids_only: 168 | return ids 169 | 170 | return ids, position_ids, mask_pos_idx_map 171 | 172 | def convert_ids_to_tokens(self, ids): 173 | """Converts a sequence of ids in wordpiece tokens using the vocab.""" 174 | tokens = [] 175 | for i in ids: 176 | tokens.append(self.ids_to_tokens[i]) 177 | return tokens 178 | 179 | @classmethod 180 | def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs): 181 | """ 182 | Instantiate a PreTrainedBertModel from a pre-trained model file. 183 | Download and cache the pre-trained model file if needed. 184 | """ 185 | if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP: 186 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name] 187 | else: 188 | vocab_file = pretrained_model_name 189 | if os.path.isdir(vocab_file): 190 | vocab_file = os.path.join(vocab_file, VOCAB_NAME) 191 | # redirect to the cache, if necessary 192 | try: 193 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 194 | except FileNotFoundError: 195 | logger.error( 196 | "Model name '{}' was not found in model name list ({}). " 197 | "We assumed '{}' was a path or url but couldn't find any file " 198 | "associated to this path or url.".format( 199 | pretrained_model_name, 200 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 201 | vocab_file)) 202 | return None 203 | if resolved_vocab_file == vocab_file: 204 | logger.info("loading vocabulary file {}".format(vocab_file)) 205 | else: 206 | logger.info("loading vocabulary file {} from cache at {}".format( 207 | vocab_file, resolved_vocab_file)) 208 | if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 209 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 210 | # than the number of positional embeddings 211 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name] 212 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 213 | # Instantiate tokenizer. 
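        # e.g. decode_seq2seq.py builds the tokenizer with:
        #   BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)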
214 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) 215 | return tokenizer 216 | 217 | 218 | class WhitespaceTokenizer(object): 219 | def tokenize(self, text): 220 | return whitespace_tokenize(text) 221 | 222 | 223 | class BasicTokenizer(object): 224 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 225 | 226 | def __init__(self, do_lower_case=True, never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 227 | """Constructs a BasicTokenizer. 228 | 229 | Args: 230 | do_lower_case: Whether to lower case the input. 231 | """ 232 | self.do_lower_case = do_lower_case 233 | self.never_split = never_split 234 | 235 | def tokenize(self, text): 236 | """Tokenizes a piece of text.""" 237 | text = self._clean_text(text) 238 | # This was added on November 1st, 2018 for the multilingual and Chinese 239 | # models. This is also applied to the English models now, but it doesn't 240 | # matter since the English models were not trained on any Chinese data 241 | # and generally don't have any Chinese data in them (there are Chinese 242 | # characters in the vocabulary because Wikipedia does have some Chinese 243 | # words in the English Wikipedia.). 244 | text = self._tokenize_chinese_chars(text) 245 | orig_tokens = whitespace_tokenize(text) 246 | split_tokens = [] 247 | for token in orig_tokens: 248 | if self.do_lower_case and token not in self.never_split: 249 | token = token.lower() 250 | token = self._run_strip_accents(token) 251 | split_tokens.extend(self._run_split_on_punc(token)) 252 | 253 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 254 | return output_tokens 255 | 256 | def _run_strip_accents(self, text): 257 | """Strips accents from a piece of text.""" 258 | text = unicodedata.normalize("NFD", text) 259 | output = [] 260 | for char in text: 261 | cat = unicodedata.category(char) 262 | if cat == "Mn": 263 | continue 264 | output.append(char) 265 | return "".join(output) 266 | 267 | def _run_split_on_punc(self, text): 268 | """Splits punctuation on a piece of text.""" 269 | if text in self.never_split: 270 | return [text] 271 | chars = list(text) 272 | i = 0 273 | start_new_word = True 274 | output = [] 275 | while i < len(chars): 276 | char = chars[i] 277 | if _is_punctuation(char): 278 | output.append([char]) 279 | start_new_word = True 280 | else: 281 | if start_new_word: 282 | output.append([]) 283 | start_new_word = False 284 | output[-1].append(char) 285 | i += 1 286 | 287 | return ["".join(x) for x in output] 288 | 289 | def _tokenize_chinese_chars(self, text): 290 | """Adds whitespace around any CJK character.""" 291 | output = [] 292 | for char in text: 293 | cp = ord(char) 294 | if self._is_chinese_char(cp): 295 | output.append(" ") 296 | output.append(char) 297 | output.append(" ") 298 | else: 299 | output.append(char) 300 | return "".join(output) 301 | 302 | def _is_chinese_char(self, cp): 303 | """Checks whether CP is the codepoint of a CJK character.""" 304 | # This defines a "chinese character" as anything in the CJK Unicode block: 305 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 306 | # 307 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 308 | # despite its name. The modern Korean Hangul alphabet is a different block, 309 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 310 | # space-separated words, so they are not treated specially and handled 311 | # like the all of the other languages. 
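        # the ranges below cover the CJK Unified Ideographs block, Extensions
        # A-E, and the CJK Compatibility Ideographs block plus its supplement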
312 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 313 | (cp >= 0x3400 and cp <= 0x4DBF) or # 314 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 315 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 316 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 317 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 318 | (cp >= 0xF900 and cp <= 0xFAFF) or # 319 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 320 | return True 321 | 322 | return False 323 | 324 | def _clean_text(self, text): 325 | """Performs invalid character removal and whitespace cleanup on text.""" 326 | output = [] 327 | for char in text: 328 | cp = ord(char) 329 | if cp == 0 or cp == 0xfffd or _is_control(char): 330 | continue 331 | if _is_whitespace(char): 332 | output.append(" ") 333 | else: 334 | output.append(char) 335 | return "".join(output) 336 | 337 | 338 | class WordpieceTokenizer(object): 339 | """Runs WordPiece tokenization.""" 340 | 341 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 342 | self.vocab = vocab 343 | self.unk_token = unk_token 344 | self.max_input_chars_per_word = max_input_chars_per_word 345 | 346 | def tokenize(self, text): 347 | """Tokenizes a piece of text into its word pieces. 348 | 349 | This uses a greedy longest-match-first algorithm to perform tokenization 350 | using the given vocabulary. 351 | 352 | For example: 353 | input = "unaffable" 354 | output = ["un", "##aff", "##able"] 355 | 356 | Args: 357 | text: A single token or whitespace separated tokens. This should have 358 | already been passed through `BasicTokenizer`. 359 | 360 | Returns: 361 | A list of wordpiece tokens. 362 | """ 363 | 364 | output_tokens = [] 365 | for token in whitespace_tokenize(text): 366 | chars = list(token) 367 | if len(chars) > self.max_input_chars_per_word: 368 | output_tokens.append(self.unk_token) 369 | continue 370 | 371 | is_bad = False 372 | start = 0 373 | sub_tokens = [] 374 | while start < len(chars): 375 | end = len(chars) 376 | cur_substr = None 377 | while start < end: 378 | substr = "".join(chars[start:end]) 379 | if start > 0: 380 | substr = "##" + substr 381 | if substr in self.vocab: 382 | cur_substr = substr 383 | break 384 | end -= 1 385 | if cur_substr is None: 386 | is_bad = True 387 | break 388 | sub_tokens.append(cur_substr) 389 | start = end 390 | 391 | if is_bad: 392 | output_tokens.append(self.unk_token) 393 | else: 394 | output_tokens.extend(sub_tokens) 395 | return output_tokens 396 | 397 | 398 | def _is_whitespace(char): 399 | """Checks whether `chars` is a whitespace character.""" 400 | # \t, \n, and \r are technically contorl characters but we treat them 401 | # as whitespace since they are generally considered as such. 402 | if char == " " or char == "\t" or char == "\n" or char == "\r": 403 | return True 404 | cat = unicodedata.category(char) 405 | if cat == "Zs": 406 | return True 407 | return False 408 | 409 | 410 | def _is_control(char): 411 | """Checks whether `chars` is a control character.""" 412 | # These are technically control characters but we count them as whitespace 413 | # characters. 414 | if char == "\t" or char == "\n" or char == "\r": 415 | return False 416 | cat = unicodedata.category(char) 417 | if cat.startswith("C"): 418 | return True 419 | return False 420 | 421 | 422 | def _is_punctuation(char): 423 | """Checks whether `chars` is a punctuation character.""" 424 | cp = ord(char) 425 | # We treat all non-letter/number ASCII as punctuation. 
426 | # Characters such as "^", "$", and "`" are not in the Unicode 427 | # Punctuation class but we treat them as punctuation anyways, for 428 | # consistency. 429 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 430 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 431 | return True 432 | cat = unicodedata.category(char) 433 | if cat.startswith("P"): 434 | return True 435 | return False 436 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | 23 | from collections import defaultdict 24 | from torch._six import container_abcs 25 | from copy import deepcopy 26 | from itertools import chain 27 | 28 | 29 | def warmup_cosine(x, warmup=0.002): 30 | if x < warmup: 31 | return x/warmup 32 | return 0.5 * (1.0 + torch.cos(math.pi * x)) 33 | 34 | 35 | def warmup_constant(x, warmup=0.002): 36 | if x < warmup: 37 | return x/warmup 38 | return 1.0 39 | 40 | 41 | def warmup_linear(x, warmup=0.002): 42 | if x < warmup: 43 | return x/warmup 44 | return max((x-1.)/(warmup-1.), 0) 45 | 46 | 47 | SCHEDULES = { 48 | 'warmup_cosine': warmup_cosine, 49 | 'warmup_constant': warmup_constant, 50 | 'warmup_linear': warmup_linear, 51 | } 52 | 53 | 54 | class BertAdam(Optimizer): 55 | """Implements BERT version of Adam algorithm with weight decay fix. 56 | Params: 57 | lr: learning rate 58 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 59 | t_total: total number of training steps for the learning 60 | rate schedule, -1 means constant learning rate. Default: -1 61 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 62 | b1: Adams b1. Default: 0.9 63 | b2: Adams b2. Default: 0.999 64 | e: Adams epsilon. Default: 1e-6 65 | weight_decay: Weight decay. Default: 0.01 66 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). 
Default: 1.0 67 | """ 68 | 69 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0): 70 | if lr is not required and lr < 0.0: 71 | raise ValueError( 72 | "Invalid learning rate: {} - should be >= 0.0".format(lr)) 73 | if schedule not in SCHEDULES: 74 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 75 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 76 | raise ValueError( 77 | "Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 78 | if not 0.0 <= b1 < 1.0: 79 | raise ValueError( 80 | "Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 81 | if not 0.0 <= b2 < 1.0: 82 | raise ValueError( 83 | "Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 84 | if not e >= 0.0: 85 | raise ValueError( 86 | "Invalid epsilon value: {} - should be >= 0.0".format(e)) 87 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 88 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 89 | max_grad_norm=max_grad_norm) 90 | super(BertAdam, self).__init__(params, defaults) 91 | 92 | def get_lr(self): 93 | lr = [] 94 | for group in self.param_groups: 95 | for p in group['params']: 96 | state = self.state[p] 97 | if len(state) == 0: 98 | return [0] 99 | if group['t_total'] != -1: 100 | schedule_fct = SCHEDULES[group['schedule']] 101 | lr_scheduled = group['lr'] * schedule_fct( 102 | state['step']/group['t_total'], group['warmup']) 103 | else: 104 | lr_scheduled = group['lr'] 105 | lr.append(lr_scheduled) 106 | return lr 107 | 108 | def step(self, closure=None): 109 | """Performs a single optimization step. 110 | 111 | Arguments: 112 | closure (callable, optional): A closure that reevaluates the model 113 | and returns the loss. 114 | """ 115 | loss = None 116 | if closure is not None: 117 | loss = closure() 118 | 119 | for group in self.param_groups: 120 | for p in group['params']: 121 | if p.grad is None: 122 | continue 123 | grad = p.grad.data 124 | if grad.is_sparse: 125 | raise RuntimeError( 126 | 'Adam does not support sparse gradients, please consider SparseAdam instead') 127 | 128 | state = self.state[p] 129 | 130 | # State initialization 131 | if len(state) == 0: 132 | state['step'] = 0 133 | # Exponential moving average of gradient values 134 | state['next_m'] = torch.zeros_like(p.data) 135 | # Exponential moving average of squared gradient values 136 | state['next_v'] = torch.zeros_like(p.data) 137 | 138 | next_m, next_v = state['next_m'], state['next_v'] 139 | beta1, beta2 = group['b1'], group['b2'] 140 | 141 | # Add grad clipping 142 | if group['max_grad_norm'] > 0: 143 | clip_grad_norm_(p, group['max_grad_norm']) 144 | 145 | # Decay the first and second moment running average coefficient 146 | # In-place operations to update the averages at the same time 147 | next_m.mul_(beta1).add_(1 - beta1, grad) 148 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 149 | update = next_m / (next_v.sqrt() + group['e']) 150 | 151 | # Just adding the square of the weights to the loss function is *not* 152 | # the correct way of using L2 regularization/weight decay with Adam, 153 | # since that will interact with the m and v parameters in strange ways. 154 | # 155 | # Instead we want to decay the weights in a manner that doesn't interact 156 | # with the m/v parameters. This is equivalent to adding the square 157 | # of the weights to the loss with plain (non-momentum) SGD. 
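                # i.e. decoupled (AdamW-style) weight decay: added to the
                # parameter update directly instead of to the loss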
158 | if group['weight_decay'] > 0.0: 159 | update += group['weight_decay'] * p.data 160 | 161 | if group['t_total'] != -1: 162 | schedule_fct = SCHEDULES[group['schedule']] 163 | lr_scheduled = group['lr'] * schedule_fct( 164 | state['step']/group['t_total'], group['warmup']) 165 | else: 166 | lr_scheduled = group['lr'] 167 | 168 | update_with_lr = lr_scheduled * update 169 | p.data.add_(-update_with_lr) 170 | 171 | state['step'] += 1 172 | 173 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 174 | # No bias correction 175 | # bias_correction1 = 1 - beta1 ** state['step'] 176 | # bias_correction2 = 1 - beta2 ** state['step'] 177 | 178 | return loss 179 | 180 | 181 | class BertAdamFineTune(BertAdam): 182 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0): 183 | self.init_param_group = [] 184 | super(BertAdamFineTune, self).__init__(params, lr, warmup, 185 | t_total, schedule, b1, b2, e, weight_decay, max_grad_norm) 186 | 187 | def save_init_param_group(self, param_groups, name_groups, missing_keys): 188 | self.init_param_group = [] 189 | for group, name in zip(param_groups, name_groups): 190 | if group['weight_decay'] > 0.0: 191 | init_p_list = [] 192 | for p, n in zip(group['params'], name): 193 | init_p = p.data.clone().detach() 194 | if any(mk in n for mk in missing_keys): 195 | print("[no finetuning weight decay]", n) 196 | # should use the original weight decay 197 | init_p.zero_() 198 | init_p_list.append(init_p) 199 | self.init_param_group.append(init_p_list) 200 | else: 201 | # placeholder 202 | self.init_param_group.append([]) 203 | 204 | def step(self, closure=None): 205 | """Performs a single optimization step. 206 | 207 | Arguments: 208 | closure (callable, optional): A closure that reevaluates the model 209 | and returns the loss. 210 | """ 211 | loss = None 212 | if closure is not None: 213 | loss = closure() 214 | 215 | for i_group, group in enumerate(self.param_groups): 216 | for i_p, p in enumerate(group['params']): 217 | if p.grad is None: 218 | continue 219 | grad = p.grad.data 220 | if grad.is_sparse: 221 | raise RuntimeError( 222 | 'Adam does not support sparse gradients, please consider SparseAdam instead') 223 | 224 | state = self.state[p] 225 | 226 | # State initialization 227 | if len(state) == 0: 228 | state['step'] = 0 229 | # Exponential moving average of gradient values 230 | state['next_m'] = torch.zeros_like(p.data) 231 | # Exponential moving average of squared gradient values 232 | state['next_v'] = torch.zeros_like(p.data) 233 | 234 | next_m, next_v = state['next_m'], state['next_v'] 235 | beta1, beta2 = group['b1'], group['b2'] 236 | 237 | # Add grad clipping 238 | if group['max_grad_norm'] > 0: 239 | clip_grad_norm_(p, group['max_grad_norm']) 240 | 241 | # Decay the first and second moment running average coefficient 242 | # In-place operations to update the averages at the same time 243 | next_m.mul_(beta1).add_(1 - beta1, grad) 244 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 245 | update = next_m / (next_v.sqrt() + group['e']) 246 | 247 | # Just adding the square of the weights to the loss function is *not* 248 | # the correct way of using L2 regularization/weight decay with Adam, 249 | # since that will interact with the m and v parameters in strange ways. 250 | # 251 | # Instead we want to decay the weights in a manner that doesn't interact 252 | # with the m/v parameters. 
This is equivalent to adding the square 253 | # of the weights to the loss with plain (non-momentum) SGD. 254 | if group['weight_decay'] > 0.0: 255 | if self.init_param_group: 256 | update += group['weight_decay'] * \ 257 | (2.0 * p.data - 258 | self.init_param_group[i_group][i_p]) 259 | else: 260 | update += group['weight_decay'] * p.data 261 | 262 | if group['t_total'] != -1: 263 | schedule_fct = SCHEDULES[group['schedule']] 264 | lr_scheduled = group['lr'] * schedule_fct( 265 | state['step']/group['t_total'], group['warmup']) 266 | else: 267 | lr_scheduled = group['lr'] 268 | 269 | update_with_lr = lr_scheduled * update 270 | p.data.add_(-update_with_lr) 271 | 272 | state['step'] += 1 273 | 274 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 275 | # No bias correction 276 | # bias_correction1 = 1 - beta1 ** state['step'] 277 | # bias_correction2 = 1 - beta2 ** state['step'] 278 | 279 | return loss 280 | 281 | def load_state_dict_subset_finetune(self, state_dict, num_load_group): 282 | r"""Loads the optimizer state. 283 | 284 | Arguments: 285 | state_dict (dict): optimizer state. Should be an object returned 286 | from a call to :meth:`state_dict`. 287 | """ 288 | # deepcopy, to be consistent with module API 289 | state_dict = deepcopy(state_dict) 290 | # Validate the state_dict 291 | groups = self.param_groups 292 | saved_groups = state_dict['param_groups'] 293 | 294 | if len(groups) < num_load_group or len(saved_groups) < num_load_group: 295 | raise ValueError("loaded state dict has a different number of " 296 | "parameter groups") 297 | param_lens = (len(g['params']) for g in groups[:num_load_group]) 298 | saved_lens = (len(g['params']) for g in saved_groups[:num_load_group]) 299 | if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): 300 | raise ValueError("loaded state dict contains a parameter group " 301 | "that doesn't match the size of optimizer's group") 302 | 303 | # Update the state 304 | id_map = {old_id: p for old_id, p in 305 | zip(chain(*(g['params'] for g in saved_groups[:num_load_group])), 306 | chain(*(g['params'] for g in groups[:num_load_group])))} 307 | 308 | def cast(param, value): 309 | r"""Make a deep copy of value, casting all tensors to device of param.""" 310 | if isinstance(value, torch.Tensor): 311 | # Floating-point types are a bit special here. They are the only ones 312 | # that are assumed to always match the type of params. 313 | if param.is_floating_point(): 314 | value = value.to(param.dtype) 315 | value = value.to(param.device) 316 | return value 317 | elif isinstance(value, dict): 318 | return {k: cast(param, v) for k, v in value.items()} 319 | elif isinstance(value, container_abcs.Iterable): 320 | return type(value)(cast(param, v) for v in value) 321 | else: 322 | return value 323 | 324 | # Copy state assigned to params (and cast tensors to appropriate types). 325 | # State that is not assigned to params is copied as is (needed for 326 | # backward compatibility). 
327 |         state = defaultdict(dict)
328 |         for k, v in state_dict['state'].items():
329 |             if k in id_map:
330 |                 param = id_map[k]
331 |                 state[param] = cast(param, v)
332 |             else:
333 |                 state[k] = v
334 |         # handle additional params
335 |         for k, v in self.state.items():
336 |             if k not in state:
337 |                 state[k] = v
338 | 
339 |         # do not change groups: {'weight_decay': 0.01, 'lr': 9.995e-06, 'schedule': 'warmup_linear', 'warmup': 0.1, 't_total': 400000, 'b1': 0.9, 'b2': 0.999, 'e': 1e-06, 'max_grad_norm': 1.0, 'params': [...]}
340 |         # # Update parameter groups, setting their 'params' value
341 |         # def update_group(group, new_group):
342 |         #     new_group['params'] = group['params']
343 |         #     return new_group
344 |         # param_groups = [
345 |         #     update_group(g, ng) for g, ng in zip(groups[:num_load_group], saved_groups[:num_load_group])]
346 |         # # handle additional params
347 |         # param_groups.extend(groups[num_load_group:])
348 | 
349 |         self.__setstate__({'state': state, 'param_groups': groups})
350 | 
351 | 
352 | def find_state_dict_subset_finetune(org_state_dict, org_name_list, no_decay, param_optimizer):
353 |     # only use the bert encoder and embeddings
354 |     want_name_set = set()
355 |     for n in org_name_list:
356 |         if ('bert.encoder' in n) or ('bert.embeddings' in n):
357 |             want_name_set.add(n)
358 |     # original: name to pid, pid to name
359 |     org_grouped_names = [[n for n in org_name_list if not any(nd in n for nd in no_decay)],
360 |                          [n for n in org_name_list if any(nd in n for nd in no_decay)]]
361 |     org_n2id, org_id2n = {}, {}
362 |     for ng, pg in zip(org_grouped_names, org_state_dict['param_groups']):
363 |         for n, pid in zip(ng, pg['params']):
364 |             org_n2id[n] = pid
365 |             org_id2n[pid] = n
366 |     # group by: whether pretrained; whether weight decay
367 |     g_np_list = [
368 |         [(n, p) for n, p in param_optimizer if n in want_name_set and not any(
369 |             nd in n for nd in no_decay)],
370 |         [(n, p) for n, p in param_optimizer if n in want_name_set and any(
371 |             nd in n for nd in no_decay)],
372 |         [(n, p) for n, p in param_optimizer if n not in want_name_set and not any(
373 |             nd in n for nd in no_decay)],
374 |         [(n, p) for n, p in param_optimizer if n not in want_name_set and any(
375 |             nd in n for nd in no_decay)],
376 |     ]
377 |     optimizer_grouped_parameters = [
378 |         {'params': [p for n, p in g_np_list[0]], 'weight_decay': 0.01},
379 |         {'params': [p for n, p in g_np_list[1]], 'weight_decay': 0.0},
380 |         {'params': [p for n, p in g_np_list[2]], 'weight_decay': 0.01},
381 |         {'params': [p for n, p in g_np_list[3]], 'weight_decay': 0.0}
382 |     ]
383 |     new_state_dict = {}
384 |     # regroup the original state_dict
385 |     new_state_dict['state'] = {pid: v for pid, v in org_state_dict['state'].items(
386 |     ) if pid not in org_id2n or org_id2n[pid] in want_name_set}
387 |     # reset step count to 0
388 |     for pid, st in new_state_dict['state'].items():
389 |         st['step'] = 0
390 | 
391 |     def _filter_group(group, g_np_list, i, org_n2id):
392 |         packed = {k: v for k, v in group.items() if k != 'params'}
393 |         packed['params'] = [pid for pid in group['params']
394 |                             if pid in org_id2n and org_id2n[pid] in want_name_set]
395 |         assert len(g_np_list[i]) == len(packed['params'])
396 |         # keep them the same order
397 |         packed['params'] = [org_n2id[n] for n, p in g_np_list[i]]
398 |         return packed
399 |     new_state_dict['param_groups'] = [_filter_group(
400 |         g, g_np_list, i, org_n2id) for i, g in enumerate(org_state_dict['param_groups'])]
401 |     return new_state_dict, optimizer_grouped_parameters
402 | 
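# --- Added usage sketch (not part of the original module) ----------------------
# A minimal, self-contained example of how BertAdam above is typically set up:
# parameters are split into a weight-decay group and a no-decay group
# (biases / LayerNorm), and the linear warmup schedule is driven by t_total.
# The tiny nn.Linear stand-in, the step count and the hyper-parameter values
# below are placeholders for illustration, not values taken from this repository.
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Linear(768, 768)      # stand-in for a BERT-sized model
    num_train_steps = 1000           # placeholder number of optimizer updates
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    param_optimizer = list(model.named_parameters())
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    ]
    optimizer = BertAdam(optimizer_grouped_parameters, lr=3e-5, warmup=0.1,
                         t_total=num_train_steps, schedule='warmup_linear')

    # One update: dummy forward/backward pass, then an optimizer step.
    loss = model(torch.zeros(2, 768)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()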
-------------------------------------------------------------------------------- /biunilm/run_ppl.py: -------------------------------------------------------------------------------- 1 | """BERT finetuning runner.""" 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import os 8 | import sys 9 | import copy 10 | import logging 11 | import glob 12 | import math 13 | import json 14 | import argparse 15 | import random 16 | import pickle 17 | from pathlib import Path 18 | from tqdm import tqdm, trange 19 | import numpy as np 20 | import pandas as pd 21 | import torch 22 | from torch.utils.data import RandomSampler 23 | from torch.utils.data.distributed import DistributedSampler 24 | 25 | from pytorch_pretrained_bert.tokenization import BertTokenizer, WhitespaceTokenizer 26 | from pytorch_pretrained_bert.modeling import BertForPreTrainingLossMask 27 | from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear 28 | 29 | from nn.data_parallel import DataParallelImbalance 30 | import biunilm.seq2seq_loader as seq2seq_loader 31 | import torch.distributed as dist 32 | 33 | 34 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 35 | datefmt='%m/%d/%Y %H:%M:%S', 36 | level=logging.INFO) 37 | logger = logging.getLogger(__name__) 38 | 39 | 40 | def _get_max_epoch_model(output_dir): 41 | fn_model_list = glob.glob(os.path.join(output_dir, "model.*.bin")) 42 | # fn_optim_list = glob.glob(os.path.join(output_dir, "optim.*.bin")) 43 | # if (not fn_model_list) or (not fn_optim_list): 44 | # return None 45 | 46 | if not fn_model_list: 47 | return None 48 | 49 | both_set = set([int(Path(fn).stem.split('.')[-1]) for fn in fn_model_list] 50 | ) 51 | if both_set: 52 | # *ZY* 53 | global_step = str(max(both_set)) 54 | fn_model = [s for s in fn_model_list if global_step in s] 55 | assert len(fn_model) == 1 56 | # fn_optim = [s for s in fn_optim_list if global_step in s] 57 | # assert len(fn_optim) == 1 58 | 59 | tmp = Path(fn_model[0]).stem.split('.')[-2].strip().split('_') 60 | n_epoch = int(tmp[0].strip('e').strip()) 61 | n_step = int(tmp[1].strip('s').strip()) 62 | return [fn_model[0], None, int(global_step), n_epoch, n_step] 63 | else: 64 | return None 65 | 66 | 67 | def pre_preprocess(train_flag, args, data_tokenizer, bi_uni_pipeline): 68 | train_flag = 'test' 69 | 70 | # TODO: PPL 71 | dial_src = os.path.join(args.data_dir, "dial.{:}".format(train_flag)) 72 | dial_ppl_src = os.path.join(args.data_dir, "dial.{:}.ppl".format(train_flag)) 73 | if not os.path.exists(dial_ppl_src): 74 | n_write = 0 75 | with open(dial_ppl_src, 'wt') as wf: 76 | with open(dial_src, 'rt') as rf: 77 | for line in rf: 78 | src, usrid, tgt, data_type = line.strip().split('\t')[:4] 79 | elems = tgt.strip().split(' ') 80 | 81 | for idx in range(len(elems)): 82 | word = elems[idx].strip() 83 | if len(word): 84 | wf.write('\t'.join([src, usrid, ' '.join(elems[:idx+1]), data_type])+'\n') 85 | n_write += 1 86 | 87 | logger.info("Write {:} samples for perplexity calculation to {:}".format(n_write, dial_ppl_src)) 88 | else: 89 | logger.info("Read ppl test file: {:}".format(dial_ppl_src)) 90 | 91 | dataset = seq2seq_loader.MyDataset( 92 | [dial_ppl_src], args.eval_batch_size, data_tokenizer, 93 | args.max_seq_length, preprocess=bi_uni_pipeline, accept_dtypes=['dial']) 94 | 95 | return dataset 96 | 97 | 98 | def validate(model, valid_dataloader, device, n_gpu): 99 | valid_ppl = 0 100 | n_samples = 0 101 | n_tokens = 0 
102 | 103 | batch_size_gpu = int(valid_dataloader.batch_size / n_gpu) 104 | 105 | iter_bar = tqdm(valid_dataloader, desc='Iter (loss=X.XXX)') 106 | 107 | with torch.no_grad(): 108 | for step, batch in enumerate(iter_bar): 109 | batch = [ 110 | t.to(device) if t is not None else None for t in batch] 111 | 112 | num_tokens_a, num_tokens_b, input_ids, usrid_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx = batch 113 | oracle_pos, oracle_weights, oracle_labels = None, None, None 114 | input_mask = None 115 | assert segment_ids is not None 116 | 117 | loss_tuple = model(input_ids, usrid_ids, segment_ids, input_mask, lm_label_ids, is_next, 118 | masked_pos=masked_pos, masked_weights=masked_weights, task_idx=task_idx, 119 | num_tokens_a=num_tokens_a, num_tokens_b=num_tokens_b, 120 | masked_pos_2=oracle_pos, masked_weights_2=oracle_weights, 121 | masked_labels_2=oracle_labels, mask_qkv=mask_qkv, is_ppl_eval=True) 122 | 123 | masked_lm_loss, next_sentence_loss, ppl = loss_tuple 124 | 125 | if n_gpu > 1: # mean() to average on multi-gpu. 126 | # loss = loss.mean() 127 | masked_lm_loss = masked_lm_loss.mean() 128 | # next_sentence_loss = next_sentence_loss.mean() 129 | ppl = ppl.sum() 130 | 131 | # loss = masked_lm_loss + next_sentence_loss 132 | loss = masked_lm_loss 133 | iter_bar.set_description('Iter (loss=%5.3f)' % loss.item()) 134 | 135 | valid_ppl += ppl.item() 136 | n_tokens += masked_weights.sum().item() 137 | n_samples += len(task_idx) 138 | 139 | # ppl = np.exp(valid_ppl / n_samples) 140 | ppl = np.exp(valid_ppl / n_tokens) # n_tokens == n_samples, I masked one token per sample 141 | 142 | return ppl 143 | 144 | 145 | def save(model, optimizer, args, i_epoch, i_step, global_step): 146 | model_to_save = model.module if hasattr( 147 | model, 'module') else model # Only save the model it-self 148 | output_model_file = os.path.join( 149 | args.output_dir, "model.e{:}_s{:}.{:}.bin".format(i_epoch, i_step, global_step)) 150 | torch.save(model_to_save.state_dict(), output_model_file) 151 | output_optim_file = os.path.join( 152 | args.output_dir, "optim.e{:}_s{:}.{:}.bin".format(i_epoch, i_step, global_step)) 153 | torch.save(optimizer.state_dict(), output_optim_file) 154 | 155 | 156 | def main(): 157 | parser = argparse.ArgumentParser() 158 | 159 | parser.add_argument('--n_clayer', type=int, required=True, 160 | help="n conditional layer") 161 | 162 | parser.add_argument('--gate', type=str, default="attn", 163 | help="gate method: [attn|gate|gate_x2] ") 164 | 165 | # Required parameters 166 | parser.add_argument("--data_dir", 167 | default=None, 168 | type=str, 169 | required=True, 170 | help="The input data dir. Should contain the .tsv files (or other data files) for the task.") 171 | 172 | parser.add_argument("--c_tfidf_map", type=str, required=True, 173 | help="e.g. 
c_tfidf_map.pkl in args.data_dir") 174 | 175 | # parser.add_argument("--tgt_file", default=None, type=str, 176 | # help="The output data file name.") 177 | 178 | parser.add_argument("--bert_model", default=None, type=str, required=True, 179 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 180 | "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.") 181 | parser.add_argument("--config_path", default=None, type=str, 182 | help="Bert config file path.") 183 | parser.add_argument("--output_dir", 184 | default=None, 185 | type=str, 186 | required=True, 187 | help="The output directory where the model predictions and checkpoints will be written.") 188 | parser.add_argument("--log_dir", 189 | default='', 190 | type=str, 191 | required=True, 192 | help="The output directory where the log will be written.") 193 | parser.add_argument("--model_recover_path", 194 | default=None, 195 | type=str, 196 | help="The file of fine-tuned pretraining model.") 197 | parser.add_argument("--optim_recover_path", 198 | default=None, 199 | type=str, 200 | help="The file of pretraining optimizer.") 201 | # Other parameters 202 | parser.add_argument("--max_seq_length", 203 | default=128, 204 | type=int, 205 | help="The maximum total input sequence length after WordPiece tokenization. \n" 206 | "Sequences longer than this will be truncated, and sequences shorter \n" 207 | "than this will be padded.") 208 | 209 | parser.add_argument("--do_lower_case", 210 | action='store_true', 211 | help="Set this flag if you are using an uncased model.") 212 | parser.add_argument("--train_batch_size", 213 | default=32, 214 | type=int, 215 | help="Total batch size for training.") 216 | parser.add_argument("--eval_batch_size", 217 | default=64, 218 | type=int, 219 | help="Total batch size for eval.") 220 | parser.add_argument("--valid_steps", 221 | default=8192, 222 | type=int) 223 | 224 | parser.add_argument("--learning_rate", default=3e-5, type=float, 225 | help="The initial learning rate for Adam.") 226 | parser.add_argument("--label_smoothing", default=0, type=float, 227 | help="The initial learning rate for Adam.") 228 | parser.add_argument("--weight_decay", 229 | default=0.01, 230 | type=float, 231 | help="The weight decay rate for Adam.") 232 | parser.add_argument("--finetune_decay", 233 | action='store_true', 234 | help="Weight decay to the original weights.") 235 | 236 | parser.add_argument("--warmup_proportion", 237 | default=0.1, 238 | type=float, 239 | help="Proportion of training to perform linear learning rate warmup for. 
" 240 | "E.g., 0.1 = 10%% of training.") 241 | parser.add_argument("--hidden_dropout_prob", default=0.1, type=float, 242 | help="Dropout rate for hidden states.") 243 | parser.add_argument("--attention_probs_dropout_prob", default=0.1, type=float, 244 | help="Dropout rate for attention probabilities.") 245 | parser.add_argument("--no_cuda", 246 | action='store_true', 247 | help="Whether not to use CUDA when available") 248 | parser.add_argument("--local_rank", 249 | type=int, 250 | default=-1, 251 | help="local_rank for distributed training on gpus") 252 | parser.add_argument('--seed', 253 | type=int, 254 | default=42, 255 | help="random seed for initialization") 256 | parser.add_argument('--gradient_accumulation_steps', 257 | type=int, 258 | default=1, 259 | help="Number of updates steps to accumulate before performing a backward/update pass.") 260 | parser.add_argument('--fp16', action='store_true', 261 | help="Whether to use 16-bit float precision instead of 32-bit") 262 | parser.add_argument('--fp32_embedding', action='store_true', 263 | help="Whether to use 32-bit float precision instead of 16-bit for embeddings") 264 | parser.add_argument('--loss_scale', type=float, default=0, 265 | help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" 266 | "0 (default value): dynamic loss scaling.\n" 267 | "Positive power of 2: static loss scaling value.\n") 268 | parser.add_argument('--amp', action='store_true', 269 | help="Whether to use amp for fp16") 270 | parser.add_argument('--from_scratch', action='store_true', 271 | help="Initialize parameters with random values (i.e., training from scratch).") 272 | parser.add_argument('--new_segment_ids', action='store_true', 273 | help="Use new segment ids for bi-uni-directional LM.") 274 | parser.add_argument('--new_pos_ids', action='store_true', 275 | help="Use new position ids for LMs.") 276 | 277 | parser.add_argument('--tokenized_input', action='store_true', 278 | help="Whether the input is tokenized.") 279 | 280 | parser.add_argument('--max_len_a', type=int, default=0, 281 | help="Truncate_config: maximum length of segment A.") 282 | parser.add_argument('--max_len_b', type=int, default=0, 283 | help="Truncate_config: maximum length of segment B.") 284 | parser.add_argument('--trunc_seg', default='', 285 | help="Truncate_config: first truncate segment A/B (option: a, b).") 286 | parser.add_argument('--always_truncate_tail', action='store_true', 287 | help="Truncate_config: Whether we should always truncate tail.") 288 | parser.add_argument("--mask_prob", default=0.15, type=float, 289 | help="Number of prediction is sometimes less than max_pred when sequence is short.") 290 | parser.add_argument("--mask_prob_eos", default=0, type=float, 291 | help="Number of prediction is sometimes less than max_pred when sequence is short.") 292 | parser.add_argument('--max_pred', type=int, default=20, 293 | help="Max tokens of prediction.") 294 | parser.add_argument("--num_workers", default=0, type=int, 295 | help="Number of workers for the data loader.") 296 | 297 | parser.add_argument('--mask_source_words', action='store_true', 298 | help="Whether to mask source words for training") 299 | parser.add_argument('--skipgram_prb', type=float, default=0.0, 300 | help='prob of ngram mask') 301 | parser.add_argument('--skipgram_size', type=int, default=1, 302 | help='the max size of ngram mask') 303 | parser.add_argument('--mask_whole_word', action='store_true', 304 | help="Whether masking a whole word.") 305 | 
parser.add_argument('--do_l2r_training', action='store_true', 306 | help="Whether to do left to right training") 307 | parser.add_argument('--has_sentence_oracle', action='store_true', 308 | help="Whether to have sentence level oracle for training. " 309 | "Only useful for summary generation") 310 | parser.add_argument('--max_position_embeddings', type=int, default=None, 311 | help="max position embeddings") 312 | parser.add_argument('--relax_projection', action='store_true', 313 | help="Use different projection layers for tasks.") 314 | parser.add_argument('--ffn_type', default=0, type=int, 315 | help="0: default mlp; 1: W((Wx+b) elem_prod x);") 316 | parser.add_argument('--num_qkv', default=0, type=int, 317 | help="Number of different .") 318 | parser.add_argument('--seg_emb', action='store_true', 319 | help="Using segment embedding for self-attention.") 320 | parser.add_argument('--s2s_special_token', action='store_true', 321 | help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.") 322 | parser.add_argument('--s2s_add_segment', action='store_true', 323 | help="Additional segmental for the encoder of S2S.") 324 | parser.add_argument('--s2s_share_segment', action='store_true', 325 | help="Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment).") 326 | parser.add_argument('--pos_shift', action='store_true', 327 | help="Using position shift for fine-tuning.") 328 | 329 | args = parser.parse_args() 330 | 331 | # Fine-tune use 332 | # assert Path(args.model_recover_path).exists( 333 | # ), "--model_recover_path doesn't exist" 334 | 335 | args.output_dir = args.output_dir.replace( 336 | '[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', '')) 337 | args.log_dir = args.log_dir.replace( 338 | '[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', '')) 339 | 340 | os.makedirs(args.output_dir, exist_ok=True) 341 | os.makedirs(args.log_dir, exist_ok=True) 342 | json.dump(args.__dict__, open(os.path.join( 343 | args.output_dir, 'opt.json'), 'w'), sort_keys=True, indent=2) 344 | 345 | if args.local_rank == -1 or args.no_cuda: 346 | device = torch.device( 347 | "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 348 | n_gpu = torch.cuda.device_count() 349 | 350 | else: 351 | torch.cuda.set_device(args.local_rank) 352 | device = torch.device("cuda", args.local_rank) 353 | n_gpu = 1 354 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 355 | dist.init_process_group(backend='nccl') 356 | 357 | logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( 358 | device, n_gpu, bool(args.local_rank != -1), args.fp16)) 359 | 360 | if args.gradient_accumulation_steps < 1: 361 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( 362 | args.gradient_accumulation_steps)) 363 | 364 | args.train_batch_size = int( 365 | args.train_batch_size / args.gradient_accumulation_steps) 366 | 367 | random.seed(args.seed) 368 | np.random.seed(args.seed) 369 | torch.manual_seed(args.seed) 370 | if n_gpu > 0: 371 | torch.cuda.manual_seed_all(args.seed) 372 | 373 | if args.local_rank not in (-1, 0): 374 | # Make sure only the first process in distributed training will download model & vocab 375 | dist.barrier() 376 | if args.local_rank == 0: 377 | dist.barrier() 378 | 379 | ################################### 380 | # *ZY* 381 | # Load User Mask 382 | with open(os.path.join(args.data_dir, args.c_tfidf_map), 'rb') as f: 383 | c_tfidf_map = pickle.load(f) 384 | 385 | # Get User Indexer 386 | 
c_indexer = {cid: index for index, cid in enumerate(sorted(list(c_tfidf_map.keys())))} 387 | logger.info("{:} conditions.".format(len(c_indexer))) 388 | 389 | tokenizer = BertTokenizer.from_pretrained( 390 | args.bert_model, do_lower_case=args.do_lower_case) 391 | if args.max_position_embeddings: 392 | tokenizer.max_len = args.max_position_embeddings 393 | data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer 394 | 395 | if not args.tokenized_input: 396 | logger.warning("Strongly recommend using BertTokenizer(# Slow) before.") 397 | 398 | bi_uni_pipeline = [seq2seq_loader.Preprocess4Seq2seq(args.max_pred, args.mask_prob, list(tokenizer.vocab.keys( 399 | )), tokenizer.convert_tokens_to_ids, args.max_seq_length, new_segment_ids=args.new_segment_ids, 400 | truncate_config={'max_len_a': args.max_len_a, 401 | 'max_len_b': args.max_len_b, 402 | 'trunc_seg': args.trunc_seg, 403 | 'always_truncate_tail': args.always_truncate_tail}, 404 | mask_source_words=args.mask_source_words, 405 | skipgram_prb=args.skipgram_prb, 406 | skipgram_size=args.skipgram_size, 407 | mask_whole_word=args.mask_whole_word, mode="s2s", 408 | has_oracle=args.has_sentence_oracle, num_qkv=args.num_qkv, 409 | s2s_special_token=args.s2s_special_token, 410 | s2s_add_segment=args.s2s_add_segment, 411 | s2s_share_segment=args.s2s_share_segment, 412 | pos_shift=args.pos_shift, c_indexer=c_indexer, 413 | c_tfidf_map=c_tfidf_map, only_mask_last=True)] 414 | 415 | logger.info("Preprocess Test Set...") 416 | valid_dataset = pre_preprocess('test', args, data_tokenizer, bi_uni_pipeline) 417 | 418 | valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.eval_batch_size, 419 | num_workers=args.num_workers, shuffle=False, 420 | collate_fn=seq2seq_loader.batch_list_to_batch_tensors, pin_memory=False) 421 | 422 | special_num_here = 2048 423 | recover_step = _get_max_epoch_model(args.output_dir) 424 | # (fn_model[0], fn_optim[0], int(global_step), n_epoch, n_step) 425 | 426 | if recover_step: 427 | if recover_step[-1] % special_num_here == 0: 428 | n_finished_epoch = recover_step[-2] - 1 429 | else: 430 | n_finished_epoch = recover_step[-2] 431 | recover_step[-1] = 0 # step in an epoch 432 | else: 433 | n_finished_epoch = 0 434 | 435 | logger.info("### Finished {:} Epoch(s) ###".format(n_finished_epoch)) 436 | 437 | amp_handle = None 438 | if args.fp16 and args.amp: 439 | raise NotImplementedError 440 | # from apex import amp 441 | # amp_handle = amp.init(enable_caching=True) 442 | # logger.info("enable fp16 with amp") 443 | 444 | # Prepare model 445 | cls_num_labels = 2 446 | 447 | type_vocab_size = 2 # V2 448 | 449 | c_indexer = torch.load(os.path.join(args.data_dir, 'c_indexer.pt')) 450 | n_condition = len(c_indexer) 451 | if '' not in c_indexer.keys(): 452 | n_condition += 1 453 | 454 | num_sentlvl_labels = 2 if args.has_sentence_oracle else 0 455 | relax_projection = 4 if args.relax_projection else 0 456 | if args.local_rank not in (-1, 0): 457 | # Make sure only the first process in distributed training will download model & vocab 458 | dist.barrier() 459 | if (recover_step is None) and (args.model_recover_path is None): 460 | raise ValueError 461 | 462 | else: 463 | if recover_step: 464 | assert args.model_recover_path is None # TODO: automatically recover to most recent model 465 | logger.info("***** Recover model: {:} *****".format(recover_step[0])) 466 | model_recover = torch.load(recover_step[0], map_location='cpu') 467 | # recover_step == number of epochs 468 | assert 
isinstance(recover_step[2], int) 469 | global_step = recover_step[2] 470 | elif args.model_recover_path: 471 | logger.info("***** (ONLY)Recover model: %s *****", 472 | args.model_recover_path) 473 | model_recover = torch.load( 474 | args.model_recover_path, map_location='cpu') 475 | global_step = 0 476 | 477 | n_dial = 10 # FAKE 478 | model = BertForPreTrainingLossMask.from_pretrained( 479 | args.bert_model, state_dict=model_recover, num_labels=cls_num_labels, num_rel=0, 480 | type_vocab_size=type_vocab_size, config_path=args.config_path, task_idx=3, 481 | num_sentlvl_labels=num_sentlvl_labels, max_position_embeddings=args.max_position_embeddings, 482 | label_smoothing=args.label_smoothing, fp32_embedding=args.fp32_embedding, relax_projection=relax_projection, 483 | new_pos_ids=args.new_pos_ids, ffn_type=args.ffn_type, hidden_dropout_prob=args.hidden_dropout_prob, 484 | attention_probs_dropout_prob=args.attention_probs_dropout_prob, num_qkv=args.num_qkv, seg_emb=args.seg_emb, 485 | n_condition=n_condition, n_dial=n_dial, n_clayer=args.n_clayer, gate=args.gate) 486 | 487 | if args.local_rank == 0: 488 | dist.barrier() 489 | 490 | if args.fp16: 491 | model.half() 492 | if args.fp32_embedding: 493 | model.bert.embeddings.word_embeddings.float() 494 | model.bert.embeddings.position_embeddings.float() 495 | model.bert.embeddings.token_type_embeddings.float() 496 | 497 | model.to(device) 498 | if args.local_rank != -1: 499 | try: 500 | from torch.nn.parallel import DistributedDataParallel as DDP 501 | except ImportError: 502 | raise ImportError("DistributedDataParallel") 503 | model = DDP(model, device_ids=[ 504 | args.local_rank], output_device=args.local_rank, find_unused_parameters=True) 505 | elif n_gpu > 1: 506 | # model = torch.nn.DataParallel(model) 507 | model = DataParallelImbalance(model) 508 | 509 | logger.info("***** CUDA.empty_cache() *****") 510 | torch.cuda.empty_cache() 511 | 512 | model.eval() 513 | # logger.info("### First Valid") 514 | valid_loss = validate(model, valid_dataloader, device, n_gpu) 515 | logger.info("### PPL {:.3f}".format(valid_loss)) 516 | 517 | 518 | if __name__ == "__main__": 519 | main() 520 | -------------------------------------------------------------------------------- /biunilm/seq2seq_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | from random import randint, shuffle, uniform 5 | from random import random as rand 6 | from random import sample as sample_func 7 | 8 | from numpy import array 9 | from numpy.random import choice 10 | 11 | import torch 12 | 13 | from biunilm.loader_utils import get_random_word, batch_list_to_batch_tensors, Pipeline 14 | 15 | # Input file format : 16 | # 1. One sentence per line. These should ideally be actual sentences, 17 | # not entire paragraphs or arbitrary spans of text. (Because we use 18 | # the sentence boundaries for the "next sentence prediction" task). 19 | # 2. Blank lines between documents. Document boundaries are needed 20 | # so that the "next sentence prediction" task doesn't span between documents. 
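# (Added note) The format described above (one sentence per line, blank lines
# between documents) is inherited from the original pre-training loader.
# MyDataset below instead reads tab-separated lines of the form
#     <src> \t <condition> \t <tgt> \t <data_type>
# where data_type is 'dial' (a source/target dialogue pair) or 'mono'
# (conditioned text; the condition may be empty). A made-up example line,
# purely for illustration:
#     how was your day ?\tuser_42\tpretty good , thanks !\tdial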
21 | 22 | 23 | def truncate_tokens_pair(tokens_a, tokens_b, max_len, max_len_a=0, max_len_b=0, trunc_seg=None, always_truncate_tail=False): 24 | num_truncated_a = [0, 0] 25 | num_truncated_b = [0, 0] 26 | while True: 27 | if len(tokens_a) + len(tokens_b) <= max_len: 28 | break 29 | if (max_len_a > 0) and len(tokens_a) > max_len_a: 30 | trunc_tokens = tokens_a 31 | num_truncated = num_truncated_a 32 | elif (max_len_b > 0) and len(tokens_b) > max_len_b: 33 | trunc_tokens = tokens_b 34 | num_truncated = num_truncated_b 35 | elif trunc_seg: 36 | # truncate the specified segment 37 | if trunc_seg == 'a': 38 | trunc_tokens = tokens_a 39 | num_truncated = num_truncated_a 40 | else: 41 | trunc_tokens = tokens_b 42 | num_truncated = num_truncated_b 43 | else: 44 | # truncate the longer segment 45 | if len(tokens_a) > len(tokens_b): 46 | trunc_tokens = tokens_a 47 | num_truncated = num_truncated_a 48 | else: 49 | trunc_tokens = tokens_b 50 | num_truncated = num_truncated_b 51 | # whether always truncate source sequences 52 | if (not always_truncate_tail) and (rand() < 0.5): 53 | del trunc_tokens[0] 54 | num_truncated[0] += 1 55 | else: 56 | trunc_tokens.pop() 57 | num_truncated[1] += 1 58 | return num_truncated_a, num_truncated_b 59 | 60 | 61 | class MyDataset(torch.utils.data.Dataset): 62 | def __init__(self, file_src_list, batch_size, tokenizer, max_len, file_oracle=None, short_sampling_prob=0.1, sent_reverse_order=False, preprocess=[], 63 | n_dial=-1, n_text=-1, accept_dtypes=[]): 64 | super(MyDataset).__init__() 65 | self.tokenizer = tokenizer # tokenize function 66 | 67 | print("### I set minimum source length to 4.") 68 | self.min_src_len = 4 # TODO: !!! 69 | 70 | self.max_len = max_len # maximum length of tokens 71 | self.short_sampling_prob = short_sampling_prob 72 | assert isinstance(preprocess, list) 73 | assert len(preprocess) == 1 74 | self.preprocess = preprocess 75 | self.n_condition = len(self.preprocess[0].c_indexer) 76 | 77 | self.batch_size = batch_size 78 | self.sent_reverse_order = sent_reverse_order 79 | 80 | assert file_oracle is None 81 | 82 | assert len(accept_dtypes) > 0 83 | 84 | # read the file into memory 85 | self.is_pretrain = True 86 | dial = [] 87 | non_text = [] 88 | c_text = [] 89 | 90 | assert isinstance(file_src_list, list) 91 | for file_src in file_src_list: 92 | with open(file_src, "r", encoding='utf-8') as f: 93 | for index, line in enumerate(f): 94 | if index % 500000 == 0: 95 | print('Preprocess the {:}th line...'.format(index)) 96 | sys.stdout.flush() 97 | 98 | src, cond, tgt, data_type = line.strip('\n').split('\t')[:4] 99 | src_tk = tokenizer.tokenize(src.strip()) 100 | tgt_tk = tokenizer.tokenize(tgt.strip()) 101 | cond = cond.strip() 102 | 103 | if len(src_tk) < self.min_src_len: 104 | src_tk = [] # TODO: !!! 
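                    # (Added note) Emptying src_tk here means that 'dial' samples whose
                    # source is shorter than min_src_len are skipped entirely below, while
                    # for 'mono' samples only the too-short source is discarded and the
                    # target is still kept.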
105 | 106 | sample = (src_tk, tgt_tk, cond, data_type) 107 | 108 | if len(tgt_tk) > 0 and len(cond) > 0: 109 | if data_type in accept_dtypes: 110 | if data_type == 'dial': 111 | if len(src_tk): 112 | dial.append(sample) 113 | elif data_type == 'mono': 114 | if cond == '': 115 | non_text.append(sample) 116 | else: 117 | c_text.append(sample) 118 | else: 119 | raise ValueError 120 | 121 | if 0 <= n_dial < len(dial): 122 | dial = sample_func(dial, n_dial) 123 | 124 | if 0 <= n_text < len(c_text): 125 | c_text = sample_func(c_text, n_text) 126 | 127 | print('Load {:} labeled dial samples.'.format(len(dial))) 128 | print('Load {:} labeled text samples.'.format(len(c_text))) 129 | print('Load {:} text samples.'.format(len(non_text))) 130 | 131 | if len(non_text): 132 | raise NotImplementedError # I have not checked it. 133 | 134 | self.n_samples = len(dial) + len(c_text) + len(non_text) 135 | self.n_dial_samples = len(dial) # 0215 136 | self.ex_list = [dial, c_text, non_text] 137 | 138 | self.index_map = {} 139 | index = 0 140 | for idx, _ in enumerate(dial): 141 | assert index not in self.index_map.keys() 142 | self.index_map[index] = (0, idx) 143 | index += 1 144 | 145 | for idx, _ in enumerate(c_text): 146 | assert index not in self.index_map.keys() 147 | self.index_map[index] = (1, idx) 148 | index += 1 149 | 150 | for idx, _ in enumerate(non_text): 151 | assert index not in self.index_map.keys() 152 | self.index_map[index] = (2, idx) 153 | index += 1 154 | 155 | assert list(self.index_map.keys()) == list(range(self.n_samples)) 156 | 157 | def __len__(self): 158 | return self.n_samples 159 | 160 | def __getitem__(self, index): 161 | data_type, idx = self.index_map[index] 162 | instance = self.preprocess[0](self.ex_list[data_type][idx]) 163 | return instance 164 | 165 | 166 | class MySampler(torch.utils.data.Sampler): 167 | def __init__(self, my_dataset, batch_size, n_gpu, n_ctext=-1, equal_sample=False): 168 | assert isinstance(my_dataset, MyDataset) 169 | assert batch_size % n_gpu == 0 170 | 171 | self.batch_size = batch_size 172 | self.n_gpu = n_gpu 173 | self.batch_size_gpu = int(self.batch_size / self.n_gpu) 174 | 175 | self.n_samples = my_dataset.n_samples 176 | 177 | self.dial_index = [] 178 | self.ctext_index = [] 179 | self.non_index = [] 180 | for index, p in my_dataset.index_map.items(): 181 | if p[0] == 0: 182 | self.dial_index.append(index) 183 | elif p[0] == 1: 184 | self.ctext_index.append(index) 185 | elif p[0] == 2: 186 | self.non_index.append(index) 187 | else: 188 | raise ValueError 189 | 190 | print("### Train Set: dial {:}, ctext {:}, non-text {:}".format(len(self.dial_index), 191 | len(self.ctext_index), 192 | len(self.non_index))) 193 | 194 | if n_ctext > 0: 195 | self.n_ctext = min(n_ctext, self.batch_size_gpu) 196 | self.n_non = 0 197 | self.n_dial = 0 198 | 199 | else: 200 | self.n_non = 0 201 | if equal_sample: 202 | self.n_ctext = round(self.batch_size_gpu * 1 / 2) if len(self.ctext_index) else 0 203 | else: 204 | self.n_ctext = round(self.batch_size_gpu * 1 / 4) if len(self.ctext_index) else 0 205 | 206 | self.n_dial = self.batch_size_gpu - self.n_ctext - self.n_non 207 | 208 | print("### Sampler: dial {:}, ctext {:}, non-text {:}".format(self.n_dial, self.n_ctext, self.n_non)) 209 | assert self.n_dial >= 0 210 | 211 | self.dial_gen = self.get_batch_index_generator(self.dial_index, self.n_dial) 212 | self.ctext_gen = self.get_batch_index_generator(self.ctext_index, self.n_ctext) 213 | self.non_gen = self.get_batch_index_generator(self.non_index, self.n_non) 214 | 
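    # (Added note) Worked example of the batch composition above, assuming both
    # dialogue and conditioned-text samples are present: with batch_size=64 and
    # n_gpu=2, batch_size_gpu is 32; under the defaults (n_ctext=-1,
    # equal_sample=False) each per-GPU slice contains round(32 / 4) = 8 ctext
    # samples and 32 - 8 = 24 dial samples, drawn from the shuffled per-type
    # generators and refreshed in get_batch() whenever one is exhausted.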
215 | def __len__(self): 216 | # return math.ceil(self.n_samples / float(self.batch_size)) 217 | return self.n_samples 218 | 219 | def __iter__(self): # iterator to load data 220 | for __ in range(math.ceil(self.n_samples / float(self.batch_size))): 221 | batch_index = [] 222 | 223 | for i in range(self.n_gpu): 224 | batch_index_gpu = self.get_batch() 225 | batch_index.extend(batch_index_gpu) 226 | 227 | for index in batch_index: 228 | yield index 229 | 230 | def get_batch(self): 231 | batch_index = [] 232 | if self.n_dial > 0: 233 | try: 234 | batch_index.extend(next(self.dial_gen)) 235 | except StopIteration: 236 | self.dial_gen = self.get_batch_index_generator(self.dial_index, self.n_dial) 237 | batch_index.extend(next(self.dial_gen)) 238 | 239 | if self.n_ctext > 0: 240 | try: 241 | batch_index.extend(next(self.ctext_gen)) 242 | except StopIteration: 243 | self.ctext_gen = self.get_batch_index_generator(self.ctext_index, self.n_ctext) 244 | batch_index.extend(next(self.ctext_gen)) 245 | 246 | if self.n_non > 0: 247 | try: 248 | batch_index.extend(next(self.non_gen)) 249 | except StopIteration: 250 | self.non_gen = self.get_batch_index_generator(self.non_index, self.n_non) 251 | batch_index.extend(next(self.non_gen)) 252 | 253 | return batch_index 254 | 255 | def get_batch_index_generator(self, a_list, batch_size): 256 | def get_batch_index(a_list, batch_size): 257 | assert isinstance(a_list, list) 258 | for start in range(0, len(a_list), batch_size): 259 | yield a_list[start:start + batch_size] 260 | 261 | assert isinstance(a_list, list) 262 | a_list = sample_func(a_list, len(a_list)) 263 | generator = get_batch_index(a_list, batch_size) 264 | return generator 265 | 266 | 267 | class Preprocess4Seq2seq(Pipeline): 268 | """ Pre-processing steps for pretraining transformer """ 269 | 270 | def __init__(self, max_pred, mask_prob, vocab_words, indexer, max_len=512, skipgram_prb=0, skipgram_size=0, 271 | block_mask=False, mask_whole_word=False, new_segment_ids=False, truncate_config={}, mask_source_words=False, 272 | mode="s2s", has_oracle=False, num_qkv=0, s2s_special_token=False, s2s_add_segment=False, 273 | s2s_share_segment=False, pos_shift=False, 274 | c_indexer=None, c_tfidf_map=None, tfidf_eps=1e-8, dial_mask_rate=0, only_mask_last=False, FGfree_indexer=None): 275 | super().__init__() 276 | self.max_len = max_len 277 | self.max_pred = max_pred # max tokens of prediction 278 | self.mask_prob = mask_prob # masking probability 279 | self.vocab_words = vocab_words # vocabulary (sub)words 280 | self.indexer = indexer # function from token to token index 281 | self.FGfree_indexer = FGfree_indexer 282 | 283 | # *ZY* 284 | self.dial_mask_rate = dial_mask_rate 285 | self.only_mask_last = only_mask_last # TODO: to calculate perplexity 286 | 287 | assert isinstance(c_tfidf_map, dict) 288 | self.c_tfidf_map = c_tfidf_map 289 | self.tfidf_eps = tfidf_eps 290 | 291 | self.nan_cond = '' 292 | assert isinstance(c_indexer, dict) 293 | if self.nan_cond not in c_indexer.keys(): 294 | print('#'*10+'To add condition, we re-arranged c_indexer (+1)'+'#'*10) 295 | sys.stdout.flush() 296 | self.c_indexer = {self.nan_cond: 0} 297 | for i, u in enumerate(c_indexer.keys()): 298 | self.c_indexer[u] = i + 1 299 | else: 300 | self.c_indexer = c_indexer 301 | 302 | # Check 303 | assert sorted(list(self.c_indexer.values())) == list(range(len(self.c_indexer))) 304 | 305 | self.max_len = max_len 306 | self._tril_matrix = torch.tril(torch.ones( 307 | (max_len, max_len), dtype=torch.long)) 308 | self.skipgram_prb = 
skipgram_prb 309 | self.skipgram_size = skipgram_size 310 | self.mask_whole_word = mask_whole_word 311 | self.new_segment_ids = new_segment_ids 312 | self.always_truncate_tail = truncate_config.get( 313 | 'always_truncate_tail', False) 314 | self.max_len_a = truncate_config.get('max_len_a', None) 315 | self.max_len_b = truncate_config.get('max_len_b', None) 316 | self.trunc_seg = truncate_config.get('trunc_seg', None) 317 | self.task_idx = 3 # relax projection layer for different tasks 318 | self.mask_source_words = mask_source_words 319 | assert mode in ("s2s", "l2r") 320 | self.mode = mode 321 | self.has_oracle = has_oracle 322 | self.num_qkv = num_qkv 323 | self.s2s_special_token = s2s_special_token 324 | self.s2s_add_segment = s2s_add_segment 325 | self.s2s_share_segment = s2s_share_segment 326 | self.pos_shift = pos_shift 327 | 328 | assert self.has_oracle is False 329 | assert self.pos_shift is False # I did not check this option 330 | assert self.num_qkv == 0 331 | 332 | def tfidf_mask(self, cid, cand_pos_tk, n_sample): 333 | tk_tfidf = [] 334 | for _, tk in cand_pos_tk: 335 | try: 336 | tk_tfidf.append(max(self.c_tfidf_map[cid][tk], self.tfidf_eps)) 337 | except KeyError: 338 | tk_tfidf.append(self.tfidf_eps) 339 | 340 | tk_tfidf = array(tk_tfidf) 341 | tk_tfidf = tk_tfidf / tk_tfidf.sum() 342 | 343 | tk_index = choice(range(len(tk_tfidf)), size=n_sample, replace=False, p=tk_tfidf).tolist() 344 | 345 | return [cand_pos_tk[idx][0] for idx in tk_index] 346 | 347 | def preprocess(self, tokens_a, tokens_b, cond, task_idx): 348 | 349 | # tokens_a = ['i', 'love', 'you'] 350 | # tokens_b = ['you', 'like', 'me'] 351 | # cond = '' 352 | # task_idx = 3 353 | 354 | try: 355 | cid = self.c_indexer[cond] 356 | except KeyError: 357 | print("Warning: {:} not in c_indexer".format(cond)) 358 | cid = self.c_indexer[self.nan_cond] 359 | 360 | # -3 for special tokens [CLS], [SEP], [SEP] 361 | num_truncated_a, _ = truncate_tokens_pair(tokens_a, tokens_b, self.max_len - 3, max_len_a=self.max_len_a, 362 | max_len_b=self.max_len_b, trunc_seg=self.trunc_seg, always_truncate_tail=self.always_truncate_tail) 363 | 364 | # Add Special Tokens 365 | if len(tokens_a) > 0: 366 | if (task_idx == 3) and self.s2s_special_token: # dial 367 | tokens = ['[S2S_CLS]'] + tokens_a + ['[S2S_SEP]'] + tokens_b + ['[SEP]'] 368 | else: 369 | tokens = ['[CLS]'] + tokens_a + ['[SEP]'] + tokens_b + ['[SEP]'] 370 | 371 | num_tokens_a = len(tokens_a) + 2 372 | num_tokens_b = len(tokens_b) + 1 373 | 374 | else: # text 375 | tokens = ['[CLS]'] + tokens_b + ['[SEP]'] 376 | num_tokens_a = 0 377 | num_tokens_b = len(tokens_b) + 2 378 | 379 | effective_length = len(tokens_b) 380 | # if (task_idx != 3) and self.mask_source_words: 381 | # effective_length += len(tokens_a) 382 | n_pred = min(self.max_pred, max( 383 | 1, int(round(effective_length*self.mask_prob)))) 384 | # candidate positions of masked tokens 385 | 386 | cand_pos_tk = [] 387 | special_pos = set() # will not be masked 388 | for i, tk in enumerate(tokens): 389 | if len(tokens_a) and (i >= len(tokens_a)+2) and (tk != '[CLS]'): # TODO: mask tokens_b (target sequence) 390 | # we will mask [SEP] as an ending symbol 391 | cand_pos_tk.append((i, tk)) 392 | 393 | elif (len(tokens_a) == 0) and (i >= 1) and (tk != '[CLS]') and (not tk.startswith('[SEP')): 394 | cand_pos_tk.append((i, tk)) 395 | 396 | else: 397 | special_pos.add(i) 398 | 399 | if self.only_mask_last: 400 | cand_pos_tk = [(len(tokens)-2, tokens[-2])] 401 | 402 | # *ZY* 403 | if cond != self.nan_cond: 404 | if task_idx == 
1: 405 | cand_pos = self.tfidf_mask(cond, cand_pos_tk, n_pred) 406 | elif (task_idx == 3) and (self.dial_mask_rate > 0.01) and (rand() < self.dial_mask_rate): 407 | cand_pos = self.tfidf_mask(cond, cand_pos_tk, n_pred) 408 | else: 409 | cand_pos = [p[0] for p in cand_pos_tk] 410 | else: 411 | cand_pos = [p[0] for p in cand_pos_tk] 412 | 413 | if self.only_mask_last: 414 | masked_pos = [len(tokens) - 2] 415 | n_real_pred = 1 416 | else: 417 | shuffle(cand_pos) 418 | masked_pos = set() 419 | max_cand_pos = max(cand_pos) 420 | 421 | for pos in cand_pos: # Uniform Distribution Here 422 | if len(masked_pos) >= n_pred: 423 | break 424 | if pos in masked_pos: # Avoid Overlapping 425 | continue 426 | 427 | def _expand_whole_word(st, end): 428 | # because of using WordPiece 429 | new_st, new_end = st, end 430 | while (new_st >= 0) and tokens[new_st].startswith('##'): 431 | new_st -= 1 432 | while (new_end < len(tokens)) and tokens[new_end].startswith('##'): 433 | new_end += 1 434 | return new_st, new_end 435 | 436 | if (self.skipgram_prb > 0) and (self.skipgram_size >= 2) and (rand() < self.skipgram_prb): 437 | # ngram 438 | cur_skipgram_size = randint(2, self.skipgram_size) 439 | if self.mask_whole_word: 440 | st_pos, end_pos = _expand_whole_word( 441 | pos, pos + cur_skipgram_size) 442 | else: 443 | st_pos, end_pos = pos, pos + cur_skipgram_size 444 | else: 445 | # directly mask 446 | if self.mask_whole_word: 447 | st_pos, end_pos = _expand_whole_word(pos, pos + 1) 448 | else: 449 | st_pos, end_pos = pos, pos + 1 450 | 451 | for mp in range(st_pos, end_pos): 452 | if (0 < mp <= max_cand_pos) and (mp not in special_pos): 453 | masked_pos.add(mp) 454 | else: 455 | break 456 | 457 | masked_pos = list(masked_pos) 458 | n_real_pred = len(masked_pos) 459 | if n_real_pred > n_pred: 460 | shuffle(masked_pos) 461 | masked_pos = masked_pos[:n_pred] 462 | n_real_pred = n_pred 463 | 464 | masked_tokens = [tokens[pos] for pos in masked_pos] 465 | 466 | for pos in masked_pos: 467 | if self.only_mask_last or rand() < 0.8: # 80% 468 | tokens[pos] = '[MASK]' 469 | elif rand() < 0.5: # 10% 470 | tokens[pos] = get_random_word(self.vocab_words) 471 | 472 | # when n_pred < max_pred, we only calculate loss within n_pred 473 | masked_weights = [1]*len(masked_tokens) 474 | 475 | # Token Indexing 476 | masked_ids = self.indexer(masked_tokens) 477 | 478 | # Token Indexing 479 | input_ids = self.indexer(tokens) 480 | 481 | # Zero Padding 482 | n_pad = self.max_len - len(input_ids) 483 | input_ids.extend([0]*n_pad) 484 | 485 | mask_qkv = None 486 | 487 | is_next = 1 488 | 489 | if task_idx == 3: 490 | segment_ids = [0] * num_tokens_a + [1] * num_tokens_b 491 | input_mask = torch.zeros(self.max_len, self.max_len, dtype=torch.long) 492 | input_mask[:num_tokens_a, :num_tokens_a].fill_(1) 493 | tril = torch.tril(torch.ones((self.max_len, self.max_len), dtype=torch.long)) 494 | input_mask[num_tokens_a:, :] = tril[num_tokens_a:, :] 495 | 496 | elif task_idx == 1: # left-to-right 497 | segment_ids = [1] * len(tokens) 498 | input_mask = torch.tril(torch.ones((self.max_len, self.max_len), dtype=torch.long)) 499 | 500 | elif task_idx == 0: # bi-attn 501 | segment_ids = [0] * len(tokens) 502 | input_mask = torch.ones((self.max_len, self.max_len), dtype=torch.long) 503 | 504 | else: 505 | raise ValueError 506 | 507 | segment_ids.extend([0]*n_pad) 508 | 509 | # Zero Padding for masked target 510 | if self.max_pred > n_real_pred: 511 | n_pad = self.max_pred - n_real_pred 512 | if masked_ids is not None: 513 | masked_ids.extend([0]*n_pad) 
514 | if masked_pos is not None: 515 | masked_pos.extend([0]*n_pad) 516 | if masked_weights is not None: 517 | masked_weights.extend([0]*n_pad) 518 | 519 | # print("tokens, ", tokens) 520 | # print("input_ids, ", input_ids) 521 | # print("segment_ids, ", segment_ids) 522 | # print("masked_ids, ", masked_ids) 523 | # print("masked_pos, ", masked_pos) 524 | # print("input_mask, ", input_mask) 525 | # exit() 526 | 527 | return (num_tokens_a, num_tokens_b, input_ids, cid, segment_ids, input_mask, mask_qkv, masked_ids, masked_pos, masked_weights, is_next, task_idx) 528 | 529 | def preprocess_FGfree(self, tokens_a, tokens_b, cond, task_idx): 530 | def _get_attn_mask(n_words, num_tokens_a, 531 | mask_pos_idx_map_sorted, 532 | task_idx): 533 | 534 | if task_idx == 3: 535 | input_mask = torch.zeros(self.max_len, self.max_len, dtype=torch.long) 536 | # Source 537 | input_mask[:num_tokens_a, :num_tokens_a].fill_(1) 538 | 539 | # Target 540 | tril = torch.tril(torch.ones((self.max_len, self.max_len), dtype=torch.long)) 541 | input_mask[num_tokens_a:, :] = tril[num_tokens_a:, :] 542 | 543 | elif task_idx == 1: 544 | input_mask = torch.tril(torch.ones((self.max_len, self.max_len), dtype=torch.long)) 545 | 546 | else: 547 | raise ValueError("do not support task_idx {:}".format(task_idx)) 548 | 549 | for i, (pos, idx) in enumerate(mask_pos_idx_map_sorted): 550 | input_mask[:, idx].fill_(0) 551 | input_mask[idx, idx].fill_(1) 552 | 553 | input_mask[n_words:, :].fill_(0) 554 | return input_mask 555 | 556 | # tokens_a = ['i', 'love', 'you'] 557 | # tokens_b = ['you', 'like', 'me'] 558 | # cond = '' 559 | # task_idx = 3 560 | 561 | try: 562 | cid = self.c_indexer[cond] 563 | except KeyError: 564 | print("Warning: {:} not in c_indexer".format(cond)) 565 | cid = self.c_indexer[self.nan_cond] 566 | 567 | effective_length = len(tokens_b) 568 | # if (task_idx != 3) and self.mask_source_words: 569 | # effective_length += len(tokens_a) 570 | n_pred = min(self.max_pred, max( 571 | 1, int(round(effective_length*self.mask_prob)))) 572 | # candidate positions of masked tokens 573 | 574 | # -3 for special tokens [CLS], [SEP], [SEP] 575 | num_truncated_a, _ = truncate_tokens_pair(tokens_a, tokens_b, self.max_len - 3 - n_pred, max_len_a=self.max_len_a, 576 | max_len_b=self.max_len_b, trunc_seg=self.trunc_seg, always_truncate_tail=self.always_truncate_tail) 577 | 578 | # Add Special Tokens 579 | if len(tokens_a) > 0: 580 | if (task_idx == 3) and self.s2s_special_token: # dial 581 | tokens = ['[S2S_CLS]'] + tokens_a + ['[S2S_SEP]'] + tokens_b + ['[SEP]'] 582 | else: # text 583 | tokens = ['[CLS]'] + tokens_a + ['[SEP]'] + tokens_b + ['[SEP]'] 584 | 585 | num_tokens_a = len(tokens_a) + 2 586 | num_tokens_b = len(tokens_b) + 1 587 | 588 | else: # text 589 | tokens = ['[CLS]'] + tokens_b + ['[SEP]'] 590 | num_tokens_a = 0 591 | num_tokens_b = len(tokens_b) + 2 592 | 593 | cand_pos_tk = [] 594 | special_pos = set() # will not be masked 595 | for i, tk in enumerate(tokens): 596 | if len(tokens_a) and (i >= len(tokens_a)+2) and (tk != '[CLS]'): # TODO: mask tokens_b (target sequence) 597 | # we will mask [SEP] as an ending symbol 598 | cand_pos_tk.append((i, tk)) 599 | 600 | elif (len(tokens_a) == 0) and (i >= 1) and (tk != '[CLS]') and (not tk.startswith('[SEP')): 601 | cand_pos_tk.append((i, tk)) 602 | 603 | else: 604 | special_pos.add(i) 605 | 606 | if self.only_mask_last: 607 | cand_pos_tk = [(len(tokens)-2, tokens[-2])] 608 | 609 | # *ZY* 610 | if cond != self.nan_cond: 611 | if task_idx == 1: 612 | cand_pos = 
self.tfidf_mask(cond, cand_pos_tk, n_pred) 613 | elif (task_idx == 3) and (self.dial_mask_rate > 0.01) and (rand() < self.dial_mask_rate): 614 | cand_pos = self.tfidf_mask(cond, cand_pos_tk, n_pred) 615 | else: 616 | cand_pos = [p[0] for p in cand_pos_tk] 617 | else: 618 | cand_pos = [p[0] for p in cand_pos_tk] 619 | 620 | shuffle(cand_pos) 621 | masked_pos = set() 622 | max_cand_pos = max(cand_pos) 623 | 624 | for pos in cand_pos: # Uniform Distribution Here 625 | if len(masked_pos) >= n_pred: 626 | break 627 | if pos in masked_pos: # Avoid Overlapping 628 | continue 629 | 630 | def _expand_whole_word(st, end): 631 | # because of using WordPiece 632 | new_st, new_end = st, end 633 | while (new_st >= 0) and tokens[new_st].startswith('##'): 634 | new_st -= 1 635 | while (new_end < len(tokens)) and tokens[new_end].startswith('##'): 636 | new_end += 1 637 | return new_st, new_end 638 | 639 | if (self.skipgram_prb > 0) and (self.skipgram_size >= 2) and (rand() < self.skipgram_prb): 640 | # ngram 641 | cur_skipgram_size = randint(2, self.skipgram_size) 642 | if self.mask_whole_word: 643 | st_pos, end_pos = _expand_whole_word( 644 | pos, pos + cur_skipgram_size) 645 | else: 646 | st_pos, end_pos = pos, pos + cur_skipgram_size 647 | else: 648 | # directly mask 649 | if self.mask_whole_word: 650 | st_pos, end_pos = _expand_whole_word(pos, pos + 1) 651 | else: 652 | st_pos, end_pos = pos, pos + 1 653 | 654 | for mp in range(st_pos, end_pos): 655 | if (0 < mp <= max_cand_pos) and (mp not in special_pos): 656 | masked_pos.add(mp) 657 | else: 658 | break 659 | 660 | masked_pos = list(masked_pos) 661 | n_real_pred = len(masked_pos) 662 | if n_real_pred > n_pred: 663 | shuffle(masked_pos) 664 | masked_pos = masked_pos[:n_pred] 665 | n_real_pred = n_pred 666 | 667 | masked_tokens = [tokens[pos] for pos in masked_pos] 668 | 669 | for pos in masked_pos: 670 | if rand() < 0.8: # 80% 671 | tokens[pos] = ('[MASK]', tokens[pos]) 672 | elif rand() < 0.5: # 10% 673 | tokens[pos] = (get_random_word(self.vocab_words), tokens[pos]) 674 | else: 675 | tokens[pos] = (tokens[pos], tokens[pos]) 676 | 677 | # when n_pred < max_pred, we only calculate loss within n_pred 678 | masked_weights = [1]*len(masked_tokens) 679 | 680 | # Token Indexing 681 | masked_ids = self.FGfree_indexer(masked_tokens) 682 | 683 | # Token Indexing 684 | # input_ids = self.indexer(tokens) 685 | input_ids, position_ids, mask_pos_idx_map = self.FGfree_indexer(tokens, ret_ids_only=False) 686 | mask_pos_idx_map_sorted = sorted(mask_pos_idx_map.items(), key=lambda p: p[1]) 687 | 688 | num_tokens_b += n_real_pred 689 | 690 | is_next = 1 691 | mask_qkv = None 692 | 693 | if task_idx == 3: 694 | segment_ids = [0] * num_tokens_a + [1] * num_tokens_b 695 | 696 | elif task_idx == 1: 697 | segment_ids = [1] * (num_tokens_a + num_tokens_b) 698 | 699 | elif task_idx == 0: 700 | segment_ids = [0] * (num_tokens_a + num_tokens_b) 701 | 702 | else: 703 | raise ValueError 704 | 705 | assert len(input_ids) == len(position_ids) 706 | assert len(input_ids) == len(segment_ids) 707 | 708 | n_words = len(input_ids) 709 | n_pad = self.max_len - n_words 710 | end_at = position_ids[-1] + 1 711 | 712 | # Zero Padding 713 | input_ids.extend([0]*n_pad) 714 | segment_ids.extend([0]*n_pad) 715 | position_ids.extend(list(range(end_at, end_at+n_pad))) 716 | 717 | assert len(input_ids) == len(position_ids) 718 | 719 | input_mask = _get_attn_mask(n_words, num_tokens_a, mask_pos_idx_map_sorted, task_idx) 720 | 721 | masked_pos = [mask_pos_idx_map[pos] for pos in masked_pos] 722 | 
        # Zero Padding for masked target
        if self.max_pred > n_real_pred:
            n_pad = self.max_pred - n_real_pred
            if masked_ids is not None:
                masked_ids.extend([0]*n_pad)
            if masked_pos is not None:
                masked_pos.extend([0]*n_pad)
            if masked_weights is not None:
                masked_weights.extend([0]*n_pad)

        # print("tokens, ", tokens)
        # print("input_ids, ", input_ids)
        # print("segment_ids, ", segment_ids)
        # print("position_ids, ", position_ids)
        # print("masked_ids, ", masked_ids)
        # print("masked_pos, ", masked_pos)
        # print("input_mask, ", input_mask[:n_words+2, :n_words+2])
        # exit()

        return (num_tokens_a, num_tokens_b, input_ids, cid, segment_ids, input_mask, mask_qkv, masked_ids, masked_pos, masked_weights, is_next, task_idx)

    def __call__(self, instance):
        tokens_a, tokens_b, cond, data_type = instance

        # print("instance: ", instance)

        if data_type == 'dial':
            task_idx = 3  # seq2seq
        elif data_type == 'mono':

            if len(tokens_a):  # TODO: Notice Here!
                tokens_b = tokens_a + ['[SEP]'] + tokens_b
                tokens_a = []

            if (rand() < 0.5) or (cond == ''):
                task_idx = 1  # generation
            else:
                task_idx = 0  # bi-attn, encoding
        else:
            raise ValueError

        if (self.FGfree_indexer is None) or (task_idx == 0):
            return self.preprocess(tokens_a, tokens_b, cond, task_idx)
        else:
            return self.preprocess_FGfree(tokens_a, tokens_b, cond, task_idx)


class Preprocess4Decoder(Pipeline):
    """ Pre-processing steps for seq2seq decoding """
    def __init__(self, vocab_words, indexer, max_len=512, max_tgt_length=128, new_segment_ids=False, mode="s2s",
                 num_qkv=0, s2s_special_token=False, s2s_add_segment=False, s2s_share_segment=False, pos_shift=False,
                 c_indexer=None):
        super().__init__()
        self.max_len = max_len
        self.vocab_words = vocab_words  # vocabulary (sub)words
        self.indexer = indexer  # function from token to token index
        self._tril_matrix = torch.tril(torch.ones(
            (max_len, max_len), dtype=torch.long))
        self.new_segment_ids = new_segment_ids
        self.task_idx = 3  # relax projection layer for different tasks
        assert mode in ("s2s", "l2r")
        self.mode = mode
        self.max_tgt_length = max_tgt_length
        self.num_qkv = num_qkv
        self.s2s_special_token = s2s_special_token
        self.s2s_add_segment = s2s_add_segment
        self.s2s_share_segment = s2s_share_segment
        self.pos_shift = pos_shift

        # *ZY*
        self.nan_cond = ''
        assert isinstance(c_indexer, dict)
        if self.nan_cond not in c_indexer.keys():
            print('#'*10+'To add user, we re-arranged c_indexer (+1)'+'#'*10)
            sys.stdout.flush()
            self.c_indexer = {self.nan_cond: 0}
            for i, u in enumerate(c_indexer.keys()):
                self.c_indexer[u] = i + 1
            # Check
            assert sorted(list(self.c_indexer.values())) == list(range(len(self.c_indexer)))

    def __call__(self, instance):
        tokens_a, usrid, max_a_len = instance

        try:
            cid = self.c_indexer[usrid]
        except KeyError:
            print("Warning: {:} not in c_indexer".format(usrid))
            cid = self.c_indexer[self.nan_cond]

        # Add Special Tokens
        if self.s2s_special_token:
            padded_tokens_a = ['[S2S_CLS]'] + tokens_a + ['[S2S_SEP]']
        else:
            padded_tokens_a = ['[CLS]'] + tokens_a + ['[SEP]']
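        # The source side is padded below to a fixed length of max_a_len + 2 so every example
        # in the batch starts its target segment at the same offset; the position_ids built
        # further down assign 0 to the [PAD] slots and let the target positions continue from
        # the true (unpadded) source length.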
        assert len(padded_tokens_a) <= max_a_len + 2
        if max_a_len + 2 > len(padded_tokens_a):
            padded_tokens_a += ['[PAD]'] * \
                (max_a_len + 2 - len(padded_tokens_a))
        assert len(padded_tokens_a) == max_a_len + 2
        max_len_in_batch = min(self.max_tgt_length +
                               max_a_len + 2, self.max_len)
        tokens = padded_tokens_a

        segment_ids = [0]*(len(padded_tokens_a)) \
            + [1]*(max_len_in_batch - len(padded_tokens_a))

        if self.num_qkv > 1:
            mask_qkv = [0]*(len(padded_tokens_a)) + [1] * \
                (max_len_in_batch - len(padded_tokens_a))
        else:
            mask_qkv = None

        position_ids = []
        for i in range(len(tokens_a) + 2):
            position_ids.append(i)
        for i in range(len(tokens_a) + 2, max_a_len + 2):
            position_ids.append(0)
        for i in range(max_a_len + 2, max_len_in_batch):
            position_ids.append(i - (max_a_len + 2) + len(tokens_a) + 2)

        # Token Indexing
        input_ids = self.indexer(tokens)

        # Zero Padding
        input_mask = torch.zeros(
            max_len_in_batch, max_len_in_batch, dtype=torch.long)
        if self.mode == "s2s":
            input_mask[:, :len(tokens_a)+2].fill_(1)
        else:
            st, end = 0, len(tokens_a) + 2
            input_mask[st:end, st:end].copy_(
                self._tril_matrix[:end, :end])
            input_mask[end:, :len(tokens_a)+2].fill_(1)
        second_st, second_end = len(padded_tokens_a), max_len_in_batch

        input_mask[second_st:second_end, second_st:second_end].copy_(
            self._tril_matrix[:second_end-second_st, :second_end-second_st])

        return (input_ids, cid, segment_ids, position_ids, input_mask, mask_qkv, self.task_idx)
--------------------------------------------------------------------------------
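Note: the attention-mask construction used by _get_attn_mask in preprocess_FGfree above is easiest to see on a toy example. The sketch below is not part of the repository; it only reproduces the seq2seq (task_idx == 3) branch with made-up sizes, and placeholder_idx is a hypothetical stand-in for the indices that FGfree_indexer assigns to the inserted [MASK] slots.

import torch

max_len = 8            # padded length of the toy example
num_tokens_a = 3       # [S2S_CLS] + source + [S2S_SEP]
n_words = 6            # number of real (non-padding) positions
placeholder_idx = [5]  # hypothetical indices of the inserted [MASK] placeholders

mask = torch.zeros(max_len, max_len, dtype=torch.long)
# source tokens attend bidirectionally among themselves
mask[:num_tokens_a, :num_tokens_a].fill_(1)
# target tokens attend left-to-right (causally) over everything before them
tril = torch.tril(torch.ones(max_len, max_len, dtype=torch.long))
mask[num_tokens_a:, :] = tril[num_tokens_a:, :]
# no position may attend to a placeholder column, except the placeholder itself
for idx in placeholder_idx:
    mask[:, idx] = 0
    mask[idx, idx] = 1
# padding rows attend to nothing
mask[n_words:, :] = 0

print(mask)

In the printed matrix only row 5 can read column 5, so a placeholder never influences any other token's representation; exactly where FGfree_indexer places each placeholder and which position id it receives is decided by that indexer and is not shown in this file.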