├── models ├── pretrained_common │ ├── __init__.py │ ├── activations.py │ ├── optimization.py │ ├── configuration_utils.py │ ├── file_utils.py │ ├── tokenization_utils.py │ └── modeling_utils.py ├── __init__.py ├── bert_post_training.py ├── bert_base_cls.py ├── utils │ ├── scorer.py │ └── checkpointing.py └── bert │ ├── configuration_bert.py │ └── tokenization_bert.py ├── model_overview.jpg ├── resources ├── bert-base-uncased │ └── bert-base-uncased-config.json └── bert-post-uncased │ └── bert-post-uncased-config.json ├── scripts ├── download_post_checkpoints.sh └── download_datasets.sh ├── config └── hparams.py ├── README.md ├── data ├── data_utils.py ├── dataset.py └── create_bert_post_training_data.py ├── main.py ├── evaluation.py ├── train.py └── post_train.py /models/pretrained_common/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "2.8.0" 2 | -------------------------------------------------------------------------------- /model_overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taesunwhang/BERT-ResSel/HEAD/model_overview.jpg -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.bert_base_cls import BERTbase 2 | from models.bert_post_training import BertDomainPostTraining 3 | 4 | 5 | def Model(hparams, *args): 6 | name_model_map = { 7 | "bert_base_ft" : BERTbase, 8 | "bert_dpt_ft" : BERTbase, 9 | 10 | "bert_ubuntu_pt" : BertDomainPostTraining, 11 | } 12 | 13 | return name_model_map[hparams.model_type](hparams, *args) -------------------------------------------------------------------------------- /resources/bert-base-uncased/bert-base-uncased-config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "max_position_embeddings": 512, 12 | "num_attention_heads": 12, 13 | "num_hidden_layers": 12, 14 | "type_vocab_size": 2, 15 | "vocab_size": 30522 16 | } 17 | -------------------------------------------------------------------------------- /resources/bert-post-uncased/bert-post-uncased-config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "max_position_embeddings": 512, 12 | "num_attention_heads": 12, 13 | "num_hidden_layers": 12, 14 | "type_vocab_size": 2, 15 | "vocab_size": 30523 16 | } 17 | -------------------------------------------------------------------------------- /scripts/download_post_checkpoints.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export file_name=bert-base-uncased-pytorch_model.bin 4 | if [ -f $PWD/resources/bert-base-uncased/$file_name ]; then 5 | echo "$file_name exists" 6 | else 7 | wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin 8 | mv $file_name resources/bert-base-uncased/ 9 | fi 10 | 11 | export 
file_name=bert-post-uncased-pytorch_model.pth 12 | if [ -f $PWD/resources/bert-post-uncased/$file_name ]; then 13 | echo "$file_name exists" 14 | else 15 | echo "$file_name does not exist" 16 | export file_id=1jt0RhVT9y2d4AITn84kSOk06hjIv1y49 17 | 18 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id='$file_id -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=$file_id" -O $file_name && rm -rf /tmp/cookies.txt 19 | mv $file_name resources/bert-post-uncased/ 20 | fi 21 | -------------------------------------------------------------------------------- /models/bert_post_training.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.nn as nn 3 | 4 | from models.bert import modeling_bert, configuration_bert 5 | 6 | class BertDomainPostTraining(nn.Module): 7 | def __init__(self, hparams): 8 | super(BertDomainPostTraining, self).__init__() 9 | self.hparams = hparams 10 | bert_config = configuration_bert.BertConfig.from_pretrained( 11 | os.path.join(self.hparams.bert_pretrained_dir, self.hparams.bert_pretrained, 12 | "%s-config.json" % self.hparams.bert_pretrained), 13 | ) 14 | self._bert_model = modeling_bert.BertForPreTraining.from_pretrained( 15 | os.path.join(self.hparams.bert_pretrained_dir, self.hparams.bert_pretrained, self.hparams.bert_checkpoint_path), 16 | config=bert_config 17 | ) 18 | 19 | if self.hparams.do_eot: 20 | self._bert_model.resize_token_embeddings(self._bert_model.config.vocab_size + 1) # [EOT] 21 | 22 | def forward(self, batch): 23 | 24 | bert_outputs = self._bert_model( 25 | input_ids=batch["input_ids"], 26 | token_type_ids=batch["token_type_ids"], 27 | attention_mask=batch["attention_mask"], 28 | masked_lm_labels=batch["masked_lm_labels"], 29 | next_sentence_label=batch["next_sentence_labels"] 30 | ) 31 | mlm_loss, nsp_loss, prediction_scores, seq_relationship_score = bert_outputs[:4] 32 | 33 | return mlm_loss, nsp_loss -------------------------------------------------------------------------------- /scripts/download_datasets.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export file_name=ubuntu.zip 4 | if [ -f $PWD/data/ubuntu_corpus_v1/ubuntu_train.pkl ]; then 5 | echo "ubuntu_train.pkl exists" 6 | else 7 | echo "ubuntu_train.pkl does not exist" 8 | export file_id=1VKQaNNC5NR-6TwVPpxYAVZQp_nwK3d5u 9 | 10 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id='$file_id -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=$file_id" -O $file_name && rm -rf /tmp/cookies.txt 11 | unzip $file_name -d data/ubuntu_corpus_v1 12 | rm -r $file_name 13 | fi 14 | 15 | export file_name=ubuntu_post_training.txt 16 | if [ -f $PWD/data/ubuntu_corpus_v1/$file_name ]; then 17 | echo "$file_name exists" 18 | else 19 | echo "$file_name does not exist" 20 | export file_id=1mYS_PrnrKx4zDWOPTFhx_SeEwdumYXCK 21 | 22 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id='$file_id -O- | sed -rn 
's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=$file_id" -O $file_name && rm -rf /tmp/cookies.txt 23 | mv $file_name data/ubuntu_corpus_v1/ 24 | fi 25 | 26 | -------------------------------------------------------------------------------- /models/bert_base_cls.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | 5 | from models.bert import modeling_bert, configuration_bert 6 | 7 | class BERTbase(nn.Module): 8 | def __init__(self, hparams): 9 | super(BERTbase, self).__init__() 10 | self.hparams = hparams 11 | 12 | bert_config = configuration_bert.BertConfig.from_pretrained( 13 | os.path.join(self.hparams.bert_pretrained_dir, self.hparams.bert_pretrained, 14 | "%s-config.json" % self.hparams.bert_pretrained), 15 | ) 16 | self._bert_model = modeling_bert.BertModel.from_pretrained( 17 | os.path.join(self.hparams.bert_pretrained_dir, self.hparams.bert_pretrained, self.hparams.bert_checkpoint_path), 18 | config=bert_config 19 | ) 20 | 21 | if self.hparams.do_eot and self.hparams.model_type == "bert_base_ft": 22 | self._bert_model.resize_token_embeddings(self._bert_model.config.vocab_size + 1) # [EOT] 23 | 24 | self._classification = nn.Sequential( 25 | nn.Dropout(p=1 - self.hparams.dropout_keep_prob), 26 | nn.Linear(self.hparams.bert_hidden_dim, 1) 27 | ) 28 | 29 | def forward(self, batch): 30 | bert_outputs, _ = self._bert_model( 31 | batch["anno_sent"], 32 | token_type_ids=batch["segment_ids"], 33 | attention_mask=batch["attention_mask"] 34 | ) 35 | cls_logits = bert_outputs[:,0,:] # bs, bert_output_size 36 | logits = self._classification(cls_logits) # bs, 1 37 | logits = logits.squeeze(-1) 38 | 39 | return logits -------------------------------------------------------------------------------- /models/pretrained_common/activations.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | def swish(x): 10 | return x * torch.sigmoid(x) 11 | 12 | 13 | def _gelu_python(x): 14 | """ Original Implementation of the gelu activation function in Google Bert repo when initially created. 15 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 16 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 17 | This is now written in C in torch.nn.functional 18 | Also see https://arxiv.org/abs/1606.08415 19 | """ 20 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 21 | 22 | 23 | def gelu_new(x): 24 | """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT). 
25 | Also see https://arxiv.org/abs/1606.08415 26 | """ 27 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) 28 | 29 | 30 | if torch.__version__ < "1.4.0": 31 | gelu = _gelu_python 32 | else: 33 | gelu = F.gelu 34 | 35 | ACT2FN = { 36 | "relu": F.relu, 37 | "swish": swish, 38 | "gelu": gelu, 39 | "tanh": torch.tanh, 40 | "gelu_new": gelu_new, 41 | } 42 | 43 | 44 | def get_activation(activation_string): 45 | if activation_string in ACT2FN: 46 | return ACT2FN[activation_string] 47 | else: 48 | raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys()))) -------------------------------------------------------------------------------- /config/hparams.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | BASE_PARAMS = defaultdict( 4 | # lambda: None, # Set default value to None. 5 | # GPU params 6 | gpu_ids=[0], 7 | 8 | # Input params 9 | train_batch_size=8, 10 | eval_batch_size=100, 11 | virtual_batch_size=32, 12 | 13 | evaluate_candidates_num=10, 14 | recall_k_list=[1,2,5,10], 15 | 16 | # Training BERT params 17 | learning_rate=3e-5, 18 | dropout_keep_prob=0.8, 19 | num_epochs=10, 20 | max_gradient_norm=5, 21 | 22 | pad_idx=0, 23 | max_position_embeddings=100, 24 | num_hidden_layers=12, 25 | num_attention_heads=12, 26 | intermediate_size=3072, 27 | bert_hidden_dim=768, 28 | attention_probs_dropout_prob=0.1, 29 | layer_norm_eps=1e-12, 30 | 31 | # Train Model Config 32 | task_name="ubuntu", 33 | do_bert=True, 34 | do_eot=True, 35 | max_dialog_len=448, 36 | max_response_len=64, 37 | # summation -> 512 38 | 39 | # Need to change to train...(e.g.data dir, config dir, vocab dir, etc.) 40 | save_dirpath='checkpoints/', # /path/to/checkpoints 41 | 42 | bert_pretrained="bert-base-uncased", # should be defined here 43 | bert_checkpoint_path="bert-base-uncased-pytorch_model.bin", 44 | model_type="bert_base_ft", 45 | 46 | load_pthpath="", 47 | cpu_workers=8, 48 | tensorboard_step=1000, 49 | evaluate_print_step=100, 50 | ) 51 | 52 | DPT_FINETUNING_PARAMS = BASE_PARAMS.copy() 53 | DPT_FINETUNING_PARAMS.update( 54 | bert_checkpoint_path="bert-post-uncased-pytorch_model.pth", # should be defined here 55 | model_type="bert_dpt_ft" 56 | ) 57 | 58 | POST_TRAINING_PARAMS = BASE_PARAMS.copy() 59 | POST_TRAINING_PARAMS.update( 60 | num_epochs=3, 61 | # lambda: None, # Set default value to None. 
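    # Only the keys listed in this update() call are overridden; every other
    # hyper-parameter is inherited from the BASE_PARAMS copy made above.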
62 | # GPU params 63 | gpu_ids=[0], 64 | 65 | # Input params 66 | train_batch_size=8, 67 | virtual_batch_size=512, 68 | tensorboard_step=100, 69 | 70 | checkpoint_save_step=2500, # virtual_batch -> 10000 step 71 | model_type="bert_ubuntu_pt", 72 | data_dir="./data/ubuntu_corpus_v1/ubuntu_post_training.hdf5", 73 | ) -------------------------------------------------------------------------------- /models/utils/scorer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def calculate_candidates_ranking(prediction, ground_truth, eval_candidates_num=10): 4 | total_num_split = len(ground_truth) / eval_candidates_num 5 | 6 | pred_split = np.split(prediction, total_num_split) 7 | gt_split = np.split(np.array(ground_truth), total_num_split) 8 | orig_rank_split = np.split(np.tile(np.arange(0, eval_candidates_num), int(total_num_split)), total_num_split) 9 | stack_scores = np.stack((gt_split, pred_split, orig_rank_split), axis=-1) 10 | 11 | rank_by_pred_l = [] 12 | for i, stack_score in enumerate(stack_scores): 13 | rank_by_pred = sorted(stack_score, key=lambda x: x[1], reverse=True) 14 | rank_by_pred = np.stack(rank_by_pred, axis=-1) 15 | rank_by_pred_l.append(rank_by_pred[0]) 16 | 17 | return np.array(rank_by_pred_l) 18 | 19 | def logits_recall_at_k(rank_by_pred, k_list=[1, 2, 5, 10]): 20 | # 1 dialog, 10 response candidates ground truth 1 or 0 21 | # prediction_score : [batch_size] 22 | # target : [batch_size] e.g. 1 0 0 0 0 0 0 0 0 0 23 | # e.g. batch : 100 -> 100/10 = 10 24 | 25 | num_correct = np.zeros([rank_by_pred.shape[0], len(k_list)]) 26 | 27 | pos_index = [] 28 | for sorted_score in rank_by_pred: 29 | for p_i, score in enumerate(sorted_score): 30 | if int(score) == 1: 31 | pos_index.append(p_i) 32 | index_dict = dict() 33 | for i, p_i in enumerate(pos_index): 34 | index_dict[i] = p_i 35 | 36 | for i, p_i in enumerate(pos_index): 37 | for j, k in enumerate(k_list): 38 | if p_i + 1 <= k: 39 | num_correct[i][j] += 1 40 | 41 | return np.sum(num_correct, axis=0), pos_index 42 | 43 | def logits_mrr(rank_by_pred): 44 | pos_index = [] 45 | for sorted_score in rank_by_pred: 46 | for p_i, score in enumerate(sorted_score): 47 | if int(score) == 1: 48 | pos_index.append(p_i) 49 | 50 | # print("pos_index", pos_index) 51 | mrr = [] 52 | for i, p_i in enumerate(pos_index): 53 | mrr.append(1 / (p_i + 1)) 54 | 55 | # print(mrr) 56 | 57 | return np.sum(mrr) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | An Effective Domain Adaptive Post-Training Method for BERT in Response Selection 2 | ==================================== 3 | Implements the model described in the following paper [An Effective Domain Adaptive Post-Training Method for BERT in Response Selection](https://arxiv.org/abs/1908.04812v2). 4 | ``` 5 | @inproceedings{whang2020domain, 6 | author={Whang, Taesun and Lee, Dongyub and Lee, Chanhee and Yang, Kisu and Oh, Dongsuk and Lim, HeuiSeok}, 7 | title="An Effective Domain Adaptive Post-Training Method for BERT in Response Selection", 8 | year=2020, 9 | booktitle={Proc. Interspeech 2020} 10 | } 11 | ``` 12 | 13 | This code is reimplemented as a fork of [huggingface/transformers][7]. 14 | 15 |

16 | ![model overview](model_overview.jpg)

18 | 19 | 20 | Data Creation 21 | -------- 22 | 1. Download `ubuntu_train.pkl, ubuntu_valid.pkl, ubuntu_test.pkl` [here][1], or create the `pkl` files yourself to train the BERT-based response selection model. 23 | To create the pkl files, download the ubuntu_corpus_v1 dataset [here][2], provided by [Xu et al. (2016)](https://arxiv.org/pdf/1605.05110.pdf), and keep the files under the `data/ubuntu_corpus_v1` directory. 24 | 2. The Ubuntu corpus for domain post-training will be created by running: 25 | ```shell 26 | python data/data_utils.py 27 | ``` 28 | 29 | Post Training Data Creation 30 | -------- 31 | Download the `ubuntu_post_training.txt` corpus [here][3] and simply run 32 | ```shell 33 | python data/create_bert_post_training_data.py 34 | ``` 35 | After creating the post-training data, keep the `ubuntu_post_training.hdf5` file under the `data/ubuntu_corpus_v1` directory. 36 | 37 | Domain Post Training BERT 38 | -------- 39 | To domain post-train BERT, simply run 40 | ```shell 41 | python main.py --model bert_ubuntu_pt --train_type post_training --bert_pretrained bert-base-uncased --data_dir ./data/ubuntu_corpus_v1/ubuntu_post_training.hdf5 42 | ``` 43 | 44 | BERT Fine-tuning (Response Selection) 45 | -------- 46 | ### Training 47 | Train a response selection model based on `BERT_base`: 48 | ```shell 49 | python main.py --model bert_base_ft --train_type fine_tuning --bert_pretrained bert-base-uncased 50 | ``` 51 | 52 | Train a response selection model based on the `Domain post-trained BERT`. If you wish to use the provided domain post-trained BERT, download the model checkpoint (`bert-post-uncased-pytorch_model.pth`) [here][4], 53 | and keep the checkpoint under the `resources/bert-post-uncased` directory: 54 | ```shell 55 | python main.py --model bert_dpt_ft --train_type fine_tuning --bert_pretrained bert-post-uncased 56 | ``` 57 | 58 | ### Evaluation 59 | To evaluate the `bert_base` and `bert_dpt` models, set a model checkpoint path and simply run 60 | ```shell 61 | python main.py --model bert_dpt_ft --train_type fine_tuning --bert_pretrained bert-post-uncased --evaluate /path/to/checkpoint.pth 62 | ``` 63 | If you wish to use a pre-trained response selection model, we provide the model checkpoints below. 64 | 65 | | Model | R@1 | R@2 | R@5 | MRR | 66 | |:---------:|:------:|:------:|:------:|:------:| 67 | | [BERT_base][5] | 0.8115 | 0.9003 | 0.9768 | 0.8809 | 68 | | [BERT_DPT][6] | 0.8515 | 0.9272 | 0.9851 | 0.9081 | 69 | 70 | 71 | Acknowledgements 72 | -------- 73 | - This work was supported by an Institute for Information & communications Technology Promotion (IITP) grant funded by the Korea government (MSIT) (no. 2016-0-00010-003, Digital Content InHouse R&D) 74 | - Work in collaboration with [Kakao Corp][8].
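Metric Computation
--------
The R@k and MRR values in the table above are computed by `models/utils/scorer.py`: every 10 consecutive candidate scores form one dialog context, and the rank of the ground-truth response (label 1) after sorting by score determines the metrics. Below is a minimal, illustrative sketch of calling those helpers on a flat prediction array (run from the repository root; the random scores are placeholders, not real model outputs):
```python
import numpy as np
from models.utils.scorer import calculate_candidates_ranking, logits_recall_at_k, logits_mrr

# Two dialogs with 10 candidates each; a label of 1 marks the ground-truth response.
scores = np.random.rand(20)                     # one model score per candidate
labels = [1] + [0] * 9 + [1] + [0] * 9          # labels in the same candidate order

rank_by_pred = calculate_candidates_ranking(scores, labels, eval_candidates_num=10)
num_correct, _ = logits_recall_at_k(rank_by_pred, k_list=[1, 2, 5, 10])

recall_at_k = num_correct / len(rank_by_pred)   # fraction of dialogs answered within top-k
mean_mrr = logits_mrr(rank_by_pred) / len(rank_by_pred)
print(recall_at_k, mean_mrr)
```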
75 | 76 | [1]: https://drive.google.com/drive/folders/1mLzXifYYwmlFEWDzSbbecLlzKstB8gQK?usp=sharing 77 | [2]: https://www.dropbox.com/s/2fdn26rj6h9bpvl/ubuntu_data.zip 78 | [3]: https://drive.google.com/file/d/1mYS_PrnrKx4zDWOPTFhx_SeEwdumYXCK/view?usp=sharing 79 | [4]: https://drive.google.com/file/d/1jt0RhVT9y2d4AITn84kSOk06hjIv1y49/view?usp=sharing 80 | [5]: https://drive.google.com/file/d/1amuPQ_CtfvNuQMdRR8eo0YGAQLP4XBP7/view?usp=sharing 81 | [6]: https://drive.google.com/file/d/1Ip_VqzpByWZRAgiN7OxPeyYxK6onPia0/view?usp=sharing 82 | [7]: https://github.com/huggingface/transformers 83 | [8]: https://www.kakaocorp.com 84 | -------------------------------------------------------------------------------- /data/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.getcwd()) 4 | import pickle 5 | from tqdm import tqdm 6 | 7 | from models.bert import tokenization_bert 8 | 9 | class InputExamples(object): 10 | def __init__(self, utterances, response, label, seq_lengths): 11 | 12 | self.utterances = utterances 13 | self.response = response 14 | self.label = label 15 | 16 | self.dialog_len = seq_lengths[0] 17 | self.response_len = seq_lengths[1] 18 | 19 | class UbuntuDataUtils(object): 20 | def __init__(self, txt_path, bert_pretrained_dir): 21 | # bert_tokenizer init 22 | self.txt_path = txt_path 23 | self._bert_tokenizer_init(bert_pretrained_dir) 24 | 25 | def _bert_tokenizer_init(self, bert_pretrained_dir, bert_pretrained='bert-base-uncased'): 26 | 27 | self._bert_tokenizer = tokenization_bert.BertTokenizer( 28 | vocab_file=os.path.join(os.path.join(bert_pretrained_dir, bert_pretrained), 29 | "%s-vocab.txt" % bert_pretrained)) 30 | print("BERT tokenizer init completes") 31 | 32 | def read_raw_file(self, data_type): 33 | print("Loading raw txt file...") 34 | 35 | ubuntu_path = self.txt_path % data_type # train, dev, test 36 | with open(ubuntu_path, "r", encoding="utf8") as fr_handle: 37 | data = [line.strip() for line in fr_handle if len(line.strip()) > 0] 38 | print("(%s) total number of sentence : %d" % (data_type, len(data))) 39 | 40 | return data 41 | 42 | def make_post_training_corpus(self, data, post_training_path): 43 | with open(post_training_path, "w", encoding="utf-8") as fw_handle: 44 | cnt = 0 45 | for document in data: 46 | dialog_data = document.split("\t") 47 | if dialog_data[0] == '0': 48 | continue 49 | for utt in dialog_data[1:-1]: 50 | if len(utt) == 0: 51 | continue 52 | fw_handle.write(utt.strip() + "\n") 53 | fw_handle.write("\n") 54 | cnt+=1 55 | 56 | def make_examples_pkl(self, data, ubuntu_pkl_path): 57 | with open(ubuntu_pkl_path, "ab") as pkl_handle: 58 | for dialog in tqdm(data): 59 | dialog_data = dialog.split("\t") 60 | label = dialog_data[0] 61 | utterances = [] 62 | dialog_len = [] 63 | 64 | for utt in dialog_data[1:-1]: 65 | utt_tok = self._bert_tokenizer.tokenize(utt) 66 | utterances.append(utt_tok) 67 | dialog_len.append(len(utt_tok)) 68 | 69 | response = self._bert_tokenizer.tokenize(dialog_data[-1]) 70 | 71 | pickle.dump(InputExamples( 72 | utterances=utterances, response=response, label=int(label), 73 | seq_lengths=(dialog_len, len(response))), pkl_handle) 74 | 75 | print(ubuntu_pkl_path, " save completes!") 76 | 77 | def ubuntu_manual(self): 78 | knowledge_path = "ubuntu_manual_knowledge.txt" 79 | ubuntu_knowledge_dict = dict() 80 | ubuntu_man_l = [] 81 | 82 | with open(knowledge_path, "r", encoding="utf-8") as f_handle: 83 | for line in f_handle: 84 | 
ubuntu_man = line.strip().split("\t") 85 | if len(ubuntu_man) == 2: 86 | ubuntu_knowledge_dict[ubuntu_man[0]] = ubuntu_man[1] 87 | ubuntu_man_l.append(ubuntu_man[1]) 88 | 89 | print(ubuntu_knowledge_dict.keys()) 90 | print(len(ubuntu_knowledge_dict)) 91 | 92 | if __name__ == '__main__': 93 | ubuntu_raw_path = "./data/ubuntu_corpus_v1/%s.txt" 94 | ubuntu_pkl_path = "./data/ubuntu_corpus_v1/ubuntu_%s.pkl" 95 | bert_pretrained_dir = "./resources" 96 | 97 | ubuntu_utils = UbuntuDataUtils(ubuntu_raw_path, bert_pretrained_dir) 98 | 99 | # response seleciton fine-tuning pkl creation 100 | for data_type in ["train", "valid", "test"]: 101 | data = ubuntu_utils.read_raw_file(data_type) 102 | ubuntu_utils.make_examples_pkl(data, ubuntu_pkl_path % data_type) 103 | 104 | # domain post training corpus creation 105 | for data_type in ["train"]: 106 | data = ubuntu_utils.read_raw_file(data_type) 107 | ubuntu_utils.make_post_training_corpus(data, "./data/ubuntu_corpus_v1/ubuntu_post_training.txt") 108 | 109 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import collections 4 | import logging 5 | from datetime import datetime 6 | 7 | from config.hparams import * 8 | from data.data_utils import InputExamples 9 | from train import ResponseSelection 10 | from evaluation import Evaluation 11 | from post_train import BERTDomainPostTraining 12 | 13 | PARAMS_MAP = { 14 | # fine-tuning (ft) 15 | "bert_base_ft" : BASE_PARAMS, 16 | "bert_dpt_ft" : DPT_FINETUNING_PARAMS, 17 | 18 | # post-training (pt) 19 | "bert_ubuntu_pt" : POST_TRAINING_PARAMS, 20 | } 21 | 22 | MODEL = { 23 | "fine_tuning" : ResponseSelection, 24 | "post_training" : BERTDomainPostTraining 25 | } 26 | 27 | def init_logger(path:str): 28 | if not os.path.exists(path): 29 | os.makedirs(path) 30 | logger = logging.getLogger() 31 | logger.handlers = [] 32 | logger.setLevel(logging.DEBUG) 33 | debug_fh = logging.FileHandler(os.path.join(path, "debug.log")) 34 | debug_fh.setLevel(logging.DEBUG) 35 | 36 | info_fh = logging.FileHandler(os.path.join(path, "info.log")) 37 | info_fh.setLevel(logging.INFO) 38 | 39 | ch = logging.StreamHandler() 40 | ch.setLevel(logging.INFO) 41 | 42 | info_formatter = logging.Formatter('%(asctime)s | %(levelname)-8s | %(message)s') 43 | debug_formatter = logging.Formatter('%(asctime)s | %(levelname)-8s | %(message)s | %(lineno)d:%(funcName)s') 44 | 45 | ch.setFormatter(info_formatter) 46 | info_fh.setFormatter(info_formatter) 47 | debug_fh.setFormatter(debug_formatter) 48 | 49 | logger.addHandler(ch) 50 | logger.addHandler(debug_fh) 51 | logger.addHandler(info_fh) 52 | 53 | return logger 54 | 55 | def train_model(args): 56 | hparams = PARAMS_MAP[args.model] 57 | hparams["root_dir"] = args.root_dir 58 | hparams["bert_pretrained_dir"] = args.bert_pretrained_dir 59 | hparams["bert_pretrained"] = args.bert_pretrained 60 | hparams["data_dir"] = args.data_dir 61 | hparams["model_type"] = args.model 62 | 63 | timestamp = datetime.now().strftime('%Y%m%d-%H%M%S') 64 | root_dir = os.path.join(hparams["root_dir"], args.model, args.train_type, "%s/" % timestamp) 65 | logger = init_logger(root_dir) 66 | logger.info("Hyper-parameters: %s" % str(hparams)) 67 | hparams["root_dir"] = root_dir 68 | 69 | hparams = collections.namedtuple("HParams", sorted(hparams.keys()))(**hparams) 70 | model = MODEL[args.train_type](hparams) 71 | model.train() 72 | 73 | def evaluate_model(args): 74 | 
hparams = PARAMS_MAP[args.model] 75 | 76 | hparams = collections.namedtuple("HParams", sorted(hparams.keys()))(**hparams) 77 | 78 | model = Evaluation(hparams) 79 | model.run_evaluate(args.evaluate) 80 | 81 | if __name__ == '__main__': 82 | arg_parser = argparse.ArgumentParser(description="Bert / Response Selection (PyTorch)") 83 | arg_parser.add_argument("--model", dest="model", type=str, 84 | default="bert_base", 85 | help="Model Name") 86 | arg_parser.add_argument("--root_dir", dest="root_dir", type=str, 87 | default="./results", 88 | help="model train logs, checkpoints") 89 | arg_parser.add_argument("--data_dir", dest="data_dir", type=str, 90 | default="./data/ubuntu_corpus_v1/%s_%s.pkl", 91 | help="ubuntu corpus v1 pkl path") # ubuntu_train.pkl, ubuntu_valid_pkl, ubuntu_test.pkl 92 | arg_parser.add_argument("--bert_pretrained_dir", dest="bert_pretrained_dir", type=str, 93 | default="./resources", 94 | help="bert pretrained directory") 95 | arg_parser.add_argument("--bert_pretrained", dest="bert_pretrained", type=str, 96 | default="bert-post-uncased", 97 | help="bert pretrained directory") # bert-base-uncased, bert-post-uncased -> under bert_pretrained_dir 98 | arg_parser.add_argument("--train_type", dest="train_type", type=str, 99 | default="fine_tuning", 100 | help="Train type") # fine_tuning, post_training 101 | arg_parser.add_argument("--evaluate", dest="evaluate", type=str, 102 | help="Evaluation Checkpoint", default="") 103 | 104 | args = arg_parser.parse_args() 105 | if args.evaluate: 106 | evaluate_model(args) 107 | else: 108 | train_model(args) -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | # os.environ["CUDA_VISIBLE_DEVICES"] = "1" 3 | import logging 4 | from tqdm import tqdm 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.utils.data import DataLoader 10 | 11 | from models import Model 12 | from data.dataset import ResponseSelectionDataset 13 | from models.utils.checkpointing import load_checkpoint 14 | from models.utils.scorer import calculate_candidates_ranking, logits_mrr, logits_recall_at_k 15 | 16 | class Evaluation(object): 17 | def __init__(self, hparams, model=None, split = "test"): 18 | 19 | self.hparams = hparams 20 | self.model = model 21 | self._logger = logging.getLogger(__name__) 22 | self.device = (torch.device("cuda", self.hparams.gpu_ids[0]) 23 | if self.hparams.gpu_ids[0] >= 0 else torch.device("cpu")) 24 | self.split = split 25 | print("Evaluation Split :", self.split) 26 | do_valid, do_test = False, False 27 | if split == "valid": 28 | do_valid = True 29 | else: 30 | do_test = True 31 | self._build_dataloader(do_valid=do_valid, do_test=do_test) 32 | self._dataloader = self.valid_dataloader if split == 'valid' else self.test_dataloader 33 | 34 | if model is None: 35 | print("No pre-defined model!") 36 | self._build_model() 37 | 38 | def _build_dataloader(self, do_valid=False, do_test=False): 39 | 40 | if do_valid: 41 | self.valid_dataset = ResponseSelectionDataset( 42 | self.hparams, 43 | split="valid", 44 | ) 45 | self.valid_dataloader = DataLoader( 46 | self.valid_dataset, 47 | batch_size=self.hparams.eval_batch_size, 48 | num_workers=self.hparams.cpu_workers, 49 | drop_last=False, 50 | ) 51 | 52 | if do_test: 53 | self.test_dataset = ResponseSelectionDataset( 54 | self.hparams, 55 | split="test", 56 | ) 57 | 58 | self.test_dataloader = DataLoader( 59 | self.test_dataset, 
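                # shuffle is left at its default (False): the scorer groups every
                # 10 consecutive scores (evaluate_candidates_num) as one dialog's
                # candidates, so example order must be preserved during evaluation.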
60 | batch_size=self.hparams.eval_batch_size, 61 | num_workers=self.hparams.cpu_workers, 62 | drop_last=False, 63 | ) 64 | 65 | def _build_model(self): 66 | self.model = Model(self.hparams) 67 | self.model = self.model.to(self.device) 68 | # Use Multi-GPUs 69 | if -1 not in self.hparams.gpu_ids and len(self.hparams.gpu_ids) > 1: 70 | self.model = nn.DataParallel(self.model, self.hparams.gpu_ids) 71 | 72 | def run_evaluate(self, evaluation_path): 73 | self._logger.info("Evaluation") 74 | model_state_dict, optimizer_state_dict = load_checkpoint(evaluation_path) 75 | 76 | if isinstance(self.model, nn.DataParallel): 77 | self.model.module.load_state_dict(model_state_dict) 78 | else: 79 | self.model.load_state_dict(model_state_dict) 80 | 81 | k_list = self.hparams.recall_k_list 82 | total_mrr = 0 83 | total_examples, total_correct = 0, 0 84 | self.model.eval() 85 | with torch.no_grad(): 86 | for batch_idx, batch in enumerate(tqdm(self._dataloader)): 87 | buffer_batch = batch.copy() 88 | for key in batch: 89 | buffer_batch[key] = batch[key].to(self.device) 90 | 91 | logits = self.model(buffer_batch) 92 | pred = torch.sigmoid(logits).to("cpu").tolist() # bs 93 | 94 | rank_by_pred = calculate_candidates_ranking(np.array(pred), np.array(buffer_batch["label"].to("cpu").tolist()), 95 | self.hparams.evaluate_candidates_num) 96 | num_correct, pos_index = logits_recall_at_k(rank_by_pred, k_list) 97 | 98 | total_mrr += logits_mrr(rank_by_pred) 99 | 100 | total_correct = np.add(total_correct, num_correct) 101 | total_examples = (batch_idx + 1) * rank_by_pred.shape[0] 102 | 103 | recall_result = "" 104 | if (batch_idx + 1) % self.hparams.evaluate_print_step == 0: 105 | for i in range(len(k_list)): 106 | recall_result += "Recall@%s : " % k_list[i] + "%.2f%% | " % ((total_correct[i] / total_examples) * 100) 107 | else: 108 | print("%d[th] | %s | MRR : %.3f" % (batch_idx + 1, recall_result, float(total_mrr / total_examples))) 109 | self._logger.info("%d[th] | %s | MRR : %.3f" % (batch_idx + 1, recall_result, float(total_mrr / total_examples))) 110 | 111 | avg_mrr = float(total_mrr / total_examples) 112 | recall_result = "" 113 | 114 | for i in range(len(k_list)): 115 | recall_result += "Recall@%s : " % k_list[i] + "%.2f%% | " % ((total_correct[i] / total_examples) * 100) 116 | self._logger.info(recall_result) 117 | self._logger.info("MRR: %.4f" % avg_mrr) -------------------------------------------------------------------------------- /models/utils/checkpointing.py: -------------------------------------------------------------------------------- 1 | """ 2 | A checkpoint manager periodically saves model and optimizer as .pth 3 | files during training. 4 | Checkpoint managers help with experiment reproducibility, they record 5 | the commit SHA of your current codebase in the checkpoint saving 6 | directory. While loading any checkpoint from other commit, they raise a 7 | friendly warning, a signal to inspect commit diffs for potential bugs. 8 | Moreover, they copy experiment hyper-parameters as a YAML config in 9 | this directory. 10 | That said, always run your experiments after committing your changes, 11 | this doesn't account for untracked or staged, but uncommitted changes. 12 | """ 13 | from pathlib import Path 14 | from subprocess import PIPE, Popen 15 | import warnings 16 | 17 | import torch 18 | from torch import nn, optim 19 | import json 20 | 21 | class CheckpointManager(object): 22 | """A checkpoint manager saves state dicts of model and optimizer 23 | as .pth files in a specified directory. 
This class closely follows 24 | the API of PyTorch optimizers and learning rate schedulers. 25 | Note:: 26 | For ``DataParallel`` modules, ``model.module.state_dict()`` is 27 | saved, instead of ``model.state_dict()``. 28 | Parameters 29 | ---------- 30 | model: nn.Module 31 | Wrapped model, which needs to be checkpointed. 32 | optimizer: optim.Optimizer 33 | Wrapped optimizer which needs to be checkpointed. 34 | checkpoint_dirpath: str 35 | Path to an empty or non-existent directory to save checkpoints. 36 | step_size: int, optional (default=1) 37 | Period of saving checkpoints. 38 | last_epoch: int, optional (default=-1) 39 | The index of last epoch. 40 | Example 41 | -------- 42 | >>> model = torch.nn.Linear(10, 2) 43 | >>> optimizer = torch.optim.Adam(model.parameters()) 44 | >>> ckpt_manager = CheckpointManager(model, optimizer, "/tmp/ckpt") 45 | >>> for epoch in range(20): 46 | ... for batch in dataloader: 47 | ... do_iteration(batch) 48 | ... ckpt_manager.step() 49 | """ 50 | 51 | def __init__( 52 | self, 53 | model, 54 | optimizer, 55 | checkpoint_dirpath, 56 | step_size=1, 57 | last_epoch=-1, 58 | **kwargs, 59 | ): 60 | 61 | if not isinstance(model, nn.Module): 62 | raise TypeError("{} is not a Module".format(type(model).__name__)) 63 | 64 | if not isinstance(optimizer, optim.Optimizer): 65 | raise TypeError( 66 | "{} is not an Optimizer".format(type(optimizer).__name__) 67 | ) 68 | 69 | self.model = model 70 | self.optimizer = optimizer 71 | self.ckpt_dirpath = Path(checkpoint_dirpath) 72 | self.step_size = step_size 73 | self.last_epoch = last_epoch 74 | self.init_directory(**kwargs) 75 | 76 | def init_directory(self, hparams): 77 | """Initialize empty checkpoint directory and record commit SHA 78 | in it. Also save hyper-parameters config in this directory to 79 | associate checkpoints with their hyper-parameters. 80 | """ 81 | 82 | self.ckpt_dirpath.mkdir(parents=True, exist_ok=True) 83 | # save current git commit hash in this checkpoint directory 84 | commit_sha_subprocess = Popen( 85 | ["git", "rev-parse", "--short", "HEAD"], stdout=PIPE, stderr=PIPE 86 | ) 87 | commit_sha, _ = commit_sha_subprocess.communicate() 88 | commit_sha = commit_sha.decode("utf-8").strip().replace("\n", "") 89 | commit_sha_filepath = self.ckpt_dirpath / f".commit-{commit_sha}" 90 | commit_sha_filepath.touch() 91 | with open(str(self.ckpt_dirpath / "hparams.json"), 'w') as hparams_handle: 92 | json.dump(hparams, hparams_handle) 93 | 94 | def step(self, epoch=None): 95 | """Save checkpoint if step size conditions meet. """ 96 | 97 | if not epoch: 98 | epoch = self.last_epoch + 1 99 | self.last_epoch = epoch 100 | 101 | if not self.last_epoch % self.step_size: 102 | torch.save( 103 | { 104 | "model": self._model_state_dict(), 105 | "optimizer": self.optimizer.state_dict(), 106 | }, 107 | self.ckpt_dirpath / f"checkpoint_{self.last_epoch}.pth", 108 | ) 109 | 110 | def _model_state_dict(self): 111 | """Returns state dict of model, taking care of DataParallel case.""" 112 | if isinstance(self.model, nn.DataParallel): 113 | return self.model.module.state_dict() 114 | else: 115 | return self.model.state_dict() 116 | 117 | 118 | def load_checkpoint(checkpoint_pthpath): 119 | """Given a path to saved checkpoint, load corresponding state dicts 120 | of model and optimizer from it. This method checks if the current 121 | commit SHA of codebase matches the commit SHA recorded when this 122 | checkpoint was saved by checkpoint manager. 
123 | Parameters 124 | ---------- 125 | checkpoint_pthpath: str or pathlib.Path 126 | Path to saved checkpoint (as created by ``CheckpointManager``). 127 | Returns 128 | ------- 129 | nn.Module, optim.Optimizer 130 | Model and optimizer state dicts loaded from checkpoint. 131 | Raises 132 | ------ 133 | UserWarning 134 | If commit SHA do not match, or if the directory doesn't have 135 | the recorded commit SHA. 136 | """ 137 | 138 | if isinstance(checkpoint_pthpath, str): 139 | checkpoint_pthpath = Path(checkpoint_pthpath) 140 | checkpoint_dirpath = checkpoint_pthpath.resolve().parent 141 | checkpoint_commit_sha = list(checkpoint_dirpath.glob(".commit-*")) 142 | 143 | if len(checkpoint_commit_sha) == 0: 144 | warnings.warn( 145 | "Commit SHA was not recorded while saving checkpoints." 146 | ) 147 | else: 148 | # verify commit sha, raise warning if it doesn't match 149 | commit_sha_subprocess = Popen( 150 | ["git", "rev-parse", "--short", "HEAD"], stdout=PIPE, stderr=PIPE 151 | ) 152 | commit_sha, _ = commit_sha_subprocess.communicate() 153 | commit_sha = commit_sha.decode("utf-8").strip().replace("\n", "") 154 | 155 | # remove ".commit-" 156 | checkpoint_commit_sha = checkpoint_commit_sha[0].name[8:] 157 | 158 | if commit_sha != checkpoint_commit_sha: 159 | warnings.warn( 160 | f"Current commit ({commit_sha}) and the commit " 161 | f"({checkpoint_commit_sha}) at which checkpoint was saved," 162 | " are different. This might affect reproducibility." 163 | ) 164 | 165 | # load encoder, decoder, optimizer state_dicts 166 | components = torch.load(checkpoint_pthpath) 167 | 168 | return components["model"], components["optimizer"] -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | import h5py 5 | import numpy as np 6 | 7 | from models.bert import tokenization_bert 8 | from torch.utils.data import Dataset 9 | 10 | class ResponseSelectionDataset(Dataset): 11 | """ 12 | A full representation of VisDial v1.0 (train/val/test) dataset. According 13 | to the appropriate split, it returns dictionary of question, image, 14 | history, ground truth answer, answer options, dense annotations etc. 15 | """ 16 | def __init__( 17 | self, 18 | hparams, 19 | split: str = "", 20 | ): 21 | super().__init__() 22 | 23 | self.hparams = hparams 24 | self.split = split 25 | 26 | # read pkls -> Input Examples 27 | self.input_examples = [] 28 | with open(hparams.data_dir % (hparams.task_name, split), "rb") as pkl_handle: 29 | while True: 30 | try: 31 | self.input_examples.append(pickle.load(pkl_handle)) 32 | if len(self.input_examples) % 100000 == 0: 33 | print("%d examples has been loaded!" 
% len(self.input_examples)) 34 | except EOFError: 35 | break 36 | 37 | print("total %s examples" % split, len(self.input_examples)) 38 | 39 | bert_pretrained_dir = os.path.join(self.hparams.bert_pretrained_dir, self.hparams.bert_pretrained) 40 | print(bert_pretrained_dir) 41 | self._bert_tokenizer = tokenization_bert.BertTokenizer( 42 | vocab_file=os.path.join(bert_pretrained_dir, "%s-vocab.txt" % self.hparams.bert_pretrained)) 43 | 44 | # End of Turn Token 45 | if self.hparams.do_eot: 46 | self._bert_tokenizer.add_tokens(["[EOT]"]) 47 | 48 | def __len__(self): 49 | return len(self.input_examples) 50 | 51 | def __getitem__(self, index): 52 | # Get Input Examples 53 | """ 54 | InputExamples 55 | self.utterances = utterances 56 | self.response = response 57 | self.label 58 | """ 59 | 60 | anno_sent, segment_ids, attention_mask = self._annotate_sentence(self.input_examples[index]) 61 | 62 | current_feature = dict() 63 | current_feature["anno_sent"] = torch.tensor(anno_sent).long() 64 | current_feature["segment_ids"] = torch.tensor(segment_ids).long() 65 | current_feature["attention_mask"] = torch.tensor(attention_mask).long() 66 | current_feature["label"] = torch.tensor(self.input_examples[index].label).float() 67 | 68 | return current_feature 69 | 70 | def _annotate_sentence(self, example): 71 | 72 | dialog_context = [] 73 | if self.hparams.do_eot: 74 | for utt in example.utterances: 75 | dialog_context.extend(utt + ["[EOT]"]) 76 | else: 77 | for utt in example.utterances: 78 | dialog_context.extend(utt) 79 | 80 | # Set Dialog Context length to 280, Response length to 40 81 | dialog_context, response = self._max_len_trim_seq(dialog_context, example.response) 82 | 83 | # dialog context 84 | dialog_context = ["[CLS]"] + dialog_context + ["[SEP]"] 85 | segment_ids = [0] * self.hparams.max_dialog_len 86 | attention_mask = [1] * len(dialog_context) 87 | 88 | while len(dialog_context) < self.hparams.max_dialog_len: 89 | dialog_context.append("[PAD]") 90 | attention_mask.append(0) 91 | 92 | assert len(dialog_context) == len(segment_ids) == len(attention_mask) 93 | 94 | response = response + ["[SEP]"] 95 | segment_ids.extend([1] * len(response)) 96 | attention_mask.extend([1] * len(response)) 97 | 98 | while len(response) < self.hparams.max_response_len: 99 | response.append("[PAD]") 100 | segment_ids.append(0) 101 | attention_mask.append(0) 102 | 103 | dialog_response = dialog_context + response 104 | 105 | # print(segment_ids) 106 | # print(attention_mask) 107 | # print(len(dialog_response), len(segment_ids), len(attention_mask)) 108 | 109 | assert len(dialog_response) == len(segment_ids) == len(attention_mask) 110 | anno_sent = self._bert_tokenizer.convert_tokens_to_ids(dialog_response) 111 | 112 | return anno_sent, segment_ids, attention_mask 113 | 114 | def _max_len_trim_seq(self, dialog_context, response): 115 | while len(dialog_context) > self.hparams.max_dialog_len - 2: 116 | dialog_context.pop(0) # from the front 117 | 118 | while len(response) > self.hparams.max_response_len - 1: 119 | response.pop() # from the back 120 | 121 | return dialog_context, response 122 | 123 | class BertPostTrainingDataset(Dataset): 124 | """ 125 | A full representation of VisDial v1.0 (train/val/test) dataset. According 126 | to the appropriate split, it returns dictionary of question, image, 127 | history, ground truth answer, answer options, dense annotations etc. 
128 | """ 129 | 130 | def __init__( 131 | self, 132 | hparams, 133 | split: str = "", 134 | ): 135 | super().__init__() 136 | 137 | self.hparams = hparams 138 | self.split = split 139 | 140 | with h5py.File(self.hparams.data_dir, "r") as features_hdf: 141 | self.feature_keys = list(features_hdf.keys()) 142 | self.num_instances = np.array(features_hdf.get("next_sentence_labels")).shape[0] 143 | print("total %s examples : %d" % (split, self.num_instances)) 144 | 145 | def __len__(self): 146 | return self.num_instances 147 | 148 | def __getitem__(self, index): 149 | # Get Input Examples 150 | """ 151 | InputExamples 152 | self.utterances = utterances 153 | self.response = response 154 | self.label 155 | """ 156 | features = self._read_hdf_features(index) 157 | anno_masked_lm_labels = self._anno_mask_inputs(features["masked_lm_ids"], features["masked_lm_positions"]) 158 | curr_features = dict() 159 | for feat_key in features.keys(): 160 | curr_features[feat_key] = torch.tensor(features[feat_key]).long() 161 | curr_features["masked_lm_labels"] = torch.tensor(anno_masked_lm_labels).long() 162 | 163 | return curr_features 164 | 165 | def _read_hdf_features(self, index): 166 | features = {} 167 | with h5py.File(self.hparams.data_dir, "r") as features_hdf: 168 | for f_key in self.feature_keys: 169 | features[f_key] = features_hdf[f_key][index] 170 | 171 | return features 172 | 173 | def _anno_mask_inputs(self, masked_lm_ids, masked_lm_positions, max_seq_len=512): 174 | # masked_lm_ids -> labels 175 | anno_masked_lm_labels = [-1] * max_seq_len 176 | 177 | for pos, label in zip(masked_lm_positions, masked_lm_ids): 178 | if pos == 0: continue 179 | anno_masked_lm_labels[pos] = label 180 | 181 | return anno_masked_lm_labels -------------------------------------------------------------------------------- /models/bert/configuration_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """ BERT model configuration """ 17 | 18 | from __future__ import absolute_import, division, print_function, unicode_literals 19 | 20 | import json 21 | import logging 22 | import sys 23 | from io import open 24 | 25 | from models.pretrained_common.configuration_utils import PretrainedConfig 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 30 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json", 31 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json", 32 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json", 33 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json", 34 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json", 35 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json", 36 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json", 37 | 'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json", 38 | 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json", 39 | 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json", 40 | 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json", 41 | 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json", 42 | 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json", 43 | } 44 | 45 | 46 | class BertConfig(PretrainedConfig): 47 | r""" 48 | :class:`~pytorch_transformers.BertConfig` is the configuration class to store the configuration of a 49 | `BertModel`. 50 | Arguments: 51 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. 52 | hidden_size: Size of the encoder layers and the pooler layer. 53 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 54 | num_attention_heads: Number of attention heads for each attention layer in 55 | the Transformer encoder. 56 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 57 | layer in the Transformer encoder. 58 | hidden_act: The non-linear activation function (function or string) in the 59 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported. 60 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 61 | layers in the embeddings, encoder, and pooler. 62 | attention_probs_dropout_prob: The dropout ratio for the attention 63 | probabilities. 64 | max_position_embeddings: The maximum sequence length that this model might 65 | ever be used with. Typically set this to something large just in case 66 | (e.g., 512 or 1024 or 2048). 67 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 68 | `BertModel`. 
69 | initializer_range: The sttdev of the truncated_normal_initializer for 70 | initializing all weight matrices. 71 | layer_norm_eps: The epsilon used by LayerNorm. 72 | """ 73 | pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP 74 | 75 | def __init__(self, 76 | vocab_size_or_config_json_file=30522, 77 | hidden_size=768, 78 | num_hidden_layers=12, 79 | num_attention_heads=12, 80 | intermediate_size=3072, 81 | hidden_act="gelu", 82 | hidden_dropout_prob=0.1, 83 | attention_probs_dropout_prob=0.1, 84 | max_position_embeddings=512, 85 | type_vocab_size=2, 86 | initializer_range=0.02, 87 | layer_norm_eps=1e-12, 88 | **kwargs): 89 | super(BertConfig, self).__init__(**kwargs) 90 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 91 | and isinstance(vocab_size_or_config_json_file, unicode)): 92 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 93 | json_config = json.loads(reader.read()) 94 | for key, value in json_config.items(): 95 | self.__dict__[key] = value 96 | elif isinstance(vocab_size_or_config_json_file, int): 97 | self.vocab_size = vocab_size_or_config_json_file 98 | self.hidden_size = hidden_size 99 | self.num_hidden_layers = num_hidden_layers 100 | self.num_attention_heads = num_attention_heads 101 | self.hidden_act = hidden_act 102 | self.intermediate_size = intermediate_size 103 | self.hidden_dropout_prob = hidden_dropout_prob 104 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 105 | self.max_position_embeddings = max_position_embeddings 106 | self.type_vocab_size = type_vocab_size 107 | self.initializer_range = initializer_range 108 | self.layer_norm_eps = layer_norm_eps 109 | else: 110 | raise ValueError("First argument must be either a vocabulary size (int)" 111 | " or the path to a pretrained model config file (str)") -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 3 | import logging 4 | 5 | from datetime import datetime 6 | from tqdm import tqdm 7 | 8 | import torch 9 | from torch import nn, optim 10 | from torch.utils.data import DataLoader 11 | from torch.utils.tensorboard import SummaryWriter 12 | 13 | from data.dataset import ResponseSelectionDataset 14 | from models.utils.checkpointing import CheckpointManager, load_checkpoint 15 | from models import Model 16 | from evaluation import Evaluation 17 | 18 | class ResponseSelection(object): 19 | def __init__(self, hparams): 20 | self.hparams = hparams 21 | self._logger = logging.getLogger(__name__) 22 | 23 | def _build_dataloader(self): 24 | # ============================================================================= 25 | # SETUP DATASET, DATALOADER 26 | # ============================================================================= 27 | self.train_dataset = ResponseSelectionDataset(self.hparams, split="train") 28 | self.train_dataloader = DataLoader( 29 | self.train_dataset, 30 | batch_size=self.hparams.train_batch_size, 31 | num_workers=self.hparams.cpu_workers, 32 | shuffle=True, 33 | drop_last=True 34 | ) 35 | 36 | print(""" 37 | # ------------------------------------------------------------------------- 38 | # DATALOADER FINISHED 39 | # ------------------------------------------------------------------------- 40 | """) 41 | 42 | def _build_model(self): 43 | # ============================================================================= 44 | # 
MODEL : Standard, Mention Pooling, Entity Marker 45 | # ============================================================================= 46 | print('\t* Building model...') 47 | 48 | self.model = Model(self.hparams) 49 | self.model = self.model.to(self.device) 50 | 51 | # Use Multi-GPUs 52 | if -1 not in self.hparams.gpu_ids and len(self.hparams.gpu_ids) > 1: 53 | self.model = nn.DataParallel(self.model, self.hparams.gpu_ids) 54 | 55 | # ============================================================================= 56 | # CRITERION 57 | # ============================================================================= 58 | self.criterion = nn.BCEWithLogitsLoss() 59 | 60 | self.optimizer = optim.Adam(self.model.parameters(), lr=self.hparams.learning_rate) 61 | self.iterations = len(self.train_dataset) // self.hparams.virtual_batch_size 62 | 63 | def _setup_training(self): 64 | if self.hparams.save_dirpath == 'checkpoints/': 65 | self.save_dirpath = os.path.join(self.hparams.root_dir, self.hparams.save_dirpath) 66 | self.summary_writer = SummaryWriter(self.save_dirpath) 67 | self.checkpoint_manager = CheckpointManager(self.model, self.optimizer, self.save_dirpath, hparams=self.hparams) 68 | 69 | # If loading from checkpoint, adjust start epoch and load parameters. 70 | if self.hparams.load_pthpath == "": 71 | self.start_epoch = 1 72 | else: 73 | # "path/to/checkpoint_xx.pth" -> xx 74 | self.start_epoch = int(self.hparams.load_pthpath.split("_")[-1][:-4]) 75 | self.start_epoch += 1 76 | model_state_dict, optimizer_state_dict = load_checkpoint(self.hparams.load_pthpath) 77 | if isinstance(self.model, nn.DataParallel): 78 | self.model.module.load_state_dict(model_state_dict) 79 | else: 80 | self.model.load_state_dict(model_state_dict) 81 | self.optimizer.load_state_dict(optimizer_state_dict) 82 | self.previous_model_path = self.hparams.load_pthpath 83 | print("Loaded model from {}".format(self.hparams.load_pthpath)) 84 | 85 | print( 86 | """ 87 | # ------------------------------------------------------------------------- 88 | # Setup Training Finished 89 | # ------------------------------------------------------------------------- 90 | """ 91 | ) 92 | 93 | def train(self): 94 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 95 | 96 | self._build_dataloader() 97 | self._build_model() 98 | self._setup_training() 99 | 100 | # Evaluation Setup 101 | evaluation = Evaluation(self.hparams, model=self.model, split="test") 102 | 103 | start_time = datetime.now().strftime('%H:%M:%S') 104 | self._logger.info("Start train model at %s" % start_time) 105 | 106 | train_begin = datetime.utcnow() # New 107 | global_iteration_step = 0 108 | accumulate_loss = 0 109 | accu_count = 0 110 | for epoch in range(self.start_epoch, self.hparams.num_epochs): 111 | self.model.train() 112 | 113 | tqdm_batch_iterator = tqdm(self.train_dataloader) 114 | accumulate_batch = 0 115 | 116 | for batch_idx, batch in enumerate(tqdm_batch_iterator): 117 | buffer_batch = batch.copy() 118 | for key in batch: 119 | buffer_batch[key] = buffer_batch[key].to(self.device) 120 | 121 | logits = self.model(buffer_batch) 122 | loss = self.criterion(logits, buffer_batch["label"]) 123 | 124 | loss.backward() 125 | accumulate_loss += loss.item() 126 | accu_count += 1 127 | 128 | # TODO: virtual batch implementation 129 | accumulate_batch += buffer_batch["label"].shape[0] 130 | if self.hparams.virtual_batch_size == accumulate_batch \ 131 | or batch_idx == (len(self.train_dataset) // self.hparams.train_batch_size): # last batch 
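                    # Gradient accumulation: the loss.backward() calls above only
                    # accumulate gradients; they are clipped and applied here once a
                    # full virtual batch (virtual_batch_size examples) has been seen,
                    # emulating a larger effective batch size than train_batch_size.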
132 | nn.utils.clip_grad_norm_(self.model.parameters(), self.hparams.max_gradient_norm) 133 | 134 | self.optimizer.step() 135 | self.optimizer.zero_grad() 136 | accumulate_batch = 0 137 | 138 | global_iteration_step += 1 139 | description = "[{}][Epoch: {:3d}][Iter: {:6d}][Loss: {:6f}][lr: {:7f}]".format( 140 | datetime.utcnow() - train_begin, 141 | epoch, 142 | global_iteration_step, (accumulate_loss / accu_count), 143 | self.optimizer.param_groups[0]['lr']) 144 | tqdm_batch_iterator.set_description(description) 145 | 146 | # tensorboard 147 | if global_iteration_step % self.hparams.tensorboard_step == 0: 148 | description = "[{}][Epoch: {:3d}][Iter: {:6d}][Loss: {:6f}][lr: {:7f}]".format( 149 | datetime.utcnow() - train_begin, 150 | epoch, 151 | global_iteration_step, (accumulate_loss / accu_count), 152 | self.optimizer.param_groups[0]['lr'], 153 | ) 154 | self._logger.info(description) 155 | accumulate_loss, accu_count = 0, 0 156 | 157 | # ------------------------------------------------------------------------- 158 | # ON EPOCH END (checkpointing and validation) 159 | # ------------------------------------------------------------------------- 160 | self.checkpoint_manager.step(epoch) 161 | self.previous_model_path = os.path.join(self.checkpoint_manager.ckpt_dirpath, "checkpoint_%d.pth" % (epoch)) 162 | self._logger.info(self.previous_model_path) 163 | 164 | torch.cuda.empty_cache() 165 | self._logger.info("Evaluation after %d epoch" % epoch) 166 | evaluation.run_evaluate(self.previous_model_path) 167 | torch.cuda.empty_cache() 168 | -------------------------------------------------------------------------------- /post_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | from datetime import datetime 5 | from tqdm import tqdm 6 | 7 | import torch 8 | from torch import nn, optim 9 | from torch.utils.data import DataLoader 10 | from torch.utils.tensorboard import SummaryWriter 11 | 12 | from data.dataset import BertPostTrainingDataset 13 | from models.utils.checkpointing import CheckpointManager, load_checkpoint 14 | from models import Model 15 | 16 | class BERTDomainPostTraining(object): 17 | def __init__(self, hparams): 18 | self.hparams = hparams 19 | self._logger = logging.getLogger(__name__) 20 | 21 | def _build_dataloader(self): 22 | # ============================================================================= 23 | # SETUP DATASET, DATALOADER 24 | # ============================================================================= 25 | self.train_dataset = BertPostTrainingDataset(self.hparams, split="train") 26 | self.train_dataloader = DataLoader( 27 | self.train_dataset, 28 | batch_size=self.hparams.train_batch_size, 29 | num_workers=self.hparams.cpu_workers, 30 | shuffle=False, 31 | drop_last=True 32 | ) 33 | 34 | print(""" 35 | # ------------------------------------------------------------------------- 36 | # DATALOADER FINISHED 37 | # ------------------------------------------------------------------------- 38 | """) 39 | 40 | def _build_model(self): 41 | # ============================================================================= 42 | # MODEL : Standard, Mention Pooling, Entity Marker 43 | # ============================================================================= 44 | print('\t* Building model...') 45 | 46 | self.model = Model(self.hparams) 47 | self.model = self.model.to(self.device) 48 | 49 | # Use Multi-GPUs 50 | if -1 not in self.hparams.gpu_ids and len(self.hparams.gpu_ids) > 1: 51 | 
self.model = nn.DataParallel(self.model, self.hparams.gpu_ids) 52 | 53 | self.optimizer = optim.Adam(self.model.parameters(), lr=self.hparams.learning_rate) 54 | self.iterations = len(self.train_dataset) // self.hparams.virtual_batch_size 55 | 56 | print( 57 | """ 58 | # ------------------------------------------------------------------------- 59 | # Building Model Finished 60 | # ------------------------------------------------------------------------- 61 | """ 62 | ) 63 | 64 | def _setup_training(self): 65 | if self.hparams.save_dirpath == 'checkpoints/': 66 | self.save_dirpath = os.path.join(self.hparams.root_dir, self.hparams.save_dirpath) 67 | self.summary_writer = SummaryWriter(self.save_dirpath) 68 | self.checkpoint_manager = CheckpointManager(self.model, self.optimizer, self.save_dirpath, hparams=self.hparams) 69 | 70 | # If loading from checkpoint, adjust start epoch and load parameters. 71 | if self.hparams.load_pthpath == "": 72 | self.start_epoch = 1 73 | else: 74 | # "path/to/checkpoint_xx.pth" -> xx 75 | self.start_epoch = int(self.hparams.load_pthpath.split("_")[-1][:-4]) 76 | self.start_epoch += 1 77 | model_state_dict, optimizer_state_dict = load_checkpoint(self.hparams.load_pthpath) 78 | if isinstance(self.model, nn.DataParallel): 79 | self.model.module.load_state_dict(model_state_dict) 80 | else: 81 | self.model.load_state_dict(model_state_dict) 82 | self.optimizer.load_state_dict(optimizer_state_dict) 83 | self.previous_model_path = self.hparams.load_pthpath 84 | print("Loaded model from {}".format(self.hparams.load_pthpath)) 85 | 86 | print( 87 | """ 88 | # ------------------------------------------------------------------------- 89 | # Setup Training Finished 90 | # ------------------------------------------------------------------------- 91 | """ 92 | ) 93 | 94 | def train(self): 95 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 96 | 97 | self._build_dataloader() 98 | self._build_model() 99 | self._setup_training() 100 | 101 | start_time = datetime.now().strftime('%H:%M:%S') 102 | self._logger.info("Start train model at %s" % start_time) 103 | 104 | train_begin = datetime.utcnow() # New 105 | global_iteration_step = 0 106 | accu_mlm_loss, accu_nsp_loss = 0, 0 107 | accumulate_batch, accu_count = 0, 0 108 | 109 | for epoch in range(self.start_epoch, self.hparams.num_epochs): 110 | self.model.train() 111 | 112 | tqdm_batch_iterator = tqdm(self.train_dataloader) 113 | for batch_idx, batch in enumerate(tqdm_batch_iterator): 114 | buffer_batch = batch.copy() 115 | for key in batch: 116 | buffer_batch[key] = buffer_batch[key].to(self.device) 117 | 118 | mlm_loss, nsp_loss = self.model(buffer_batch) 119 | total_loss = mlm_loss.mean() + nsp_loss.mean() 120 | total_loss.backward() 121 | accu_mlm_loss += mlm_loss.mean().item() 122 | accu_nsp_loss += nsp_loss.mean().item() 123 | accu_count += 1 124 | 125 | # TODO: virtual batch implementation 126 | accumulate_batch += buffer_batch["next_sentence_labels"].shape[0] 127 | if self.hparams.virtual_batch_size == accumulate_batch \ 128 | or batch_idx == (len(self.train_dataset) // self.hparams.train_batch_size): # last batch 129 | 130 | nn.utils.clip_grad_norm_(self.model.parameters(), self.hparams.max_gradient_norm) 131 | 132 | self.optimizer.step() 133 | self.optimizer.zero_grad() 134 | 135 | global_iteration_step += 1 136 | description = "[{}][Epoch: {:3d}][Iter: {:6d}][MLM_Loss: {:6f}][NSP_Loss: {:6f}][lr: {:7f}]".format( 137 | datetime.utcnow() - train_begin, 138 | epoch, 139 | 
global_iteration_step, (accu_mlm_loss / accu_count), (accu_nsp_loss / accu_count), 140 | self.optimizer.param_groups[0]['lr']) 141 | tqdm_batch_iterator.set_description(description) 142 | 143 | # tensorboard 144 | if global_iteration_step % self.hparams.tensorboard_step == 0: 145 | description = "[{}][Epoch: {:3d}][Iter: {:6d}]MLM_Loss: {:6f}][NSP_Loss: {:6f}][lr: {:7f}]".format( 146 | datetime.utcnow() - train_begin, 147 | epoch, 148 | global_iteration_step, (accu_mlm_loss / accu_count), (accu_nsp_loss / accu_count), 149 | self.optimizer.param_groups[0]['lr'], 150 | ) 151 | self._logger.info(description) 152 | 153 | accumulate_batch, accu_count = 0, 0 154 | accu_mlm_loss, accu_nsp_loss = 0, 0 155 | 156 | if global_iteration_step % self.hparams.checkpoint_save_step == 0: 157 | # ------------------------------------------------------------------------- 158 | # ON EPOCH END (checkpointing and validation) 159 | # ------------------------------------------------------------------------- 160 | self.checkpoint_manager.step(global_iteration_step) 161 | self.previous_model_path = os.path.join(self.checkpoint_manager.ckpt_dirpath, 162 | "checkpoint_%d.pth" % (global_iteration_step)) 163 | self._logger.info(self.previous_model_path) -------------------------------------------------------------------------------- /models/pretrained_common/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import logging 18 | import math 19 | 20 | import torch 21 | from torch.optim import Optimizer 22 | from torch.optim.lr_scheduler import LambdaLR 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | class ConstantLRSchedule(LambdaLR): 27 | """ Constant learning rate schedule. 28 | """ 29 | def __init__(self, optimizer, last_epoch=-1): 30 | super(ConstantLRSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=last_epoch) 31 | 32 | 33 | class WarmupConstantSchedule(LambdaLR): 34 | """ Linear warmup and then constant. 35 | Linearly increases learning rate schedule from 0 to 1 over `warmup_steps` training steps. 36 | Keeps learning rate schedule equal to 1. after warmup_steps. 37 | """ 38 | def __init__(self, optimizer, warmup_steps, last_epoch=-1): 39 | self.warmup_steps = warmup_steps 40 | super(WarmupConstantSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 41 | 42 | def lr_lambda(self, step): 43 | if step < self.warmup_steps: 44 | return float(step) / float(max(1.0, self.warmup_steps)) 45 | return 1. 46 | 47 | 48 | class WarmupLinearSchedule(LambdaLR): 49 | """ Linear warmup and then linear decay. 50 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. 51 | Linearly decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps. 
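Example (a minimal sketch; `optimizer`, `num_warmup_steps` and `num_train_steps`
are placeholder names rather than values used in this repository)::

    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_total=num_train_steps)
    for step in range(num_train_steps):
        ...                    # forward / backward for one batch
        optimizer.step()
        scheduler.step()       # these schedules are stepped once per optimizer update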
52 | """ 53 | def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1): 54 | self.warmup_steps = warmup_steps 55 | self.t_total = t_total 56 | super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 57 | 58 | def lr_lambda(self, step): 59 | if step < self.warmup_steps: 60 | return float(step) / float(max(1, self.warmup_steps)) 61 | return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps))) 62 | 63 | 64 | class WarmupCosineSchedule(LambdaLR): 65 | """ Linear warmup and then cosine decay. 66 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. 67 | Decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps following a cosine curve. 68 | If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup. 69 | """ 70 | def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1): 71 | self.warmup_steps = warmup_steps 72 | self.t_total = t_total 73 | self.cycles = cycles 74 | super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 75 | 76 | def lr_lambda(self, step): 77 | if step < self.warmup_steps: 78 | return float(step) / float(max(1.0, self.warmup_steps)) 79 | # progress after warmup 80 | progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) 81 | return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress))) 82 | 83 | 84 | class WarmupCosineWithHardRestartsSchedule(LambdaLR): 85 | """ Linear warmup and then cosine cycles with hard restarts. 86 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. 87 | If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying 88 | learning rate (with hard restarts). 89 | """ 90 | def __init__(self, optimizer, warmup_steps, t_total, cycles=1., last_epoch=-1): 91 | self.warmup_steps = warmup_steps 92 | self.t_total = t_total 93 | self.cycles = cycles 94 | super(WarmupCosineWithHardRestartsSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 95 | 96 | def lr_lambda(self, step): 97 | if step < self.warmup_steps: 98 | return float(step) / float(max(1, self.warmup_steps)) 99 | # progress after warmup 100 | progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) 101 | if progress >= 1.0: 102 | return 0.0 103 | return max(0.0, 0.5 * (1. + math.cos(math.pi * ((float(self.cycles) * progress) % 1.0)))) 104 | 105 | 106 | class AdamW(Optimizer): 107 | """ Implements Adam algorithm with weight decay fix. 108 | 109 | Parameters: 110 | lr (float): learning rate. Default 1e-3. 111 | betas (tuple of 2 floats): Adams beta parameters (b1, b2). Default: (0.9, 0.999) 112 | eps (float): Adams epsilon. Default: 1e-6 113 | weight_decay (float): Weight decay. Default: 0.0 114 | correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. like in Bert TF repository). Default True. 
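Example (a minimal usage sketch; the parameter grouping and hyper-parameter
values are common conventions, not settings taken from this repository, and
`model` is a placeholder)::

    no_decay = ["bias", "LayerNorm.weight"]
    grouped_parameters = [
        {"params": [p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
        {"params": [p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(grouped_parameters, lr=5e-5, eps=1e-6)

Because the decay term in `step()` below is applied directly to the weights
(decoupled from the Adam moment estimates), excluding biases and LayerNorm
weights this way only changes the decay step, never the gradient statistics.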
115 | """ 116 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, correct_bias=True): 117 | if lr < 0.0: 118 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 119 | if not 0.0 <= betas[0] < 1.0: 120 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0])) 121 | if not 0.0 <= betas[1] < 1.0: 122 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1])) 123 | if not 0.0 <= eps: 124 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps)) 125 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, 126 | correct_bias=correct_bias) 127 | super(AdamW, self).__init__(params, defaults) 128 | 129 | def step(self, closure=None): 130 | """Performs a single optimization step. 131 | 132 | Arguments: 133 | closure (callable, optional): A closure that reevaluates the model 134 | and returns the loss. 135 | """ 136 | loss = None 137 | if closure is not None: 138 | loss = closure() 139 | 140 | for group in self.param_groups: 141 | for p in group['params']: 142 | if p.grad is None: 143 | continue 144 | grad = p.grad.data 145 | if grad.is_sparse: 146 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 147 | 148 | state = self.state[p] 149 | 150 | # State initialization 151 | if len(state) == 0: 152 | state['step'] = 0 153 | # Exponential moving average of gradient values 154 | state['exp_avg'] = torch.zeros_like(p.data) 155 | # Exponential moving average of squared gradient values 156 | state['exp_avg_sq'] = torch.zeros_like(p.data) 157 | 158 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 159 | beta1, beta2 = group['betas'] 160 | 161 | state['step'] += 1 162 | 163 | # Decay the first and second moment running average coefficient 164 | # In-place operations to update the averages at the same time 165 | exp_avg.mul_(beta1).add_(1.0 - beta1, grad) 166 | exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad) 167 | denom = exp_avg_sq.sqrt().add_(group['eps']) 168 | 169 | step_size = group['lr'] 170 | if group['correct_bias']: # No bias correction for Bert 171 | bias_correction1 = 1.0 - beta1 ** state['step'] 172 | bias_correction2 = 1.0 - beta2 ** state['step'] 173 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 174 | 175 | p.data.addcdiv_(-step_size, exp_avg, denom) 176 | 177 | # Just adding the square of the weights to the loss function is *not* 178 | # the correct way of using L2 regularization/weight decay with Adam, 179 | # since that will interact with the m and v parameters in strange ways. 180 | # 181 | # Instead we want to decay the weights in a manner that doesn't interact 182 | # with the m/v parameters. This is equivalent to adding the square 183 | # of the weights to the loss with plain (non-momentum) SGD. 184 | # Add weight decay at the end (fixed version) 185 | if group['weight_decay'] > 0.0: 186 | p.data.add_(-group['lr'] * group['weight_decay'], p.data) 187 | 188 | return loss 189 | -------------------------------------------------------------------------------- /models/pretrained_common/configuration_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 
4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Configuration base class and utilities.""" 17 | 18 | from __future__ import (absolute_import, division, print_function, 19 | unicode_literals) 20 | 21 | import copy 22 | import json 23 | import logging 24 | import os 25 | from io import open 26 | 27 | from .file_utils import cached_path, CONFIG_NAME 28 | 29 | logger = logging.getLogger(__name__) 30 | 31 | class PretrainedConfig(object): 32 | r""" Base class for all configuration classes. 33 | Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations. 34 | 35 | Note: 36 | A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to initialize a model does **not** load the model weights. 37 | It only affects the model's configuration. 38 | 39 | Class attributes (overridden by derived classes): 40 | - ``pretrained_config_archive_map``: a python ``dict`` of with `short-cut-names` (string) as keys and `url` (string) of associated pretrained model configurations as values. 41 | 42 | Parameters: 43 | ``finetuning_task``: string, default `None`. Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint. 44 | ``num_labels``: integer, default `2`. Number of classes to use when the model is a classification model (sequences/tokens) 45 | ``output_attentions``: boolean, default `False`. Should the model returns attentions weights. 46 | ``output_hidden_states``: string, default `False`. Should the model returns all hidden-states. 47 | ``torchscript``: string, default `False`. Is the model used with Torchscript. 48 | """ 49 | pretrained_config_archive_map = {} 50 | 51 | def __init__(self, **kwargs): 52 | self.finetuning_task = kwargs.pop('finetuning_task', None) 53 | self.num_labels = kwargs.pop('num_labels', 2) 54 | self.output_attentions = kwargs.pop('output_attentions', False) 55 | self.output_hidden_states = kwargs.pop('output_hidden_states', False) 56 | self.torchscript = kwargs.pop('torchscript', False) 57 | self.pruned_heads = kwargs.pop('pruned_heads', {}) 58 | 59 | def save_pretrained(self, save_directory): 60 | """ Save a configuration object to the directory `save_directory`, so that it 61 | can be re-loaded using the :func:`~pytorch_transformers.PretrainedConfig.from_pretrained` class method. 62 | """ 63 | assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved" 64 | 65 | # If we save using the predefined names, we can load using `from_pretrained` 66 | output_config_file = os.path.join(save_directory, CONFIG_NAME) 67 | 68 | self.to_json_file(output_config_file) 69 | 70 | @classmethod 71 | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): 72 | r""" Instantiate a :class:`~pytorch_transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration. 
73 | 74 | Parameters: 75 | pretrained_model_name_or_path: either: 76 | 77 | - a string with the `shortcut name` of a pre-trained model configuration to load from cache or download, e.g.: ``bert-base-uncased``. 78 | - a path to a `directory` containing a configuration file saved using the :func:`~pytorch_transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``. 79 | - a path or url to a saved configuration JSON `file`, e.g.: ``./my_model_directory/configuration.json``. 80 | 81 | cache_dir: (`optional`) string: 82 | Path to a directory in which a downloaded pre-trained model 83 | configuration should be cached if the standard cache should not be used. 84 | 85 | kwargs: (`optional`) dict: key/value pairs with which to update the configuration object after loading. 86 | 87 | - The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. 88 | - Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled by the `return_unused_kwargs` keyword parameter. 89 | 90 | force_download: (`optional`) boolean, default False: 91 | Force to (re-)download the model weights and configuration files and override the cached versions if they exists. 92 | 93 | proxies: (`optional`) dict, default None: 94 | A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. 95 | The proxies are used on each request. 96 | 97 | return_unused_kwargs: (`optional`) bool: 98 | 99 | - If False, then this function returns just the final configuration object. 100 | - If True, then this functions returns a tuple `(config, unused_kwargs)` where `unused_kwargs` is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: ie the part of kwargs which has not been used to update `config` and is otherwise ignored. 101 | 102 | Examples:: 103 | 104 | # We can't instantiate directly the base class `PretrainedConfig` so let's show the examples on a 105 | # derived class: BertConfig 106 | config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache. 107 | config = BertConfig.from_pretrained('./test/saved_model/') # E.g. 
config (or model) was saved using `save_pretrained('./test/saved_model/')` 108 | config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json') 109 | config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False) 110 | assert config.output_attention == True 111 | config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, 112 | foo=False, return_unused_kwargs=True) 113 | assert config.output_attention == True 114 | assert unused_kwargs == {'foo': False} 115 | 116 | """ 117 | cache_dir = kwargs.pop('cache_dir', None) 118 | force_download = kwargs.pop('force_download', False) 119 | proxies = kwargs.pop('proxies', None) 120 | return_unused_kwargs = kwargs.pop('return_unused_kwargs', False) 121 | 122 | if pretrained_model_name_or_path in cls.pretrained_config_archive_map: 123 | config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path] 124 | elif os.path.isdir(pretrained_model_name_or_path): 125 | config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) 126 | else: 127 | config_file = pretrained_model_name_or_path 128 | # redirect to the cache, if necessary 129 | try: 130 | resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies) 131 | except EnvironmentError as e: 132 | if pretrained_model_name_or_path in cls.pretrained_config_archive_map: 133 | logger.error( 134 | "Couldn't reach server at '{}' to download pretrained model configuration file.".format( 135 | config_file)) 136 | else: 137 | logger.error( 138 | "Model name '{}' was not found in model name list ({}). " 139 | "We assumed '{}' was a path or url but couldn't find any file " 140 | "associated to this path or url.".format( 141 | pretrained_model_name_or_path, 142 | ', '.join(cls.pretrained_config_archive_map.keys()), 143 | config_file)) 144 | raise e 145 | if resolved_config_file == config_file: 146 | logger.info("loading configuration file {}".format(config_file)) 147 | else: 148 | logger.info("loading configuration file {} from cache at {}".format( 149 | config_file, resolved_config_file)) 150 | 151 | # Load config 152 | config = cls.from_json_file(resolved_config_file) 153 | 154 | if hasattr(config, 'pruned_heads'): 155 | config.pruned_heads = dict((int(key), set(value)) for key, value in config.pruned_heads.items()) 156 | 157 | # Update config with kwargs if needed 158 | to_remove = [] 159 | for key, value in kwargs.items(): 160 | if hasattr(config, key): 161 | setattr(config, key, value) 162 | to_remove.append(key) 163 | for key in to_remove: 164 | kwargs.pop(key, None) 165 | 166 | logger.info("Model config %s", config) 167 | if return_unused_kwargs: 168 | return config, kwargs 169 | else: 170 | return config 171 | 172 | @classmethod 173 | def from_dict(cls, json_object): 174 | """Constructs a `Config` from a Python dictionary of parameters.""" 175 | config = cls(vocab_size_or_config_json_file=-1) 176 | for key, value in json_object.items(): 177 | config.__dict__[key] = value 178 | return config 179 | 180 | @classmethod 181 | def from_json_file(cls, json_file): 182 | """Constructs a `BertConfig` from a json file of parameters.""" 183 | with open(json_file, "r", encoding='utf-8') as reader: 184 | text = reader.read() 185 | return cls.from_dict(json.loads(text)) 186 | 187 | def __eq__(self, other): 188 | return self.__dict__ == other.__dict__ 189 | 190 | def __repr__(self): 191 | return str(self.to_json_string()) 192 | 193 | def to_dict(self): 194 | 
"""Serializes this instance to a Python dictionary.""" 195 | output = copy.deepcopy(self.__dict__) 196 | return output 197 | 198 | def to_json_string(self): 199 | """Serializes this instance to a JSON string.""" 200 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 201 | 202 | def to_json_file(self, json_file_path): 203 | """ Save this instance to a json file.""" 204 | with open(json_file_path, "w", encoding='utf-8') as writer: 205 | writer.write(self.to_json_string()) 206 | -------------------------------------------------------------------------------- /data/create_bert_post_training_data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Create masked LM/next sentence masked_lm TF examples for BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import sys 22 | import os 23 | sys.path.append(os.getcwd()) 24 | import collections 25 | import random 26 | import argparse 27 | import h5py 28 | import numpy as np 29 | from tqdm import tqdm 30 | 31 | from models.bert import tokenization_bert 32 | 33 | class TrainingInstance(object): 34 | """A single training instance (sentence pair).""" 35 | 36 | def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels, 37 | is_random_next): 38 | self.tokens = tokens 39 | self.segment_ids = segment_ids 40 | self.is_random_next = is_random_next 41 | self.masked_lm_positions = masked_lm_positions 42 | self.masked_lm_labels = masked_lm_labels 43 | 44 | MaskedLmInstance = collections.namedtuple("MaskedLmInstance", 45 | ["index", "label"]) 46 | 47 | class CreateBertPretrainingData(object): 48 | def __init__(self, args): 49 | self.args = args 50 | self._bert_tokenizer_init(args.special_tok) 51 | 52 | def _bert_tokenizer_init(self, special_tok, bert_pretrained='bert-base-uncased'): 53 | bert_pretrained_dir = os.path.join("./resources", bert_pretrained) 54 | vocab_file_path = "%s-vocab.txt" % bert_pretrained 55 | 56 | self._bert_tokenizer = tokenization_bert.BertTokenizer(vocab_file=os.path.join(bert_pretrained_dir, vocab_file_path)) 57 | self._bert_tokenizer.add_tokens([special_tok]) 58 | 59 | print("BERT tokenizer init completes") 60 | 61 | def _add_special_tokens(self, tokens, special_tok="[EOT]"): 62 | tokens = tokens + [special_tok] 63 | return tokens 64 | 65 | def create_training_instances(self, input_file, max_seq_length, 66 | dupe_factor, short_seq_prob, masked_lm_prob, 67 | max_predictions_per_seq, rng, special_tok=None): 68 | """Create `TrainingInstance`s from raw text.""" 69 | all_documents = [[]] 70 | 71 | # Input file format: 72 | # (1) One sentence per line. These should ideally be actual sentences, not 73 | # entire paragraphs or arbitrary spans of text. 
(Because we use the 74 | # sentence boundaries for the "next sentence prediction" task). 75 | # (2) Blank lines between documents. Document boundaries are needed so 76 | # that the "next sentence prediction" task doesn't span between documents. 77 | document_cnt = 0 78 | with open(input_file, "r", encoding="utf=8") as fr_handle: 79 | for line in tqdm(fr_handle): 80 | line = line.strip() 81 | 82 | # Empty lines are used as document delimiters 83 | if len(line) == 0: 84 | all_documents.append([]) 85 | document_cnt +=1 86 | if document_cnt % 50000 == 0: 87 | print("%d documents have been tokenized!" % document_cnt) 88 | tokens = self._bert_tokenizer.tokenize(line) 89 | 90 | if special_tok: 91 | tokens = self._add_special_tokens(tokens, special_tok) # special tok per sentence 92 | 93 | if tokens: 94 | all_documents[-1].append(tokens) 95 | 96 | # Remove empty documents 97 | all_documents = [x for x in all_documents if x] 98 | rng.shuffle(all_documents) 99 | 100 | vocab_words = list(self._bert_tokenizer.vocab.keys()) 101 | 102 | self.feature_keys = ["input_ids", "attention_mask", "token_type_ids", 103 | "masked_lm_positions", "masked_lm_ids", "next_sentence_labels"] 104 | 105 | print("Total number of documents : %d" % len(all_documents)) 106 | hf = h5py.File(self.args.output_file, "w") 107 | for d in range(dupe_factor): 108 | rng.shuffle(all_documents) 109 | 110 | self.all_doc_feat_dict = dict() 111 | for feat_key in self.feature_keys: 112 | self.all_doc_feat_dict[feat_key] = [] 113 | 114 | for document_index in tqdm(range(len(all_documents))): 115 | instances = self.create_instances_from_document( 116 | all_documents, document_index, max_seq_length, short_seq_prob, 117 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng) 118 | self.instance_to_example_feature(instances, self.args.max_seq_length, 119 | self.args.max_predictions_per_seq) 120 | self.build_h5_data(hf, d_idx=d), 121 | print("Current Dupe Factor : %d" % (d + 1)) 122 | print("Pretraining Data Creation Completes!") 123 | hf.close() 124 | 125 | def build_h5_data(self, hf, d_idx=0): 126 | """ 127 | features["input_ids"] = torch.tensor(input_ids).long() 128 | features["attention_mask"] = torch.tensor(input_mask).long() 129 | features["token_type_ids"] = torch.tensor(segment_ids).long() 130 | features["masked_lm_positions"] = torch.tensor(masked_lm_positions).long() 131 | features["masked_lm_ids"] = torch.tensor(masked_lm_ids).long() # masked_lm_ids 132 | features["next_sentence_labels"] = torch.tensor([next_sentence_label]).long() 133 | """ 134 | 135 | h5_key_dict = {} 136 | print("Number of documents features", len(self.all_doc_feat_dict["next_sentence_labels"])) 137 | if d_idx == 0: 138 | for feat_key in self.feature_keys: 139 | key_size = [len(self.all_doc_feat_dict["next_sentence_labels"])] + [len(self.all_doc_feat_dict[feat_key][0])] 140 | h5_key_dict[feat_key] = hf.create_dataset(feat_key, tuple(key_size), dtype='i8', chunks=True, 141 | maxshape=(None, tuple(key_size)[1]), 142 | data=np.array(self.all_doc_feat_dict[feat_key])) 143 | else: 144 | for feat_key in self.feature_keys: 145 | hf[feat_key].resize((hf[feat_key].shape[0] + len(self.all_doc_feat_dict["next_sentence_labels"])), axis=0) 146 | hf[feat_key][-len(self.all_doc_feat_dict["next_sentence_labels"]):] = np.array(self.all_doc_feat_dict[feat_key]) 147 | 148 | def create_instances_from_document(self, all_documents, document_index, max_seq_length, short_seq_prob, 149 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng): 150 | """Creates `TrainingInstance`s for 
a single document.""" 151 | document = all_documents[document_index] 152 | 153 | # Account for [CLS], [SEP], [SEP] 154 | max_num_tokens = max_seq_length - 3 155 | 156 | # We *usually* want to fill up the entire sequence since we are padding 157 | # to `max_seq_length` anyways, so short sequences are generally wasted 158 | # computation. However, we *sometimes* 159 | # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter 160 | # sequences to minimize the mismatch between pre-training and fine-tuning. 161 | # The `target_seq_length` is just a rough target however, whereas 162 | # `max_seq_length` is a hard limit. 163 | target_seq_length = max_num_tokens 164 | if rng.random() < short_seq_prob: 165 | target_seq_length = rng.randint(2, max_num_tokens) 166 | 167 | # We DON'T just concatenate all of the tokens from a document into a long 168 | # sequence and choose an arbitrary split point because this would make the 169 | # next sentence prediction task too easy. Instead, we split the input into 170 | # segments "A" and "B" based on the actual "sentences" provided by the user 171 | # input. 172 | instances = [] 173 | current_chunk = [] 174 | current_length = 0 175 | i = 0 176 | 177 | while i < len(document): 178 | segment = document[i] 179 | current_chunk.append(segment) 180 | current_length += len(segment) 181 | if i == len(document) - 1 or current_length >= target_seq_length: 182 | if current_chunk: 183 | # `a_end` is how many segments from `current_chunk` go into the `A` 184 | # (first) sentence. 185 | a_end = 1 186 | if len(current_chunk) >= 2: 187 | a_end = rng.randint(1, len(current_chunk) - 1) 188 | 189 | tokens_a = [] 190 | for j in range(a_end): 191 | tokens_a.extend(current_chunk[j]) 192 | 193 | tokens_b = [] 194 | # Random next 195 | is_random_next = False 196 | if len(current_chunk) == 1 or rng.random() < 0.5: 197 | is_random_next = True 198 | target_b_length = target_seq_length - len(tokens_a) 199 | 200 | # This should rarely go for more than one iteration for large 201 | # corpora. However, just to be careful, we try to make sure that 202 | # the random document is not the same as the document 203 | # we're processing. 204 | for _ in range(10): 205 | random_document_index = rng.randint(0, len(all_documents) - 1) 206 | if random_document_index != document_index: 207 | break 208 | 209 | random_document = all_documents[random_document_index] 210 | random_start = rng.randint(0, len(random_document) - 1) 211 | for j in range(random_start, len(random_document)): 212 | tokens_b.extend(random_document[j]) 213 | if len(tokens_b) >= target_b_length: 214 | break 215 | # We didn't actually use these segments so we "put them back" so 216 | # they don't go to waste. 
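# Worked illustration of the rewind below (numbers are illustrative): if
# current_chunk holds 5 sentence segments and a_end == 2, segments 2..4 were
# never consumed because the "B" side was sampled from a random document;
# num_unused_segments == 3 and i is moved back by that amount, so the outer
# while-loop revisits those segments as the start of the next instance instead
# of silently dropping them.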
217 | num_unused_segments = len(current_chunk) - a_end 218 | i -= num_unused_segments 219 | # Actual next 220 | else: 221 | is_random_next = False 222 | for j in range(a_end, len(current_chunk)): 223 | tokens_b.extend(current_chunk[j]) 224 | self.truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng) 225 | 226 | assert len(tokens_a) >= 1 227 | assert len(tokens_b) >= 1 228 | 229 | tokens = [] 230 | segment_ids = [] 231 | tokens.append("[CLS]") 232 | segment_ids.append(0) 233 | for token in tokens_a: 234 | tokens.append(token) 235 | segment_ids.append(0) 236 | 237 | tokens.append("[SEP]") 238 | segment_ids.append(0) 239 | 240 | for token in tokens_b: 241 | tokens.append(token) 242 | segment_ids.append(1) 243 | tokens.append("[SEP]") 244 | segment_ids.append(1) 245 | 246 | (tokens, masked_lm_positions, masked_lm_labels) = self.create_masked_lm_predictions( 247 | tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng) 248 | instance = TrainingInstance( 249 | tokens=tokens, 250 | segment_ids=segment_ids, 251 | is_random_next=is_random_next, 252 | masked_lm_positions=masked_lm_positions, 253 | masked_lm_labels=masked_lm_labels) 254 | instances.append(instance) 255 | 256 | current_chunk = [] 257 | current_length = 0 258 | i += 1 259 | 260 | return instances 261 | 262 | def create_masked_lm_predictions(self, tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng): 263 | """Creates the predictions for the masked LM objective.""" 264 | 265 | cand_indexes = [] 266 | for (i, token) in enumerate(tokens): 267 | if token == "[CLS]" or token == "[SEP]": 268 | continue 269 | # Whole Word Masking means that if we mask all of the wordpieces 270 | # corresponding to an original word. When a word has been split into 271 | # WordPieces, the first token does not have any marker and any subsequence 272 | # tokens are prefixed with ##. So whenever we see the ## token, we 273 | # append it to the previous set of word indexes. 274 | # 275 | # Note that Whole Word Masking does *not* change the training code 276 | # at all -- we still predict each WordPiece independently, softmaxed 277 | # over the entire vocabulary. 278 | if (self.args.do_whole_word_mask and len(cand_indexes) >= 1 and 279 | token.startswith("##")): 280 | cand_indexes[-1].append(i) 281 | else: 282 | cand_indexes.append([i]) 283 | 284 | rng.shuffle(cand_indexes) 285 | 286 | output_tokens = list(tokens) 287 | 288 | num_to_predict = min(max_predictions_per_seq, 289 | max(1, int(round(len(tokens) * masked_lm_prob)))) 290 | 291 | masked_lms = [] 292 | covered_indexes = set() 293 | for index_set in cand_indexes: 294 | if len(masked_lms) >= num_to_predict: 295 | break 296 | # If adding a whole-word mask would exceed the maximum number of 297 | # predictions, then just skip this candidate. 
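# Illustration of the whole-word grouping built above (hypothetical WordPiece
# sequence, [CLS]/[SEP] omitted): for ["un", "##afford", "##able", "price"],
# cand_indexes becomes [[0, 1, 2], [3]], so the three pieces of "unaffordable"
# are masked together or not at all, whereas per-WordPiece masking would treat
# each index as its own candidate. The check below keeps the number of masked
# positions within max_predictions_per_seq even when such a multi-piece group
# is selected.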
298 | if len(masked_lms) + len(index_set) > num_to_predict: 299 | continue 300 | is_any_index_covered = False 301 | for index in index_set: 302 | if index in covered_indexes: 303 | is_any_index_covered = True 304 | break 305 | if is_any_index_covered: 306 | continue 307 | for index in index_set: 308 | covered_indexes.add(index) 309 | 310 | masked_token = None 311 | # 80% of the time, replace with [MASK] 312 | if rng.random() < 0.8: 313 | masked_token = "[MASK]" 314 | else: 315 | # 10% of the time, keep original 316 | if rng.random() < 0.5: 317 | masked_token = tokens[index] 318 | # 10% of the time, replace with random word 319 | else: 320 | masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)] 321 | 322 | output_tokens[index] = masked_token 323 | 324 | masked_lms.append(MaskedLmInstance(index=index, label=tokens[index])) 325 | 326 | assert len(masked_lms) <= num_to_predict 327 | masked_lms = sorted(masked_lms, key=lambda x: x.index) 328 | 329 | masked_lm_positions = [] 330 | masked_lm_labels = [] 331 | for p in masked_lms: 332 | masked_lm_positions.append(p.index) 333 | masked_lm_labels.append(p.label) 334 | 335 | return (output_tokens, masked_lm_positions, masked_lm_labels) 336 | 337 | def truncate_seq_pair(self, tokens_a, tokens_b, max_num_tokens, rng): 338 | """Truncates a pair of sequences to a maximum sequence length.""" 339 | while True: 340 | total_length = len(tokens_a) + len(tokens_b) 341 | if total_length <= max_num_tokens: 342 | break 343 | 344 | trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b 345 | assert len(trunc_tokens) >= 1 346 | 347 | # We want to sometimes truncate from the front and sometimes from the 348 | # back to add more randomness and avoid biases. 349 | if rng.random() < 0.5: 350 | del trunc_tokens[0] 351 | else: 352 | trunc_tokens.pop() 353 | 354 | def instance_to_example_feature(self, instances, max_seq_length, max_predictions_per_seq): 355 | for instance in instances: 356 | input_ids = self._bert_tokenizer.convert_tokens_to_ids(instance.tokens) 357 | input_mask = [1] * len(input_ids) 358 | segment_ids = list(instance.segment_ids) 359 | 360 | assert len(input_ids) <= max_seq_length 361 | 362 | while len(input_ids) < max_seq_length: 363 | input_ids.append(0) 364 | input_mask.append(0) 365 | segment_ids.append(0) 366 | 367 | assert len(input_ids) == max_seq_length 368 | assert len(input_mask) == max_seq_length 369 | assert len(segment_ids) == max_seq_length 370 | 371 | masked_lm_positions = list(instance.masked_lm_positions) 372 | masked_lm_ids = self._bert_tokenizer.convert_tokens_to_ids(instance.masked_lm_labels) 373 | masked_lm_weights = [1.0] * len(masked_lm_ids) 374 | 375 | while len(masked_lm_positions) < max_predictions_per_seq: 376 | masked_lm_positions.append(0) 377 | masked_lm_ids.append(0) 378 | masked_lm_weights.append(0.0) 379 | 380 | next_sentence_label = 1 if instance.is_random_next else 0 381 | 382 | self.all_doc_feat_dict["input_ids"].append(input_ids) 383 | self.all_doc_feat_dict["attention_mask"].append(input_mask) 384 | self.all_doc_feat_dict["token_type_ids"].append(segment_ids) 385 | self.all_doc_feat_dict["masked_lm_positions"].append(masked_lm_positions) 386 | self.all_doc_feat_dict["masked_lm_ids"].append(masked_lm_ids) 387 | self.all_doc_feat_dict["next_sentence_labels"].append([next_sentence_label]) 388 | 389 | if __name__ == "__main__": 390 | arg_parser = argparse.ArgumentParser(description="Bert / Create Pretraining Data") 391 | arg_parser.add_argument("--input_file", dest="input_file", type=str, 392 | 
default="./data/ubuntu_corpus_v1/ubuntu_post_training.txt", 393 | help="Input raw text file (or comma-separated list of files).") 394 | arg_parser.add_argument("--output_file", dest="output_file", type=str, 395 | default="./data/ubuntu_corpus_v1/ubuntu_post_training.hdf5", 396 | help="Output example pkl.") 397 | arg_parser.add_argument("--do_lower_case", dest="do_lower_case", type=bool, default=True, 398 | help="Whether to lower case the input text. Should be True for uncased.") 399 | arg_parser.add_argument("--do_whole_word_mask", dest="do_whole_word_mask", type=bool, default=True, 400 | help="Whether to use whole word masking rather than per-WordPiece masking.") 401 | arg_parser.add_argument("--max_seq_length", dest="max_seq_length", type=int, default=512, 402 | help="Maximum sequence length.") 403 | arg_parser.add_argument("--max_predictions_per_seq", dest="max_predictions_per_seq", type=int, default=70, 404 | help="Maximum number of masked LM predictions per sequence.") 405 | arg_parser.add_argument("--random_seed", dest="random_seed", type=int, default=12345, 406 | help="Random seed for data generation.") 407 | arg_parser.add_argument("--dupe_factor", dest="dupe_factor", type=int, default=1, 408 | help="Number of times to duplicate the input data (with different masks).") 409 | arg_parser.add_argument("--masked_lm_prob", dest="masked_lm_prob", type=float, default=0.15, 410 | help="Masked LM probability.") 411 | arg_parser.add_argument("--short_seq_prob", dest="short_seq_prob", type=float, default=0.1, 412 | help="Probability of creating sequences which are shorter than the maximum length.") 413 | arg_parser.add_argument("--special_tok", dest="special_tok", type=str, default="[EOT]", 414 | help="Special Token.") 415 | args = arg_parser.parse_args() 416 | 417 | create_data = CreateBertPretrainingData(args) 418 | 419 | rng = random.Random(args.random_seed) 420 | create_data.create_training_instances( 421 | args.input_file, args.max_seq_length, args.dupe_factor, 422 | args.short_seq_prob, args.masked_lm_prob, args.max_predictions_per_seq, rng, args.special_tok) 423 | -------------------------------------------------------------------------------- /models/pretrained_common/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 5 | """ 6 | 7 | import fnmatch 8 | import json 9 | import logging 10 | import os 11 | import shutil 12 | import sys 13 | import tarfile 14 | import tempfile 15 | from contextlib import contextmanager 16 | from functools import partial, wraps 17 | from hashlib import sha256 18 | from typing import Optional 19 | from urllib.parse import urlparse 20 | from zipfile import ZipFile, is_zipfile 21 | 22 | import requests 23 | from filelock import FileLock 24 | from tqdm.auto import tqdm 25 | 26 | from . 
import __version__ 27 | 28 | 29 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 30 | 31 | try: 32 | USE_TF = os.environ.get("USE_TF", "AUTO").upper() 33 | USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() 34 | if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"): 35 | import torch 36 | 37 | _torch_available = True # pylint: disable=invalid-name 38 | logger.info("PyTorch version {} available.".format(torch.__version__)) 39 | else: 40 | logger.info("Disabling PyTorch because USE_TF is set") 41 | _torch_available = False 42 | except ImportError: 43 | _torch_available = False # pylint: disable=invalid-name 44 | 45 | try: 46 | USE_TF = os.environ.get("USE_TF", "AUTO").upper() 47 | USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() 48 | 49 | if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"): 50 | import tensorflow as tf 51 | 52 | assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2 53 | _tf_available = True # pylint: disable=invalid-name 54 | logger.info("TensorFlow version {} available.".format(tf.__version__)) 55 | else: 56 | logger.info("Disabling Tensorflow because USE_TORCH is set") 57 | _tf_available = False 58 | except (ImportError, AssertionError): 59 | _tf_available = False # pylint: disable=invalid-name 60 | 61 | try: 62 | from torch.hub import _get_torch_home 63 | 64 | torch_cache_home = _get_torch_home() 65 | except ImportError: 66 | torch_cache_home = os.path.expanduser( 67 | os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch")) 68 | ) 69 | default_cache_path = os.path.join(torch_cache_home, "transformers") 70 | 71 | try: 72 | from pathlib import Path 73 | 74 | PYTORCH_PRETRAINED_BERT_CACHE = Path( 75 | os.getenv("PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)) 76 | ) 77 | except (AttributeError, ImportError): 78 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv( 79 | "PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path) 80 | ) 81 | 82 | PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility 83 | TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility 84 | 85 | WEIGHTS_NAME = "pytorch_model.bin" 86 | TF2_WEIGHTS_NAME = "tf_model.h5" 87 | TF_WEIGHTS_NAME = "model.ckpt" 88 | CONFIG_NAME = "config.json" 89 | MODEL_CARD_NAME = "modelcard.json" 90 | 91 | 92 | MULTIPLE_CHOICE_DUMMY_INPUTS = [[[0], [1]], [[0], [1]]] 93 | DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]] 94 | DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]] 95 | 96 | S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert" 97 | CLOUDFRONT_DISTRIB_PREFIX = "https://d2ws9o8vfrpkyk.cloudfront.net" 98 | 99 | 100 | def is_torch_available(): 101 | return _torch_available 102 | 103 | 104 | def is_tf_available(): 105 | return _tf_available 106 | 107 | 108 | def add_start_docstrings(*docstr): 109 | def docstring_decorator(fn): 110 | fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "") 111 | return fn 112 | 113 | return docstring_decorator 114 | 115 | 116 | def add_start_docstrings_to_callable(*docstr): 117 | def docstring_decorator(fn): 118 | class_name = ":class:`~transformers.{}`".format(fn.__qualname__.split(".")[0]) 119 | intro = " The {} forward method, overrides the :func:`__call__` special method.".format(class_name) 120 | note = r""" 121 | 122 | .. 
note:: 123 | Although the recipe for forward pass needs to be defined within 124 | this function, one should call the :class:`Module` instance afterwards 125 | instead of this since the former takes care of running the 126 | pre and post processing steps while the latter silently ignores them. 127 | """ 128 | fn.__doc__ = intro + note + "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "") 129 | return fn 130 | 131 | return docstring_decorator 132 | 133 | 134 | def add_end_docstrings(*docstr): 135 | def docstring_decorator(fn): 136 | fn.__doc__ = fn.__doc__ + "".join(docstr) 137 | return fn 138 | 139 | return docstring_decorator 140 | 141 | 142 | def is_remote_url(url_or_filename): 143 | parsed = urlparse(url_or_filename) 144 | return parsed.scheme in ("http", "https") 145 | 146 | 147 | def hf_bucket_url(identifier, postfix=None, cdn=False) -> str: 148 | endpoint = CLOUDFRONT_DISTRIB_PREFIX if cdn else S3_BUCKET_PREFIX 149 | if postfix is None: 150 | return "/".join((endpoint, identifier)) 151 | else: 152 | return "/".join((endpoint, identifier, postfix)) 153 | 154 | 155 | def url_to_filename(url, etag=None): 156 | """ 157 | Convert `url` into a hashed filename in a repeatable way. 158 | If `etag` is specified, append its hash to the url's, delimited 159 | by a period. 160 | If the url ends with .h5 (Keras HDF5 weights) adds '.h5' to the name 161 | so that TF 2.0 can identify it as a HDF5 file 162 | (see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380) 163 | """ 164 | url_bytes = url.encode("utf-8") 165 | url_hash = sha256(url_bytes) 166 | filename = url_hash.hexdigest() 167 | 168 | if etag: 169 | etag_bytes = etag.encode("utf-8") 170 | etag_hash = sha256(etag_bytes) 171 | filename += "." + etag_hash.hexdigest() 172 | 173 | if url.endswith(".h5"): 174 | filename += ".h5" 175 | 176 | return filename 177 | 178 | 179 | def filename_to_url(filename, cache_dir=None): 180 | """ 181 | Return the url and etag (which may be ``None``) stored for `filename`. 182 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 183 | """ 184 | if cache_dir is None: 185 | cache_dir = TRANSFORMERS_CACHE 186 | if isinstance(cache_dir, Path): 187 | cache_dir = str(cache_dir) 188 | 189 | cache_path = os.path.join(cache_dir, filename) 190 | if not os.path.exists(cache_path): 191 | raise EnvironmentError("file {} not found".format(cache_path)) 192 | 193 | meta_path = cache_path + ".json" 194 | if not os.path.exists(meta_path): 195 | raise EnvironmentError("file {} not found".format(meta_path)) 196 | 197 | with open(meta_path, encoding="utf-8") as meta_file: 198 | metadata = json.load(meta_file) 199 | url = metadata["url"] 200 | etag = metadata["etag"] 201 | 202 | return url, etag 203 | 204 | 205 | def cached_path( 206 | url_or_filename, 207 | cache_dir=None, 208 | force_download=False, 209 | proxies=None, 210 | resume_download=False, 211 | user_agent=None, 212 | extract_compressed_file=False, 213 | force_extract=False, 214 | local_files_only=False, 215 | ) -> Optional[str]: 216 | """ 217 | Given something that might be a URL (or might be a local path), 218 | determine which. If it's a URL, download the file and cache it, and 219 | return the path to the cached file. If it's already a local path, 220 | make sure the file exists and then return the path. 221 | Args: 222 | cache_dir: specify a cache directory to save the file to (overwrite the default cache dir). 
223 | force_download: if True, re-dowload the file even if it's already cached in the cache dir. 224 | resume_download: if True, resume the download if incompletly recieved file is found. 225 | user_agent: Optional string or dict that will be appended to the user-agent on remote requests. 226 | extract_compressed_file: if True and the path point to a zip or tar file, extract the compressed 227 | file in a folder along the archive. 228 | force_extract: if True when extract_compressed_file is True and the archive was already extracted, 229 | re-extract the archive and overide the folder where it was extracted. 230 | 231 | Return: 232 | None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk). 233 | Local path (string) otherwise 234 | """ 235 | if cache_dir is None: 236 | cache_dir = TRANSFORMERS_CACHE 237 | if isinstance(url_or_filename, Path): 238 | url_or_filename = str(url_or_filename) 239 | if isinstance(cache_dir, Path): 240 | cache_dir = str(cache_dir) 241 | 242 | if is_remote_url(url_or_filename): 243 | # URL, so get it from the cache (downloading if necessary) 244 | output_path = get_from_cache( 245 | url_or_filename, 246 | cache_dir=cache_dir, 247 | force_download=force_download, 248 | proxies=proxies, 249 | resume_download=resume_download, 250 | user_agent=user_agent, 251 | local_files_only=local_files_only, 252 | ) 253 | elif os.path.exists(url_or_filename): 254 | # File, and it exists. 255 | output_path = url_or_filename 256 | elif urlparse(url_or_filename).scheme == "": 257 | # File, but it doesn't exist. 258 | raise EnvironmentError("file {} not found".format(url_or_filename)) 259 | else: 260 | # Something unknown 261 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 262 | 263 | if extract_compressed_file: 264 | if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path): 265 | return output_path 266 | 267 | # Path where we extract compressed archives 268 | # We avoid '.' 
in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/" 269 | output_dir, output_file = os.path.split(output_path) 270 | output_extract_dir_name = output_file.replace(".", "-") + "-extracted" 271 | output_path_extracted = os.path.join(output_dir, output_extract_dir_name) 272 | 273 | if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract: 274 | return output_path_extracted 275 | 276 | # Prevent parallel extractions 277 | lock_path = output_path + ".lock" 278 | with FileLock(lock_path): 279 | shutil.rmtree(output_path_extracted, ignore_errors=True) 280 | os.makedirs(output_path_extracted) 281 | if is_zipfile(output_path): 282 | with ZipFile(output_path, "r") as zip_file: 283 | zip_file.extractall(output_path_extracted) 284 | zip_file.close() 285 | elif tarfile.is_tarfile(output_path): 286 | tar_file = tarfile.open(output_path) 287 | tar_file.extractall(output_path_extracted) 288 | tar_file.close() 289 | else: 290 | raise EnvironmentError("Archive format of {} could not be identified".format(output_path)) 291 | 292 | return output_path_extracted 293 | 294 | return output_path 295 | 296 | 297 | def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None): 298 | ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0]) 299 | if is_torch_available(): 300 | ua += "; torch/{}".format(torch.__version__) 301 | if is_tf_available(): 302 | ua += "; tensorflow/{}".format(tf.__version__) 303 | if isinstance(user_agent, dict): 304 | ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items()) 305 | elif isinstance(user_agent, str): 306 | ua += "; " + user_agent 307 | headers = {"user-agent": ua} 308 | if resume_size > 0: 309 | headers["Range"] = "bytes=%d-" % (resume_size,) 310 | response = requests.get(url, stream=True, proxies=proxies, headers=headers) 311 | if response.status_code == 416: # Range not satisfiable 312 | return 313 | content_length = response.headers.get("Content-Length") 314 | total = resume_size + int(content_length) if content_length is not None else None 315 | progress = tqdm( 316 | unit="B", 317 | unit_scale=True, 318 | total=total, 319 | initial=resume_size, 320 | desc="Downloading", 321 | disable=bool(logger.getEffectiveLevel() == logging.NOTSET), 322 | ) 323 | for chunk in response.iter_content(chunk_size=1024): 324 | if chunk: # filter out keep-alive new chunks 325 | progress.update(len(chunk)) 326 | temp_file.write(chunk) 327 | progress.close() 328 | 329 | 330 | def get_from_cache( 331 | url, 332 | cache_dir=None, 333 | force_download=False, 334 | proxies=None, 335 | etag_timeout=10, 336 | resume_download=False, 337 | user_agent=None, 338 | local_files_only=False, 339 | ) -> Optional[str]: 340 | """ 341 | Given a URL, look for the corresponding file in the local cache. 342 | If it's not there, download it. Then return the path to the cached file. 343 | 344 | Return: 345 | None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk). 
346 | Local path (string) otherwise 347 | """ 348 | if cache_dir is None: 349 | cache_dir = TRANSFORMERS_CACHE 350 | if isinstance(cache_dir, Path): 351 | cache_dir = str(cache_dir) 352 | 353 | os.makedirs(cache_dir, exist_ok=True) 354 | 355 | etag = None 356 | if not local_files_only: 357 | try: 358 | response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout) 359 | if response.status_code == 200: 360 | etag = response.headers.get("ETag") 361 | except (EnvironmentError, requests.exceptions.Timeout): 362 | # etag is already None 363 | pass 364 | 365 | filename = url_to_filename(url, etag) 366 | 367 | # get cache path to put the file 368 | cache_path = os.path.join(cache_dir, filename) 369 | 370 | # etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible. 371 | # try to get the last downloaded one 372 | if etag is None: 373 | if os.path.exists(cache_path): 374 | return cache_path 375 | else: 376 | matching_files = [ 377 | file 378 | for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*") 379 | if not file.endswith(".json") and not file.endswith(".lock") 380 | ] 381 | if len(matching_files) > 0: 382 | return os.path.join(cache_dir, matching_files[-1]) 383 | else: 384 | # If files cannot be found and local_files_only=True, 385 | # the models might've been found if local_files_only=False 386 | # Notify the user about that 387 | if local_files_only: 388 | raise ValueError( 389 | "Cannot find the requested files in the cached path and outgoing traffic has been" 390 | " disabled. To enable model look-ups and downloads online, set 'local_files_only'" 391 | " to False." 392 | ) 393 | return None 394 | 395 | # From now on, etag is not None. 396 | if os.path.exists(cache_path) and not force_download: 397 | return cache_path 398 | 399 | # Prevent parallel downloads of the same file with a lock. 400 | lock_path = cache_path + ".lock" 401 | with FileLock(lock_path): 402 | 403 | # If the download just completed while the lock was activated. 404 | if os.path.exists(cache_path) and not force_download: 405 | # Even if returning early like here, the lock will be released. 406 | return cache_path 407 | 408 | if resume_download: 409 | incomplete_path = cache_path + ".incomplete" 410 | 411 | @contextmanager 412 | def _resumable_file_manager(): 413 | with open(incomplete_path, "a+b") as f: 414 | yield f 415 | 416 | temp_file_manager = _resumable_file_manager 417 | if os.path.exists(incomplete_path): 418 | resume_size = os.stat(incomplete_path).st_size 419 | else: 420 | resume_size = 0 421 | else: 422 | temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False) 423 | resume_size = 0 424 | 425 | # Download to temporary file, then copy to cache dir once finished. 426 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
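# Standalone sketch of the write-temp-then-rename pattern used below
# (`download` is a placeholder; the other names mirror variables in this
# function):
#
#     with tempfile.NamedTemporaryFile(dir=cache_dir, delete=False) as tmp:
#         download(url, tmp)                # stream the payload into the temp file
#     os.replace(tmp.name, cache_path)      # atomic within one filesystem
#
# Readers therefore see either the previous cache entry or the fully written
# new one, never a truncated file, and the surrounding FileLock keeps two
# processes from downloading the same URL at once.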
427 | with temp_file_manager() as temp_file: 428 | logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name) 429 | 430 | http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent) 431 | 432 | logger.info("storing %s in cache at %s", url, cache_path) 433 | os.replace(temp_file.name, cache_path) 434 | 435 | logger.info("creating metadata file for %s", cache_path) 436 | meta = {"url": url, "etag": etag} 437 | meta_path = cache_path + ".json" 438 | with open(meta_path, "w") as meta_file: 439 | json.dump(meta, meta_file) 440 | 441 | return cache_path 442 | 443 | 444 | class cached_property(property): 445 | """ 446 | Descriptor that mimics @property but caches output in member variable. 447 | 448 | From tensorflow_datasets 449 | 450 | Built-in in functools from Python 3.8. 451 | """ 452 | 453 | def __get__(self, obj, objtype=None): 454 | # See docs.python.org/3/howto/descriptor.html#properties 455 | if obj is None: 456 | return self 457 | if self.fget is None: 458 | raise AttributeError("unreadable attribute") 459 | attr = "__cached_" + self.fget.__name__ 460 | cached = getattr(obj, attr, None) 461 | if cached is None: 462 | cached = self.fget(obj) 463 | setattr(obj, attr, cached) 464 | return cached 465 | 466 | 467 | def torch_required(func): 468 | # Chose a different decorator name than in tests so it's clear they are not the same. 469 | @wraps(func) 470 | def wrapper(*args, **kwargs): 471 | if is_torch_available(): 472 | return func(*args, **kwargs) 473 | else: 474 | raise ImportError(f"Method `{func.__name__}` requires PyTorch.") 475 | 476 | return wrapper 477 | 478 | 479 | def tf_required(func): 480 | # Chose a different decorator name than in tests so it's clear they are not the same. 481 | @wraps(func) 482 | def wrapper(*args, **kwargs): 483 | if is_tf_available(): 484 | return func(*args, **kwargs) 485 | else: 486 | raise ImportError(f"Method `{func.__name__}` requires TF.") 487 | 488 | return wrapper 489 | -------------------------------------------------------------------------------- /models/bert/tokenization_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
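# --- Editor's illustrative sketch (not part of the original repository code) ---
# Behaviour of the cached_property descriptor defined in
# models/pretrained_common/file_utils.py above: the getter runs once per
# instance, the result is stored on the instance, and later accesses reuse it.
# The Expensive class is a made-up example.
from models.pretrained_common.file_utils import cached_property

class Expensive:
    def __init__(self):
        self.calls = 0

    @cached_property
    def value(self):
        self.calls += 1
        return 42

obj = Expensive()
assert obj.value == 42
assert obj.value == 42
assert obj.calls == 1  # the getter body only ran once
# --- end of sketch ---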
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import collections 20 | import logging 21 | import os 22 | import unicodedata 23 | from io import open 24 | 25 | from models.pretrained_common.tokenization_utils import PreTrainedTokenizer 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'} 30 | 31 | PRETRAINED_VOCAB_FILES_MAP = { 32 | 'vocab_file': 33 | { 34 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 35 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 36 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 37 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 38 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 39 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 40 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 41 | 'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt", 42 | 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt", 43 | 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt", 44 | 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt", 45 | 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt", 46 | 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt", 47 | } 48 | } 49 | 50 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 51 | 'bert-base-uncased': 512, 52 | 'bert-large-uncased': 512, 53 | 'bert-base-cased': 512, 54 | 'bert-large-cased': 512, 55 | 'bert-base-multilingual-uncased': 512, 56 | 'bert-base-multilingual-cased': 512, 57 | 'bert-base-chinese': 512, 58 | 'bert-base-german-cased': 512, 59 | 'bert-large-uncased-whole-word-masking': 512, 60 | 'bert-large-cased-whole-word-masking': 512, 61 | 'bert-large-uncased-whole-word-masking-finetuned-squad': 512, 62 | 'bert-large-cased-whole-word-masking-finetuned-squad': 512, 63 | 'bert-base-cased-finetuned-mrpc': 512, 64 | } 65 | 66 | def load_vocab(vocab_file): 67 | """Loads a vocabulary file into a dictionary.""" 68 | vocab = collections.OrderedDict() 69 | with open(vocab_file, "r", encoding="utf-8") as reader: 70 | tokens = reader.readlines() 71 | for index, token in enumerate(tokens): 72 | token = token.rstrip('\n') 73 | vocab[token] = index 74 | return vocab 75 | 76 | 77 | def whitespace_tokenize(text): 78 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 79 | text = text.strip() 80 | if not text: 81 | return [] 82 | tokens = text.split() 83 | return tokens 84 | 85 | 86 | class BertTokenizer(PreTrainedTokenizer): 87 | r""" 88 | Constructs a BertTokenizer. 
89 | :class:`~pytorch_pretrained_bert.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece 90 | Args: 91 | vocab_file: Path to a one-wordpiece-per-line vocabulary file 92 | do_lower_case: Whether to lower case the input. Only has an effect when do_wordpiece_only=False 93 | do_basic_tokenize: Whether to do basic tokenization before wordpiece. 94 | max_len: An artificial maximum length to truncate tokenized sequences to; Effective maximum length is always the 95 | minimum of this value (if specified) and the underlying BERT model's sequence length. 96 | never_split: List of tokens which will never be split during tokenization. Only has an effect when 97 | do_wordpiece_only=False 98 | """ 99 | 100 | vocab_files_names = VOCAB_FILES_NAMES 101 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 102 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 103 | 104 | def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None, 105 | unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]", 106 | mask_token="[MASK]", tokenize_chinese_chars=True, **kwargs): 107 | """Constructs a BertTokenizer. 108 | Args: 109 | **vocab_file**: Path to a one-wordpiece-per-line vocabulary file 110 | **do_lower_case**: (`optional`) boolean (default True) 111 | Whether to lower case the input 112 | Only has an effect when do_basic_tokenize=True 113 | **do_basic_tokenize**: (`optional`) boolean (default True) 114 | Whether to do basic tokenization before wordpiece. 115 | **never_split**: (`optional`) list of string 116 | List of tokens which will never be split during tokenization. 117 | Only has an effect when do_basic_tokenize=True 118 | **tokenize_chinese_chars**: (`optional`) boolean (default True) 119 | Whether to tokenize Chinese characters. 120 | This should likely be desactivated for Japanese: 121 | see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328 122 | """ 123 | super(BertTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token, 124 | pad_token=pad_token, cls_token=cls_token, 125 | mask_token=mask_token, **kwargs) 126 | 127 | if not os.path.isfile(vocab_file): 128 | raise ValueError( 129 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " 130 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) 131 | print(vocab_file) 132 | self.vocab = load_vocab(vocab_file) 133 | self.ids_to_tokens = collections.OrderedDict( 134 | [(ids, tok) for tok, ids in self.vocab.items()]) 135 | self.do_basic_tokenize = do_basic_tokenize 136 | if do_basic_tokenize: 137 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, 138 | never_split=never_split, 139 | tokenize_chinese_chars=tokenize_chinese_chars) 140 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token) 141 | 142 | @property 143 | def vocab_size(self): 144 | return len(self.vocab) 145 | 146 | def _tokenize(self, text): 147 | split_tokens = [] 148 | if self.do_basic_tokenize: 149 | for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens): 150 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 151 | split_tokens.append(sub_token) 152 | else: 153 | split_tokens = self.wordpiece_tokenizer.tokenize(text) 154 | return split_tokens 155 | 156 | def _convert_token_to_id(self, token): 157 | """ Converts a token (str/unicode) in an id using the vocab. 
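# --- Editor's illustrative sketch (not part of the original repository code) ---
# End-to-end use of the BertTokenizer defined above with a tiny throwaway
# vocabulary file (the vocabulary contents are made up). _tokenize() runs the
# BasicTokenizer and WordpieceTokenizer classes defined later in this file.
import tempfile
from models.bert.tokenization_bert import BertTokenizer

vocab_tokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "play", "##ing", "cards"]
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
    f.write("\n".join(vocab_tokens))
    vocab_path = f.name

tokenizer = BertTokenizer(vocab_file=vocab_path, do_lower_case=True)
print(tokenizer.tokenize("Playing cards"))   # expected: ['play', '##ing', 'cards']
print(tokenizer.encode("Playing cards"))     # expected: [5, 6, 7]
# --- end of sketch ---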
""" 158 | return self.vocab.get(token, self.vocab.get(self.unk_token)) 159 | 160 | def _convert_id_to_token(self, index): 161 | """Converts an index (integer) in a token (string/unicode) using the vocab.""" 162 | return self.ids_to_tokens.get(index, self.unk_token) 163 | 164 | def convert_tokens_to_string(self, tokens): 165 | """ Converts a sequence of tokens (string) in a single string. """ 166 | out_string = ' '.join(tokens).replace(' ##', '').strip() 167 | return out_string 168 | 169 | def save_vocabulary(self, vocab_path): 170 | """Save the tokenizer vocabulary to a directory or file.""" 171 | index = 0 172 | if os.path.isdir(vocab_path): 173 | vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file']) 174 | with open(vocab_file, "w", encoding="utf-8") as writer: 175 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): 176 | if index != token_index: 177 | logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive." 178 | " Please check that the vocabulary is not corrupted!".format(vocab_file)) 179 | index = token_index 180 | writer.write(token + u'\n') 181 | index += 1 182 | return (vocab_file,) 183 | 184 | @classmethod 185 | def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): 186 | """ Instantiate a BertTokenizer from pre-trained vocabulary files. 187 | """ 188 | if pretrained_model_name_or_path in PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES: 189 | if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True): 190 | logger.warning("The pre-trained model you are loading is a cased model but you have not set " 191 | "`do_lower_case` to False. We are setting `do_lower_case=False` for you but " 192 | "you may want to check this behavior.") 193 | kwargs['do_lower_case'] = False 194 | elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True): 195 | logger.warning("The pre-trained model you are loading is an uncased model but you have set " 196 | "`do_lower_case` to False. We are setting `do_lower_case=True` for you " 197 | "but you may want to check this behavior.") 198 | kwargs['do_lower_case'] = True 199 | 200 | return super(BertTokenizer, cls)._from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 201 | 202 | 203 | class BasicTokenizer(object): 204 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 205 | 206 | def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True): 207 | """ Constructs a BasicTokenizer. 208 | Args: 209 | **do_lower_case**: Whether to lower case the input. 210 | **never_split**: (`optional`) list of str 211 | Kept for backward compatibility purposes. 212 | Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`) 213 | List of token not to split. 214 | **tokenize_chinese_chars**: (`optional`) boolean (default True) 215 | Whether to tokenize Chinese characters. 216 | This should likely be desactivated for Japanese: 217 | see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328 218 | """ 219 | if never_split is None: 220 | never_split = [] 221 | self.do_lower_case = do_lower_case 222 | self.never_split = never_split 223 | self.tokenize_chinese_chars = tokenize_chinese_chars 224 | 225 | def tokenize(self, text, never_split=None): 226 | """ Basic Tokenization of a piece of text. 227 | Split on "white spaces" only, for sub-word tokenization, see WordPieceTokenizer. 
228 | Args: 229 | **never_split**: (`optional`) list of str 230 | Kept for backward compatibility purposes. 231 | Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`) 232 | List of token not to split. 233 | """ 234 | never_split = self.never_split + (never_split if never_split is not None else []) 235 | text = self._clean_text(text) 236 | # This was added on November 1st, 2018 for the multilingual and Chinese 237 | # models. This is also applied to the English models now, but it doesn't 238 | # matter since the English models were not trained on any Chinese data 239 | # and generally don't have any Chinese data in them (there are Chinese 240 | # characters in the vocabulary because Wikipedia does have some Chinese 241 | # words in the English Wikipedia.). 242 | if self.tokenize_chinese_chars: 243 | text = self._tokenize_chinese_chars(text) 244 | orig_tokens = whitespace_tokenize(text) 245 | split_tokens = [] 246 | for token in orig_tokens: 247 | if self.do_lower_case and token not in never_split: 248 | token = token.lower() 249 | token = self._run_strip_accents(token) 250 | split_tokens.extend(self._run_split_on_punc(token)) 251 | 252 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 253 | return output_tokens 254 | 255 | def _run_strip_accents(self, text): 256 | """Strips accents from a piece of text.""" 257 | text = unicodedata.normalize("NFD", text) 258 | output = [] 259 | for char in text: 260 | cat = unicodedata.category(char) 261 | if cat == "Mn": 262 | continue 263 | output.append(char) 264 | return "".join(output) 265 | 266 | def _run_split_on_punc(self, text, never_split=None): 267 | """Splits punctuation on a piece of text.""" 268 | if never_split is not None and text in never_split: 269 | return [text] 270 | chars = list(text) 271 | i = 0 272 | start_new_word = True 273 | output = [] 274 | while i < len(chars): 275 | char = chars[i] 276 | if _is_punctuation(char): 277 | output.append([char]) 278 | start_new_word = True 279 | else: 280 | if start_new_word: 281 | output.append([]) 282 | start_new_word = False 283 | output[-1].append(char) 284 | i += 1 285 | 286 | return ["".join(x) for x in output] 287 | 288 | def _tokenize_chinese_chars(self, text): 289 | """Adds whitespace around any CJK character.""" 290 | output = [] 291 | for char in text: 292 | cp = ord(char) 293 | if self._is_chinese_char(cp): 294 | output.append(" ") 295 | output.append(char) 296 | output.append(" ") 297 | else: 298 | output.append(char) 299 | return "".join(output) 300 | 301 | def _is_chinese_char(self, cp): 302 | """Checks whether CP is the codepoint of a CJK character.""" 303 | # This defines a "chinese character" as anything in the CJK Unicode block: 304 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 305 | # 306 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 307 | # despite its name. The modern Korean Hangul alphabet is a different block, 308 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 309 | # space-separated words, so they are not treated specially and handled 310 | # like the all of the other languages. 
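# --- Editor's illustrative sketch (not part of the original repository code) ---
# What BasicTokenizer.tokenize (defined above) produces before WordPiece runs:
# text cleanup, optional lower-casing, accent stripping and punctuation
# splitting. The sample strings are made up.
from models.bert.tokenization_bert import BasicTokenizer

basic = BasicTokenizer(do_lower_case=True)
print(basic.tokenize("Héllo, World!"))   # expected: ['hello', ',', 'world', '!']
print(basic.tokenize("naïve café"))      # expected: ['naive', 'cafe']
# --- end of sketch ---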
311 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 312 | (cp >= 0x3400 and cp <= 0x4DBF) or # 313 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 314 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 315 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 316 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 317 | (cp >= 0xF900 and cp <= 0xFAFF) or # 318 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 319 | return True 320 | 321 | return False 322 | 323 | def _clean_text(self, text): 324 | """Performs invalid character removal and whitespace cleanup on text.""" 325 | output = [] 326 | for char in text: 327 | cp = ord(char) 328 | if cp == 0 or cp == 0xfffd or _is_control(char): 329 | continue 330 | if _is_whitespace(char): 331 | output.append(" ") 332 | else: 333 | output.append(char) 334 | return "".join(output) 335 | 336 | 337 | class WordpieceTokenizer(object): 338 | """Runs WordPiece tokenization.""" 339 | 340 | def __init__(self, vocab, unk_token, max_input_chars_per_word=100): 341 | self.vocab = vocab 342 | self.unk_token = unk_token 343 | self.max_input_chars_per_word = max_input_chars_per_word 344 | 345 | def tokenize(self, text): 346 | """Tokenizes a piece of text into its word pieces. 347 | This uses a greedy longest-match-first algorithm to perform tokenization 348 | using the given vocabulary. 349 | For example: 350 | input = "unaffable" 351 | output = ["un", "##aff", "##able"] 352 | Args: 353 | text: A single token or whitespace separated tokens. This should have 354 | already been passed through `BasicTokenizer`. 355 | Returns: 356 | A list of wordpiece tokens. 357 | """ 358 | 359 | output_tokens = [] 360 | for token in whitespace_tokenize(text): 361 | chars = list(token) 362 | if len(chars) > self.max_input_chars_per_word: 363 | output_tokens.append(self.unk_token) 364 | continue 365 | 366 | is_bad = False 367 | start = 0 368 | sub_tokens = [] 369 | while start < len(chars): 370 | end = len(chars) 371 | cur_substr = None 372 | while start < end: 373 | substr = "".join(chars[start:end]) 374 | if start > 0: 375 | substr = "##" + substr 376 | if substr in self.vocab: 377 | cur_substr = substr 378 | break 379 | end -= 1 380 | if cur_substr is None: 381 | is_bad = True 382 | break 383 | sub_tokens.append(cur_substr) 384 | start = end 385 | 386 | if is_bad: 387 | output_tokens.append(self.unk_token) 388 | else: 389 | output_tokens.extend(sub_tokens) 390 | return output_tokens 391 | 392 | 393 | def _is_whitespace(char): 394 | """Checks whether `chars` is a whitespace character.""" 395 | # \t, \n, and \r are technically contorl characters but we treat them 396 | # as whitespace since they are generally considered as such. 397 | if char == " " or char == "\t" or char == "\n" or char == "\r": 398 | return True 399 | cat = unicodedata.category(char) 400 | if cat == "Zs": 401 | return True 402 | return False 403 | 404 | 405 | def _is_control(char): 406 | """Checks whether `chars` is a control character.""" 407 | # These are technically control characters but we count them as whitespace 408 | # characters. 409 | if char == "\t" or char == "\n" or char == "\r": 410 | return False 411 | cat = unicodedata.category(char) 412 | if cat.startswith("C"): 413 | return True 414 | return False 415 | 416 | 417 | def _is_punctuation(char): 418 | """Checks whether `chars` is a punctuation character.""" 419 | cp = ord(char) 420 | # We treat all non-letter/number ASCII as punctuation. 421 | # Characters such as "^", "$", and "`" are not in the Unicode 422 | # Punctuation class but we treat them as punctuation anyways, for 423 | # consistency. 
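# --- Editor's illustrative sketch (not part of the original repository code) ---
# The greedy longest-match-first behaviour of the WordpieceTokenizer defined
# above, shown with a tiny made-up vocabulary.
from models.bert.tokenization_bert import WordpieceTokenizer

toy_vocab = {"un": 0, "##aff": 1, "##able": 2, "[UNK]": 3}
wordpiece = WordpieceTokenizer(vocab=toy_vocab, unk_token="[UNK]")
print(wordpiece.tokenize("unaffable"))   # expected: ['un', '##aff', '##able']
print(wordpiece.tokenize("xyz"))         # expected: ['[UNK]'] -- no sub-piece is in the vocab
# --- end of sketch ---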
424 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 425 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 426 | return True 427 | cat = unicodedata.category(char) 428 | if cat.startswith("P"): 429 | return True 430 | return False -------------------------------------------------------------------------------- /models/pretrained_common/tokenization_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for OpenAI GPT.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import logging 20 | import os 21 | import json 22 | import six 23 | from io import open 24 | 25 | from models.pretrained_common.file_utils import cached_path 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | SPECIAL_TOKENS_MAP_FILE = 'special_tokens_map.json' 30 | ADDED_TOKENS_FILE = 'added_tokens.json' 31 | 32 | class PreTrainedTokenizer(object): 33 | """ An abstract class to handle dowloading and loading pretrained tokenizers and adding tokens to the vocabulary. 34 | 35 | Derived class can set up a few special tokens to be used in common scripts and internals: 36 | bos_token, eos_token, EOP_TOKEN, EOD_TOKEN, unk_token, sep_token, pad_token, cls_token, mask_token 37 | additional_special_tokens = [] 38 | 39 | We defined an added_tokens_encoder to add new tokens to the vocabulary without having to handle the 40 | specific vocabulary augmentation methods of the various underlying dictionnary structures (BPE, sentencepiece...). 
41 | """ 42 | vocab_files_names = {} 43 | pretrained_vocab_files_map = {} 44 | max_model_input_sizes = {} 45 | 46 | SPECIAL_TOKENS_ATTRIBUTES = ["bos_token", "eos_token", "unk_token", "sep_token", 47 | "pad_token", "cls_token", "mask_token", 48 | "additional_special_tokens"] 49 | 50 | @property 51 | def bos_token(self): 52 | if self._bos_token is None: 53 | logger.error("Using bos_token, but it is not set yet.") 54 | return self._bos_token 55 | 56 | @property 57 | def eos_token(self): 58 | if self._eos_token is None: 59 | logger.error("Using eos_token, but it is not set yet.") 60 | return self._eos_token 61 | 62 | @property 63 | def unk_token(self): 64 | if self._unk_token is None: 65 | logger.error("Using unk_token, but it is not set yet.") 66 | return self._unk_token 67 | 68 | @property 69 | def sep_token(self): 70 | if self._sep_token is None: 71 | logger.error("Using sep_token, but it is not set yet.") 72 | return self._sep_token 73 | 74 | @property 75 | def pad_token(self): 76 | if self._pad_token is None: 77 | logger.error("Using pad_token, but it is not set yet.") 78 | return self._pad_token 79 | 80 | @property 81 | def cls_token(self): 82 | if self._cls_token is None: 83 | logger.error("Using cls_token, but it is not set yet.") 84 | return self._cls_token 85 | 86 | @property 87 | def mask_token(self): 88 | if self._mask_token is None: 89 | logger.error("Using mask_token, but it is not set yet.") 90 | return self._mask_token 91 | 92 | @property 93 | def additional_special_tokens(self): 94 | if self._additional_special_tokens is None: 95 | logger.error("Using additional_special_tokens, but it is not set yet.") 96 | return self._additional_special_tokens 97 | 98 | @bos_token.setter 99 | def bos_token(self, value): 100 | self._bos_token = value 101 | 102 | @eos_token.setter 103 | def eos_token(self, value): 104 | self._eos_token = value 105 | 106 | @unk_token.setter 107 | def unk_token(self, value): 108 | self._unk_token = value 109 | 110 | @sep_token.setter 111 | def sep_token(self, value): 112 | self._sep_token = value 113 | 114 | @pad_token.setter 115 | def pad_token(self, value): 116 | self._pad_token = value 117 | 118 | @cls_token.setter 119 | def cls_token(self, value): 120 | self._cls_token = value 121 | 122 | @mask_token.setter 123 | def mask_token(self, value): 124 | self._mask_token = value 125 | 126 | @additional_special_tokens.setter 127 | def additional_special_tokens(self, value): 128 | self._additional_special_tokens = value 129 | 130 | def __init__(self, max_len=None, **kwargs): 131 | self._bos_token = None 132 | self._eos_token = None 133 | self._unk_token = None 134 | self._sep_token = None 135 | self._pad_token = None 136 | self._cls_token = None 137 | self._mask_token = None 138 | self._additional_special_tokens = [] 139 | 140 | self.max_len = max_len if max_len is not None else int(1e12) 141 | self.added_tokens_encoder = {} 142 | self.added_tokens_decoder = {} 143 | 144 | for key, value in kwargs.items(): 145 | if key in self.SPECIAL_TOKENS_ATTRIBUTES: 146 | setattr(self, key, value) 147 | 148 | 149 | @classmethod 150 | def from_pretrained(cls, *inputs, **kwargs): 151 | return cls._from_pretrained(*inputs, **kwargs) 152 | 153 | 154 | @classmethod 155 | def _from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 156 | """ 157 | Instantiate a PreTrainedTokenizer from pre-trained vocabulary files. 158 | Download and cache the vocabulary files if needed. 
159 | """ 160 | s3_models = list(cls.max_model_input_sizes.keys()) 161 | vocab_files = {} 162 | if pretrained_model_name_or_path in s3_models: 163 | # Get the vocabulary from AWS S3 bucket 164 | for file_id, map_list in cls.pretrained_vocab_files_map.items(): 165 | vocab_files[file_id] = map_list[pretrained_model_name_or_path] 166 | else: 167 | # Get the vocabulary from local files 168 | logger.info( 169 | "Model name '{}' not found in model shortcut name list ({}). " 170 | "Assuming '{}' is a path or url to a directory containing tokenizer files.".format( 171 | pretrained_model_name_or_path, ', '.join(s3_models), 172 | pretrained_model_name_or_path)) 173 | 174 | # Look for the tokenizer main vocabulary files 175 | for file_id, file_name in cls.vocab_files_names.items(): 176 | if os.path.isdir(pretrained_model_name_or_path): 177 | # If a directory is provided we look for the standard filenames 178 | full_file_name = os.path.join(pretrained_model_name_or_path, file_name) 179 | else: 180 | # If a path to a file is provided we use it (will only work for non-BPE tokenizer using a single vocabulary file) 181 | full_file_name = pretrained_model_name_or_path 182 | if not os.path.exists(full_file_name): 183 | logger.info("Didn't find file {}. We won't load it.".format(full_file_name)) 184 | full_file_name = None 185 | vocab_files[file_id] = full_file_name 186 | 187 | # Look for the additional tokens files 188 | all_vocab_files_names = {'added_tokens_file': ADDED_TOKENS_FILE, 189 | 'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE} 190 | 191 | # If a path to a file was provided, get the parent directory 192 | saved_directory = pretrained_model_name_or_path 193 | if os.path.exists(saved_directory) and not os.path.isdir(saved_directory): 194 | saved_directory = os.path.dirname(saved_directory) 195 | 196 | for file_id, file_name in all_vocab_files_names.items(): 197 | full_file_name = os.path.join(saved_directory, file_name) 198 | if not os.path.exists(full_file_name): 199 | logger.info("Didn't find file {}. We won't load it.".format(full_file_name)) 200 | full_file_name = None 201 | vocab_files[file_id] = full_file_name 202 | 203 | if all(full_file_name is None for full_file_name in vocab_files.values()): 204 | logger.error( 205 | "Model name '{}' was not found in model name list ({}). " 206 | "We assumed '{}' was a path or url but couldn't find tokenizer files" 207 | "at this path or url.".format( 208 | pretrained_model_name_or_path, ', '.join(s3_models), 209 | pretrained_model_name_or_path, )) 210 | return None 211 | 212 | # Get files from url, cache, or disk depending on the case 213 | try: 214 | resolved_vocab_files = {} 215 | for file_id, file_path in vocab_files.items(): 216 | if file_path is None: 217 | resolved_vocab_files[file_id] = None 218 | else: 219 | resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir) 220 | except EnvironmentError: 221 | if pretrained_model_name_or_path in s3_models: 222 | logger.error("Couldn't reach server to download vocabulary.") 223 | else: 224 | logger.error( 225 | "Model name '{}' was not found in model name list ({}). 
" 226 | "We assumed '{}' was a path or url but couldn't find files {} " 227 | "at this path or url.".format( 228 | pretrained_model_name_or_path, ', '.join(s3_models), 229 | pretrained_model_name_or_path, str(vocab_files.keys()))) 230 | return None 231 | 232 | for file_id, file_path in vocab_files.items(): 233 | if file_path == resolved_vocab_files[file_id]: 234 | logger.info("loading file {}".format(file_path)) 235 | else: 236 | logger.info("loading file {} from cache at {}".format( 237 | file_path, resolved_vocab_files[file_id])) 238 | 239 | # Set max length if needed 240 | if pretrained_model_name_or_path in cls.max_model_input_sizes: 241 | # if we're using a pretrained model, ensure the tokenizer 242 | # wont index sequences longer than the number of positional embeddings 243 | max_len = cls.max_model_input_sizes[pretrained_model_name_or_path] 244 | if max_len is not None and isinstance(max_len, (int, float)): 245 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 246 | 247 | # Merge resolved_vocab_files arguments in kwargs. 248 | added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None) 249 | special_tokens_map_file = resolved_vocab_files.pop('special_tokens_map_file', None) 250 | for args_name, file_path in resolved_vocab_files.items(): 251 | if args_name not in kwargs: 252 | kwargs[args_name] = file_path 253 | if special_tokens_map_file is not None: 254 | special_tokens_map = json.load(open(special_tokens_map_file, encoding="utf-8")) 255 | for key, value in special_tokens_map.items(): 256 | if key not in kwargs: 257 | kwargs[key] = value 258 | 259 | # Instantiate tokenizer. 260 | tokenizer = cls(*inputs, **kwargs) 261 | 262 | # Add supplementary tokens. 263 | if added_tokens_file is not None: 264 | added_tok_encoder = json.load(open(added_tokens_file, encoding="utf-8")) 265 | added_tok_decoder = {v:k for k, v in added_tok_encoder.items()} 266 | tokenizer.added_tokens_encoder.update(added_tok_encoder) 267 | tokenizer.added_tokens_decoder.update(added_tok_decoder) 268 | 269 | return tokenizer 270 | 271 | 272 | def save_pretrained(self, save_directory): 273 | """ Save the tokenizer vocabulary files (with added tokens) and the 274 | special-tokens-to-class-attributes-mapping to a directory, so that it 275 | can be re-loaded using the `from_pretrained(save_directory)` class method. 276 | """ 277 | if not os.path.isdir(save_directory): 278 | logger.error("Saving directory ({}) should be a directory".format(save_directory)) 279 | return 280 | 281 | special_tokens_map_file = os.path.join(save_directory, SPECIAL_TOKENS_MAP_FILE) 282 | added_tokens_file = os.path.join(save_directory, ADDED_TOKENS_FILE) 283 | 284 | with open(special_tokens_map_file, 'w', encoding='utf-8') as f: 285 | f.write(json.dumps(self.special_tokens_map, ensure_ascii=False)) 286 | 287 | with open(added_tokens_file, 'w', encoding='utf-8') as f: 288 | if self.added_tokens_encoder: 289 | out_str = json.dumps(self.added_tokens_encoder, ensure_ascii=False) 290 | else: 291 | out_str = u"{}" 292 | f.write(out_str) 293 | 294 | vocab_files = self.save_vocabulary(save_directory) 295 | 296 | return vocab_files + (special_tokens_map_file, added_tokens_file) 297 | 298 | 299 | def save_vocabulary(self, save_directory): 300 | """ Save the tokenizer vocabulary to a directory. This method doesn't save added tokens 301 | and special token mappings. 
302 | 303 | Please use `save_pretrained()` to save the full Tokenizer state so that it can be 304 | reloaded using the `from_pretrained(save_directory)` class method. 305 | """ 306 | raise NotImplementedError 307 | 308 | 309 | def vocab_size(self): 310 | raise NotImplementedError 311 | 312 | 313 | def __len__(self): 314 | return self.vocab_size + len(self.added_tokens_encoder) 315 | 316 | 317 | def add_tokens(self, new_tokens): 318 | """ Add a list of new tokens to the tokenizer class. If the new tokens are not in the 319 | vocabulary, they are added to the added_tokens_encoder with indices starting from 320 | the last index of the current vocabulary. 321 | 322 | Returns: 323 | Number of tokens added to the vocabulary which can be used to correspondingly 324 | increase the size of the associated model embedding matrices. 325 | """ 326 | if not new_tokens: 327 | return 0 328 | 329 | to_add_tokens = [] 330 | for token in new_tokens: 331 | if self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token): 332 | to_add_tokens.append(token) 333 | logger.info("Adding %s to the vocabulary", token) 334 | 335 | added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(to_add_tokens)) 336 | added_tok_decoder = {v:k for k, v in added_tok_encoder.items()} 337 | self.added_tokens_encoder.update(added_tok_encoder) 338 | self.added_tokens_decoder.update(added_tok_decoder) 339 | 340 | return len(to_add_tokens) 341 | 342 | 343 | def add_special_tokens(self, special_tokens_dict): 344 | """ Add a dictionnary of special tokens (eos, pad, cls...) to the encoder and link them 345 | to class attributes. If the special tokens are not in the vocabulary, they are added 346 | to it and indexed starting from the last index of the current vocabulary. 347 | 348 | Returns: 349 | Number of tokens added to the vocabulary which can be used to correspondingly 350 | increase the size of the associated model embedding matrices. 351 | """ 352 | if not special_tokens_dict: 353 | return 0 354 | 355 | added_special_tokens = self.add_tokens(special_tokens_dict.values()) 356 | for key, value in special_tokens_dict.items(): 357 | logger.info("Assigning %s to the %s key of the tokenizer", value, key) 358 | setattr(self, key, value) 359 | 360 | return added_special_tokens 361 | 362 | 363 | def tokenize(self, text, **kwargs): 364 | """ Converts a string in a sequence of tokens (string), using the tokenizer. 365 | Split in words for word-based vocabulary or sub-words for sub-word-based 366 | vocabularies (BPE/SentencePieces/WordPieces). 367 | 368 | Take care of added tokens. 369 | """ 370 | def split_on_tokens(tok_list, text): 371 | if not text: 372 | return [] 373 | if not tok_list: 374 | return self._tokenize(text, **kwargs) 375 | tok = tok_list[0] 376 | split_text = text.split(tok) 377 | return sum((split_on_tokens(tok_list[1:], sub_text.strip()) + [tok] \ 378 | for sub_text in split_text), [])[:-1] 379 | 380 | added_tokens = list(self.added_tokens_encoder.keys()) + self.all_special_tokens 381 | tokenized_text = split_on_tokens(added_tokens, text) 382 | return tokenized_text 383 | 384 | def _tokenize(self, text, **kwargs): 385 | """ Converts a string in a sequence of tokens (string), using the tokenizer. 386 | Split in words for word-based vocabulary or sub-words for sub-word-based 387 | vocabularies (BPE/SentencePieces/WordPieces). 388 | 389 | Don't take care of added tokens. 
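# --- Editor's illustrative sketch (not part of the original repository code) ---
# add_tokens() above assigns new ids after the base vocabulary, and tokenize()
# splits on added tokens before WordPiece so they are never broken apart.
# The vocabulary and the "[NEW_TOK]" token are made up.
import tempfile
from models.bert.tokenization_bert import BertTokenizer

with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
    f.write("\n".join(["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "hello"]))
    vocab_path = f.name

tokenizer = BertTokenizer(vocab_file=vocab_path)
assert len(tokenizer) == 6
assert tokenizer.add_tokens(["[NEW_TOK]"]) == 1
assert tokenizer.added_tokens_encoder["[NEW_TOK]"] == 6
print(tokenizer.tokenize("hello [NEW_TOK] hello"))   # expected: ['hello', '[NEW_TOK]', 'hello']
# --- end of sketch ---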
390 | """ 391 | raise NotImplementedError 392 | 393 | def convert_tokens_to_ids(self, tokens): 394 | """ Converts a single token or a sequence of tokens (str/unicode) in a integer id 395 | (resp.) a sequence of ids, using the vocabulary. 396 | """ 397 | if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)): 398 | return self._convert_token_to_id_with_added_voc(tokens) 399 | 400 | ids = [] 401 | for token in tokens: 402 | ids.append(self._convert_token_to_id_with_added_voc(token)) 403 | if len(ids) > self.max_len: 404 | logger.warning("Token indices sequence length is longer than the specified maximum sequence length " 405 | "for this model ({} > {}). Running this sequence through the model will result in " 406 | "indexing errors".format(len(ids), self.max_len)) 407 | return ids 408 | 409 | def _convert_token_to_id_with_added_voc(self, token): 410 | if token in self.added_tokens_encoder: 411 | return self.added_tokens_encoder[token] 412 | return self._convert_token_to_id(token) 413 | 414 | def _convert_token_to_id(self, token): 415 | raise NotImplementedError 416 | 417 | 418 | def encode(self, text): 419 | """ Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary. 420 | same as self.convert_tokens_to_ids(self.tokenize(text)). 421 | """ 422 | return self.convert_tokens_to_ids(self.tokenize(text)) 423 | 424 | 425 | def convert_ids_to_tokens(self, ids, skip_special_tokens=False): 426 | """ Converts a single index or a sequence of indices (integers) in a token " 427 | (resp.) a sequence of tokens (str/unicode), using the vocabulary and added tokens. 428 | 429 | Args: 430 | skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False 431 | """ 432 | if isinstance(ids, int): 433 | if ids in self.added_tokens_decoder: 434 | return self.added_tokens_decoder[ids] 435 | else: 436 | return self._convert_id_to_token(ids) 437 | tokens = [] 438 | for index in ids: 439 | if index in self.all_special_ids and skip_special_tokens: 440 | continue 441 | if index in self.added_tokens_decoder: 442 | tokens.append(self.added_tokens_decoder[index]) 443 | else: 444 | tokens.append(self._convert_id_to_token(index)) 445 | return tokens 446 | 447 | def _convert_id_to_token(self, index): 448 | raise NotImplementedError 449 | 450 | def convert_tokens_to_string(self, tokens): 451 | """ Converts a sequence of tokens (string) in a single string. 452 | The most simple way to do it is ' '.join(self.convert_ids_to_tokens(token_ids)) 453 | but we often want to remove sub-word tokenization artifacts at the same time. 454 | """ 455 | return ' '.join(self.convert_ids_to_tokens(tokens)) 456 | 457 | def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True): 458 | """ Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary 459 | with options to remove special tokens and clean up tokenization spaces. 460 | """ 461 | filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens) 462 | text = self.convert_tokens_to_string(filtered_tokens) 463 | if clean_up_tokenization_spaces: 464 | text = clean_up_tokenization(text) 465 | return text 466 | 467 | @property 468 | def special_tokens_map(self): 469 | """ A dictionary mapping special token class attribute (cls_token, unk_token...) to their 470 | values ('', ''...) 
471 | """ 472 | set_attr = {} 473 | for attr in self.SPECIAL_TOKENS_ATTRIBUTES: 474 | attr_value = getattr(self, "_" + attr) 475 | if attr_value: 476 | set_attr[attr] = attr_value 477 | return set_attr 478 | 479 | @property 480 | def all_special_tokens(self): 481 | """ List all the special tokens ('', ''...) mapped to class attributes 482 | (cls_token, unk_token...). 483 | """ 484 | all_toks = [] 485 | set_attr = self.special_tokens_map 486 | for attr_value in set_attr.values(): 487 | all_toks = all_toks + (attr_value if isinstance(attr_value, (list, tuple)) else [attr_value]) 488 | all_toks = list(set(all_toks)) 489 | return all_toks 490 | 491 | @property 492 | def all_special_ids(self): 493 | """ List the vocabulary indices of the special tokens ('', ''...) mapped to 494 | class attributes (cls_token, unk_token...). 495 | """ 496 | all_toks = self.all_special_tokens 497 | all_ids = list(self.convert_tokens_to_ids(t) for t in all_toks) 498 | return all_ids 499 | 500 | 501 | 502 | def clean_up_tokenization(out_string): 503 | out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',' 504 | ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't" 505 | ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re") 506 | return out_string 507 | -------------------------------------------------------------------------------- /models/pretrained_common/modeling_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """PyTorch BERT model.""" 17 | 18 | from __future__ import (absolute_import, division, print_function, 19 | unicode_literals) 20 | 21 | import copy 22 | import json 23 | from collections import OrderedDict 24 | import logging 25 | import os 26 | from io import open 27 | 28 | import six 29 | import torch 30 | from torch import nn 31 | from torch.nn import CrossEntropyLoss 32 | from torch.nn import functional as F 33 | 34 | from .configuration_utils import PretrainedConfig 35 | from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME 36 | 37 | logger = logging.getLogger(__name__) 38 | 39 | 40 | try: 41 | from torch.nn import Identity 42 | except ImportError: 43 | # Older PyTorch compatibility 44 | class Identity(nn.Module): 45 | r"""A placeholder identity operator that is argument-insensitive. 46 | """ 47 | def __init__(self, *args, **kwargs): 48 | super(Identity, self).__init__() 49 | 50 | def forward(self, input): 51 | return input 52 | 53 | class PreTrainedModel(nn.Module): 54 | r""" Base class for all models. 
55 | :class:`~pytorch_transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models 56 | as well as a few methods commons to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads. 57 | Class attributes (overridden by derived classes): 58 | - ``config_class``: a class derived from :class:`~pytorch_transformers.PretrainedConfig` to use as configuration class for this model architecture. 59 | - ``pretrained_model_archive_map``: a python ``dict`` of with `short-cut-names` (string) as keys and `url` (string) of associated pretrained weights as values. 60 | - ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments: 61 | - ``model``: an instance of the relevant subclass of :class:`~pytorch_transformers.PreTrainedModel`, 62 | - ``config``: an instance of the relevant subclass of :class:`~pytorch_transformers.PretrainedConfig`, 63 | - ``path``: a path (string) to the TensorFlow checkpoint. 64 | - ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model. 65 | """ 66 | config_class = None 67 | pretrained_model_archive_map = {} 68 | load_tf_weights = lambda model, config, path: None 69 | base_model_prefix = "" 70 | 71 | def __init__(self, config, *inputs, **kwargs): 72 | super(PreTrainedModel, self).__init__() 73 | if not isinstance(config, PretrainedConfig): 74 | raise ValueError( 75 | "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. " 76 | "To create a model from a pretrained model use " 77 | "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( 78 | self.__class__.__name__, self.__class__.__name__ 79 | )) 80 | # Save config in model 81 | self.config = config 82 | 83 | def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None): 84 | """ Build a resized Embedding Module from a provided token Embedding Module. 85 | Increasing the size will add newly initialized vectors at the end 86 | Reducing the size will remove vectors from the end 87 | Args: 88 | new_num_tokens: (`optional`) int 89 | New number of tokens in the embedding matrix. 90 | Increasing the size will add newly initialized vectors at the end 91 | Reducing the size will remove vectors from the end 92 | If not provided or None: return the provided token Embedding Module. 
93 | Return: ``torch.nn.Embeddings`` 94 | Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None 95 | """ 96 | if new_num_tokens is None: 97 | return old_embeddings 98 | 99 | old_num_tokens, old_embedding_dim = old_embeddings.weight.size() 100 | if old_num_tokens == new_num_tokens: 101 | return old_embeddings 102 | 103 | # Build new embeddings 104 | new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim) 105 | new_embeddings.to(old_embeddings.weight.device) 106 | 107 | # initialize all new embeddings (in particular added tokens) 108 | self._init_weights(new_embeddings) 109 | 110 | # Copy word embeddings from the previous weights 111 | num_tokens_to_copy = min(old_num_tokens, new_num_tokens) 112 | new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :] 113 | 114 | return new_embeddings 115 | 116 | def _tie_or_clone_weights(self, first_module, second_module): 117 | """ Tie or clone module weights depending of weither we are using TorchScript or not 118 | """ 119 | if self.config.torchscript: 120 | first_module.weight = nn.Parameter(second_module.weight.clone()) 121 | else: 122 | first_module.weight = second_module.weight 123 | 124 | if hasattr(first_module, 'bias') and first_module.bias is not None: 125 | first_module.bias.data = torch.nn.functional.pad( 126 | first_module.bias.data, 127 | (0, first_module.weight.shape[0] - first_module.bias.shape[0]), 128 | 'constant', 129 | 0 130 | ) 131 | 132 | def resize_token_embeddings(self, new_num_tokens=None): 133 | """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size. 134 | Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method. 135 | Arguments: 136 | new_num_tokens: (`optional`) int: 137 | New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end. 138 | If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model. 139 | Return: ``torch.nn.Embeddings`` 140 | Pointer to the input tokens Embeddings Module of the model 141 | """ 142 | base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed 143 | model_embeds = base_model._resize_token_embeddings(new_num_tokens) # word_embeds 144 | if new_num_tokens is None: 145 | return model_embeds 146 | 147 | # Update base model and current model config 148 | self.config.vocab_size = new_num_tokens 149 | base_model.vocab_size = new_num_tokens 150 | 151 | # Tie weights again if needed 152 | if hasattr(self, 'tie_weights'): 153 | self.tie_weights() 154 | 155 | return model_embeds 156 | 157 | def init_weights(self): 158 | """ Initialize and prunes weights if needed. """ 159 | # Initialize weights 160 | self.apply(self._init_weights) 161 | 162 | # Prune heads if needed 163 | if self.config.pruned_heads: 164 | self.prune_heads(self.config.pruned_heads) 165 | 166 | def prune_heads(self, heads_to_prune): 167 | """ Prunes heads of the base model. 168 | Arguments: 169 | heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`). 170 | E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2. 
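# --- Editor's illustrative sketch (not part of the original repository code) ---
# The copy step inside _get_resized_embeddings above, in isolation: a larger
# embedding matrix is allocated, the old rows are copied in, and only the newly
# appended rows keep their fresh initialisation. The sizes are made up.
import torch
from torch import nn

old_embeddings = nn.Embedding(10, 4)
new_embeddings = nn.Embedding(11, 4)   # e.g. one extra token appended
num_to_copy = min(old_embeddings.num_embeddings, new_embeddings.num_embeddings)
new_embeddings.weight.data[:num_to_copy, :] = old_embeddings.weight.data[:num_to_copy, :]
assert torch.equal(new_embeddings.weight.data[:10], old_embeddings.weight.data)
# --- end of sketch ---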
171 | """ 172 | base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed 173 | 174 | # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads 175 | for layer, heads in heads_to_prune.items(): 176 | union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads) 177 | self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON 178 | 179 | base_model._prune_heads(heads_to_prune) 180 | 181 | def save_pretrained(self, save_directory): 182 | """ Save a model and its configuration file to a directory, so that it 183 | can be re-loaded using the `:func:`~pytorch_transformers.PreTrainedModel.from_pretrained`` class method. 184 | """ 185 | assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved" 186 | 187 | # Only save the model it-self if we are using distributed training 188 | model_to_save = self.module if hasattr(self, 'module') else self 189 | 190 | # Save configuration file 191 | model_to_save.config.save_pretrained(save_directory) 192 | 193 | # If we save using the predefined names, we can load using `from_pretrained` 194 | output_model_file = os.path.join(save_directory, WEIGHTS_NAME) 195 | 196 | torch.save(model_to_save.state_dict(), output_model_file) 197 | 198 | @classmethod 199 | def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): 200 | r"""Instantiate a pretrained pytorch model from a pre-trained model configuration. 201 | The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated) 202 | To train the model, you should first set it back in training mode with ``model.train()`` 203 | The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model. 204 | It is up to you to train those weights with a downstream fine-tuning task. 205 | The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded. 206 | Parameters: 207 | pretrained_model_name_or_path: either: 208 | - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. 209 | - a path to a `directory` containing model weights saved using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``. 210 | - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. 211 | model_args: (`optional`) Sequence of positional arguments: 212 | All remaning positional arguments will be passed to the underlying model's ``__init__`` method 213 | config: (`optional`) instance of a class derived from :class:`~pytorch_transformers.PretrainedConfig`: 214 | Configuration for the model to use instead of an automatically loaded configuation. 
Configuration can be automatically loaded when: 215 | - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or 216 | - the model was saved using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory. 217 | - the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory. 218 | state_dict: (`optional`) dict: 219 | an optional state dictionnary for the model to use instead of a state dictionary loaded from saved weights file. 220 | This option can be used if you want to create a model from a pretrained configuration but load your own weights. 221 | In this case though, you should check if using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained` and :func:`~pytorch_transformers.PreTrainedModel.from_pretrained` is not a simpler option. 222 | cache_dir: (`optional`) string: 223 | Path to a directory in which a downloaded pre-trained model 224 | configuration should be cached if the standard cache should not be used. 225 | force_download: (`optional`) boolean, default False: 226 | Force to (re-)download the model weights and configuration files and override the cached versions if they exists. 227 | proxies: (`optional`) dict, default None: 228 | A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. 229 | The proxies are used on each request. 230 | output_loading_info: (`optional`) boolean: 231 | Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages. 232 | kwargs: (`optional`) Remaining dictionary of keyword arguments: 233 | Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded: 234 | - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done) 235 | - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~pytorch_transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function. 236 | Examples:: 237 | model = BertModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache. 238 | model = BertModel.from_pretrained('./test/saved_model/') # E.g. 
model was saved using `save_pretrained('./test/saved_model/')` 239 | model = BertModel.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading 240 | assert model.config.output_attention == True 241 | # Loading from a TF checkpoint file instead of a PyTorch model (slower) 242 | config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json') 243 | model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config) 244 | """ 245 | config = kwargs.pop('config', None) 246 | state_dict = kwargs.pop('state_dict', None) 247 | cache_dir = kwargs.pop('cache_dir', None) 248 | from_tf = kwargs.pop('from_tf', False) 249 | force_download = kwargs.pop('force_download', False) 250 | proxies = kwargs.pop('proxies', None) 251 | output_loading_info = kwargs.pop('output_loading_info', False) 252 | 253 | # Load config 254 | if config is None: 255 | config, model_kwargs = cls.config_class.from_pretrained( 256 | pretrained_model_name_or_path, *model_args, 257 | cache_dir=cache_dir, return_unused_kwargs=True, 258 | force_download=force_download, 259 | **kwargs 260 | ) 261 | else: 262 | model_kwargs = kwargs 263 | 264 | # Load model 265 | if pretrained_model_name_or_path in cls.pretrained_model_archive_map: 266 | archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path] 267 | elif os.path.isdir(pretrained_model_name_or_path): 268 | if from_tf: 269 | # Directly load from a TensorFlow checkpoint 270 | archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index") 271 | else: 272 | archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) 273 | else: 274 | if from_tf: 275 | # Directly load from a TensorFlow checkpoint 276 | archive_file = pretrained_model_name_or_path + ".index" 277 | else: 278 | archive_file = pretrained_model_name_or_path 279 | # redirect to the cache, if necessary 280 | try: 281 | resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies) 282 | except EnvironmentError as e: 283 | if pretrained_model_name_or_path in cls.pretrained_model_archive_map: 284 | logger.error( 285 | "Couldn't reach server at '{}' to download pretrained weights.".format( 286 | archive_file)) 287 | else: 288 | logger.error( 289 | "Model name '{}' was not found in model name list ({}). " 290 | "We assumed '{}' was a path or url but couldn't find any file " 291 | "associated to this path or url.".format( 292 | pretrained_model_name_or_path, 293 | ', '.join(cls.pretrained_model_archive_map.keys()), 294 | archive_file)) 295 | raise e 296 | if resolved_archive_file == archive_file: 297 | logger.info("loading weights file {}".format(archive_file)) 298 | else: 299 | logger.info("loading weights file {} from cache at {}".format( 300 | archive_file, resolved_archive_file)) 301 | 302 | # Instantiate model. 
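# --- Editor's illustrative sketch (not part of the original repository code) ---
# A standalone view of the legacy key conversion performed a few lines below:
# very old BERT checkpoints store LayerNorm parameters as 'gamma'/'beta', which
# are renamed to the 'weight'/'bias' names the current modules expect. The
# state dict here is made up.
legacy_state = {
    "bert.embeddings.LayerNorm.gamma": "scale tensor",
    "bert.embeddings.LayerNorm.beta": "shift tensor",
}
renamed_state = {
    key.replace("gamma", "weight").replace("beta", "bias"): value
    for key, value in legacy_state.items()
}
assert set(renamed_state) == {
    "bert.embeddings.LayerNorm.weight",
    "bert.embeddings.LayerNorm.bias",
}
# --- end of sketch ---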
303 | model = cls(config, *model_args, **model_kwargs) 304 | 305 | if state_dict is None and not from_tf: 306 | state_dict = torch.load(resolved_archive_file, map_location='cpu') 307 | # taesun -> domain post-training model checkpoint 308 | # state_dict.keys() : "model", "optimizer"; copy the "_bert_model.bert.*" weights to top-level keys so they match this model 309 | if resolved_archive_file.endswith('pth'): 310 | for state_key in state_dict["model"].keys(): 311 | if state_key.startswith("_bert_model.bert."): 312 | state_dict[state_key[len("_bert_model.bert."):]] = state_dict["model"][state_key] 313 | 314 | if from_tf: 315 | # Directly load from a TensorFlow checkpoint 316 | return cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index' 317 | 318 | # Convert old format to new format if needed from a PyTorch state_dict 319 | old_keys = [] 320 | new_keys = [] 321 | for key in state_dict.keys(): 322 | new_key = None 323 | if 'gamma' in key: 324 | new_key = key.replace('gamma', 'weight') 325 | if 'beta' in key: 326 | new_key = key.replace('beta', 'bias') 327 | if new_key: 328 | old_keys.append(key) 329 | new_keys.append(new_key) 330 | for old_key, new_key in zip(old_keys, new_keys): 331 | state_dict[new_key] = state_dict.pop(old_key) 332 | 333 | # Load from a PyTorch state_dict 334 | missing_keys = [] 335 | unexpected_keys = [] 336 | error_msgs = [] 337 | # copy state_dict so _load_from_state_dict can modify it 338 | metadata = getattr(state_dict, '_metadata', None) 339 | state_dict = state_dict.copy() 340 | if metadata is not None: 341 | state_dict._metadata = metadata 342 | 343 | def load(module, prefix=''): 344 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 345 | module._load_from_state_dict( 346 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 347 | for name, child in module._modules.items(): 348 | if child is not None: 349 | load(child, prefix + name + '.') 350 | 351 | # Make sure we are able to load base models as well as derived models (with heads) 352 | start_prefix = '' 353 | model_to_load = model 354 | if not hasattr(model, cls.base_model_prefix) and any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()): 355 | start_prefix = cls.base_model_prefix + '.' 356 | if hasattr(model, cls.base_model_prefix) and not any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()): 357 | model_to_load = getattr(model, cls.base_model_prefix) 358 | 359 | load(model_to_load, prefix=start_prefix) 360 | if len(missing_keys) > 0: 361 | logger.info("Weights of {} not initialized from pretrained model: {}".format( 362 | model.__class__.__name__, missing_keys)) 363 | if len(unexpected_keys) > 0: 364 | logger.info("Weights from pretrained model not used in {}: {}".format( 365 | model.__class__.__name__, unexpected_keys)) 366 | if len(error_msgs) > 0: 367 | raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( 368 | model.__class__.__name__, "\n\t".join(error_msgs))) 369 | 370 | if hasattr(model, 'tie_weights'): 371 | model.tie_weights() # make sure word embedding weights are still tied 372 | 373 | # Set model in evaluation mode to deactivate DropOut modules by default 374 | model.eval() 375 | 376 | if output_loading_info: 377 | loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs} 378 | return model, loading_info 379 | 380 | return model 381 | 382 | 383 | class Conv1D(nn.Module): 384 | def __init__(self, nf, nx): 385 | """ Conv1D layer as defined by Radford et al.
for OpenAI GPT (and also used in GPT-2) 386 | Basically works like a Linear layer but the weights are transposed 387 | """ 388 | super(Conv1D, self).__init__() 389 | self.nf = nf 390 | w = torch.empty(nx, nf) 391 | nn.init.normal_(w, std=0.02) 392 | self.weight = nn.Parameter(w) 393 | self.bias = nn.Parameter(torch.zeros(nf)) 394 | 395 | def forward(self, x): 396 | size_out = x.size()[:-1] + (self.nf,) 397 | x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) 398 | x = x.view(*size_out) 399 | return x 400 | 401 | 402 | class PoolerStartLogits(nn.Module): 403 | """ Compute SQuAD start_logits from sequence hidden states. """ 404 | def __init__(self, config): 405 | super(PoolerStartLogits, self).__init__() 406 | self.dense = nn.Linear(config.hidden_size, 1) 407 | 408 | def forward(self, hidden_states, p_mask=None): 409 | """ Args: 410 | **p_mask**: (`optional`) ``torch.FloatTensor`` of shape `(batch_size, seq_len)` 411 | invalid position mask such as query and special symbols (PAD, SEP, CLS) 412 | 1.0 means token should be masked. 413 | """ 414 | x = self.dense(hidden_states).squeeze(-1) 415 | 416 | if p_mask is not None: 417 | x = x * (1 - p_mask) - 1e30 * p_mask 418 | 419 | return x 420 | 421 | 422 | class PoolerEndLogits(nn.Module): 423 | """ Compute SQuAD end_logits from sequence hidden states and start token hidden state. 424 | """ 425 | def __init__(self, config): 426 | super(PoolerEndLogits, self).__init__() 427 | self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) 428 | self.activation = nn.Tanh() 429 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 430 | self.dense_1 = nn.Linear(config.hidden_size, 1) 431 | 432 | def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None): 433 | """ Args: 434 | One of ``start_states``, ``start_positions`` should be not None. 435 | If both are set, ``start_positions`` overrides ``start_states``. 436 | **start_states**: ``torch.LongTensor`` of shape identical to hidden_states 437 | hidden states of the first tokens for the labeled span. 438 | **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` 439 | position of the first token for the labeled span: 440 | **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)`` 441 | Mask of invalid position such as query and special symbols (PAD, SEP, CLS) 442 | 1.0 means token should be masked. 443 | """ 444 | assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None" 445 | if start_positions is not None: 446 | slen, hsz = hidden_states.shape[-2:] 447 | start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) 448 | start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz) 449 | start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz) 450 | 451 | x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1)) 452 | x = self.activation(x) 453 | x = self.LayerNorm(x) 454 | x = self.dense_1(x).squeeze(-1) 455 | 456 | if p_mask is not None: 457 | x = x * (1 - p_mask) - 1e30 * p_mask 458 | 459 | return x 460 | 461 | 462 | class PoolerAnswerClass(nn.Module): 463 | """ Compute SQuAD 2.0 answer class from classification and start tokens hidden states. 
""" 464 | def __init__(self, config): 465 | super(PoolerAnswerClass, self).__init__() 466 | self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) 467 | self.activation = nn.Tanh() 468 | self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False) 469 | 470 | def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None): 471 | """ 472 | Args: 473 | One of ``start_states``, ``start_positions`` should be not None. 474 | If both are set, ``start_positions`` overrides ``start_states``. 475 | **start_states**: ``torch.LongTensor`` of shape identical to ``hidden_states``. 476 | hidden states of the first tokens for the labeled span. 477 | **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` 478 | position of the first token for the labeled span. 479 | **cls_index**: torch.LongTensor of shape ``(batch_size,)`` 480 | position of the CLS token. If None, take the last token. 481 | note(Original repo): 482 | no dependency on end_feature so that we can obtain one single `cls_logits` 483 | for each sample 484 | """ 485 | hsz = hidden_states.shape[-1] 486 | assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None" 487 | if start_positions is not None: 488 | start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) 489 | start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz) 490 | 491 | if cls_index is not None: 492 | cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) 493 | cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz) 494 | else: 495 | cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz) 496 | 497 | x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1)) 498 | x = self.activation(x) 499 | x = self.dense_1(x).squeeze(-1) 500 | 501 | return x 502 | 503 | 504 | class SQuADHead(nn.Module): 505 | r""" A SQuAD head inspired by XLNet. 506 | Parameters: 507 | config (:class:`~pytorch_transformers.XLNetConfig`): Model configuration class with all the parameters of the model. 508 | Inputs: 509 | **hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)`` 510 | hidden states of sequence tokens 511 | **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` 512 | position of the first token for the labeled span. 513 | **end_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` 514 | position of the last token for the labeled span. 515 | **cls_index**: torch.LongTensor of shape ``(batch_size,)`` 516 | position of the CLS token. If None, take the last token. 517 | **is_impossible**: ``torch.LongTensor`` of shape ``(batch_size,)`` 518 | Whether the question has a possible answer in the paragraph or not. 519 | **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)`` 520 | Mask of invalid position such as query and special symbols (PAD, SEP, CLS) 521 | 1.0 means token should be masked. 522 | Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: 523 | **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``: 524 | Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses. 
525 | **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) 526 | ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)`` 527 | Log probabilities for the top config.start_n_top start token possibilities (beam-search). 528 | **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) 529 | ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)`` 530 | Indices for the top config.start_n_top start token possibilities (beam-search). 531 | **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) 532 | ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)`` 533 | Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search). 534 | **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) 535 | ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)`` 536 | Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search). 537 | **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) 538 | ``torch.FloatTensor`` of shape ``(batch_size,)`` 539 | Log probabilities for the ``is_impossible`` label of the answers. 540 | """ 541 | def __init__(self, config): 542 | super(SQuADHead, self).__init__() 543 | self.start_n_top = config.start_n_top 544 | self.end_n_top = config.end_n_top 545 | 546 | self.start_logits = PoolerStartLogits(config) 547 | self.end_logits = PoolerEndLogits(config) 548 | self.answer_class = PoolerAnswerClass(config) 549 | 550 | def forward(self, hidden_states, start_positions=None, end_positions=None, 551 | cls_index=None, is_impossible=None, p_mask=None): 552 | outputs = () 553 | 554 | start_logits = self.start_logits(hidden_states, p_mask=p_mask) 555 | 556 | if start_positions is not None and end_positions is not None: 557 | # If we are on multi-GPU, let's remove the dimension added by batch splitting 558 | for x in (start_positions, end_positions, cls_index, is_impossible): 559 | if x is not None and x.dim() > 1: 560 | x.squeeze_(-1) 561 | 562 | # during training, compute the end logits based on the ground truth of the start position 563 | end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask) 564 | 565 | loss_fct = CrossEntropyLoss() 566 | start_loss = loss_fct(start_logits, start_positions) 567 | end_loss = loss_fct(end_logits, end_positions) 568 | total_loss = (start_loss + end_loss) / 2 569 | 570 | if cls_index is not None and is_impossible is not None: 571 | # Predict answerability from the representation of CLS and START 572 | cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index) 573 | loss_fct_cls = nn.BCEWithLogitsLoss() 574 | cls_loss = loss_fct_cls(cls_logits, is_impossible) 575 | 576 | # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss 577 | total_loss += cls_loss * 0.5 578 | 579 | outputs = (total_loss,) + outputs 580 | 581 | else: 582 | # during inference, compute the end logits based on beam search 583 | bsz, slen, hsz = hidden_states.size() 584 | start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen) 585 | 586 | start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1) # shape (bsz, 
start_n_top) 587 | start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz) 588 | start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz) 589 | start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz) 590 | 591 | hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states) # shape (bsz, slen, start_n_top, hsz) 592 | p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None 593 | end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask) 594 | end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top) 595 | 596 | end_top_log_probs, end_top_index = torch.topk(end_log_probs, self.end_n_top, dim=1) # shape (bsz, end_n_top, start_n_top) 597 | end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top) 598 | end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top) 599 | 600 | start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs) 601 | cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index) 602 | 603 | outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs 604 | 605 | # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits 606 | # or (if labels are provided) (total_loss,) 607 | return outputs 608 | 609 | 610 | class SequenceSummary(nn.Module): 611 | r""" Compute a single vector summary of a sequence hidden states according to various possibilities: 612 | Args of the config class: 613 | summary_type: 614 | - 'last' => [default] take the last token hidden state (like XLNet) 615 | - 'first' => take the first token hidden state (like Bert) 616 | - 'mean' => take the mean of all tokens hidden states 617 | - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2) 618 | - 'attn' => Not implemented now, use multi-head attention 619 | summary_use_proj: Add a projection after the vector extraction 620 | summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False. 621 | summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default 622 | summary_first_dropout: Add a dropout before the projection and activation 623 | summary_last_dropout: Add a dropout after the projection and activation 624 | """ 625 | def __init__(self, config): 626 | super(SequenceSummary, self).__init__() 627 | 628 | self.summary_type = config.summary_type if hasattr(config, 'summary_use_proj') else 'last' 629 | if self.summary_type == 'attn': 630 | # We should use a standard multi-head attention module with absolute positional embedding for that. 631 | # Cf. 
https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 632 | # We can probably just use the multi-head attention module of PyTorch >=1.1.0 633 | raise NotImplementedError 634 | 635 | self.summary = Identity() 636 | if hasattr(config, 'summary_use_proj') and config.summary_use_proj: 637 | if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0: 638 | num_classes = config.num_labels 639 | else: 640 | num_classes = config.hidden_size 641 | self.summary = nn.Linear(config.hidden_size, num_classes) 642 | 643 | self.activation = Identity() 644 | if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh': 645 | self.activation = nn.Tanh() 646 | 647 | self.first_dropout = Identity() 648 | if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0: 649 | self.first_dropout = nn.Dropout(config.summary_first_dropout) 650 | 651 | self.last_dropout = Identity() 652 | if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0: 653 | self.last_dropout = nn.Dropout(config.summary_last_dropout) 654 | 655 | def forward(self, hidden_states, cls_index=None): 656 | """ hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer. 657 | cls_index: [optional] position of the classification token if summary_type == 'cls_index', 658 | shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states. 659 | if summary_type == 'cls_index' and cls_index is None: 660 | we take the last token of the sequence as classification token 661 | """ 662 | if self.summary_type == 'last': 663 | output = hidden_states[:, -1] 664 | elif self.summary_type == 'first': 665 | output = hidden_states[:, 0] 666 | elif self.summary_type == 'mean': 667 | output = hidden_states.mean(dim=1) 668 | elif self.summary_type == 'cls_index': 669 | if cls_index is None: 670 | cls_index = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2]-1, dtype=torch.long) 671 | else: 672 | cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) 673 | cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),)) 674 | # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states 675 | output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) 676 | elif self.summary_type == 'attn': 677 | raise NotImplementedError 678 | 679 | output = self.first_dropout(output) 680 | output = self.summary(output) 681 | output = self.activation(output) 682 | output = self.last_dropout(output) 683 | 684 | return output 685 | 686 | 687 | def prune_linear_layer(layer, index, dim=0): 688 | """ Prune a linear layer (a model parameters) to keep only entries in index. 689 | Return the pruned layer as a new layer with requires_grad=True. 690 | Used to remove heads. 
691 | """ 692 | index = index.to(layer.weight.device) 693 | W = layer.weight.index_select(dim, index).clone().detach() 694 | if layer.bias is not None: 695 | if dim == 1: 696 | b = layer.bias.clone().detach() 697 | else: 698 | b = layer.bias[index].clone().detach() 699 | new_size = list(layer.weight.size()) 700 | new_size[dim] = len(index) 701 | new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device) 702 | new_layer.weight.requires_grad = False 703 | new_layer.weight.copy_(W.contiguous()) 704 | new_layer.weight.requires_grad = True 705 | if layer.bias is not None: 706 | new_layer.bias.requires_grad = False 707 | new_layer.bias.copy_(b.contiguous()) 708 | new_layer.bias.requires_grad = True 709 | return new_layer 710 | 711 | 712 | def prune_conv1d_layer(layer, index, dim=1): 713 | """ Prune a Conv1D layer (a model parameters) to keep only entries in index. 714 | A Conv1D work as a Linear layer (see e.g. BERT) but the weights are transposed. 715 | Return the pruned layer as a new layer with requires_grad=True. 716 | Used to remove heads. 717 | """ 718 | index = index.to(layer.weight.device) 719 | W = layer.weight.index_select(dim, index).clone().detach() 720 | if dim == 0: 721 | b = layer.bias.clone().detach() 722 | else: 723 | b = layer.bias[index].clone().detach() 724 | new_size = list(layer.weight.size()) 725 | new_size[dim] = len(index) 726 | new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device) 727 | new_layer.weight.requires_grad = False 728 | new_layer.weight.copy_(W.contiguous()) 729 | new_layer.weight.requires_grad = True 730 | new_layer.bias.requires_grad = False 731 | new_layer.bias.copy_(b.contiguous()) 732 | new_layer.bias.requires_grad = True 733 | return new_layer 734 | 735 | def prune_layer(layer, index, dim=None): 736 | """ Prune a Conv1D or nn.Linear layer (a model parameters) to keep only entries in index. 737 | Return the pruned layer as a new layer with requires_grad=True. 738 | Used to remove heads. 739 | """ 740 | if isinstance(layer, nn.Linear): 741 | return prune_linear_layer(layer, index, dim=0 if dim is None else dim) 742 | elif isinstance(layer, Conv1D): 743 | return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim) 744 | else: 745 | raise ValueError("Can't prune layer of class {}".format(layer.__class__)) --------------------------------------------------------------------------------