├── nn
│   ├── __init__.py
│   └── data_parallel.py
├── biunilm
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── loader_utils.cpython-37.pyc
│   │   └── seq2seq_loader.cpython-37.pyc
│   ├── loader_utils.py
│   ├── decode_seq2seq.py
│   ├── run_ppl.py
│   └── seq2seq_loader.py
├── img
│   └── model.png
├── pytorch_pretrained_bert
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── file_utils.cpython-37.pyc
│   │   └── tokenization.cpython-37.pyc
│   ├── __init__.py
│   ├── __main__.py
│   ├── loss.py
│   ├── optimization_fp16.py
│   ├── file_utils.py
│   ├── tokenization.py
│   └── optimization.py
├── .idea
│   ├── vcs.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── deployment.xml
│   ├── MultiT-C-Dialog-git.iml
│   ├── webServers.xml
│   └── workspace.xml
├── utils.py
├── run_eval.sh
├── run_ppl.sh
├── run_train.sh
├── run_pretrain.sh
├── run_2step_pre.sh
├── run_2step_ft.sh
├── run_sequential_train.py
├── setup.py
├── get_tfidf.py
├── pre_tokenize.py
├── README.md
├── eval.py
└── qg
    ├── eval.py
    └── eval_on_unilm_tokenized_ref.py
/nn/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/biunilm/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/img/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengyan-97/MultiT-C-Dialog/HEAD/img/model.png
--------------------------------------------------------------------------------
/biunilm/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengyan-97/MultiT-C-Dialog/HEAD/biunilm/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/biunilm/__pycache__/loader_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengyan-97/MultiT-C-Dialog/HEAD/biunilm/__pycache__/loader_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/biunilm/__pycache__/seq2seq_loader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengyan-97/MultiT-C-Dialog/HEAD/biunilm/__pycache__/seq2seq_loader.cpython-37.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengyan-97/MultiT-C-Dialog/HEAD/pytorch_pretrained_bert/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/__pycache__/file_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengyan-97/MultiT-C-Dialog/HEAD/pytorch_pretrained_bert/__pycache__/file_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/__pycache__/tokenization.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengyan-97/MultiT-C-Dialog/HEAD/pytorch_pretrained_bert/__pycache__/tokenization.cpython-37.pyc
--------------------------------------------------------------------------------
/biunilm/decode_seq2seq.py:
--------------------------------------------------------------------------------
.") 81 | parser.add_argument('--seg_emb', action='store_true', 82 | help="Using segment embedding for self-attention.") 83 | 84 | # decoding parameters 85 | parser.add_argument('--fp16', action='store_true', 86 | help="Whether to use 16-bit float precision instead of 32-bit") 87 | parser.add_argument('--amp', action='store_true', 88 | help="Whether to use amp for fp16") 89 | 90 | parser.add_argument('--subset', type=int, default=0, 91 | help="Decode a subset of the input dataset.") 92 | parser.add_argument("--split", type=str, default="", 93 | help="Data split (train/val/test).") 94 | parser.add_argument('--tokenized_input', action='store_true', 95 | help="Whether the input is tokenized.") 96 | parser.add_argument('--seed', type=int, default=123, 97 | help="random seed for initialization") 98 | parser.add_argument("--do_lower_case", action='store_true', 99 | help="Set this flag if you are using an uncased model.") 100 | parser.add_argument('--new_segment_ids', action='store_true', 101 | help="Use new segment ids for bi-uni-directional LM.") 102 | parser.add_argument('--new_pos_ids', action='store_true', 103 | help="Use new position ids for LMs.") 104 | parser.add_argument('--batch_size', type=int, default=4, 105 | help="Batch size for decoding.") 106 | parser.add_argument('--beam_size', type=int, default=1, 107 | help="Beam size for searching") 108 | parser.add_argument('--length_penalty', type=float, default=0, 109 | help="Length penalty for beam search") 110 | 111 | parser.add_argument('--forbid_duplicate_ngrams', action='store_true') 112 | parser.add_argument('--forbid_ignore_word', type=str, default=None, 113 | help="Ignore the word during forbid_duplicate_ngrams") 114 | parser.add_argument("--min_len", default=None, type=int) 115 | parser.add_argument('--need_score_traces', action='store_true') 116 | parser.add_argument('--ngram_size', type=int, default=3) 117 | parser.add_argument('--mode', default="s2s", 118 | choices=["s2s", "l2r", "both"]) 119 | parser.add_argument('--max_tgt_length', type=int, default=128, 120 | help="maximum length of target sequence") 121 | parser.add_argument('--s2s_special_token', action='store_true', 122 | help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.") 123 | parser.add_argument('--s2s_add_segment', action='store_true', 124 | help="Additional segmental for the encoder of S2S.") 125 | parser.add_argument('--s2s_share_segment', action='store_true', 126 | help="Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment).") 127 | parser.add_argument('--pos_shift', action='store_true', 128 | help="Using position shift for fine-tuning.") 129 | parser.add_argument('--not_predict_token', type=str, default=None, 130 | help="Do not predict the tokens during decoding.") 131 | 132 | args = parser.parse_args() 133 | 134 | if args.need_score_traces and args.beam_size <= 1: 135 | raise ValueError( 136 | "Score trace is only available for beam search with beam size > 1.") 137 | if args.max_tgt_length >= args.max_seq_length - 2: 138 | raise ValueError("Maximum tgt length exceeds max seq length - 2.") 139 | 140 | device = torch.device( 141 | "cuda" if torch.cuda.is_available() else "cpu") 142 | n_gpu = torch.cuda.device_count() 143 | 144 | random.seed(args.seed) 145 | np.random.seed(args.seed) 146 | torch.manual_seed(args.seed) 147 | if n_gpu > 0: 148 | torch.cuda.manual_seed_all(args.seed) 149 | 150 | tokenizer = BertTokenizer.from_pretrained( 151 | args.bert_model, do_lower_case=args.do_lower_case) 152 | 153 | tokenizer.max_len = 
args.max_seq_length 154 | 155 | c_indexer = torch.load(os.path.join(args.data_dir, 'c_indexer.pt')) 156 | logger.info("{:} conditions.".format(len(c_indexer))) 157 | 158 | pair_num_relation = 0 159 | bi_uni_pipeline = [ 160 | seq2seq_loader.Preprocess4Decoder(list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, 161 | args.max_seq_length, max_tgt_length=args.max_tgt_length, 162 | new_segment_ids=args.new_segment_ids, 163 | mode="s2s", num_qkv=args.num_qkv, 164 | s2s_special_token=args.s2s_special_token, 165 | s2s_add_segment=args.s2s_add_segment, 166 | s2s_share_segment=args.s2s_share_segment, pos_shift=args.pos_shift, 167 | c_indexer=c_indexer)] 168 | 169 | logger.info("### Some c_indexer ###") 170 | tmp = sorted(bi_uni_pipeline[0].c_indexer.items(), key=lambda p: p[1]) 171 | print(tmp[:10]) 172 | sys.stdout.flush() 173 | 174 | amp_handle = None 175 | if args.fp16 and args.amp: 176 | raise NotImplementedError 177 | # from apex import amp 178 | # amp_handle = amp.init(enable_caching=True) 179 | # logger.info("enable fp16 with amp") 180 | 181 | # Prepare model 182 | cls_num_labels = 2 183 | type_vocab_size = 2 184 | mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids( 185 | ["[MASK]", "[SEP]", "[S2S_SOS]"]) 186 | 187 | def _get_token_id_set(s): 188 | r = None 189 | if s: 190 | w_list = [] 191 | for w in s.split('|'): 192 | if w.startswith('[') and w.endswith(']'): 193 | w_list.append(w.upper()) 194 | else: 195 | w_list.append(w) 196 | r = set(tokenizer.convert_tokens_to_ids(w_list)) 197 | return r 198 | 199 | forbid_ignore_set = _get_token_id_set(args.forbid_ignore_word) 200 | not_predict_set = _get_token_id_set(args.not_predict_token) 201 | print(args.model_recover_path) 202 | for model_recover_path in glob.glob(args.model_recover_path.strip()): 203 | logger.info("***** Recover model: %s *****", model_recover_path) 204 | model_recover = torch.load(model_recover_path) 205 | 206 | if '' not in c_indexer.keys(): 207 | n_condition = len(c_indexer) + 1 208 | else: 209 | n_condition = len(c_indexer) 210 | assert c_indexer[' '] == 0 # Check 211 | 212 | n_dial = 10 # Fake 213 | model = BertForSeq2SeqDecoder.from_pretrained(args.bert_model, state_dict=model_recover, 214 | num_labels=cls_num_labels, num_rel=pair_num_relation, 215 | type_vocab_size=type_vocab_size, task_idx=3, 216 | mask_word_id=mask_word_id, search_beam_size=args.beam_size, 217 | length_penalty=args.length_penalty, eos_id=eos_word_ids, 218 | sos_id=sos_word_id, 219 | forbid_duplicate_ngrams=args.forbid_duplicate_ngrams, 220 | forbid_ignore_set=forbid_ignore_set, 221 | not_predict_set=not_predict_set, ngram_size=args.ngram_size, 222 | min_len=args.min_len, mode=args.mode, 223 | max_position_embeddings=args.max_seq_length, 224 | ffn_type=args.ffn_type, num_qkv=args.num_qkv, 225 | seg_emb=args.seg_emb, pos_shift=args.pos_shift, 226 | n_condition=n_condition, n_dial=n_dial, n_clayer=args.n_clayer, gate=args.gate) 227 | 228 | del model_recover 229 | 230 | model.to(device) 231 | if n_gpu > 1: 232 | model = torch.nn.DataParallel(model) 233 | 234 | torch.cuda.empty_cache() 235 | model.eval() 236 | next_i = 0 237 | max_src_length = args.max_seq_length - 2 - args.max_tgt_length 238 | 239 | with open(os.path.join(args.data_dir, args.input_file), encoding="utf-8") as fin: 240 | # *ZY* 241 | input_lines = [line.strip().split('\t')[:2] for line in fin.readlines()] 242 | if args.subset > 0: 243 | logger.info("Decoding subset: %d", args.subset) 244 | input_lines = input_lines[:args.subset] 245 | 246 | data_tokenizer = 
WhitespaceTokenizer() if args.tokenized_input else tokenizer 247 | 248 | input_lines = [[data_tokenizer.tokenize( 249 | src)[:max_src_length], uid] for src, uid in input_lines] 250 | 251 | input_lines = sorted(list(enumerate(input_lines)), 252 | key=lambda x: -len(x[1][0])) 253 | 254 | output_lines = [""] * len(input_lines) 255 | score_trace_list = [None] * len(input_lines) 256 | total_batch = math.ceil(len(input_lines) / args.batch_size) 257 | 258 | with tqdm(total=total_batch) as pbar: 259 | while next_i < len(input_lines): 260 | _chunk = input_lines[next_i:next_i + args.batch_size] 261 | buf_id = [x[0] for x in _chunk] 262 | buf = [x[1] for x in _chunk] 263 | next_i += args.batch_size 264 | max_a_len = max([len(x[0]) for x in buf]) 265 | instances = [] 266 | for instance in [(x[0], x[1], max_a_len) for x in buf]: 267 | for proc in bi_uni_pipeline: 268 | instances.append(proc(instance)) 269 | 270 | with torch.no_grad(): 271 | batch = seq2seq_loader.batch_list_to_batch_tensors( 272 | instances) 273 | batch = [ 274 | t.to(device) if t is not None else None for t in batch] 275 | 276 | input_ids, usrid_ids, token_type_ids, position_ids, input_mask, mask_qkv, task_idx = batch 277 | traces = model(input_ids, usrid_ids, token_type_ids, position_ids, input_mask, 278 | task_idx=task_idx, mask_qkv=mask_qkv) 279 | 280 | if args.beam_size > 1: 281 | traces = {k: v.tolist() for k, v in traces.items()} 282 | output_ids = traces['pred_seq'] 283 | # print(output_ids) # Debug 284 | else: 285 | output_ids = traces.tolist() 286 | for i in range(len(buf)): 287 | w_ids = output_ids[i] 288 | output_buf = tokenizer.convert_ids_to_tokens(w_ids) 289 | output_tokens = [] 290 | for t in output_buf: 291 | if t in ("[SEP]", "[PAD]"): 292 | break 293 | output_tokens.append(t) 294 | output_sequence = ' '.join(detokenize(output_tokens)) 295 | output_lines[buf_id[i]] = output_sequence 296 | if args.need_score_traces: 297 | score_trace_list[buf_id[i]] = { 298 | 'scores': traces['scores'][i], 'wids': traces['wids'][i], 'ptrs': traces['ptrs'][i]} 299 | pbar.update(1) 300 | 301 | if args.output_file: 302 | fn_out = args.output_file 303 | else: 304 | fn_out = model_recover_path + '.' + args.split 305 | 306 | len_list = [] 307 | with open(fn_out, "w", encoding="utf-8") as fout: 308 | for l in output_lines: 309 | fout.write(l) 310 | fout.write("\n") 311 | 312 | len_list.append(len(l.strip().split(' '))) 313 | 314 | print("### average len: {:}".format(np.mean(len_list))) 315 | 316 | if args.need_score_traces: 317 | with open(fn_out + ".trace.pickle", "wb") as fout_trace: 318 | pickle.dump( 319 | {"version": 0.0, "num_samples": len(input_lines)}, fout_trace) 320 | for x in score_trace_list: 321 | pickle.dump(x, fout_trace) 322 | 323 | 324 | if __name__ == "__main__": 325 | main() 326 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | import os 24 | import logging 25 | 26 | from .file_utils import cached_path 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 31 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 32 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 33 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 34 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 35 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 36 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 37 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 38 | } 39 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 40 | 'bert-base-uncased': 512, 41 | 'bert-large-uncased': 512, 42 | 'bert-base-cased': 512, 43 | 'bert-large-cased': 512, 44 | 'bert-base-multilingual-uncased': 512, 45 | 'bert-base-multilingual-cased': 512, 46 | 'bert-base-chinese': 512, 47 | } 48 | VOCAB_NAME = 'vocab.txt' 49 | 50 | 51 | def load_vocab(vocab_file): 52 | """Loads a vocabulary file into a dictionary.""" 53 | # mapping unused tokens to special tokens 54 | extra_map = {} 55 | extra_map['[unused1]'] = '[X_SEP]' 56 | for i in range(10): 57 | extra_map['[unused{}]'.format(i+2)] = '[SEP_{}]'.format(i) 58 | extra_map['[unused12]'] = '[S2S_SEP]' 59 | extra_map['[unused13]'] = '[S2S_CLS]' 60 | extra_map['[unused14]'] = '[L2R_SEP]' 61 | extra_map['[unused15]'] = '[L2R_CLS]' 62 | extra_map['[unused16]'] = '[R2L_SEP]' 63 | extra_map['[unused17]'] = '[R2L_CLS]' 64 | extra_map['[unused18]'] = '[S2S_SOS]' 65 | 66 | vocab = collections.OrderedDict() 67 | index = 0 68 | with open(vocab_file, "r", encoding="utf-8") as reader: 69 | while True: 70 | token = reader.readline() 71 | if not token: 72 | break 73 | token = token.strip() 74 | if token in extra_map: 75 | token = extra_map[token] 76 | vocab[token] = index 77 | index += 1 78 | return vocab 79 | 80 | 81 | def whitespace_tokenize(text): 82 | """Runs basic whitespace cleaning and splitting on a peice of text.""" 83 | text = text.strip() 84 | if not text: 85 | return [] 86 | tokens = text.split() 87 | return tokens 88 | 89 | 90 | class BertTokenizer(object): 91 | """Runs end-to-end tokenization: punctuation splitting + wordpiece""" 92 | 93 | def __init__(self, vocab_file, do_lower_case=True, max_len=None, never_split=("[UNK]", "[SEP]", "[X_SEP]", "[PAD]", "[CLS]", "[MASK]")): 94 | if not os.path.isfile(vocab_file): 95 | raise ValueError( 96 | "Can't find a vocabulary file at path '{}'. 
To load the vocabulary from a Google pretrained " 97 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) 98 | self.vocab = load_vocab(vocab_file) 99 | # *ZY* 100 | self.my_special = list(never_split) + [' ', ' ', ' ', ' ', '[SRC_SEP]'] 101 | self.vocab[' '] = 20 102 | self.vocab[' '] = 21 103 | self.vocab[' '] = 22 104 | self.vocab[' '] = 23 105 | self.vocab['[SRC_SEP]'] = 24 106 | 107 | del self.vocab['[unused19]'] 108 | del self.vocab['[unused20]'] 109 | del self.vocab['[unused21]'] 110 | del self.vocab['[unused22]'] 111 | del self.vocab['[unused23]'] 112 | 113 | self.ids_to_tokens = collections.OrderedDict( 114 | [(ids, tok) for tok, ids in self.vocab.items()]) 115 | self.basic_tokenizer = BasicTokenizer( 116 | do_lower_case=do_lower_case, never_split=self.my_special) # *ZY* 117 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 118 | self.max_len = max_len if max_len is not None else int(1e12) 119 | 120 | def tokenize(self, text): 121 | split_tokens = [] 122 | for token in self.basic_tokenizer.tokenize(text): 123 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 124 | split_tokens.append(sub_token) 125 | return split_tokens 126 | 127 | def convert_tokens_to_ids(self, tokens): 128 | """Converts a sequence of tokens into ids using the vocab.""" 129 | ids = [] 130 | for token in tokens: 131 | ids.append(self.vocab[token]) 132 | if len(ids) > self.max_len: 133 | raise ValueError( 134 | "Token indices sequence length is longer than the specified maximum " 135 | " sequence length for this BERT model ({} > {}). Running this" 136 | " sequence through BERT will result in indexing errors".format( 137 | len(ids), self.max_len) 138 | ) 139 | return ids 140 | 141 | def convert_tokens_to_ids_FGfree(self, tokens, ret_ids_only=True): 142 | """Converts a sequence of tokens into ids using the vocab.""" 143 | ids = [] 144 | position_ids = [] 145 | mask_pos_idx_map = {} 146 | idx_counter = 0 147 | for pos, token in enumerate(tokens): 148 | if isinstance(token, str): 149 | ids.append(self.vocab[token]) 150 | position_ids.append(pos) 151 | idx_counter += 1 152 | else: # TODO: masked tokens -- here is tuple 153 | mask_pos_idx_map[pos] = idx_counter 154 | assert len(token) == 2 155 | for t in token: # here is tuple, and the first position is [MASK] 156 | ids.append(self.vocab[t]) 157 | position_ids.append(pos) 158 | idx_counter += 1 159 | 160 | if len(ids) > self.max_len: 161 | raise ValueError( 162 | "Token indices sequence length is longer than the specified maximum " 163 | " sequence length for this BERT model ({} > {}). Running this" 164 | " sequence through BERT will result in indexing errors".format( 165 | len(ids), self.max_len) 166 | ) 167 | if ret_ids_only: 168 | return ids 169 | 170 | return ids, position_ids, mask_pos_idx_map 171 | 172 | def convert_ids_to_tokens(self, ids): 173 | """Converts a sequence of ids in wordpiece tokens using the vocab.""" 174 | tokens = [] 175 | for i in ids: 176 | tokens.append(self.ids_to_tokens[i]) 177 | return tokens 178 | 179 | @classmethod 180 | def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs): 181 | """ 182 | Instantiate a PreTrainedBertModel from a pre-trained model file. 183 | Download and cache the pre-trained model file if needed. 
184 | """ 185 | if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP: 186 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name] 187 | else: 188 | vocab_file = pretrained_model_name 189 | if os.path.isdir(vocab_file): 190 | vocab_file = os.path.join(vocab_file, VOCAB_NAME) 191 | # redirect to the cache, if necessary 192 | try: 193 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 194 | except FileNotFoundError: 195 | logger.error( 196 | "Model name '{}' was not found in model name list ({}). " 197 | "We assumed '{}' was a path or url but couldn't find any file " 198 | "associated to this path or url.".format( 199 | pretrained_model_name, 200 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 201 | vocab_file)) 202 | return None 203 | if resolved_vocab_file == vocab_file: 204 | logger.info("loading vocabulary file {}".format(vocab_file)) 205 | else: 206 | logger.info("loading vocabulary file {} from cache at {}".format( 207 | vocab_file, resolved_vocab_file)) 208 | if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 209 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 210 | # than the number of positional embeddings 211 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name] 212 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 213 | # Instantiate tokenizer. 214 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) 215 | return tokenizer 216 | 217 | 218 | class WhitespaceTokenizer(object): 219 | def tokenize(self, text): 220 | return whitespace_tokenize(text) 221 | 222 | 223 | class BasicTokenizer(object): 224 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 225 | 226 | def __init__(self, do_lower_case=True, never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 227 | """Constructs a BasicTokenizer. 228 | 229 | Args: 230 | do_lower_case: Whether to lower case the input. 231 | """ 232 | self.do_lower_case = do_lower_case 233 | self.never_split = never_split 234 | 235 | def tokenize(self, text): 236 | """Tokenizes a piece of text.""" 237 | text = self._clean_text(text) 238 | # This was added on November 1st, 2018 for the multilingual and Chinese 239 | # models. This is also applied to the English models now, but it doesn't 240 | # matter since the English models were not trained on any Chinese data 241 | # and generally don't have any Chinese data in them (there are Chinese 242 | # characters in the vocabulary because Wikipedia does have some Chinese 243 | # words in the English Wikipedia.). 
244 | text = self._tokenize_chinese_chars(text) 245 | orig_tokens = whitespace_tokenize(text) 246 | split_tokens = [] 247 | for token in orig_tokens: 248 | if self.do_lower_case and token not in self.never_split: 249 | token = token.lower() 250 | token = self._run_strip_accents(token) 251 | split_tokens.extend(self._run_split_on_punc(token)) 252 | 253 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 254 | return output_tokens 255 | 256 | def _run_strip_accents(self, text): 257 | """Strips accents from a piece of text.""" 258 | text = unicodedata.normalize("NFD", text) 259 | output = [] 260 | for char in text: 261 | cat = unicodedata.category(char) 262 | if cat == "Mn": 263 | continue 264 | output.append(char) 265 | return "".join(output) 266 | 267 | def _run_split_on_punc(self, text): 268 | """Splits punctuation on a piece of text.""" 269 | if text in self.never_split: 270 | return [text] 271 | chars = list(text) 272 | i = 0 273 | start_new_word = True 274 | output = [] 275 | while i < len(chars): 276 | char = chars[i] 277 | if _is_punctuation(char): 278 | output.append([char]) 279 | start_new_word = True 280 | else: 281 | if start_new_word: 282 | output.append([]) 283 | start_new_word = False 284 | output[-1].append(char) 285 | i += 1 286 | 287 | return ["".join(x) for x in output] 288 | 289 | def _tokenize_chinese_chars(self, text): 290 | """Adds whitespace around any CJK character.""" 291 | output = [] 292 | for char in text: 293 | cp = ord(char) 294 | if self._is_chinese_char(cp): 295 | output.append(" ") 296 | output.append(char) 297 | output.append(" ") 298 | else: 299 | output.append(char) 300 | return "".join(output) 301 | 302 | def _is_chinese_char(self, cp): 303 | """Checks whether CP is the codepoint of a CJK character.""" 304 | # This defines a "chinese character" as anything in the CJK Unicode block: 305 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 306 | # 307 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 308 | # despite its name. The modern Korean Hangul alphabet is a different block, 309 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 310 | # space-separated words, so they are not treated specially and handled 311 | # like the all of the other languages. 312 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 313 | (cp >= 0x3400 and cp <= 0x4DBF) or # 314 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 315 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 316 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 317 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 318 | (cp >= 0xF900 and cp <= 0xFAFF) or # 319 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 320 | return True 321 | 322 | return False 323 | 324 | def _clean_text(self, text): 325 | """Performs invalid character removal and whitespace cleanup on text.""" 326 | output = [] 327 | for char in text: 328 | cp = ord(char) 329 | if cp == 0 or cp == 0xfffd or _is_control(char): 330 | continue 331 | if _is_whitespace(char): 332 | output.append(" ") 333 | else: 334 | output.append(char) 335 | return "".join(output) 336 | 337 | 338 | class WordpieceTokenizer(object): 339 | """Runs WordPiece tokenization.""" 340 | 341 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 342 | self.vocab = vocab 343 | self.unk_token = unk_token 344 | self.max_input_chars_per_word = max_input_chars_per_word 345 | 346 | def tokenize(self, text): 347 | """Tokenizes a piece of text into its word pieces. 
348 | 349 | This uses a greedy longest-match-first algorithm to perform tokenization 350 | using the given vocabulary. 351 | 352 | For example: 353 | input = "unaffable" 354 | output = ["un", "##aff", "##able"] 355 | 356 | Args: 357 | text: A single token or whitespace separated tokens. This should have 358 | already been passed through `BasicTokenizer`. 359 | 360 | Returns: 361 | A list of wordpiece tokens. 362 | """ 363 | 364 | output_tokens = [] 365 | for token in whitespace_tokenize(text): 366 | chars = list(token) 367 | if len(chars) > self.max_input_chars_per_word: 368 | output_tokens.append(self.unk_token) 369 | continue 370 | 371 | is_bad = False 372 | start = 0 373 | sub_tokens = [] 374 | while start < len(chars): 375 | end = len(chars) 376 | cur_substr = None 377 | while start < end: 378 | substr = "".join(chars[start:end]) 379 | if start > 0: 380 | substr = "##" + substr 381 | if substr in self.vocab: 382 | cur_substr = substr 383 | break 384 | end -= 1 385 | if cur_substr is None: 386 | is_bad = True 387 | break 388 | sub_tokens.append(cur_substr) 389 | start = end 390 | 391 | if is_bad: 392 | output_tokens.append(self.unk_token) 393 | else: 394 | output_tokens.extend(sub_tokens) 395 | return output_tokens 396 | 397 | 398 | def _is_whitespace(char): 399 | """Checks whether `chars` is a whitespace character.""" 400 | # \t, \n, and \r are technically contorl characters but we treat them 401 | # as whitespace since they are generally considered as such. 402 | if char == " " or char == "\t" or char == "\n" or char == "\r": 403 | return True 404 | cat = unicodedata.category(char) 405 | if cat == "Zs": 406 | return True 407 | return False 408 | 409 | 410 | def _is_control(char): 411 | """Checks whether `chars` is a control character.""" 412 | # These are technically control characters but we count them as whitespace 413 | # characters. 414 | if char == "\t" or char == "\n" or char == "\r": 415 | return False 416 | cat = unicodedata.category(char) 417 | if cat.startswith("C"): 418 | return True 419 | return False 420 | 421 | 422 | def _is_punctuation(char): 423 | """Checks whether `chars` is a punctuation character.""" 424 | cp = ord(char) 425 | # We treat all non-letter/number ASCII as punctuation. 426 | # Characters such as "^", "$", and "`" are not in the Unicode 427 | # Punctuation class but we treat them as punctuation anyways, for 428 | # consistency. 429 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 430 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 431 | return True 432 | cat = unicodedata.category(char) 433 | if cat.startswith("P"): 434 | return True 435 | return False 436 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | 23 | from collections import defaultdict 24 | from torch._six import container_abcs 25 | from copy import deepcopy 26 | from itertools import chain 27 | 28 | 29 | def warmup_cosine(x, warmup=0.002): 30 | if x < warmup: 31 | return x/warmup 32 | return 0.5 * (1.0 + torch.cos(math.pi * x)) 33 | 34 | 35 | def warmup_constant(x, warmup=0.002): 36 | if x < warmup: 37 | return x/warmup 38 | return 1.0 39 | 40 | 41 | def warmup_linear(x, warmup=0.002): 42 | if x < warmup: 43 | return x/warmup 44 | return max((x-1.)/(warmup-1.), 0) 45 | 46 | 47 | SCHEDULES = { 48 | 'warmup_cosine': warmup_cosine, 49 | 'warmup_constant': warmup_constant, 50 | 'warmup_linear': warmup_linear, 51 | } 52 | 53 | 54 | class BertAdam(Optimizer): 55 | """Implements BERT version of Adam algorithm with weight decay fix. 56 | Params: 57 | lr: learning rate 58 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 59 | t_total: total number of training steps for the learning 60 | rate schedule, -1 means constant learning rate. Default: -1 61 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 62 | b1: Adams b1. Default: 0.9 63 | b2: Adams b2. Default: 0.999 64 | e: Adams epsilon. Default: 1e-6 65 | weight_decay: Weight decay. Default: 0.01 66 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0 67 | """ 68 | 69 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0): 70 | if lr is not required and lr < 0.0: 71 | raise ValueError( 72 | "Invalid learning rate: {} - should be >= 0.0".format(lr)) 73 | if schedule not in SCHEDULES: 74 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 75 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 76 | raise ValueError( 77 | "Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 78 | if not 0.0 <= b1 < 1.0: 79 | raise ValueError( 80 | "Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 81 | if not 0.0 <= b2 < 1.0: 82 | raise ValueError( 83 | "Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 84 | if not e >= 0.0: 85 | raise ValueError( 86 | "Invalid epsilon value: {} - should be >= 0.0".format(e)) 87 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 88 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 89 | max_grad_norm=max_grad_norm) 90 | super(BertAdam, self).__init__(params, defaults) 91 | 92 | def get_lr(self): 93 | lr = [] 94 | for group in self.param_groups: 95 | for p in group['params']: 96 | state = self.state[p] 97 | if len(state) == 0: 98 | return [0] 99 | if group['t_total'] != -1: 100 | schedule_fct = SCHEDULES[group['schedule']] 101 | lr_scheduled = group['lr'] * schedule_fct( 102 | state['step']/group['t_total'], group['warmup']) 103 | else: 104 | lr_scheduled = group['lr'] 105 | lr.append(lr_scheduled) 106 | return lr 107 | 108 | def step(self, closure=None): 109 | """Performs a single optimization step. 110 | 111 | Arguments: 112 | closure (callable, optional): A closure that reevaluates the model 113 | and returns the loss. 
114 | """ 115 | loss = None 116 | if closure is not None: 117 | loss = closure() 118 | 119 | for group in self.param_groups: 120 | for p in group['params']: 121 | if p.grad is None: 122 | continue 123 | grad = p.grad.data 124 | if grad.is_sparse: 125 | raise RuntimeError( 126 | 'Adam does not support sparse gradients, please consider SparseAdam instead') 127 | 128 | state = self.state[p] 129 | 130 | # State initialization 131 | if len(state) == 0: 132 | state['step'] = 0 133 | # Exponential moving average of gradient values 134 | state['next_m'] = torch.zeros_like(p.data) 135 | # Exponential moving average of squared gradient values 136 | state['next_v'] = torch.zeros_like(p.data) 137 | 138 | next_m, next_v = state['next_m'], state['next_v'] 139 | beta1, beta2 = group['b1'], group['b2'] 140 | 141 | # Add grad clipping 142 | if group['max_grad_norm'] > 0: 143 | clip_grad_norm_(p, group['max_grad_norm']) 144 | 145 | # Decay the first and second moment running average coefficient 146 | # In-place operations to update the averages at the same time 147 | next_m.mul_(beta1).add_(1 - beta1, grad) 148 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 149 | update = next_m / (next_v.sqrt() + group['e']) 150 | 151 | # Just adding the square of the weights to the loss function is *not* 152 | # the correct way of using L2 regularization/weight decay with Adam, 153 | # since that will interact with the m and v parameters in strange ways. 154 | # 155 | # Instead we want to decay the weights in a manner that doesn't interact 156 | # with the m/v parameters. This is equivalent to adding the square 157 | # of the weights to the loss with plain (non-momentum) SGD. 158 | if group['weight_decay'] > 0.0: 159 | update += group['weight_decay'] * p.data 160 | 161 | if group['t_total'] != -1: 162 | schedule_fct = SCHEDULES[group['schedule']] 163 | lr_scheduled = group['lr'] * schedule_fct( 164 | state['step']/group['t_total'], group['warmup']) 165 | else: 166 | lr_scheduled = group['lr'] 167 | 168 | update_with_lr = lr_scheduled * update 169 | p.data.add_(-update_with_lr) 170 | 171 | state['step'] += 1 172 | 173 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 174 | # No bias correction 175 | # bias_correction1 = 1 - beta1 ** state['step'] 176 | # bias_correction2 = 1 - beta2 ** state['step'] 177 | 178 | return loss 179 | 180 | 181 | class BertAdamFineTune(BertAdam): 182 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0): 183 | self.init_param_group = [] 184 | super(BertAdamFineTune, self).__init__(params, lr, warmup, 185 | t_total, schedule, b1, b2, e, weight_decay, max_grad_norm) 186 | 187 | def save_init_param_group(self, param_groups, name_groups, missing_keys): 188 | self.init_param_group = [] 189 | for group, name in zip(param_groups, name_groups): 190 | if group['weight_decay'] > 0.0: 191 | init_p_list = [] 192 | for p, n in zip(group['params'], name): 193 | init_p = p.data.clone().detach() 194 | if any(mk in n for mk in missing_keys): 195 | print("[no finetuning weight decay]", n) 196 | # should use the original weight decay 197 | init_p.zero_() 198 | init_p_list.append(init_p) 199 | self.init_param_group.append(init_p_list) 200 | else: 201 | # placeholder 202 | self.init_param_group.append([]) 203 | 204 | def step(self, closure=None): 205 | """Performs a single optimization step. 
206 | 207 | Arguments: 208 | closure (callable, optional): A closure that reevaluates the model 209 | and returns the loss. 210 | """ 211 | loss = None 212 | if closure is not None: 213 | loss = closure() 214 | 215 | for i_group, group in enumerate(self.param_groups): 216 | for i_p, p in enumerate(group['params']): 217 | if p.grad is None: 218 | continue 219 | grad = p.grad.data 220 | if grad.is_sparse: 221 | raise RuntimeError( 222 | 'Adam does not support sparse gradients, please consider SparseAdam instead') 223 | 224 | state = self.state[p] 225 | 226 | # State initialization 227 | if len(state) == 0: 228 | state['step'] = 0 229 | # Exponential moving average of gradient values 230 | state['next_m'] = torch.zeros_like(p.data) 231 | # Exponential moving average of squared gradient values 232 | state['next_v'] = torch.zeros_like(p.data) 233 | 234 | next_m, next_v = state['next_m'], state['next_v'] 235 | beta1, beta2 = group['b1'], group['b2'] 236 | 237 | # Add grad clipping 238 | if group['max_grad_norm'] > 0: 239 | clip_grad_norm_(p, group['max_grad_norm']) 240 | 241 | # Decay the first and second moment running average coefficient 242 | # In-place operations to update the averages at the same time 243 | next_m.mul_(beta1).add_(1 - beta1, grad) 244 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 245 | update = next_m / (next_v.sqrt() + group['e']) 246 | 247 | # Just adding the square of the weights to the loss function is *not* 248 | # the correct way of using L2 regularization/weight decay with Adam, 249 | # since that will interact with the m and v parameters in strange ways. 250 | # 251 | # Instead we want to decay the weights in a manner that doesn't interact 252 | # with the m/v parameters. This is equivalent to adding the square 253 | # of the weights to the loss with plain (non-momentum) SGD. 254 | if group['weight_decay'] > 0.0: 255 | if self.init_param_group: 256 | update += group['weight_decay'] * \ 257 | (2.0 * p.data - 258 | self.init_param_group[i_group][i_p]) 259 | else: 260 | update += group['weight_decay'] * p.data 261 | 262 | if group['t_total'] != -1: 263 | schedule_fct = SCHEDULES[group['schedule']] 264 | lr_scheduled = group['lr'] * schedule_fct( 265 | state['step']/group['t_total'], group['warmup']) 266 | else: 267 | lr_scheduled = group['lr'] 268 | 269 | update_with_lr = lr_scheduled * update 270 | p.data.add_(-update_with_lr) 271 | 272 | state['step'] += 1 273 | 274 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 275 | # No bias correction 276 | # bias_correction1 = 1 - beta1 ** state['step'] 277 | # bias_correction2 = 1 - beta2 ** state['step'] 278 | 279 | return loss 280 | 281 | def load_state_dict_subset_finetune(self, state_dict, num_load_group): 282 | r"""Loads the optimizer state. 283 | 284 | Arguments: 285 | state_dict (dict): optimizer state. Should be an object returned 286 | from a call to :meth:`state_dict`. 
287 | """ 288 | # deepcopy, to be consistent with module API 289 | state_dict = deepcopy(state_dict) 290 | # Validate the state_dict 291 | groups = self.param_groups 292 | saved_groups = state_dict['param_groups'] 293 | 294 | if len(groups) < num_load_group or len(saved_groups) < num_load_group: 295 | raise ValueError("loaded state dict has a different number of " 296 | "parameter groups") 297 | param_lens = (len(g['params']) for g in groups[:num_load_group]) 298 | saved_lens = (len(g['params']) for g in saved_groups[:num_load_group]) 299 | if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): 300 | raise ValueError("loaded state dict contains a parameter group " 301 | "that doesn't match the size of optimizer's group") 302 | 303 | # Update the state 304 | id_map = {old_id: p for old_id, p in 305 | zip(chain(*(g['params'] for g in saved_groups[:num_load_group])), 306 | chain(*(g['params'] for g in groups[:num_load_group])))} 307 | 308 | def cast(param, value): 309 | r"""Make a deep copy of value, casting all tensors to device of param.""" 310 | if isinstance(value, torch.Tensor): 311 | # Floating-point types are a bit special here. They are the only ones 312 | # that are assumed to always match the type of params. 313 | if param.is_floating_point(): 314 | value = value.to(param.dtype) 315 | value = value.to(param.device) 316 | return value 317 | elif isinstance(value, dict): 318 | return {k: cast(param, v) for k, v in value.items()} 319 | elif isinstance(value, container_abcs.Iterable): 320 | return type(value)(cast(param, v) for v in value) 321 | else: 322 | return value 323 | 324 | # Copy state assigned to params (and cast tensors to appropriate types). 325 | # State that is not assigned to params is copied as is (needed for 326 | # backward compatibility). 
327 | state = defaultdict(dict) 328 | for k, v in state_dict['state'].items(): 329 | if k in id_map: 330 | param = id_map[k] 331 | state[param] = cast(param, v) 332 | else: 333 | state[k] = v 334 | # handle additional params 335 | for k, v in self.state: 336 | if k not in state: 337 | state[k] = v 338 | 339 | # do not change groups: {'weight_decay': 0.01, 'lr': 9.995e-06, 'schedule': 'warmup_linear', 'warmup': 0.1, 't_total': 400000, 'b1': 0.9, 'b2': 0.999, 'e': 1e-06, 'max_grad_norm': 1.0, 'params': [...]} 340 | # # Update parameter groups, setting their 'params' value 341 | # def update_group(group, new_group): 342 | # new_group['params'] = group['params'] 343 | # return new_group 344 | # param_groups = [ 345 | # update_group(g, ng) for g, ng in zip(groups[:num_load_group], saved_groups[:num_load_group])] 346 | # # handle additional params 347 | # param_groups.extend(groups[num_load_group:]) 348 | 349 | self.__setstate__({'state': state, 'param_groups': groups}) 350 | 351 | 352 | def find_state_dict_subset_finetune(org_state_dict, org_name_list, no_decay, param_optimizer): 353 | # only use the bert encoder and embeddings 354 | want_name_set = set() 355 | for n in org_name_list: 356 | if ('bert.encoder' in n) or ('bert.embeddings' in n): 357 | want_name_set.add(n) 358 | # original: name to pid, pid to name 359 | org_grouped_names = [[n for n in org_name_list if not any(nd in n for nd in no_decay)], 360 | [n for n in org_name_list if any(nd in n for nd in no_decay)]] 361 | org_n2id, org_id2n = {}, {} 362 | for ng, pg in zip(org_grouped_names, org_state_dict['param_groups']): 363 | for n, pid in zip(ng, pg['params']): 364 | org_n2id[n] = pid 365 | org_id2n[pid] = n 366 | # group by: whether pretrained; whether weight decay 367 | g_np_list = [ 368 | [(n, p) for n, p in param_optimizer if n in want_name_set and not any( 369 | nd in n for nd in no_decay)], 370 | [(n, p) for n, p in param_optimizer if n in want_name_set and any( 371 | nd in n for nd in no_decay)], 372 | [(n, p) for n, p in param_optimizer if n not in want_name_set and not any( 373 | nd in n for nd in no_decay)], 374 | [(n, p) for n, p in param_optimizer if n not in want_name_set and any( 375 | nd in n for nd in no_decay)], 376 | ] 377 | optimizer_grouped_parameters = [ 378 | {'params': [p for n, p in g_np_list[0]], 'weight_decay': 0.01}, 379 | {'params': [p for n, p in g_np_list[1]], 'weight_decay': 0.0}, 380 | {'params': [p for n, p in g_np_list[2]], 'weight_decay': 0.01}, 381 | {'params': [p for n, p in g_np_list[3]], 'weight_decay': 0.0} 382 | ] 383 | new_state_dict = {} 384 | # regroup the original state_dict 385 | new_state_dict['state'] = {pid: v for pid, v in org_state_dict['state'].items( 386 | ) if pid not in org_id2n or org_id2n[pid] in want_name_set} 387 | # reset step count to 0 388 | for pid, st in new_state_dict['state'].items(): 389 | st['step'] = 0 390 | 391 | def _filter_group(group, g_np_list, i, org_n2id): 392 | packed = {k: v for k, v in group.items() if k != 'params'} 393 | packed['params'] = [pid for pid in group['params'] 394 | if pid in org_id2n and org_id2n[pid] in want_name_set] 395 | assert len(g_np_list[i]) == len(packed['params']) 396 | # keep them the same order 397 | packed['params'] = [org_n2id[n] for n, p in g_np_list[i]] 398 | return packed 399 | new_state_dict['param_groups'] = [_filter_group( 400 | g, g_np_list, i, org_n2id) for i, g in enumerate(org_state_dict['param_groups'])] 401 | return new_state_dict, optimizer_grouped_parameters 402 | 
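The comment block inside `BertAdam.step` above is the key design point of this optimizer: the `weight_decay * p.data` term is added to the update *after* the moment-normalized step, so the decay never flows through the `next_m`/`next_v` statistics (decoupled weight decay), and no bias correction is applied. Below is a minimal, illustrative sketch of one such parameter update on plain PyTorch tensors; the function and variable names are hypothetical and this is not code from the repository.

import torch

def bert_adam_update(p, grad, state, lr=3e-5, b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01):
    # One BertAdam-style step: EMA moments, no bias correction, decoupled weight decay.
    if 'next_m' not in state:
        state['next_m'] = torch.zeros_like(p)
        state['next_v'] = torch.zeros_like(p)
    m, v = state['next_m'], state['next_v']
    m.mul_(b1).add_(grad, alpha=1 - b1)            # first-moment EMA
    v.mul_(b2).addcmul_(grad, grad, value=1 - b2)  # second-moment EMA
    update = m / (v.sqrt() + e)
    if weight_decay > 0.0:
        update = update + weight_decay * p         # decay applied outside the loss and the moments
    p.add_(update, alpha=-lr)                      # in-place parameter update
    return p

Adding an L2 penalty to the loss instead would send the decay term through the m/v estimates, which is exactly what the original comment warns against.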
-------------------------------------------------------------------------------- /biunilm/run_ppl.py: -------------------------------------------------------------------------------- 1 | """BERT finetuning runner.""" 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import os 8 | import sys 9 | import copy 10 | import logging 11 | import glob 12 | import math 13 | import json 14 | import argparse 15 | import random 16 | import pickle 17 | from pathlib import Path 18 | from tqdm import tqdm, trange 19 | import numpy as np 20 | import pandas as pd 21 | import torch 22 | from torch.utils.data import RandomSampler 23 | from torch.utils.data.distributed import DistributedSampler 24 | 25 | from pytorch_pretrained_bert.tokenization import BertTokenizer, WhitespaceTokenizer 26 | from pytorch_pretrained_bert.modeling import BertForPreTrainingLossMask 27 | from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear 28 | 29 | from nn.data_parallel import DataParallelImbalance 30 | import biunilm.seq2seq_loader as seq2seq_loader 31 | import torch.distributed as dist 32 | 33 | 34 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 35 | datefmt='%m/%d/%Y %H:%M:%S', 36 | level=logging.INFO) 37 | logger = logging.getLogger(__name__) 38 | 39 | 40 | def _get_max_epoch_model(output_dir): 41 | fn_model_list = glob.glob(os.path.join(output_dir, "model.*.bin")) 42 | # fn_optim_list = glob.glob(os.path.join(output_dir, "optim.*.bin")) 43 | # if (not fn_model_list) or (not fn_optim_list): 44 | # return None 45 | 46 | if not fn_model_list: 47 | return None 48 | 49 | both_set = set([int(Path(fn).stem.split('.')[-1]) for fn in fn_model_list] 50 | ) 51 | if both_set: 52 | # *ZY* 53 | global_step = str(max(both_set)) 54 | fn_model = [s for s in fn_model_list if global_step in s] 55 | assert len(fn_model) == 1 56 | # fn_optim = [s for s in fn_optim_list if global_step in s] 57 | # assert len(fn_optim) == 1 58 | 59 | tmp = Path(fn_model[0]).stem.split('.')[-2].strip().split('_') 60 | n_epoch = int(tmp[0].strip('e').strip()) 61 | n_step = int(tmp[1].strip('s').strip()) 62 | return [fn_model[0], None, int(global_step), n_epoch, n_step] 63 | else: 64 | return None 65 | 66 | 67 | def pre_preprocess(train_flag, args, data_tokenizer, bi_uni_pipeline): 68 | train_flag = 'test' 69 | 70 | # TODO: PPL 71 | dial_src = os.path.join(args.data_dir, "dial.{:}".format(train_flag)) 72 | dial_ppl_src = os.path.join(args.data_dir, "dial.{:}.ppl".format(train_flag)) 73 | if not os.path.exists(dial_ppl_src): 74 | n_write = 0 75 | with open(dial_ppl_src, 'wt') as wf: 76 | with open(dial_src, 'rt') as rf: 77 | for line in rf: 78 | src, usrid, tgt, data_type = line.strip().split('\t')[:4] 79 | elems = tgt.strip().split(' ') 80 | 81 | for idx in range(len(elems)): 82 | word = elems[idx].strip() 83 | if len(word): 84 | wf.write('\t'.join([src, usrid, ' '.join(elems[:idx+1]), data_type])+'\n') 85 | n_write += 1 86 | 87 | logger.info("Write {:} samples for perplexity calculation to {:}".format(n_write, dial_ppl_src)) 88 | else: 89 | logger.info("Read ppl test file: {:}".format(dial_ppl_src)) 90 | 91 | dataset = seq2seq_loader.MyDataset( 92 | [dial_ppl_src], args.eval_batch_size, data_tokenizer, 93 | args.max_seq_length, preprocess=bi_uni_pipeline, accept_dtypes=['dial']) 94 | 95 | return dataset 96 | 97 | 98 | def validate(model, valid_dataloader, device, n_gpu): 99 | valid_ppl = 0 100 | n_samples = 0 101 | n_tokens = 0 
102 | 103 | batch_size_gpu = int(valid_dataloader.batch_size / n_gpu) 104 | 105 | iter_bar = tqdm(valid_dataloader, desc='Iter (loss=X.XXX)') 106 | 107 | with torch.no_grad(): 108 | for step, batch in enumerate(iter_bar): 109 | batch = [ 110 | t.to(device) if t is not None else None for t in batch] 111 | 112 | num_tokens_a, num_tokens_b, input_ids, usrid_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx = batch 113 | oracle_pos, oracle_weights, oracle_labels = None, None, None 114 | input_mask = None 115 | assert segment_ids is not None 116 | 117 | loss_tuple = model(input_ids, usrid_ids, segment_ids, input_mask, lm_label_ids, is_next, 118 | masked_pos=masked_pos, masked_weights=masked_weights, task_idx=task_idx, 119 | num_tokens_a=num_tokens_a, num_tokens_b=num_tokens_b, 120 | masked_pos_2=oracle_pos, masked_weights_2=oracle_weights, 121 | masked_labels_2=oracle_labels, mask_qkv=mask_qkv, is_ppl_eval=True) 122 | 123 | masked_lm_loss, next_sentence_loss, ppl = loss_tuple 124 | 125 | if n_gpu > 1: # mean() to average on multi-gpu. 126 | # loss = loss.mean() 127 | masked_lm_loss = masked_lm_loss.mean() 128 | # next_sentence_loss = next_sentence_loss.mean() 129 | ppl = ppl.sum() 130 | 131 | # loss = masked_lm_loss + next_sentence_loss 132 | loss = masked_lm_loss 133 | iter_bar.set_description('Iter (loss=%5.3f)' % loss.item()) 134 | 135 | valid_ppl += ppl.item() 136 | n_tokens += masked_weights.sum().item() 137 | n_samples += len(task_idx) 138 | 139 | # ppl = np.exp(valid_ppl / n_samples) 140 | ppl = np.exp(valid_ppl / n_tokens) # n_tokens == n_samples, I masked one token per sample 141 | 142 | return ppl 143 | 144 | 145 | def save(model, optimizer, args, i_epoch, i_step, global_step): 146 | model_to_save = model.module if hasattr( 147 | model, 'module') else model # Only save the model it-self 148 | output_model_file = os.path.join( 149 | args.output_dir, "model.e{:}_s{:}.{:}.bin".format(i_epoch, i_step, global_step)) 150 | torch.save(model_to_save.state_dict(), output_model_file) 151 | output_optim_file = os.path.join( 152 | args.output_dir, "optim.e{:}_s{:}.{:}.bin".format(i_epoch, i_step, global_step)) 153 | torch.save(optimizer.state_dict(), output_optim_file) 154 | 155 | 156 | def main(): 157 | parser = argparse.ArgumentParser() 158 | 159 | parser.add_argument('--n_clayer', type=int, required=True, 160 | help="n conditional layer") 161 | 162 | parser.add_argument('--gate', type=str, default="attn", 163 | help="gate method: [attn|gate|gate_x2] ") 164 | 165 | # Required parameters 166 | parser.add_argument("--data_dir", 167 | default=None, 168 | type=str, 169 | required=True, 170 | help="The input data dir. Should contain the .tsv files (or other data files) for the task.") 171 | 172 | parser.add_argument("--c_tfidf_map", type=str, required=True, 173 | help="e.g. 
c_tfidf_map.pkl in args.data_dir") 174 | 175 | # parser.add_argument("--tgt_file", default=None, type=str, 176 | # help="The output data file name.") 177 | 178 | parser.add_argument("--bert_model", default=None, type=str, required=True, 179 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 180 | "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.") 181 | parser.add_argument("--config_path", default=None, type=str, 182 | help="Bert config file path.") 183 | parser.add_argument("--output_dir", 184 | default=None, 185 | type=str, 186 | required=True, 187 | help="The output directory where the model predictions and checkpoints will be written.") 188 | parser.add_argument("--log_dir", 189 | default='', 190 | type=str, 191 | required=True, 192 | help="The output directory where the log will be written.") 193 | parser.add_argument("--model_recover_path", 194 | default=None, 195 | type=str, 196 | help="The file of fine-tuned pretraining model.") 197 | parser.add_argument("--optim_recover_path", 198 | default=None, 199 | type=str, 200 | help="The file of pretraining optimizer.") 201 | # Other parameters 202 | parser.add_argument("--max_seq_length", 203 | default=128, 204 | type=int, 205 | help="The maximum total input sequence length after WordPiece tokenization. \n" 206 | "Sequences longer than this will be truncated, and sequences shorter \n" 207 | "than this will be padded.") 208 | 209 | parser.add_argument("--do_lower_case", 210 | action='store_true', 211 | help="Set this flag if you are using an uncased model.") 212 | parser.add_argument("--train_batch_size", 213 | default=32, 214 | type=int, 215 | help="Total batch size for training.") 216 | parser.add_argument("--eval_batch_size", 217 | default=64, 218 | type=int, 219 | help="Total batch size for eval.") 220 | parser.add_argument("--valid_steps", 221 | default=8192, 222 | type=int) 223 | 224 | parser.add_argument("--learning_rate", default=3e-5, type=float, 225 | help="The initial learning rate for Adam.") 226 | parser.add_argument("--label_smoothing", default=0, type=float, 227 | help="The initial learning rate for Adam.") 228 | parser.add_argument("--weight_decay", 229 | default=0.01, 230 | type=float, 231 | help="The weight decay rate for Adam.") 232 | parser.add_argument("--finetune_decay", 233 | action='store_true', 234 | help="Weight decay to the original weights.") 235 | 236 | parser.add_argument("--warmup_proportion", 237 | default=0.1, 238 | type=float, 239 | help="Proportion of training to perform linear learning rate warmup for. 
" 240 | "E.g., 0.1 = 10%% of training.") 241 | parser.add_argument("--hidden_dropout_prob", default=0.1, type=float, 242 | help="Dropout rate for hidden states.") 243 | parser.add_argument("--attention_probs_dropout_prob", default=0.1, type=float, 244 | help="Dropout rate for attention probabilities.") 245 | parser.add_argument("--no_cuda", 246 | action='store_true', 247 | help="Whether not to use CUDA when available") 248 | parser.add_argument("--local_rank", 249 | type=int, 250 | default=-1, 251 | help="local_rank for distributed training on gpus") 252 | parser.add_argument('--seed', 253 | type=int, 254 | default=42, 255 | help="random seed for initialization") 256 | parser.add_argument('--gradient_accumulation_steps', 257 | type=int, 258 | default=1, 259 | help="Number of updates steps to accumulate before performing a backward/update pass.") 260 | parser.add_argument('--fp16', action='store_true', 261 | help="Whether to use 16-bit float precision instead of 32-bit") 262 | parser.add_argument('--fp32_embedding', action='store_true', 263 | help="Whether to use 32-bit float precision instead of 16-bit for embeddings") 264 | parser.add_argument('--loss_scale', type=float, default=0, 265 | help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" 266 | "0 (default value): dynamic loss scaling.\n" 267 | "Positive power of 2: static loss scaling value.\n") 268 | parser.add_argument('--amp', action='store_true', 269 | help="Whether to use amp for fp16") 270 | parser.add_argument('--from_scratch', action='store_true', 271 | help="Initialize parameters with random values (i.e., training from scratch).") 272 | parser.add_argument('--new_segment_ids', action='store_true', 273 | help="Use new segment ids for bi-uni-directional LM.") 274 | parser.add_argument('--new_pos_ids', action='store_true', 275 | help="Use new position ids for LMs.") 276 | 277 | parser.add_argument('--tokenized_input', action='store_true', 278 | help="Whether the input is tokenized.") 279 | 280 | parser.add_argument('--max_len_a', type=int, default=0, 281 | help="Truncate_config: maximum length of segment A.") 282 | parser.add_argument('--max_len_b', type=int, default=0, 283 | help="Truncate_config: maximum length of segment B.") 284 | parser.add_argument('--trunc_seg', default='', 285 | help="Truncate_config: first truncate segment A/B (option: a, b).") 286 | parser.add_argument('--always_truncate_tail', action='store_true', 287 | help="Truncate_config: Whether we should always truncate tail.") 288 | parser.add_argument("--mask_prob", default=0.15, type=float, 289 | help="Number of prediction is sometimes less than max_pred when sequence is short.") 290 | parser.add_argument("--mask_prob_eos", default=0, type=float, 291 | help="Number of prediction is sometimes less than max_pred when sequence is short.") 292 | parser.add_argument('--max_pred', type=int, default=20, 293 | help="Max tokens of prediction.") 294 | parser.add_argument("--num_workers", default=0, type=int, 295 | help="Number of workers for the data loader.") 296 | 297 | parser.add_argument('--mask_source_words', action='store_true', 298 | help="Whether to mask source words for training") 299 | parser.add_argument('--skipgram_prb', type=float, default=0.0, 300 | help='prob of ngram mask') 301 | parser.add_argument('--skipgram_size', type=int, default=1, 302 | help='the max size of ngram mask') 303 | parser.add_argument('--mask_whole_word', action='store_true', 304 | help="Whether masking a whole word.") 305 | 
parser.add_argument('--do_l2r_training', action='store_true', 306 | help="Whether to do left to right training") 307 | parser.add_argument('--has_sentence_oracle', action='store_true', 308 | help="Whether to have sentence level oracle for training. " 309 | "Only useful for summary generation") 310 | parser.add_argument('--max_position_embeddings', type=int, default=None, 311 | help="max position embeddings") 312 | parser.add_argument('--relax_projection', action='store_true', 313 | help="Use different projection layers for tasks.") 314 | parser.add_argument('--ffn_type', default=0, type=int, 315 | help="0: default mlp; 1: W((Wx+b) elem_prod x);") 316 | parser.add_argument('--num_qkv', default=0, type=int, 317 | help="Number of different <q,k,v>.") 318 | parser.add_argument('--seg_emb', action='store_true', 319 | help="Using segment embedding for self-attention.") 320 | parser.add_argument('--s2s_special_token', action='store_true', 321 | help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.") 322 | parser.add_argument('--s2s_add_segment', action='store_true', 323 | help="Additional segment embedding for the encoder of S2S.") 324 | parser.add_argument('--s2s_share_segment', action='store_true', 325 | help="Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment).") 326 | parser.add_argument('--pos_shift', action='store_true', 327 | help="Using position shift for fine-tuning.") 328 | 329 | args = parser.parse_args() 330 | 331 | # Fine-tune use 332 | # assert Path(args.model_recover_path).exists( 333 | # ), "--model_recover_path doesn't exist" 334 | 335 | args.output_dir = args.output_dir.replace( 336 | '[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', '')) 337 | args.log_dir = args.log_dir.replace( 338 | '[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', '')) 339 | 340 | os.makedirs(args.output_dir, exist_ok=True) 341 | os.makedirs(args.log_dir, exist_ok=True) 342 | json.dump(args.__dict__, open(os.path.join( 343 | args.output_dir, 'opt.json'), 'w'), sort_keys=True, indent=2) 344 | 345 | if args.local_rank == -1 or args.no_cuda: 346 | device = torch.device( 347 | "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 348 | n_gpu = torch.cuda.device_count() 349 | 350 | else: 351 | torch.cuda.set_device(args.local_rank) 352 | device = torch.device("cuda", args.local_rank) 353 | n_gpu = 1 354 | # Initializes the distributed backend which will take care of synchronizing nodes/GPUs 355 | dist.init_process_group(backend='nccl') 356 | 357 | logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( 358 | device, n_gpu, bool(args.local_rank != -1), args.fp16)) 359 | 360 | if args.gradient_accumulation_steps < 1: 361 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( 362 | args.gradient_accumulation_steps)) 363 | 364 | args.train_batch_size = int( 365 | args.train_batch_size / args.gradient_accumulation_steps) 366 | 367 | random.seed(args.seed) 368 | np.random.seed(args.seed) 369 | torch.manual_seed(args.seed) 370 | if n_gpu > 0: 371 | torch.cuda.manual_seed_all(args.seed) 372 | 373 | if args.local_rank not in (-1, 0): 374 | # Make sure only the first process in distributed training will download model & vocab 375 | dist.barrier() 376 | if args.local_rank == 0: 377 | dist.barrier() 378 | 379 | ################################### 380 | # *ZY* 381 | # Load User Mask 382 | with open(os.path.join(args.data_dir, args.c_tfidf_map), 'rb') as f: 383 | c_tfidf_map = pickle.load(f) 384 | 385 | # Get User Indexer 386 |
c_indexer = {cid: index for index, cid in enumerate(sorted(list(c_tfidf_map.keys())))} 387 | logger.info("{:} conditions.".format(len(c_indexer))) 388 | 389 | tokenizer = BertTokenizer.from_pretrained( 390 | args.bert_model, do_lower_case=args.do_lower_case) 391 | if args.max_position_embeddings: 392 | tokenizer.max_len = args.max_position_embeddings 393 | data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer 394 | 395 | if not args.tokenized_input: 396 | logger.warning("Strongly recommend using BertTokenizer(# Slow) before.") 397 | 398 | bi_uni_pipeline = [seq2seq_loader.Preprocess4Seq2seq(args.max_pred, args.mask_prob, list(tokenizer.vocab.keys( 399 | )), tokenizer.convert_tokens_to_ids, args.max_seq_length, new_segment_ids=args.new_segment_ids, 400 | truncate_config={'max_len_a': args.max_len_a, 401 | 'max_len_b': args.max_len_b, 402 | 'trunc_seg': args.trunc_seg, 403 | 'always_truncate_tail': args.always_truncate_tail}, 404 | mask_source_words=args.mask_source_words, 405 | skipgram_prb=args.skipgram_prb, 406 | skipgram_size=args.skipgram_size, 407 | mask_whole_word=args.mask_whole_word, mode="s2s", 408 | has_oracle=args.has_sentence_oracle, num_qkv=args.num_qkv, 409 | s2s_special_token=args.s2s_special_token, 410 | s2s_add_segment=args.s2s_add_segment, 411 | s2s_share_segment=args.s2s_share_segment, 412 | pos_shift=args.pos_shift, c_indexer=c_indexer, 413 | c_tfidf_map=c_tfidf_map, only_mask_last=True)] 414 | 415 | logger.info("Preprocess Test Set...") 416 | valid_dataset = pre_preprocess('test', args, data_tokenizer, bi_uni_pipeline) 417 | 418 | valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.eval_batch_size, 419 | num_workers=args.num_workers, shuffle=False, 420 | collate_fn=seq2seq_loader.batch_list_to_batch_tensors, pin_memory=False) 421 | 422 | special_num_here = 2048 423 | recover_step = _get_max_epoch_model(args.output_dir) 424 | # (fn_model[0], fn_optim[0], int(global_step), n_epoch, n_step) 425 | 426 | if recover_step: 427 | if recover_step[-1] % special_num_here == 0: 428 | n_finished_epoch = recover_step[-2] - 1 429 | else: 430 | n_finished_epoch = recover_step[-2] 431 | recover_step[-1] = 0 # step in an epoch 432 | else: 433 | n_finished_epoch = 0 434 | 435 | logger.info("### Finished {:} Epoch(s) ###".format(n_finished_epoch)) 436 | 437 | amp_handle = None 438 | if args.fp16 and args.amp: 439 | raise NotImplementedError 440 | # from apex import amp 441 | # amp_handle = amp.init(enable_caching=True) 442 | # logger.info("enable fp16 with amp") 443 | 444 | # Prepare model 445 | cls_num_labels = 2 446 | 447 | type_vocab_size = 2 # V2 448 | 449 | c_indexer = torch.load(os.path.join(args.data_dir, 'c_indexer.pt')) 450 | n_condition = len(c_indexer) 451 | if '' not in c_indexer.keys(): 452 | n_condition += 1 453 | 454 | num_sentlvl_labels = 2 if args.has_sentence_oracle else 0 455 | relax_projection = 4 if args.relax_projection else 0 456 | if args.local_rank not in (-1, 0): 457 | # Make sure only the first process in distributed training will download model & vocab 458 | dist.barrier() 459 | if (recover_step is None) and (args.model_recover_path is None): 460 | raise ValueError 461 | 462 | else: 463 | if recover_step: 464 | assert args.model_recover_path is None # TODO: automatically recover to most recent model 465 | logger.info("***** Recover model: {:} *****".format(recover_step[0])) 466 | model_recover = torch.load(recover_step[0], map_location='cpu') 467 | # recover_step == number of epochs 468 | assert 
isinstance(recover_step[2], int) 469 | global_step = recover_step[2] 470 | elif args.model_recover_path: 471 | logger.info("***** (ONLY)Recover model: %s *****", 472 | args.model_recover_path) 473 | model_recover = torch.load( 474 | args.model_recover_path, map_location='cpu') 475 | global_step = 0 476 | 477 | n_dial = 10 # FAKE 478 | model = BertForPreTrainingLossMask.from_pretrained( 479 | args.bert_model, state_dict=model_recover, num_labels=cls_num_labels, num_rel=0, 480 | type_vocab_size=type_vocab_size, config_path=args.config_path, task_idx=3, 481 | num_sentlvl_labels=num_sentlvl_labels, max_position_embeddings=args.max_position_embeddings, 482 | label_smoothing=args.label_smoothing, fp32_embedding=args.fp32_embedding, relax_projection=relax_projection, 483 | new_pos_ids=args.new_pos_ids, ffn_type=args.ffn_type, hidden_dropout_prob=args.hidden_dropout_prob, 484 | attention_probs_dropout_prob=args.attention_probs_dropout_prob, num_qkv=args.num_qkv, seg_emb=args.seg_emb, 485 | n_condition=n_condition, n_dial=n_dial, n_clayer=args.n_clayer, gate=args.gate) 486 | 487 | if args.local_rank == 0: 488 | dist.barrier() 489 | 490 | if args.fp16: 491 | model.half() 492 | if args.fp32_embedding: 493 | model.bert.embeddings.word_embeddings.float() 494 | model.bert.embeddings.position_embeddings.float() 495 | model.bert.embeddings.token_type_embeddings.float() 496 | 497 | model.to(device) 498 | if args.local_rank != -1: 499 | try: 500 | from torch.nn.parallel import DistributedDataParallel as DDP 501 | except ImportError: 502 | raise ImportError("DistributedDataParallel") 503 | model = DDP(model, device_ids=[ 504 | args.local_rank], output_device=args.local_rank, find_unused_parameters=True) 505 | elif n_gpu > 1: 506 | # model = torch.nn.DataParallel(model) 507 | model = DataParallelImbalance(model) 508 | 509 | logger.info("***** CUDA.empty_cache() *****") 510 | torch.cuda.empty_cache() 511 | 512 | model.eval() 513 | # logger.info("### First Valid") 514 | valid_loss = validate(model, valid_dataloader, device, n_gpu) 515 | logger.info("### PPL {:.3f}".format(valid_loss)) 516 | 517 | 518 | if __name__ == "__main__": 519 | main() 520 | -------------------------------------------------------------------------------- /biunilm/seq2seq_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | from random import randint, shuffle, uniform 5 | from random import random as rand 6 | from random import sample as sample_func 7 | 8 | from numpy import array 9 | from numpy.random import choice 10 | 11 | import torch 12 | 13 | from biunilm.loader_utils import get_random_word, batch_list_to_batch_tensors, Pipeline 14 | 15 | # Input file format : 16 | # 1. One sentence per line. These should ideally be actual sentences, 17 | # not entire paragraphs or arbitrary spans of text. (Because we use 18 | # the sentence boundaries for the "next sentence prediction" task). 19 | # 2. Blank lines between documents. Document boundaries are needed 20 | # so that the "next sentence prediction" task doesn't span between documents. 
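# The format notes above appear to carry over from the original BERT/UniLM loader. MyDataset
# below actually expects tab-separated lines of the form
#     source<TAB>condition<TAB>target<TAB>data_type
# where data_type is 'dial' for a dialogue pair or 'mono' for text, and a missing condition
# is stored as a single space ' '. A minimal sketch of that parsing, with made-up field
# values for illustration only:
#
#     line = "how are you ?\tsports\ti am doing great .\tdial"
#     src, cond, tgt, data_type = line.strip('\n').split('\t')[:4]
#     # -> src="how are you ?", cond="sports", tgt="i am doing great .", data_type="dial"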
21 | 22 | 23 | def truncate_tokens_pair(tokens_a, tokens_b, max_len, max_len_a=0, max_len_b=0, trunc_seg=None, always_truncate_tail=False): 24 | num_truncated_a = [0, 0] 25 | num_truncated_b = [0, 0] 26 | while True: 27 | if len(tokens_a) + len(tokens_b) <= max_len: 28 | break 29 | if (max_len_a > 0) and len(tokens_a) > max_len_a: 30 | trunc_tokens = tokens_a 31 | num_truncated = num_truncated_a 32 | elif (max_len_b > 0) and len(tokens_b) > max_len_b: 33 | trunc_tokens = tokens_b 34 | num_truncated = num_truncated_b 35 | elif trunc_seg: 36 | # truncate the specified segment 37 | if trunc_seg == 'a': 38 | trunc_tokens = tokens_a 39 | num_truncated = num_truncated_a 40 | else: 41 | trunc_tokens = tokens_b 42 | num_truncated = num_truncated_b 43 | else: 44 | # truncate the longer segment 45 | if len(tokens_a) > len(tokens_b): 46 | trunc_tokens = tokens_a 47 | num_truncated = num_truncated_a 48 | else: 49 | trunc_tokens = tokens_b 50 | num_truncated = num_truncated_b 51 | # whether always truncate source sequences 52 | if (not always_truncate_tail) and (rand() < 0.5): 53 | del trunc_tokens[0] 54 | num_truncated[0] += 1 55 | else: 56 | trunc_tokens.pop() 57 | num_truncated[1] += 1 58 | return num_truncated_a, num_truncated_b 59 | 60 | 61 | class MyDataset(torch.utils.data.Dataset): 62 | def __init__(self, file_src_list, batch_size, tokenizer, max_len, file_oracle=None, short_sampling_prob=0.1, sent_reverse_order=False, preprocess=[], 63 | n_dial=-1, n_text=-1, accept_dtypes=[]): 64 | super(MyDataset).__init__() 65 | self.tokenizer = tokenizer # tokenize function 66 | 67 | print("### I set minimum source length to 4.") 68 | self.min_src_len = 4 # TODO: !!! 69 | 70 | self.max_len = max_len # maximum length of tokens 71 | self.short_sampling_prob = short_sampling_prob 72 | assert isinstance(preprocess, list) 73 | assert len(preprocess) == 1 74 | self.preprocess = preprocess 75 | self.n_condition = len(self.preprocess[0].c_indexer) 76 | 77 | self.batch_size = batch_size 78 | self.sent_reverse_order = sent_reverse_order 79 | 80 | assert file_oracle is None 81 | 82 | assert len(accept_dtypes) > 0 83 | 84 | # read the file into memory 85 | self.is_pretrain = True 86 | dial = [] 87 | non_text = [] 88 | c_text = [] 89 | 90 | assert isinstance(file_src_list, list) 91 | for file_src in file_src_list: 92 | with open(file_src, "r", encoding='utf-8') as f: 93 | for index, line in enumerate(f): 94 | if index % 500000 == 0: 95 | print('Preprocess the {:}th line...'.format(index)) 96 | sys.stdout.flush() 97 | 98 | src, cond, tgt, data_type = line.strip('\n').split('\t')[:4] 99 | src_tk = tokenizer.tokenize(src.strip()) 100 | tgt_tk = tokenizer.tokenize(tgt.strip()) 101 | cond = cond.strip() 102 | 103 | if len(src_tk) < self.min_src_len: 104 | src_tk = [] # TODO: !!! 
105 | 106 | sample = (src_tk, tgt_tk, cond, data_type) 107 | 108 | if len(tgt_tk) > 0 and len(cond) > 0: 109 | if data_type in accept_dtypes: 110 | if data_type == 'dial': 111 | if len(src_tk): 112 | dial.append(sample) 113 | elif data_type == 'mono': 114 | if cond == ' ': 115 | non_text.append(sample) 116 | else: 117 | c_text.append(sample) 118 | else: 119 | raise ValueError 120 | 121 | if 0 <= n_dial < len(dial): 122 | dial = sample_func(dial, n_dial) 123 | 124 | if 0 <= n_text < len(c_text): 125 | c_text = sample_func(c_text, n_text) 126 | 127 | print('Load {:} labeled dial samples.'.format(len(dial))) 128 | print('Load {:} labeled text samples.'.format(len(c_text))) 129 | print('Load {:} text samples.'.format(len(non_text))) 130 | 131 | if len(non_text): 132 | raise NotImplementedError # I have not checked it. 133 | 134 | self.n_samples = len(dial) + len(c_text) + len(non_text) 135 | self.n_dial_samples = len(dial) # 0215 136 | self.ex_list = [dial, c_text, non_text] 137 | 138 | self.index_map = {} 139 | index = 0 140 | for idx, _ in enumerate(dial): 141 | assert index not in self.index_map.keys() 142 | self.index_map[index] = (0, idx) 143 | index += 1 144 | 145 | for idx, _ in enumerate(c_text): 146 | assert index not in self.index_map.keys() 147 | self.index_map[index] = (1, idx) 148 | index += 1 149 | 150 | for idx, _ in enumerate(non_text): 151 | assert index not in self.index_map.keys() 152 | self.index_map[index] = (2, idx) 153 | index += 1 154 | 155 | assert list(self.index_map.keys()) == list(range(self.n_samples)) 156 | 157 | def __len__(self): 158 | return self.n_samples 159 | 160 | def __getitem__(self, index): 161 | data_type, idx = self.index_map[index] 162 | instance = self.preprocess[0](self.ex_list[data_type][idx]) 163 | return instance 164 | 165 | 166 | class MySampler(torch.utils.data.Sampler): 167 | def __init__(self, my_dataset, batch_size, n_gpu, n_ctext=-1, equal_sample=False): 168 | assert isinstance(my_dataset, MyDataset) 169 | assert batch_size % n_gpu == 0 170 | 171 | self.batch_size = batch_size 172 | self.n_gpu = n_gpu 173 | self.batch_size_gpu = int(self.batch_size / self.n_gpu) 174 | 175 | self.n_samples = my_dataset.n_samples 176 | 177 | self.dial_index = [] 178 | self.ctext_index = [] 179 | self.non_index = [] 180 | for index, p in my_dataset.index_map.items(): 181 | if p[0] == 0: 182 | self.dial_index.append(index) 183 | elif p[0] == 1: 184 | self.ctext_index.append(index) 185 | elif p[0] == 2: 186 | self.non_index.append(index) 187 | else: 188 | raise ValueError 189 | 190 | print("### Train Set: dial {:}, ctext {:}, non-text {:}".format(len(self.dial_index), 191 | len(self.ctext_index), 192 | len(self.non_index))) 193 | 194 | if n_ctext > 0: 195 | self.n_ctext = min(n_ctext, self.batch_size_gpu) 196 | self.n_non = 0 197 | self.n_dial = 0 198 | 199 | else: 200 | self.n_non = 0 201 | if equal_sample: 202 | self.n_ctext = round(self.batch_size_gpu * 1 / 2) if len(self.ctext_index) else 0 203 | else: 204 | self.n_ctext = round(self.batch_size_gpu * 1 / 4) if len(self.ctext_index) else 0 205 | 206 | self.n_dial = self.batch_size_gpu - self.n_ctext - self.n_non 207 | 208 | print("### Sampler: dial {:}, ctext {:}, non-text {:}".format(self.n_dial, self.n_ctext, self.n_non)) 209 | assert self.n_dial >= 0 210 | 211 | self.dial_gen = self.get_batch_index_generator(self.dial_index, self.n_dial) 212 | self.ctext_gen = self.get_batch_index_generator(self.ctext_index, self.n_ctext) 213 | self.non_gen = self.get_batch_index_generator(self.non_index, self.n_non) 214 | 
215 | def __len__(self): 216 | # return math.ceil(self.n_samples / float(self.batch_size)) 217 | return self.n_samples 218 | 219 | def __iter__(self): # iterator to load data 220 | for __ in range(math.ceil(self.n_samples / float(self.batch_size))): 221 | batch_index = [] 222 | 223 | for i in range(self.n_gpu): 224 | batch_index_gpu = self.get_batch() 225 | batch_index.extend(batch_index_gpu) 226 | 227 | for index in batch_index: 228 | yield index 229 | 230 | def get_batch(self): 231 | batch_index = [] 232 | if self.n_dial > 0: 233 | try: 234 | batch_index.extend(next(self.dial_gen)) 235 | except StopIteration: 236 | self.dial_gen = self.get_batch_index_generator(self.dial_index, self.n_dial) 237 | batch_index.extend(next(self.dial_gen)) 238 | 239 | if self.n_ctext > 0: 240 | try: 241 | batch_index.extend(next(self.ctext_gen)) 242 | except StopIteration: 243 | self.ctext_gen = self.get_batch_index_generator(self.ctext_index, self.n_ctext) 244 | batch_index.extend(next(self.ctext_gen)) 245 | 246 | if self.n_non > 0: 247 | try: 248 | batch_index.extend(next(self.non_gen)) 249 | except StopIteration: 250 | self.non_gen = self.get_batch_index_generator(self.non_index, self.n_non) 251 | batch_index.extend(next(self.non_gen)) 252 | 253 | return batch_index 254 | 255 | def get_batch_index_generator(self, a_list, batch_size): 256 | def get_batch_index(a_list, batch_size): 257 | assert isinstance(a_list, list) 258 | for start in range(0, len(a_list), batch_size): 259 | yield a_list[start:start + batch_size] 260 | 261 | assert isinstance(a_list, list) 262 | a_list = sample_func(a_list, len(a_list)) 263 | generator = get_batch_index(a_list, batch_size) 264 | return generator 265 | 266 | 267 | class Preprocess4Seq2seq(Pipeline): 268 | """ Pre-processing steps for pretraining transformer """ 269 | 270 | def __init__(self, max_pred, mask_prob, vocab_words, indexer, max_len=512, skipgram_prb=0, skipgram_size=0, 271 | block_mask=False, mask_whole_word=False, new_segment_ids=False, truncate_config={}, mask_source_words=False, 272 | mode="s2s", has_oracle=False, num_qkv=0, s2s_special_token=False, s2s_add_segment=False, 273 | s2s_share_segment=False, pos_shift=False, 274 | c_indexer=None, c_tfidf_map=None, tfidf_eps=1e-8, dial_mask_rate=0, only_mask_last=False, FGfree_indexer=None): 275 | super().__init__() 276 | self.max_len = max_len 277 | self.max_pred = max_pred # max tokens of prediction 278 | self.mask_prob = mask_prob # masking probability 279 | self.vocab_words = vocab_words # vocabulary (sub)words 280 | self.indexer = indexer # function from token to token index 281 | self.FGfree_indexer = FGfree_indexer 282 | 283 | # *ZY* 284 | self.dial_mask_rate = dial_mask_rate 285 | self.only_mask_last = only_mask_last # TODO: to calculate perplexity 286 | 287 | assert isinstance(c_tfidf_map, dict) 288 | self.c_tfidf_map = c_tfidf_map 289 | self.tfidf_eps = tfidf_eps 290 | 291 | self.nan_cond = ' ' 292 | assert isinstance(c_indexer, dict) 293 | if self.nan_cond not in c_indexer.keys(): 294 | print('#'*10+'To add condition, we re-arranged c_indexer (+1)'+'#'*10) 295 | sys.stdout.flush() 296 | self.c_indexer = {self.nan_cond: 0} 297 | for i, u in enumerate(c_indexer.keys()): 298 | self.c_indexer[u] = i + 1 299 | else: 300 | self.c_indexer = c_indexer 301 | 302 | # Check 303 | assert sorted(list(self.c_indexer.values())) == list(range(len(self.c_indexer))) 304 | 305 | self.max_len = max_len 306 | self._tril_matrix = torch.tril(torch.ones( 307 | (max_len, max_len), dtype=torch.long)) 308 | self.skipgram_prb = 
skipgram_prb 309 | self.skipgram_size = skipgram_size 310 | self.mask_whole_word = mask_whole_word 311 | self.new_segment_ids = new_segment_ids 312 | self.always_truncate_tail = truncate_config.get( 313 | 'always_truncate_tail', False) 314 | self.max_len_a = truncate_config.get('max_len_a', None) 315 | self.max_len_b = truncate_config.get('max_len_b', None) 316 | self.trunc_seg = truncate_config.get('trunc_seg', None) 317 | self.task_idx = 3 # relax projection layer for different tasks 318 | self.mask_source_words = mask_source_words 319 | assert mode in ("s2s", "l2r") 320 | self.mode = mode 321 | self.has_oracle = has_oracle 322 | self.num_qkv = num_qkv 323 | self.s2s_special_token = s2s_special_token 324 | self.s2s_add_segment = s2s_add_segment 325 | self.s2s_share_segment = s2s_share_segment 326 | self.pos_shift = pos_shift 327 | 328 | assert self.has_oracle is False 329 | assert self.pos_shift is False # I did not check this option 330 | assert self.num_qkv == 0 331 | 332 | def tfidf_mask(self, cid, cand_pos_tk, n_sample): 333 | tk_tfidf = [] 334 | for _, tk in cand_pos_tk: 335 | try: 336 | tk_tfidf.append(max(self.c_tfidf_map[cid][tk], self.tfidf_eps)) 337 | except KeyError: 338 | tk_tfidf.append(self.tfidf_eps) 339 | 340 | tk_tfidf = array(tk_tfidf) 341 | tk_tfidf = tk_tfidf / tk_tfidf.sum() 342 | 343 | tk_index = choice(range(len(tk_tfidf)), size=n_sample, replace=False, p=tk_tfidf).tolist() 344 | 345 | return [cand_pos_tk[idx][0] for idx in tk_index] 346 | 347 | def preprocess(self, tokens_a, tokens_b, cond, task_idx): 348 | 349 | # tokens_a = ['i', 'love', 'you'] 350 | # tokens_b = ['you', 'like', 'me'] 351 | # cond = ' ' 352 | # task_idx = 3 353 | 354 | try: 355 | cid = self.c_indexer[cond] 356 | except KeyError: 357 | print("Warning: {:} not in c_indexer".format(cond)) 358 | cid = self.c_indexer[self.nan_cond] 359 | 360 | # -3 for special tokens [CLS], [SEP], [SEP] 361 | num_truncated_a, _ = truncate_tokens_pair(tokens_a, tokens_b, self.max_len - 3, max_len_a=self.max_len_a, 362 | max_len_b=self.max_len_b, trunc_seg=self.trunc_seg, always_truncate_tail=self.always_truncate_tail) 363 | 364 | # Add Special Tokens 365 | if len(tokens_a) > 0: 366 | if (task_idx == 3) and self.s2s_special_token: # dial 367 | tokens = ['[S2S_CLS]'] + tokens_a + ['[S2S_SEP]'] + tokens_b + ['[SEP]'] 368 | else: 369 | tokens = ['[CLS]'] + tokens_a + ['[SEP]'] + tokens_b + ['[SEP]'] 370 | 371 | num_tokens_a = len(tokens_a) + 2 372 | num_tokens_b = len(tokens_b) + 1 373 | 374 | else: # text 375 | tokens = ['[CLS]'] + tokens_b + ['[SEP]'] 376 | num_tokens_a = 0 377 | num_tokens_b = len(tokens_b) + 2 378 | 379 | effective_length = len(tokens_b) 380 | # if (task_idx != 3) and self.mask_source_words: 381 | # effective_length += len(tokens_a) 382 | n_pred = min(self.max_pred, max( 383 | 1, int(round(effective_length*self.mask_prob)))) 384 | # candidate positions of masked tokens 385 | 386 | cand_pos_tk = [] 387 | special_pos = set() # will not be masked 388 | for i, tk in enumerate(tokens): 389 | if len(tokens_a) and (i >= len(tokens_a)+2) and (tk != '[CLS]'): # TODO: mask tokens_b (target sequence) 390 | # we will mask [SEP] as an ending symbol 391 | cand_pos_tk.append((i, tk)) 392 | 393 | elif (len(tokens_a) == 0) and (i >= 1) and (tk != '[CLS]') and (not tk.startswith('[SEP')): 394 | cand_pos_tk.append((i, tk)) 395 | 396 | else: 397 | special_pos.add(i) 398 | 399 | if self.only_mask_last: 400 | cand_pos_tk = [(len(tokens)-2, tokens[-2])] 401 | 402 | # *ZY* 403 | if cond != self.nan_cond: 404 | if task_idx == 
1: 405 | cand_pos = self.tfidf_mask(cond, cand_pos_tk, n_pred) 406 | elif (task_idx == 3) and (self.dial_mask_rate > 0.01) and (rand() < self.dial_mask_rate): 407 | cand_pos = self.tfidf_mask(cond, cand_pos_tk, n_pred) 408 | else: 409 | cand_pos = [p[0] for p in cand_pos_tk] 410 | else: 411 | cand_pos = [p[0] for p in cand_pos_tk] 412 | 413 | if self.only_mask_last: 414 | masked_pos = [len(tokens) - 2] 415 | n_real_pred = 1 416 | else: 417 | shuffle(cand_pos) 418 | masked_pos = set() 419 | max_cand_pos = max(cand_pos) 420 | 421 | for pos in cand_pos: # Uniform Distribution Here 422 | if len(masked_pos) >= n_pred: 423 | break 424 | if pos in masked_pos: # Avoid Overlapping 425 | continue 426 | 427 | def _expand_whole_word(st, end): 428 | # because of using WordPiece 429 | new_st, new_end = st, end 430 | while (new_st >= 0) and tokens[new_st].startswith('##'): 431 | new_st -= 1 432 | while (new_end < len(tokens)) and tokens[new_end].startswith('##'): 433 | new_end += 1 434 | return new_st, new_end 435 | 436 | if (self.skipgram_prb > 0) and (self.skipgram_size >= 2) and (rand() < self.skipgram_prb): 437 | # ngram 438 | cur_skipgram_size = randint(2, self.skipgram_size) 439 | if self.mask_whole_word: 440 | st_pos, end_pos = _expand_whole_word( 441 | pos, pos + cur_skipgram_size) 442 | else: 443 | st_pos, end_pos = pos, pos + cur_skipgram_size 444 | else: 445 | # directly mask 446 | if self.mask_whole_word: 447 | st_pos, end_pos = _expand_whole_word(pos, pos + 1) 448 | else: 449 | st_pos, end_pos = pos, pos + 1 450 | 451 | for mp in range(st_pos, end_pos): 452 | if (0 < mp <= max_cand_pos) and (mp not in special_pos): 453 | masked_pos.add(mp) 454 | else: 455 | break 456 | 457 | masked_pos = list(masked_pos) 458 | n_real_pred = len(masked_pos) 459 | if n_real_pred > n_pred: 460 | shuffle(masked_pos) 461 | masked_pos = masked_pos[:n_pred] 462 | n_real_pred = n_pred 463 | 464 | masked_tokens = [tokens[pos] for pos in masked_pos] 465 | 466 | for pos in masked_pos: 467 | if self.only_mask_last or rand() < 0.8: # 80% 468 | tokens[pos] = '[MASK]' 469 | elif rand() < 0.5: # 10% 470 | tokens[pos] = get_random_word(self.vocab_words) 471 | 472 | # when n_pred < max_pred, we only calculate loss within n_pred 473 | masked_weights = [1]*len(masked_tokens) 474 | 475 | # Token Indexing 476 | masked_ids = self.indexer(masked_tokens) 477 | 478 | # Token Indexing 479 | input_ids = self.indexer(tokens) 480 | 481 | # Zero Padding 482 | n_pad = self.max_len - len(input_ids) 483 | input_ids.extend([0]*n_pad) 484 | 485 | mask_qkv = None 486 | 487 | is_next = 1 488 | 489 | if task_idx == 3: 490 | segment_ids = [0] * num_tokens_a + [1] * num_tokens_b 491 | input_mask = torch.zeros(self.max_len, self.max_len, dtype=torch.long) 492 | input_mask[:num_tokens_a, :num_tokens_a].fill_(1) 493 | tril = torch.tril(torch.ones((self.max_len, self.max_len), dtype=torch.long)) 494 | input_mask[num_tokens_a:, :] = tril[num_tokens_a:, :] 495 | 496 | elif task_idx == 1: # left-to-right 497 | segment_ids = [1] * len(tokens) 498 | input_mask = torch.tril(torch.ones((self.max_len, self.max_len), dtype=torch.long)) 499 | 500 | elif task_idx == 0: # bi-attn 501 | segment_ids = [0] * len(tokens) 502 | input_mask = torch.ones((self.max_len, self.max_len), dtype=torch.long) 503 | 504 | else: 505 | raise ValueError 506 | 507 | segment_ids.extend([0]*n_pad) 508 | 509 | # Zero Padding for masked target 510 | if self.max_pred > n_real_pred: 511 | n_pad = self.max_pred - n_real_pred 512 | if masked_ids is not None: 513 | masked_ids.extend([0]*n_pad) 
514 | if masked_pos is not None: 515 | masked_pos.extend([0]*n_pad) 516 | if masked_weights is not None: 517 | masked_weights.extend([0]*n_pad) 518 | 519 | # print("tokens, ", tokens) 520 | # print("input_ids, ", input_ids) 521 | # print("segment_ids, ", segment_ids) 522 | # print("masked_ids, ", masked_ids) 523 | # print("masked_pos, ", masked_pos) 524 | # print("input_mask, ", input_mask) 525 | # exit() 526 | 527 | return (num_tokens_a, num_tokens_b, input_ids, cid, segment_ids, input_mask, mask_qkv, masked_ids, masked_pos, masked_weights, is_next, task_idx) 528 | 529 | def preprocess_FGfree(self, tokens_a, tokens_b, cond, task_idx): 530 | def _get_attn_mask(n_words, num_tokens_a, 531 | mask_pos_idx_map_sorted, 532 | task_idx): 533 | 534 | if task_idx == 3: 535 | input_mask = torch.zeros(self.max_len, self.max_len, dtype=torch.long) 536 | # Source 537 | input_mask[:num_tokens_a, :num_tokens_a].fill_(1) 538 | 539 | # Target 540 | tril = torch.tril(torch.ones((self.max_len, self.max_len), dtype=torch.long)) 541 | input_mask[num_tokens_a:, :] = tril[num_tokens_a:, :] 542 | 543 | elif task_idx == 1: 544 | input_mask = torch.tril(torch.ones((self.max_len, self.max_len), dtype=torch.long)) 545 | 546 | else: 547 | raise ValueError("do not support task_idx {:}".format(task_idx)) 548 | 549 | for i, (pos, idx) in enumerate(mask_pos_idx_map_sorted): 550 | input_mask[:, idx].fill_(0) 551 | input_mask[idx, idx].fill_(1) 552 | 553 | input_mask[n_words:, :].fill_(0) 554 | return input_mask 555 | 556 | # tokens_a = ['i', 'love', 'you'] 557 | # tokens_b = ['you', 'like', 'me'] 558 | # cond = ' ' 559 | # task_idx = 3 560 | 561 | try: 562 | cid = self.c_indexer[cond] 563 | except KeyError: 564 | print("Warning: {:} not in c_indexer".format(cond)) 565 | cid = self.c_indexer[self.nan_cond] 566 | 567 | effective_length = len(tokens_b) 568 | # if (task_idx != 3) and self.mask_source_words: 569 | # effective_length += len(tokens_a) 570 | n_pred = min(self.max_pred, max( 571 | 1, int(round(effective_length*self.mask_prob)))) 572 | # candidate positions of masked tokens 573 | 574 | # -3 for special tokens [CLS], [SEP], [SEP] 575 | num_truncated_a, _ = truncate_tokens_pair(tokens_a, tokens_b, self.max_len - 3 - n_pred, max_len_a=self.max_len_a, 576 | max_len_b=self.max_len_b, trunc_seg=self.trunc_seg, always_truncate_tail=self.always_truncate_tail) 577 | 578 | # Add Special Tokens 579 | if len(tokens_a) > 0: 580 | if (task_idx == 3) and self.s2s_special_token: # dial 581 | tokens = ['[S2S_CLS]'] + tokens_a + ['[S2S_SEP]'] + tokens_b + ['[SEP]'] 582 | else: # text 583 | tokens = ['[CLS]'] + tokens_a + ['[SEP]'] + tokens_b + ['[SEP]'] 584 | 585 | num_tokens_a = len(tokens_a) + 2 586 | num_tokens_b = len(tokens_b) + 1 587 | 588 | else: # text 589 | tokens = ['[CLS]'] + tokens_b + ['[SEP]'] 590 | num_tokens_a = 0 591 | num_tokens_b = len(tokens_b) + 2 592 | 593 | cand_pos_tk = [] 594 | special_pos = set() # will not be masked 595 | for i, tk in enumerate(tokens): 596 | if len(tokens_a) and (i >= len(tokens_a)+2) and (tk != '[CLS]'): # TODO: mask tokens_b (target sequence) 597 | # we will mask [SEP] as an ending symbol 598 | cand_pos_tk.append((i, tk)) 599 | 600 | elif (len(tokens_a) == 0) and (i >= 1) and (tk != '[CLS]') and (not tk.startswith('[SEP')): 601 | cand_pos_tk.append((i, tk)) 602 | 603 | else: 604 | special_pos.add(i) 605 | 606 | if self.only_mask_last: 607 | cand_pos_tk = [(len(tokens)-2, tokens[-2])] 608 | 609 | # *ZY* 610 | if cond != self.nan_cond: 611 | if task_idx == 1: 612 | cand_pos = 
self.tfidf_mask(cond, cand_pos_tk, n_pred) 613 | elif (task_idx == 3) and (self.dial_mask_rate > 0.01) and (rand() < self.dial_mask_rate): 614 | cand_pos = self.tfidf_mask(cond, cand_pos_tk, n_pred) 615 | else: 616 | cand_pos = [p[0] for p in cand_pos_tk] 617 | else: 618 | cand_pos = [p[0] for p in cand_pos_tk] 619 | 620 | shuffle(cand_pos) 621 | masked_pos = set() 622 | max_cand_pos = max(cand_pos) 623 | 624 | for pos in cand_pos: # Uniform Distribution Here 625 | if len(masked_pos) >= n_pred: 626 | break 627 | if pos in masked_pos: # Avoid Overlapping 628 | continue 629 | 630 | def _expand_whole_word(st, end): 631 | # because of using WordPiece 632 | new_st, new_end = st, end 633 | while (new_st >= 0) and tokens[new_st].startswith('##'): 634 | new_st -= 1 635 | while (new_end < len(tokens)) and tokens[new_end].startswith('##'): 636 | new_end += 1 637 | return new_st, new_end 638 | 639 | if (self.skipgram_prb > 0) and (self.skipgram_size >= 2) and (rand() < self.skipgram_prb): 640 | # ngram 641 | cur_skipgram_size = randint(2, self.skipgram_size) 642 | if self.mask_whole_word: 643 | st_pos, end_pos = _expand_whole_word( 644 | pos, pos + cur_skipgram_size) 645 | else: 646 | st_pos, end_pos = pos, pos + cur_skipgram_size 647 | else: 648 | # directly mask 649 | if self.mask_whole_word: 650 | st_pos, end_pos = _expand_whole_word(pos, pos + 1) 651 | else: 652 | st_pos, end_pos = pos, pos + 1 653 | 654 | for mp in range(st_pos, end_pos): 655 | if (0 < mp <= max_cand_pos) and (mp not in special_pos): 656 | masked_pos.add(mp) 657 | else: 658 | break 659 | 660 | masked_pos = list(masked_pos) 661 | n_real_pred = len(masked_pos) 662 | if n_real_pred > n_pred: 663 | shuffle(masked_pos) 664 | masked_pos = masked_pos[:n_pred] 665 | n_real_pred = n_pred 666 | 667 | masked_tokens = [tokens[pos] for pos in masked_pos] 668 | 669 | for pos in masked_pos: 670 | if rand() < 0.8: # 80% 671 | tokens[pos] = ('[MASK]', tokens[pos]) 672 | elif rand() < 0.5: # 10% 673 | tokens[pos] = (get_random_word(self.vocab_words), tokens[pos]) 674 | else: 675 | tokens[pos] = (tokens[pos], tokens[pos]) 676 | 677 | # when n_pred < max_pred, we only calculate loss within n_pred 678 | masked_weights = [1]*len(masked_tokens) 679 | 680 | # Token Indexing 681 | masked_ids = self.FGfree_indexer(masked_tokens) 682 | 683 | # Token Indexing 684 | # input_ids = self.indexer(tokens) 685 | input_ids, position_ids, mask_pos_idx_map = self.FGfree_indexer(tokens, ret_ids_only=False) 686 | mask_pos_idx_map_sorted = sorted(mask_pos_idx_map.items(), key=lambda p: p[1]) 687 | 688 | num_tokens_b += n_real_pred 689 | 690 | is_next = 1 691 | mask_qkv = None 692 | 693 | if task_idx == 3: 694 | segment_ids = [0] * num_tokens_a + [1] * num_tokens_b 695 | 696 | elif task_idx == 1: 697 | segment_ids = [1] * (num_tokens_a + num_tokens_b) 698 | 699 | elif task_idx == 0: 700 | segment_ids = [0] * (num_tokens_a + num_tokens_b) 701 | 702 | else: 703 | raise ValueError 704 | 705 | assert len(input_ids) == len(position_ids) 706 | assert len(input_ids) == len(segment_ids) 707 | 708 | n_words = len(input_ids) 709 | n_pad = self.max_len - n_words 710 | end_at = position_ids[-1] + 1 711 | 712 | # Zero Padding 713 | input_ids.extend([0]*n_pad) 714 | segment_ids.extend([0]*n_pad) 715 | position_ids.extend(list(range(end_at, end_at+n_pad))) 716 | 717 | assert len(input_ids) == len(position_ids) 718 | 719 | input_mask = _get_attn_mask(n_words, num_tokens_a, mask_pos_idx_map_sorted, task_idx) 720 | 721 | masked_pos = [mask_pos_idx_map[pos] for pos in masked_pos] 722 | 
723 | # Zero Padding for masked target 724 | if self.max_pred > n_real_pred: 725 | n_pad = self.max_pred - n_real_pred 726 | if masked_ids is not None: 727 | masked_ids.extend([0]*n_pad) 728 | if masked_pos is not None: 729 | masked_pos.extend([0]*n_pad) 730 | if masked_weights is not None: 731 | masked_weights.extend([0]*n_pad) 732 | 733 | # print("tokens, ", tokens) 734 | # print("input_ids, ", input_ids) 735 | # print("segment_ids, ", segment_ids) 736 | # print("position_ids, ", position_ids) 737 | # print("masked_ids, ", masked_ids) 738 | # print("masked_pos, ", masked_pos) 739 | # print("input_mask, ", input_mask[:n_words+2, :n_words+2]) 740 | # exit() 741 | 742 | return (num_tokens_a, num_tokens_b, input_ids, cid, segment_ids, input_mask, mask_qkv, masked_ids, masked_pos, masked_weights, is_next, task_idx) 743 | 744 | def __call__(self, instance): 745 | tokens_a, tokens_b, cond, data_type = instance 746 | 747 | # print("instance: ", instance) 748 | 749 | if data_type == 'dial': 750 | task_idx = 3 # seq2seq 751 | elif data_type == 'mono': 752 | 753 | if len(tokens_a): # TODO: Notice Here! 754 | tokens_b = tokens_a + ['[SEP]'] + tokens_b 755 | tokens_a = [] 756 | 757 | if (rand() < 0.5) or (cond == ' '): 758 | task_idx = 1 # generation 759 | else: 760 | task_idx = 0 # bi-attn, encoding 761 | else: 762 | raise ValueError 763 | 764 | if (self.FGfree_indexer is None) or (task_idx == 0): 765 | return self.preprocess(tokens_a, tokens_b, cond, task_idx) 766 | else: 767 | return self.preprocess_FGfree(tokens_a, tokens_b, cond, task_idx) 768 | 769 | 770 | class Preprocess4Decoder(Pipeline): 771 | """ Pre-processing steps for pretraining transformer """ 772 | def __init__(self, vocab_words, indexer, max_len=512, max_tgt_length=128, new_segment_ids=False, mode="s2s", 773 | num_qkv=0, s2s_special_token=False, s2s_add_segment=False, s2s_share_segment=False, pos_shift=False, 774 | c_indexer=None): 775 | super().__init__() 776 | self.max_len = max_len 777 | self.vocab_words = vocab_words # vocabulary (sub)words 778 | self.indexer = indexer # function from token to token index 779 | self.max_len = max_len 780 | self._tril_matrix = torch.tril(torch.ones( 781 | (max_len, max_len), dtype=torch.long)) 782 | self.new_segment_ids = new_segment_ids 783 | self.task_idx = 3 # relax projection layer for different tasks 784 | assert mode in ("s2s", "l2r") 785 | self.mode = mode 786 | self.max_tgt_length = max_tgt_length 787 | self.num_qkv = num_qkv 788 | self.s2s_special_token = s2s_special_token 789 | self.s2s_add_segment = s2s_add_segment 790 | self.s2s_share_segment = s2s_share_segment 791 | self.pos_shift = pos_shift 792 | 793 | # *ZY* 794 | self.nan_cond = ' ' 795 | assert isinstance(c_indexer, dict) 796 | if self.nan_cond not in c_indexer.keys(): 797 | print('#'*10+'To add user, we re-arranged c_indexer (+1)'+'#'*10) 798 | sys.stdout.flush() 799 | self.c_indexer = {self.nan_cond: 0} 800 | for i, u in enumerate(c_indexer.keys()): 801 | self.c_indexer[u] = i + 1 802 | # Check 803 | assert sorted(list(self.c_indexer.values())) == list(range(len(self.c_indexer))) 804 | 805 | def __call__(self, instance): 806 | tokens_a, usrid, max_a_len = instance 807 | 808 | try: 809 | cid = self.c_indexer[usrid] 810 | except KeyError: 811 | print("Warning: {:} not in c_indexer".format(usrid)) 812 | cid = self.c_indexer[self.nan_cond] 813 | 814 | # Add Special Tokens 815 | if self.s2s_special_token: 816 | padded_tokens_a = ['[S2S_CLS]'] + tokens_a + ['[S2S_SEP]'] 817 | else: 818 | padded_tokens_a = ['[CLS]'] + tokens_a + 
['[SEP]'] 819 | assert len(padded_tokens_a) <= max_a_len + 2 820 | if max_a_len + 2 > len(padded_tokens_a): 821 | padded_tokens_a += ['[PAD]'] * \ 822 | (max_a_len + 2 - len(padded_tokens_a)) 823 | assert len(padded_tokens_a) == max_a_len + 2 824 | max_len_in_batch = min(self.max_tgt_length + 825 | max_a_len + 2, self.max_len) 826 | tokens = padded_tokens_a 827 | 828 | segment_ids = [0]*(len(padded_tokens_a)) \ 829 | + [1]*(max_len_in_batch - len(padded_tokens_a)) 830 | 831 | if self.num_qkv > 1: 832 | mask_qkv = [0]*(len(padded_tokens_a)) + [1] * \ 833 | (max_len_in_batch - len(padded_tokens_a)) 834 | else: 835 | mask_qkv = None 836 | 837 | position_ids = [] 838 | for i in range(len(tokens_a) + 2): 839 | position_ids.append(i) 840 | for i in range(len(tokens_a) + 2, max_a_len + 2): 841 | position_ids.append(0) 842 | for i in range(max_a_len + 2, max_len_in_batch): 843 | position_ids.append(i - (max_a_len + 2) + len(tokens_a) + 2) 844 | 845 | # Token Indexing 846 | input_ids = self.indexer(tokens) 847 | 848 | # Zero Padding 849 | input_mask = torch.zeros( 850 | max_len_in_batch, max_len_in_batch, dtype=torch.long) 851 | if self.mode == "s2s": 852 | input_mask[:, :len(tokens_a)+2].fill_(1) 853 | else: 854 | st, end = 0, len(tokens_a) + 2 855 | input_mask[st:end, st:end].copy_( 856 | self._tril_matrix[:end, :end]) 857 | input_mask[end:, :len(tokens_a)+2].fill_(1) 858 | second_st, second_end = len(padded_tokens_a), max_len_in_batch 859 | 860 | input_mask[second_st:second_end, second_st:second_end].copy_( 861 | self._tril_matrix[:second_end-second_st, :second_end-second_st]) 862 | 863 | return (input_ids, cid, segment_ids, position_ids, input_mask, mask_qkv, self.task_idx) 864 | --------------------------------------------------------------------------------
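Both Preprocess4Seq2seq.preprocess and Preprocess4Decoder build a per-example attention mask: for the seq2seq task (task_idx == 3), positions in the source segment attend bidirectionally to the source, while target positions attend to the full source plus earlier target positions only. The snippet below is a minimal, self-contained sketch of that construction; the function name and the toy sizes are illustrative and not part of the repository.

import torch

def s2s_attention_mask(num_tokens_a, max_len):
    """Bidirectional attention over the source block, causal attention over the target."""
    mask = torch.zeros(max_len, max_len, dtype=torch.long)
    # Source rows/columns: every source position sees every other source position.
    mask[:num_tokens_a, :num_tokens_a].fill_(1)
    # Target rows: copy the lower-triangular (causal) pattern over the whole sequence.
    tril = torch.tril(torch.ones(max_len, max_len, dtype=torch.long))
    mask[num_tokens_a:, :] = tril[num_tokens_a:, :]
    return mask

# Example: a 2-token source in a length-5 sequence.
# s2s_attention_mask(2, 5) ->
# tensor([[1, 1, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])

run_ppl.py constructs Preprocess4Seq2seq with only_mask_last=True, so the reported perplexity comes from predicting only the token immediately before the final [SEP] of each example under this same mask.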