├── models
│   ├── pretrained_common
│   │   ├── __init__.py
│   │   ├── activations.py
│   │   ├── optimization.py
│   │   ├── configuration_utils.py
│   │   ├── file_utils.py
│   │   ├── tokenization_utils.py
│   │   └── modeling_utils.py
│   ├── __init__.py
│   ├── bert_post_training.py
│   ├── bert_base_cls.py
│   ├── utils
│   │   ├── scorer.py
│   │   └── checkpointing.py
│   └── bert
│       ├── configuration_bert.py
│       └── tokenization_bert.py
├── model_overview.jpg
├── resources
│   ├── bert-base-uncased
│   │   └── bert-base-uncased-config.json
│   └── bert-post-uncased
│       └── bert-post-uncased-config.json
├── scripts
│   ├── download_post_checkpoints.sh
│   └── download_datasets.sh
├── config
│   └── hparams.py
├── README.md
├── data
│   ├── data_utils.py
│   ├── dataset.py
│   └── create_bert_post_training_data.py
├── main.py
├── evaluation.py
├── train.py
└── post_train.py
/models/pretrained_common/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "2.8.0"
2 |
--------------------------------------------------------------------------------
/model_overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taesunwhang/BERT-ResSel/HEAD/model_overview.jpg
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from models.bert_base_cls import BERTbase
2 | from models.bert_post_training import BertDomainPostTraining
3 |
4 |
5 | def Model(hparams, *args):
6 | name_model_map = {
7 | "bert_base_ft" : BERTbase,
8 | "bert_dpt_ft" : BERTbase,
9 |
10 | "bert_ubuntu_pt" : BertDomainPostTraining,
11 | }
12 |
13 | return name_model_map[hparams.model_type](hparams, *args)
--------------------------------------------------------------------------------
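
A minimal usage sketch of this factory, mirroring how `main.py` builds its hyper-parameters: the dictionary is frozen into a namedtuple and `model_type` selects the class. The namedtuple below is trimmed to the single field the lookup needs, and the actual instantiation is left commented out because it requires the pretrained BERT files under `./resources`.

```python
import collections

from models import Model

# Hypothetical, trimmed hparams; the real fields come from config/hparams.py.
HParams = collections.namedtuple("HParams", ["model_type"])
hparams = HParams(model_type="bert_base_ft")

# "bert_base_ft" maps to BERTbase in name_model_map, so this would build BERTbase(hparams).
# model = Model(hparams)
```
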
/resources/bert-base-uncased/bert-base-uncased-config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "BertForMaskedLM"
4 | ],
5 | "attention_probs_dropout_prob": 0.1,
6 | "hidden_act": "gelu",
7 | "hidden_dropout_prob": 0.1,
8 | "hidden_size": 768,
9 | "initializer_range": 0.02,
10 | "intermediate_size": 3072,
11 | "max_position_embeddings": 512,
12 | "num_attention_heads": 12,
13 | "num_hidden_layers": 12,
14 | "type_vocab_size": 2,
15 | "vocab_size": 30522
16 | }
17 |
--------------------------------------------------------------------------------
/resources/bert-post-uncased/bert-post-uncased-config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "BertForMaskedLM"
4 | ],
5 | "attention_probs_dropout_prob": 0.1,
6 | "hidden_act": "gelu",
7 | "hidden_dropout_prob": 0.1,
8 | "hidden_size": 768,
9 | "initializer_range": 0.02,
10 | "intermediate_size": 3072,
11 | "max_position_embeddings": 512,
12 | "num_attention_heads": 12,
13 | "num_hidden_layers": 12,
14 | "type_vocab_size": 2,
15 | "vocab_size": 30523
16 | }
17 |
--------------------------------------------------------------------------------
/scripts/download_post_checkpoints.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | export file_name=bert-base-uncased-pytorch_model.bin
4 | if [ -f $PWD/resources/bert-base-uncased/$file_name ]; then
5 | echo "$file_name exists"
6 | else
7 | wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin
8 | mv $file_name resources/bert-base-uncased/
9 | fi
10 |
11 | export file_name=bert-post-uncased-pytorch_model.pth
12 | if [ -f $PWD/resources/bert-post-uncased/$file_name ]; then
13 | echo "$file_name exists"
14 | else
15 | echo "$file_name does not exist"
16 | export file_id=1jt0RhVT9y2d4AITn84kSOk06hjIv1y49
17 |
18 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id='$file_id -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=$file_id" -O $file_name && rm -rf /tmp/cookies.txt
19 | mv $file_name resources/bert-post-uncased/
20 | fi
21 |
--------------------------------------------------------------------------------
/models/bert_post_training.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch.nn as nn
3 |
4 | from models.bert import modeling_bert, configuration_bert
5 |
6 | class BertDomainPostTraining(nn.Module):
7 | def __init__(self, hparams):
8 | super(BertDomainPostTraining, self).__init__()
9 | self.hparams = hparams
10 | bert_config = configuration_bert.BertConfig.from_pretrained(
11 | os.path.join(self.hparams.bert_pretrained_dir, self.hparams.bert_pretrained,
12 | "%s-config.json" % self.hparams.bert_pretrained),
13 | )
14 | self._bert_model = modeling_bert.BertForPreTraining.from_pretrained(
15 | os.path.join(self.hparams.bert_pretrained_dir, self.hparams.bert_pretrained, self.hparams.bert_checkpoint_path),
16 | config=bert_config
17 | )
18 |
19 | if self.hparams.do_eot:
20 | self._bert_model.resize_token_embeddings(self._bert_model.config.vocab_size + 1) # [EOT]
21 |
22 | def forward(self, batch):
23 |
24 | bert_outputs = self._bert_model(
25 | input_ids=batch["input_ids"],
26 | token_type_ids=batch["token_type_ids"],
27 | attention_mask=batch["attention_mask"],
28 | masked_lm_labels=batch["masked_lm_labels"],
29 | next_sentence_label=batch["next_sentence_labels"]
30 | )
31 | mlm_loss, nsp_loss, prediction_scores, seq_relationship_score = bert_outputs[:4]
32 |
33 | return mlm_loss, nsp_loss
--------------------------------------------------------------------------------
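
For reference, a sketch of the batch dictionary this forward pass expects. The shapes are hypothetical (batch size 2, the 512-token sequences built by `BertPostTrainingDataset`), and constructing the model itself still requires the pretrained checkpoints under `./resources`.

```python
import torch

batch = {
    "input_ids":            torch.zeros(2, 512, dtype=torch.long),
    "token_type_ids":       torch.zeros(2, 512, dtype=torch.long),
    "attention_mask":       torch.ones(2, 512, dtype=torch.long),
    "masked_lm_labels":     torch.full((2, 512), -1, dtype=torch.long),  # -1 marks unmasked positions
    "next_sentence_labels": torch.zeros(2, dtype=torch.long),
}
# mlm_loss, nsp_loss = BertDomainPostTraining(hparams)(batch)
```
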
/scripts/download_datasets.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | export file_name=ubuntu.zip
4 | if [ -f $PWD/data/ubuntu_corpus_v1/ubuntu_train.pkl ]; then
5 | echo "ubuntu_train.pkl exists"
6 | else
7 | echo "ubuntu_train.pkl does not exist"
8 | export file_id=1VKQaNNC5NR-6TwVPpxYAVZQp_nwK3d5u
9 |
10 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id='$file_id -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=$file_id" -O $file_name && rm -rf /tmp/cookies.txt
11 | unzip $file_name -d data/ubuntu_corpus_v1
12 | rm -r $file_name
13 | fi
14 |
15 | export file_name=ubuntu_post_training.txt
16 | if [ -f $PWD/data/ubuntu_corpus_v1/$file_name ]; then
17 | echo "$file_name exists"
18 | else
19 | echo "$file_name does not exist"
20 | export file_id=1mYS_PrnrKx4zDWOPTFhx_SeEwdumYXCK
21 |
22 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id='$file_id -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=$file_id" -O $file_name && rm -rf /tmp/cookies.txt
23 | mv $file_name data/ubuntu_corpus_v1/
24 | fi
25 |
26 |
--------------------------------------------------------------------------------
/models/bert_base_cls.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 |
5 | from models.bert import modeling_bert, configuration_bert
6 |
7 | class BERTbase(nn.Module):
8 | def __init__(self, hparams):
9 | super(BERTbase, self).__init__()
10 | self.hparams = hparams
11 |
12 | bert_config = configuration_bert.BertConfig.from_pretrained(
13 | os.path.join(self.hparams.bert_pretrained_dir, self.hparams.bert_pretrained,
14 | "%s-config.json" % self.hparams.bert_pretrained),
15 | )
16 | self._bert_model = modeling_bert.BertModel.from_pretrained(
17 | os.path.join(self.hparams.bert_pretrained_dir, self.hparams.bert_pretrained, self.hparams.bert_checkpoint_path),
18 | config=bert_config
19 | )
20 |
21 | if self.hparams.do_eot and self.hparams.model_type == "bert_base_ft":
22 | self._bert_model.resize_token_embeddings(self._bert_model.config.vocab_size + 1) # [EOT]
23 |
24 | self._classification = nn.Sequential(
25 | nn.Dropout(p=1 - self.hparams.dropout_keep_prob),
26 | nn.Linear(self.hparams.bert_hidden_dim, 1)
27 | )
28 |
29 | def forward(self, batch):
30 | bert_outputs, _ = self._bert_model(
31 | batch["anno_sent"],
32 | token_type_ids=batch["segment_ids"],
33 | attention_mask=batch["attention_mask"]
34 | )
35 | cls_logits = bert_outputs[:,0,:] # bs, bert_output_size
36 | logits = self._classification(cls_logits) # bs, 1
37 | logits = logits.squeeze(-1)
38 |
39 | return logits
--------------------------------------------------------------------------------
/models/pretrained_common/activations.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 |
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 | def swish(x):
10 | return x * torch.sigmoid(x)
11 |
12 |
13 | def _gelu_python(x):
14 | """ Original Implementation of the gelu activation function in Google Bert repo when initially created.
15 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
16 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
17 | This is now written in C in torch.nn.functional
18 | Also see https://arxiv.org/abs/1606.08415
19 | """
20 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
21 |
22 |
23 | def gelu_new(x):
24 | """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
25 | Also see https://arxiv.org/abs/1606.08415
26 | """
27 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
28 |
29 |
30 | if tuple(int(v) for v in torch.__version__.split("+")[0].split(".")[:2]) < (1, 4):  # avoid lexicographic version comparison (e.g. "1.10" < "1.4")
31 | gelu = _gelu_python
32 | else:
33 | gelu = F.gelu
34 |
35 | ACT2FN = {
36 | "relu": F.relu,
37 | "swish": swish,
38 | "gelu": gelu,
39 | "tanh": torch.tanh,
40 | "gelu_new": gelu_new,
41 | }
42 |
43 |
44 | def get_activation(activation_string):
45 | if activation_string in ACT2FN:
46 | return ACT2FN[activation_string]
47 | else:
48 | raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys())))
--------------------------------------------------------------------------------
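
A small usage sketch of the activation registry above; as the docstrings note, `gelu` is the exact (erf-based) form while `gelu_new` is the tanh approximation, and the two agree closely on typical inputs.

```python
import torch

from models.pretrained_common.activations import get_activation

x = torch.linspace(-3, 3, steps=7)
print(get_activation("gelu")(x))      # exact GELU (erf-based or F.gelu, depending on torch version)
print(get_activation("gelu_new")(x))  # tanh approximation
```
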
/config/hparams.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 |
3 | BASE_PARAMS = defaultdict(
4 | # lambda: None, # Set default value to None.
5 | # GPU params
6 | gpu_ids=[0],
7 |
8 | # Input params
9 | train_batch_size=8,
10 | eval_batch_size=100,
11 | virtual_batch_size=32,
12 |
13 | evaluate_candidates_num=10,
14 | recall_k_list=[1,2,5,10],
15 |
16 | # Training BERT params
17 | learning_rate=3e-5,
18 | dropout_keep_prob=0.8,
19 | num_epochs=10,
20 | max_gradient_norm=5,
21 |
22 | pad_idx=0,
23 | max_position_embeddings=100,
24 | num_hidden_layers=12,
25 | num_attention_heads=12,
26 | intermediate_size=3072,
27 | bert_hidden_dim=768,
28 | attention_probs_dropout_prob=0.1,
29 | layer_norm_eps=1e-12,
30 |
31 | # Train Model Config
32 | task_name="ubuntu",
33 | do_bert=True,
34 | do_eot=True,
35 | max_dialog_len=448,
36 | max_response_len=64,
37 | # summation -> 512
38 |
39 | # Need to change to train...(e.g.data dir, config dir, vocab dir, etc.)
40 | save_dirpath='checkpoints/', # /path/to/checkpoints
41 |
42 | bert_pretrained="bert-base-uncased", # should be defined here
43 | bert_checkpoint_path="bert-base-uncased-pytorch_model.bin",
44 | model_type="bert_base_ft",
45 |
46 | load_pthpath="",
47 | cpu_workers=8,
48 | tensorboard_step=1000,
49 | evaluate_print_step=100,
50 | )
51 |
52 | DPT_FINETUNING_PARAMS = BASE_PARAMS.copy()
53 | DPT_FINETUNING_PARAMS.update(
54 | bert_checkpoint_path="bert-post-uncased-pytorch_model.pth", # should be defined here
55 | model_type="bert_dpt_ft"
56 | )
57 |
58 | POST_TRAINING_PARAMS = BASE_PARAMS.copy()
59 | POST_TRAINING_PARAMS.update(
60 | num_epochs=3,
61 | # lambda: None, # Set default value to None.
62 | # GPU params
63 | gpu_ids=[0],
64 |
65 | # Input params
66 | train_batch_size=8,
67 | virtual_batch_size=512,
68 | tensorboard_step=100,
69 |
70 | checkpoint_save_step=2500, # virtual_batch -> 10000 step
71 | model_type="bert_ubuntu_pt",
72 | data_dir="./data/ubuntu_corpus_v1/ubuntu_post_training.hdf5",
73 | )
--------------------------------------------------------------------------------
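
A sketch of how these dictionaries are consumed, mirroring `train_model` in `main.py`: the chosen dict is copied, command-line values are merged in, and the result is frozen into a namedtuple so the rest of the code can use attribute access. The two values filled in below are the repository's argparse defaults.

```python
import collections

from config.hparams import DPT_FINETUNING_PARAMS

hparams = dict(DPT_FINETUNING_PARAMS)
hparams.update(bert_pretrained_dir="./resources",
               data_dir="./data/ubuntu_corpus_v1/%s_%s.pkl")
hparams = collections.namedtuple("HParams", sorted(hparams.keys()))(**hparams)

print(hparams.model_type, hparams.bert_checkpoint_path)
# bert_dpt_ft bert-post-uncased-pytorch_model.pth
```
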
/models/utils/scorer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def calculate_candidates_ranking(prediction, ground_truth, eval_candidates_num=10):
4 |     total_num_split = len(ground_truth) // eval_candidates_num
5 |
6 | pred_split = np.split(prediction, total_num_split)
7 | gt_split = np.split(np.array(ground_truth), total_num_split)
8 | orig_rank_split = np.split(np.tile(np.arange(0, eval_candidates_num), int(total_num_split)), total_num_split)
9 | stack_scores = np.stack((gt_split, pred_split, orig_rank_split), axis=-1)
10 |
11 | rank_by_pred_l = []
12 | for i, stack_score in enumerate(stack_scores):
13 | rank_by_pred = sorted(stack_score, key=lambda x: x[1], reverse=True)
14 | rank_by_pred = np.stack(rank_by_pred, axis=-1)
15 | rank_by_pred_l.append(rank_by_pred[0])
16 |
17 | return np.array(rank_by_pred_l)
18 |
19 | def logits_recall_at_k(rank_by_pred, k_list=[1, 2, 5, 10]):
20 | # 1 dialog, 10 response candidates ground truth 1 or 0
21 | # prediction_score : [batch_size]
22 | # target : [batch_size] e.g. 1 0 0 0 0 0 0 0 0 0
23 | # e.g. batch : 100 -> 100/10 = 10
24 |
25 | num_correct = np.zeros([rank_by_pred.shape[0], len(k_list)])
26 |
27 | pos_index = []
28 | for sorted_score in rank_by_pred:
29 | for p_i, score in enumerate(sorted_score):
30 | if int(score) == 1:
31 | pos_index.append(p_i)
32 | index_dict = dict()
33 | for i, p_i in enumerate(pos_index):
34 | index_dict[i] = p_i
35 |
36 | for i, p_i in enumerate(pos_index):
37 | for j, k in enumerate(k_list):
38 | if p_i + 1 <= k:
39 | num_correct[i][j] += 1
40 |
41 | return np.sum(num_correct, axis=0), pos_index
42 |
43 | def logits_mrr(rank_by_pred):
44 | pos_index = []
45 | for sorted_score in rank_by_pred:
46 | for p_i, score in enumerate(sorted_score):
47 | if int(score) == 1:
48 | pos_index.append(p_i)
49 |
50 | # print("pos_index", pos_index)
51 | mrr = []
52 | for i, p_i in enumerate(pos_index):
53 | mrr.append(1 / (p_i + 1))
54 |
55 | # print(mrr)
56 |
57 | return np.sum(mrr)
--------------------------------------------------------------------------------
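
A toy walk-through of the metrics above, assuming a single dialogue with 10 candidates in which the true response (label 1) is ranked second by the model; the scores are made up for illustration.

```python
import numpy as np

from models.utils.scorer import calculate_candidates_ranking, logits_recall_at_k, logits_mrr

ground_truth = [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]   # the first candidate is the true response
prediction = np.array([0.6, 0.9, 0.4, 0.35, 0.3, 0.25, 0.2, 0.15, 0.1, 0.05])  # model ranks it 2nd

rank_by_pred = calculate_candidates_ranking(prediction, ground_truth, eval_candidates_num=10)
num_correct, _ = logits_recall_at_k(rank_by_pred, k_list=[1, 2, 5, 10])

print(num_correct)               # [0. 1. 1. 1.] -> hit within top-2/5/10 but not top-1
print(logits_mrr(rank_by_pred))  # 0.5 -> reciprocal rank of position 2
```
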
/README.md:
--------------------------------------------------------------------------------
1 | An Effective Domain Adaptive Post-Training Method for BERT in Response Selection
2 | ====================================
3 | Implements the model described in the following paper [An Effective Domain Adaptive Post-Training Method for BERT in Response Selection](https://arxiv.org/abs/1908.04812v2).
4 | ```
5 | @inproceedings{whang2020domain,
6 | author={Whang, Taesun and Lee, Dongyub and Lee, Chanhee and Yang, Kisu and Oh, Dongsuk and Lim, HeuiSeok},
7 | title="An Effective Domain Adaptive Post-Training Method for BERT in Response Selection",
8 | year=2020,
9 | booktitle={Proc. Interspeech 2020}
10 | }
11 | ```
12 |
13 | This code is reimplemented as a fork of [huggingface/transformers][7].
14 |
15 |
16 |
17 |
18 |
19 |
20 | Data Creation
21 | --------
22 | 1. Download `ubuntu_train.pkl`, `ubuntu_valid.pkl`, and `ubuntu_test.pkl` [here][1], or create the `pkl` files yourself to train the BERT-based response selection model.
23 | If you wish to create the `pkl` files, download the ubuntu_corpus_v1 dataset [here][2], provided by [Xu et al. (2016)](https://arxiv.org/pdf/1605.05110.pdf), and keep the files under the `data/ubuntu_corpus_v1` directory.
24 | 2. The `pkl` files and the Ubuntu corpus for domain post-training are created by running:
25 | ```shell
26 | python data/data_utils.py
27 | ```
28 |
29 | Post Training Data Creation
30 | --------
31 | Download `ubuntu_post_training.txt` corpus [here][3] and simply run
32 | ```shell
33 | python data/create_bert_post_training_data.py
34 | ```
35 | After creating the post-training data, keep the `ubuntu_post_training.hdf5` file under the `data/ubuntu_corpus_v1` directory.
36 |
37 | Domain Post Training BERT
38 | --------
39 | To domain post-train BERT, simply run
40 | ```shell
41 | python main.py --model bert_ubuntu_pt --train_type post_training --bert_pretrained bert-base-uncased --data_dir ./data/ubuntu_corpus_v1/ubuntu_post_training.hdf5
42 | ```
43 |
44 | BERT Fine-tuning (Response Selection)
45 | --------
46 | ### Training
47 | Train a response selection model based on `BERT_base`:
48 | ```shell
49 | python main.py --model bert_base_ft --train_type fine_tuning --bert_pretrained bert-base-uncased
50 | ```
51 |
52 | Train a response selection model based on the domain post-trained BERT. If you wish to use the domain post-trained BERT, download the model checkpoint (`bert-post-uncased-pytorch_model.pth`) [here][4],
53 | and keep it under the `resources/bert-post-uncased` directory:
54 | ```shell
55 | python main.py --model bert_dpt_ft --train_type fine_tuning --bert_pretrained bert-post-uncased
56 | ```
57 |
58 | ### Evaluation
59 | To evaluate the `bert_base` and `bert_dpt` models, set the model checkpoint path and simply run
60 | ```shell
61 | python main.py --model bert_dpt_ft --train_type fine_tuning --bert_pretrained bert-post-uncased --evaluate /path/to/checkpoint.pth
62 | ```
63 | If you wish to use the pre-trained response selection models, we provide the model checkpoints below.
64 |
65 | | Model | R@1 | R@2 | R@5 | MRR |
66 | |:---------:|:------:|:------:|:------:|:------:|
67 | | [BERT_base][5] | 0.8115 | 0.9003 | 0.9768 | 0.8809 |
68 | | [BERT_DPT][6] | 0.8515 | 0.9272 | 0.9851 | 0.9081 |
69 |
70 |
71 | Acknowledgements
72 | --------
73 | - This work was supported by the Institute for Information & communications Technology Promotion (IITP) grant funded by the Korea government (MSIT) (no. 2016-0-00010-003, Digital Content InHouse R&D).
74 | - Work in collaboration with [Kakao Corp][8].
75 |
76 | [1]: https://drive.google.com/drive/folders/1mLzXifYYwmlFEWDzSbbecLlzKstB8gQK?usp=sharing
77 | [2]: https://www.dropbox.com/s/2fdn26rj6h9bpvl/ubuntu_data.zip
78 | [3]: https://drive.google.com/file/d/1mYS_PrnrKx4zDWOPTFhx_SeEwdumYXCK/view?usp=sharing
79 | [4]: https://drive.google.com/file/d/1jt0RhVT9y2d4AITn84kSOk06hjIv1y49/view?usp=sharing
80 | [5]: https://drive.google.com/file/d/1amuPQ_CtfvNuQMdRR8eo0YGAQLP4XBP7/view?usp=sharing
81 | [6]: https://drive.google.com/file/d/1Ip_VqzpByWZRAgiN7OxPeyYxK6onPia0/view?usp=sharing
82 | [7]: https://github.com/huggingface/transformers
83 | [8]: https://www.kakaocorp.com
84 |
--------------------------------------------------------------------------------
/data/data_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append(os.getcwd())
4 | import pickle
5 | from tqdm import tqdm
6 |
7 | from models.bert import tokenization_bert
8 |
9 | class InputExamples(object):
10 | def __init__(self, utterances, response, label, seq_lengths):
11 |
12 | self.utterances = utterances
13 | self.response = response
14 | self.label = label
15 |
16 | self.dialog_len = seq_lengths[0]
17 | self.response_len = seq_lengths[1]
18 |
19 | class UbuntuDataUtils(object):
20 | def __init__(self, txt_path, bert_pretrained_dir):
21 | # bert_tokenizer init
22 | self.txt_path = txt_path
23 | self._bert_tokenizer_init(bert_pretrained_dir)
24 |
25 | def _bert_tokenizer_init(self, bert_pretrained_dir, bert_pretrained='bert-base-uncased'):
26 |
27 | self._bert_tokenizer = tokenization_bert.BertTokenizer(
28 | vocab_file=os.path.join(os.path.join(bert_pretrained_dir, bert_pretrained),
29 | "%s-vocab.txt" % bert_pretrained))
30 | print("BERT tokenizer init completes")
31 |
32 | def read_raw_file(self, data_type):
33 | print("Loading raw txt file...")
34 |
35 | ubuntu_path = self.txt_path % data_type # train, dev, test
36 | with open(ubuntu_path, "r", encoding="utf8") as fr_handle:
37 | data = [line.strip() for line in fr_handle if len(line.strip()) > 0]
38 | print("(%s) total number of sentence : %d" % (data_type, len(data)))
39 |
40 | return data
41 |
42 | def make_post_training_corpus(self, data, post_training_path):
43 | with open(post_training_path, "w", encoding="utf-8") as fw_handle:
44 | cnt = 0
45 | for document in data:
46 | dialog_data = document.split("\t")
47 | if dialog_data[0] == '0':
48 | continue
49 | for utt in dialog_data[1:-1]:
50 | if len(utt) == 0:
51 | continue
52 | fw_handle.write(utt.strip() + "\n")
53 | fw_handle.write("\n")
54 | cnt+=1
55 |
56 | def make_examples_pkl(self, data, ubuntu_pkl_path):
57 | with open(ubuntu_pkl_path, "ab") as pkl_handle:
58 | for dialog in tqdm(data):
59 | dialog_data = dialog.split("\t")
60 | label = dialog_data[0]
61 | utterances = []
62 | dialog_len = []
63 |
64 | for utt in dialog_data[1:-1]:
65 | utt_tok = self._bert_tokenizer.tokenize(utt)
66 | utterances.append(utt_tok)
67 | dialog_len.append(len(utt_tok))
68 |
69 | response = self._bert_tokenizer.tokenize(dialog_data[-1])
70 |
71 | pickle.dump(InputExamples(
72 | utterances=utterances, response=response, label=int(label),
73 | seq_lengths=(dialog_len, len(response))), pkl_handle)
74 |
75 | print(ubuntu_pkl_path, " save completes!")
76 |
77 | def ubuntu_manual(self):
78 | knowledge_path = "ubuntu_manual_knowledge.txt"
79 | ubuntu_knowledge_dict = dict()
80 | ubuntu_man_l = []
81 |
82 | with open(knowledge_path, "r", encoding="utf-8") as f_handle:
83 | for line in f_handle:
84 | ubuntu_man = line.strip().split("\t")
85 | if len(ubuntu_man) == 2:
86 | ubuntu_knowledge_dict[ubuntu_man[0]] = ubuntu_man[1]
87 | ubuntu_man_l.append(ubuntu_man[1])
88 |
89 | print(ubuntu_knowledge_dict.keys())
90 | print(len(ubuntu_knowledge_dict))
91 |
92 | if __name__ == '__main__':
93 | ubuntu_raw_path = "./data/ubuntu_corpus_v1/%s.txt"
94 | ubuntu_pkl_path = "./data/ubuntu_corpus_v1/ubuntu_%s.pkl"
95 | bert_pretrained_dir = "./resources"
96 |
97 | ubuntu_utils = UbuntuDataUtils(ubuntu_raw_path, bert_pretrained_dir)
98 |
99 |     # response selection fine-tuning pkl creation
100 | for data_type in ["train", "valid", "test"]:
101 | data = ubuntu_utils.read_raw_file(data_type)
102 | ubuntu_utils.make_examples_pkl(data, ubuntu_pkl_path % data_type)
103 |
104 | # domain post training corpus creation
105 | for data_type in ["train"]:
106 | data = ubuntu_utils.read_raw_file(data_type)
107 | ubuntu_utils.make_post_training_corpus(data, "./data/ubuntu_corpus_v1/ubuntu_post_training.txt")
108 |
109 |
--------------------------------------------------------------------------------
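
The `pkl` files written by `make_examples_pkl` above are streams of pickled `InputExamples`, one object per dialog. A minimal sketch of reading them back, mirroring the loading loop in `data/dataset.py` (importing `InputExamples` first lets pickle resolve the class, the same trick `main.py` uses):

```python
import pickle

from data.data_utils import InputExamples  # lets pickle resolve the InputExamples class

examples = []
with open("./data/ubuntu_corpus_v1/ubuntu_valid.pkl", "rb") as pkl_handle:
    while True:
        try:
            examples.append(pickle.load(pkl_handle))
        except EOFError:
            break

print(len(examples), examples[0].label, examples[0].response[:5])
```
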
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import collections
4 | import logging
5 | from datetime import datetime
6 |
7 | from config.hparams import *
8 | from data.data_utils import InputExamples
9 | from train import ResponseSelection
10 | from evaluation import Evaluation
11 | from post_train import BERTDomainPostTraining
12 |
13 | PARAMS_MAP = {
14 | # fine-tuning (ft)
15 | "bert_base_ft" : BASE_PARAMS,
16 | "bert_dpt_ft" : DPT_FINETUNING_PARAMS,
17 |
18 | # post-training (pt)
19 | "bert_ubuntu_pt" : POST_TRAINING_PARAMS,
20 | }
21 |
22 | MODEL = {
23 | "fine_tuning" : ResponseSelection,
24 | "post_training" : BERTDomainPostTraining
25 | }
26 |
27 | def init_logger(path:str):
28 | if not os.path.exists(path):
29 | os.makedirs(path)
30 | logger = logging.getLogger()
31 | logger.handlers = []
32 | logger.setLevel(logging.DEBUG)
33 | debug_fh = logging.FileHandler(os.path.join(path, "debug.log"))
34 | debug_fh.setLevel(logging.DEBUG)
35 |
36 | info_fh = logging.FileHandler(os.path.join(path, "info.log"))
37 | info_fh.setLevel(logging.INFO)
38 |
39 | ch = logging.StreamHandler()
40 | ch.setLevel(logging.INFO)
41 |
42 | info_formatter = logging.Formatter('%(asctime)s | %(levelname)-8s | %(message)s')
43 | debug_formatter = logging.Formatter('%(asctime)s | %(levelname)-8s | %(message)s | %(lineno)d:%(funcName)s')
44 |
45 | ch.setFormatter(info_formatter)
46 | info_fh.setFormatter(info_formatter)
47 | debug_fh.setFormatter(debug_formatter)
48 |
49 | logger.addHandler(ch)
50 | logger.addHandler(debug_fh)
51 | logger.addHandler(info_fh)
52 |
53 | return logger
54 |
55 | def train_model(args):
56 | hparams = PARAMS_MAP[args.model]
57 | hparams["root_dir"] = args.root_dir
58 | hparams["bert_pretrained_dir"] = args.bert_pretrained_dir
59 | hparams["bert_pretrained"] = args.bert_pretrained
60 | hparams["data_dir"] = args.data_dir
61 | hparams["model_type"] = args.model
62 |
63 | timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
64 | root_dir = os.path.join(hparams["root_dir"], args.model, args.train_type, "%s/" % timestamp)
65 | logger = init_logger(root_dir)
66 | logger.info("Hyper-parameters: %s" % str(hparams))
67 | hparams["root_dir"] = root_dir
68 |
69 | hparams = collections.namedtuple("HParams", sorted(hparams.keys()))(**hparams)
70 | model = MODEL[args.train_type](hparams)
71 | model.train()
72 |
73 | def evaluate_model(args):
74 | hparams = PARAMS_MAP[args.model]
75 |
76 | hparams = collections.namedtuple("HParams", sorted(hparams.keys()))(**hparams)
77 |
78 | model = Evaluation(hparams)
79 | model.run_evaluate(args.evaluate)
80 |
81 | if __name__ == '__main__':
82 | arg_parser = argparse.ArgumentParser(description="Bert / Response Selection (PyTorch)")
83 | arg_parser.add_argument("--model", dest="model", type=str,
84 |                             default="bert_base_ft",
85 | help="Model Name")
86 | arg_parser.add_argument("--root_dir", dest="root_dir", type=str,
87 | default="./results",
88 | help="model train logs, checkpoints")
89 | arg_parser.add_argument("--data_dir", dest="data_dir", type=str,
90 | default="./data/ubuntu_corpus_v1/%s_%s.pkl",
91 | help="ubuntu corpus v1 pkl path") # ubuntu_train.pkl, ubuntu_valid_pkl, ubuntu_test.pkl
92 | arg_parser.add_argument("--bert_pretrained_dir", dest="bert_pretrained_dir", type=str,
93 | default="./resources",
94 | help="bert pretrained directory")
95 | arg_parser.add_argument("--bert_pretrained", dest="bert_pretrained", type=str,
96 | default="bert-post-uncased",
97 | help="bert pretrained directory") # bert-base-uncased, bert-post-uncased -> under bert_pretrained_dir
98 | arg_parser.add_argument("--train_type", dest="train_type", type=str,
99 | default="fine_tuning",
100 | help="Train type") # fine_tuning, post_training
101 | arg_parser.add_argument("--evaluate", dest="evaluate", type=str,
102 | help="Evaluation Checkpoint", default="")
103 |
104 | args = arg_parser.parse_args()
105 | if args.evaluate:
106 | evaluate_model(args)
107 | else:
108 | train_model(args)
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
1 | import os
2 | # os.environ["CUDA_VISIBLE_DEVICES"] = "1"
3 | import logging
4 | from tqdm import tqdm
5 | import numpy as np
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch.utils.data import DataLoader
10 |
11 | from models import Model
12 | from data.dataset import ResponseSelectionDataset
13 | from models.utils.checkpointing import load_checkpoint
14 | from models.utils.scorer import calculate_candidates_ranking, logits_mrr, logits_recall_at_k
15 |
16 | class Evaluation(object):
17 | def __init__(self, hparams, model=None, split = "test"):
18 |
19 | self.hparams = hparams
20 | self.model = model
21 | self._logger = logging.getLogger(__name__)
22 | self.device = (torch.device("cuda", self.hparams.gpu_ids[0])
23 | if self.hparams.gpu_ids[0] >= 0 else torch.device("cpu"))
24 | self.split = split
25 | print("Evaluation Split :", self.split)
26 | do_valid, do_test = False, False
27 | if split == "valid":
28 | do_valid = True
29 | else:
30 | do_test = True
31 | self._build_dataloader(do_valid=do_valid, do_test=do_test)
32 | self._dataloader = self.valid_dataloader if split == 'valid' else self.test_dataloader
33 |
34 | if model is None:
35 | print("No pre-defined model!")
36 | self._build_model()
37 |
38 | def _build_dataloader(self, do_valid=False, do_test=False):
39 |
40 | if do_valid:
41 | self.valid_dataset = ResponseSelectionDataset(
42 | self.hparams,
43 | split="valid",
44 | )
45 | self.valid_dataloader = DataLoader(
46 | self.valid_dataset,
47 | batch_size=self.hparams.eval_batch_size,
48 | num_workers=self.hparams.cpu_workers,
49 | drop_last=False,
50 | )
51 |
52 | if do_test:
53 | self.test_dataset = ResponseSelectionDataset(
54 | self.hparams,
55 | split="test",
56 | )
57 |
58 | self.test_dataloader = DataLoader(
59 | self.test_dataset,
60 | batch_size=self.hparams.eval_batch_size,
61 | num_workers=self.hparams.cpu_workers,
62 | drop_last=False,
63 | )
64 |
65 | def _build_model(self):
66 | self.model = Model(self.hparams)
67 | self.model = self.model.to(self.device)
68 | # Use Multi-GPUs
69 | if -1 not in self.hparams.gpu_ids and len(self.hparams.gpu_ids) > 1:
70 | self.model = nn.DataParallel(self.model, self.hparams.gpu_ids)
71 |
72 | def run_evaluate(self, evaluation_path):
73 | self._logger.info("Evaluation")
74 | model_state_dict, optimizer_state_dict = load_checkpoint(evaluation_path)
75 |
76 | if isinstance(self.model, nn.DataParallel):
77 | self.model.module.load_state_dict(model_state_dict)
78 | else:
79 | self.model.load_state_dict(model_state_dict)
80 |
81 | k_list = self.hparams.recall_k_list
82 | total_mrr = 0
83 | total_examples, total_correct = 0, 0
84 | self.model.eval()
85 | with torch.no_grad():
86 | for batch_idx, batch in enumerate(tqdm(self._dataloader)):
87 | buffer_batch = batch.copy()
88 | for key in batch:
89 | buffer_batch[key] = batch[key].to(self.device)
90 |
91 | logits = self.model(buffer_batch)
92 | pred = torch.sigmoid(logits).to("cpu").tolist() # bs
93 |
94 | rank_by_pred = calculate_candidates_ranking(np.array(pred), np.array(buffer_batch["label"].to("cpu").tolist()),
95 | self.hparams.evaluate_candidates_num)
96 | num_correct, pos_index = logits_recall_at_k(rank_by_pred, k_list)
97 |
98 | total_mrr += logits_mrr(rank_by_pred)
99 |
100 | total_correct = np.add(total_correct, num_correct)
101 | total_examples = (batch_idx + 1) * rank_by_pred.shape[0]
102 |
103 | recall_result = ""
104 | if (batch_idx + 1) % self.hparams.evaluate_print_step == 0:
105 | for i in range(len(k_list)):
106 | recall_result += "Recall@%s : " % k_list[i] + "%.2f%% | " % ((total_correct[i] / total_examples) * 100)
107 | else:
108 | print("%d[th] | %s | MRR : %.3f" % (batch_idx + 1, recall_result, float(total_mrr / total_examples)))
109 | self._logger.info("%d[th] | %s | MRR : %.3f" % (batch_idx + 1, recall_result, float(total_mrr / total_examples)))
110 |
111 | avg_mrr = float(total_mrr / total_examples)
112 | recall_result = ""
113 |
114 | for i in range(len(k_list)):
115 | recall_result += "Recall@%s : " % k_list[i] + "%.2f%% | " % ((total_correct[i] / total_examples) * 100)
116 | self._logger.info(recall_result)
117 | self._logger.info("MRR: %.4f" % avg_mrr)
--------------------------------------------------------------------------------
/models/utils/checkpointing.py:
--------------------------------------------------------------------------------
1 | """
2 | A checkpoint manager periodically saves model and optimizer as .pth
3 | files during training.
4 | Checkpoint managers help with experiment reproducibility: they record
5 | the commit SHA of your current codebase in the checkpoint saving
6 | directory. When loading a checkpoint from another commit, they raise a
7 | friendly warning, a signal to inspect commit diffs for potential bugs.
8 | Moreover, they save the experiment hyper-parameters as a JSON config in
9 | this directory.
10 | That said, always run your experiments after committing your changes;
11 | this doesn't account for untracked or staged, but uncommitted, changes.
12 | """
13 | from pathlib import Path
14 | from subprocess import PIPE, Popen
15 | import warnings
16 |
17 | import torch
18 | from torch import nn, optim
19 | import json
20 |
21 | class CheckpointManager(object):
22 | """A checkpoint manager saves state dicts of model and optimizer
23 | as .pth files in a specified directory. This class closely follows
24 | the API of PyTorch optimizers and learning rate schedulers.
25 | Note::
26 | For ``DataParallel`` modules, ``model.module.state_dict()`` is
27 | saved, instead of ``model.state_dict()``.
28 | Parameters
29 | ----------
30 | model: nn.Module
31 | Wrapped model, which needs to be checkpointed.
32 | optimizer: optim.Optimizer
33 | Wrapped optimizer which needs to be checkpointed.
34 | checkpoint_dirpath: str
35 | Path to an empty or non-existent directory to save checkpoints.
36 | step_size: int, optional (default=1)
37 | Period of saving checkpoints.
38 | last_epoch: int, optional (default=-1)
39 | The index of last epoch.
40 | Example
41 | --------
42 | >>> model = torch.nn.Linear(10, 2)
43 | >>> optimizer = torch.optim.Adam(model.parameters())
44 | >>> ckpt_manager = CheckpointManager(model, optimizer, "/tmp/ckpt")
45 | >>> for epoch in range(20):
46 | ... for batch in dataloader:
47 | ... do_iteration(batch)
48 | ... ckpt_manager.step()
49 | """
50 |
51 | def __init__(
52 | self,
53 | model,
54 | optimizer,
55 | checkpoint_dirpath,
56 | step_size=1,
57 | last_epoch=-1,
58 | **kwargs,
59 | ):
60 |
61 | if not isinstance(model, nn.Module):
62 | raise TypeError("{} is not a Module".format(type(model).__name__))
63 |
64 | if not isinstance(optimizer, optim.Optimizer):
65 | raise TypeError(
66 | "{} is not an Optimizer".format(type(optimizer).__name__)
67 | )
68 |
69 | self.model = model
70 | self.optimizer = optimizer
71 | self.ckpt_dirpath = Path(checkpoint_dirpath)
72 | self.step_size = step_size
73 | self.last_epoch = last_epoch
74 | self.init_directory(**kwargs)
75 |
76 | def init_directory(self, hparams):
77 | """Initialize empty checkpoint directory and record commit SHA
78 | in it. Also save hyper-parameters config in this directory to
79 | associate checkpoints with their hyper-parameters.
80 | """
81 |
82 | self.ckpt_dirpath.mkdir(parents=True, exist_ok=True)
83 | # save current git commit hash in this checkpoint directory
84 | commit_sha_subprocess = Popen(
85 | ["git", "rev-parse", "--short", "HEAD"], stdout=PIPE, stderr=PIPE
86 | )
87 | commit_sha, _ = commit_sha_subprocess.communicate()
88 | commit_sha = commit_sha.decode("utf-8").strip().replace("\n", "")
89 | commit_sha_filepath = self.ckpt_dirpath / f".commit-{commit_sha}"
90 | commit_sha_filepath.touch()
91 | with open(str(self.ckpt_dirpath / "hparams.json"), 'w') as hparams_handle:
92 | json.dump(hparams, hparams_handle)
93 |
94 | def step(self, epoch=None):
95 | """Save checkpoint if step size conditions meet. """
96 |
97 | if not epoch:
98 | epoch = self.last_epoch + 1
99 | self.last_epoch = epoch
100 |
101 | if not self.last_epoch % self.step_size:
102 | torch.save(
103 | {
104 | "model": self._model_state_dict(),
105 | "optimizer": self.optimizer.state_dict(),
106 | },
107 | self.ckpt_dirpath / f"checkpoint_{self.last_epoch}.pth",
108 | )
109 |
110 | def _model_state_dict(self):
111 | """Returns state dict of model, taking care of DataParallel case."""
112 | if isinstance(self.model, nn.DataParallel):
113 | return self.model.module.state_dict()
114 | else:
115 | return self.model.state_dict()
116 |
117 |
118 | def load_checkpoint(checkpoint_pthpath):
119 | """Given a path to saved checkpoint, load corresponding state dicts
120 | of model and optimizer from it. This method checks if the current
121 | commit SHA of codebase matches the commit SHA recorded when this
122 | checkpoint was saved by checkpoint manager.
123 | Parameters
124 | ----------
125 | checkpoint_pthpath: str or pathlib.Path
126 | Path to saved checkpoint (as created by ``CheckpointManager``).
127 | Returns
128 | -------
129 | nn.Module, optim.Optimizer
130 | Model and optimizer state dicts loaded from checkpoint.
131 | Raises
132 | ------
133 | UserWarning
134 | If commit SHA do not match, or if the directory doesn't have
135 | the recorded commit SHA.
136 | """
137 |
138 | if isinstance(checkpoint_pthpath, str):
139 | checkpoint_pthpath = Path(checkpoint_pthpath)
140 | checkpoint_dirpath = checkpoint_pthpath.resolve().parent
141 | checkpoint_commit_sha = list(checkpoint_dirpath.glob(".commit-*"))
142 |
143 | if len(checkpoint_commit_sha) == 0:
144 | warnings.warn(
145 | "Commit SHA was not recorded while saving checkpoints."
146 | )
147 | else:
148 | # verify commit sha, raise warning if it doesn't match
149 | commit_sha_subprocess = Popen(
150 | ["git", "rev-parse", "--short", "HEAD"], stdout=PIPE, stderr=PIPE
151 | )
152 | commit_sha, _ = commit_sha_subprocess.communicate()
153 | commit_sha = commit_sha.decode("utf-8").strip().replace("\n", "")
154 |
155 | # remove ".commit-"
156 | checkpoint_commit_sha = checkpoint_commit_sha[0].name[8:]
157 |
158 | if commit_sha != checkpoint_commit_sha:
159 | warnings.warn(
160 | f"Current commit ({commit_sha}) and the commit "
161 | f"({checkpoint_commit_sha}) at which checkpoint was saved,"
162 | " are different. This might affect reproducibility."
163 | )
164 |
165 | # load encoder, decoder, optimizer state_dicts
166 | components = torch.load(checkpoint_pthpath)
167 |
168 | return components["model"], components["optimizer"]
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import pickle
4 | import h5py
5 | import numpy as np
6 |
7 | from models.bert import tokenization_bert
8 | from torch.utils.data import Dataset
9 |
10 | class ResponseSelectionDataset(Dataset):
11 | """
12 |     A dataset of Ubuntu Corpus V1 response selection examples. According to
13 |     the given split (train/valid/test), it reads pickled InputExamples and returns a
14 |     dictionary of the tokenized dialog-response pair, segment ids, attention mask, and label.
15 | """
16 | def __init__(
17 | self,
18 | hparams,
19 | split: str = "",
20 | ):
21 | super().__init__()
22 |
23 | self.hparams = hparams
24 | self.split = split
25 |
26 | # read pkls -> Input Examples
27 | self.input_examples = []
28 | with open(hparams.data_dir % (hparams.task_name, split), "rb") as pkl_handle:
29 | while True:
30 | try:
31 | self.input_examples.append(pickle.load(pkl_handle))
32 | if len(self.input_examples) % 100000 == 0:
33 | print("%d examples has been loaded!" % len(self.input_examples))
34 | except EOFError:
35 | break
36 |
37 | print("total %s examples" % split, len(self.input_examples))
38 |
39 | bert_pretrained_dir = os.path.join(self.hparams.bert_pretrained_dir, self.hparams.bert_pretrained)
40 | print(bert_pretrained_dir)
41 | self._bert_tokenizer = tokenization_bert.BertTokenizer(
42 | vocab_file=os.path.join(bert_pretrained_dir, "%s-vocab.txt" % self.hparams.bert_pretrained))
43 |
44 | # End of Turn Token
45 | if self.hparams.do_eot:
46 | self._bert_tokenizer.add_tokens(["[EOT]"])
47 |
48 | def __len__(self):
49 | return len(self.input_examples)
50 |
51 | def __getitem__(self, index):
52 | # Get Input Examples
53 | """
54 | InputExamples
55 | self.utterances = utterances
56 | self.response = response
57 | self.label
58 | """
59 |
60 | anno_sent, segment_ids, attention_mask = self._annotate_sentence(self.input_examples[index])
61 |
62 | current_feature = dict()
63 | current_feature["anno_sent"] = torch.tensor(anno_sent).long()
64 | current_feature["segment_ids"] = torch.tensor(segment_ids).long()
65 | current_feature["attention_mask"] = torch.tensor(attention_mask).long()
66 | current_feature["label"] = torch.tensor(self.input_examples[index].label).float()
67 |
68 | return current_feature
69 |
70 | def _annotate_sentence(self, example):
71 |
72 | dialog_context = []
73 | if self.hparams.do_eot:
74 | for utt in example.utterances:
75 | dialog_context.extend(utt + ["[EOT]"])
76 | else:
77 | for utt in example.utterances:
78 | dialog_context.extend(utt)
79 |
80 |         # Trim the dialog context and response to max_dialog_len / max_response_len (448 / 64 by default)
81 | dialog_context, response = self._max_len_trim_seq(dialog_context, example.response)
82 |
83 | # dialog context
84 | dialog_context = ["[CLS]"] + dialog_context + ["[SEP]"]
85 | segment_ids = [0] * self.hparams.max_dialog_len
86 | attention_mask = [1] * len(dialog_context)
87 |
88 | while len(dialog_context) < self.hparams.max_dialog_len:
89 | dialog_context.append("[PAD]")
90 | attention_mask.append(0)
91 |
92 | assert len(dialog_context) == len(segment_ids) == len(attention_mask)
93 |
94 | response = response + ["[SEP]"]
95 | segment_ids.extend([1] * len(response))
96 | attention_mask.extend([1] * len(response))
97 |
98 | while len(response) < self.hparams.max_response_len:
99 | response.append("[PAD]")
100 | segment_ids.append(0)
101 | attention_mask.append(0)
102 |
103 | dialog_response = dialog_context + response
104 |
105 | # print(segment_ids)
106 | # print(attention_mask)
107 | # print(len(dialog_response), len(segment_ids), len(attention_mask))
108 |
109 | assert len(dialog_response) == len(segment_ids) == len(attention_mask)
110 | anno_sent = self._bert_tokenizer.convert_tokens_to_ids(dialog_response)
111 |
112 | return anno_sent, segment_ids, attention_mask
113 |
114 | def _max_len_trim_seq(self, dialog_context, response):
115 | while len(dialog_context) > self.hparams.max_dialog_len - 2:
116 | dialog_context.pop(0) # from the front
117 |
118 | while len(response) > self.hparams.max_response_len - 1:
119 | response.pop() # from the back
120 |
121 | return dialog_context, response
122 |
123 | class BertPostTrainingDataset(Dataset):
124 | """
125 |     A dataset for BERT domain post-training on the Ubuntu corpus. It reads
126 |     pre-built MLM/NSP features from an HDF5 file and returns input ids, token type
127 |     ids, attention mask, masked LM labels, and next-sentence labels for each instance.
128 | """
129 |
130 | def __init__(
131 | self,
132 | hparams,
133 | split: str = "",
134 | ):
135 | super().__init__()
136 |
137 | self.hparams = hparams
138 | self.split = split
139 |
140 | with h5py.File(self.hparams.data_dir, "r") as features_hdf:
141 | self.feature_keys = list(features_hdf.keys())
142 | self.num_instances = np.array(features_hdf.get("next_sentence_labels")).shape[0]
143 | print("total %s examples : %d" % (split, self.num_instances))
144 |
145 | def __len__(self):
146 | return self.num_instances
147 |
148 | def __getitem__(self, index):
149 |         # Get post-training features for this index
150 |         """
151 |         Returns a dictionary with every feature key stored in the HDF5 file
152 |         (input_ids, token_type_ids, attention_mask, masked_lm_positions,
153 |         masked_lm_ids, next_sentence_labels), plus masked_lm_labels derived
154 |         from the masked LM positions and ids.
155 |         """
156 | features = self._read_hdf_features(index)
157 | anno_masked_lm_labels = self._anno_mask_inputs(features["masked_lm_ids"], features["masked_lm_positions"])
158 | curr_features = dict()
159 | for feat_key in features.keys():
160 | curr_features[feat_key] = torch.tensor(features[feat_key]).long()
161 | curr_features["masked_lm_labels"] = torch.tensor(anno_masked_lm_labels).long()
162 |
163 | return curr_features
164 |
165 | def _read_hdf_features(self, index):
166 | features = {}
167 | with h5py.File(self.hparams.data_dir, "r") as features_hdf:
168 | for f_key in self.feature_keys:
169 | features[f_key] = features_hdf[f_key][index]
170 |
171 | return features
172 |
173 | def _anno_mask_inputs(self, masked_lm_ids, masked_lm_positions, max_seq_len=512):
174 | # masked_lm_ids -> labels
175 | anno_masked_lm_labels = [-1] * max_seq_len
176 |
177 | for pos, label in zip(masked_lm_positions, masked_lm_ids):
178 | if pos == 0: continue
179 | anno_masked_lm_labels[pos] = label
180 |
181 | return anno_masked_lm_labels
--------------------------------------------------------------------------------
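
For reference, a sketch of the sequence layout `_annotate_sentence` produces with the default lengths from `config/hparams.py` (`max_dialog_len=448`, `max_response_len=64`, 512 tokens in total). The token strings are illustrative only.

```python
# Dialog side (448 positions, segment id 0):  [CLS] u1 ... [EOT] u2 ... [EOT] [SEP] [PAD] ...
# Response side (64 positions, segment id 1,
# with segment id 0 on its padding):          r1 r2 ... [SEP] [PAD] ...

dialog = ["[CLS]", "hi", "[EOT]", "how", "are", "you", "[EOT]", "[SEP]"]
response = ["fine", "thanks", "[SEP]"]
max_dialog_len, max_response_len = 448, 64

tokens = (dialog + ["[PAD]"] * (max_dialog_len - len(dialog))
          + response + ["[PAD]"] * (max_response_len - len(response)))
segment_ids = [0] * max_dialog_len + [1] * len(response) + [0] * (max_response_len - len(response))
attention_mask = ([1] * len(dialog) + [0] * (max_dialog_len - len(dialog))
                  + [1] * len(response) + [0] * (max_response_len - len(response)))

assert len(tokens) == len(segment_ids) == len(attention_mask) == 512
```
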
/models/bert/configuration_bert.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ BERT model configuration """
17 |
18 | from __future__ import absolute_import, division, print_function, unicode_literals
19 |
20 | import json
21 | import logging
22 | import sys
23 | from io import open
24 |
25 | from models.pretrained_common.configuration_utils import PretrainedConfig
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
30 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
31 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
32 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
33 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
34 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
35 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
36 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
37 | 'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
38 | 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
39 | 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
40 | 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
41 | 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
42 | 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
43 | }
44 |
45 |
46 | class BertConfig(PretrainedConfig):
47 | r"""
48 | :class:`~pytorch_transformers.BertConfig` is the configuration class to store the configuration of a
49 | `BertModel`.
50 | Arguments:
51 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
52 | hidden_size: Size of the encoder layers and the pooler layer.
53 | num_hidden_layers: Number of hidden layers in the Transformer encoder.
54 | num_attention_heads: Number of attention heads for each attention layer in
55 | the Transformer encoder.
56 | intermediate_size: The size of the "intermediate" (i.e., feed-forward)
57 | layer in the Transformer encoder.
58 | hidden_act: The non-linear activation function (function or string) in the
59 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
60 | hidden_dropout_prob: The dropout probabilitiy for all fully connected
61 | layers in the embeddings, encoder, and pooler.
62 | attention_probs_dropout_prob: The dropout ratio for the attention
63 | probabilities.
64 | max_position_embeddings: The maximum sequence length that this model might
65 | ever be used with. Typically set this to something large just in case
66 | (e.g., 512 or 1024 or 2048).
67 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into
68 | `BertModel`.
69 | initializer_range: The sttdev of the truncated_normal_initializer for
70 | initializing all weight matrices.
71 | layer_norm_eps: The epsilon used by LayerNorm.
72 | """
73 | pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
74 |
75 | def __init__(self,
76 | vocab_size_or_config_json_file=30522,
77 | hidden_size=768,
78 | num_hidden_layers=12,
79 | num_attention_heads=12,
80 | intermediate_size=3072,
81 | hidden_act="gelu",
82 | hidden_dropout_prob=0.1,
83 | attention_probs_dropout_prob=0.1,
84 | max_position_embeddings=512,
85 | type_vocab_size=2,
86 | initializer_range=0.02,
87 | layer_norm_eps=1e-12,
88 | **kwargs):
89 | super(BertConfig, self).__init__(**kwargs)
90 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
91 | and isinstance(vocab_size_or_config_json_file, unicode)):
92 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
93 | json_config = json.loads(reader.read())
94 | for key, value in json_config.items():
95 | self.__dict__[key] = value
96 | elif isinstance(vocab_size_or_config_json_file, int):
97 | self.vocab_size = vocab_size_or_config_json_file
98 | self.hidden_size = hidden_size
99 | self.num_hidden_layers = num_hidden_layers
100 | self.num_attention_heads = num_attention_heads
101 | self.hidden_act = hidden_act
102 | self.intermediate_size = intermediate_size
103 | self.hidden_dropout_prob = hidden_dropout_prob
104 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
105 | self.max_position_embeddings = max_position_embeddings
106 | self.type_vocab_size = type_vocab_size
107 | self.initializer_range = initializer_range
108 | self.layer_norm_eps = layer_norm_eps
109 | else:
110 | raise ValueError("First argument must be either a vocabulary size (int)"
111 | " or the path to a pretrained model config file (str)")
--------------------------------------------------------------------------------
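
Both model classes in this repository load the config from a local JSON path, which `from_pretrained` (defined in `pretrained_common/configuration_utils.py`) accepts directly. A minimal sketch, assuming the shipped `resources/` files are in place:

```python
from models.bert.configuration_bert import BertConfig

config = BertConfig.from_pretrained("./resources/bert-base-uncased/bert-base-uncased-config.json")
print(config.hidden_size, config.num_hidden_layers, config.vocab_size)  # 768 12 30522
```
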
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
3 | import logging
4 |
5 | from datetime import datetime
6 | from tqdm import tqdm
7 |
8 | import torch
9 | from torch import nn, optim
10 | from torch.utils.data import DataLoader
11 | from torch.utils.tensorboard import SummaryWriter
12 |
13 | from data.dataset import ResponseSelectionDataset
14 | from models.utils.checkpointing import CheckpointManager, load_checkpoint
15 | from models import Model
16 | from evaluation import Evaluation
17 |
18 | class ResponseSelection(object):
19 | def __init__(self, hparams):
20 | self.hparams = hparams
21 | self._logger = logging.getLogger(__name__)
22 |
23 | def _build_dataloader(self):
24 | # =============================================================================
25 | # SETUP DATASET, DATALOADER
26 | # =============================================================================
27 | self.train_dataset = ResponseSelectionDataset(self.hparams, split="train")
28 | self.train_dataloader = DataLoader(
29 | self.train_dataset,
30 | batch_size=self.hparams.train_batch_size,
31 | num_workers=self.hparams.cpu_workers,
32 | shuffle=True,
33 | drop_last=True
34 | )
35 |
36 | print("""
37 | # -------------------------------------------------------------------------
38 | # DATALOADER FINISHED
39 | # -------------------------------------------------------------------------
40 | """)
41 |
42 | def _build_model(self):
43 | # =============================================================================
44 | # MODEL : Standard, Mention Pooling, Entity Marker
45 | # =============================================================================
46 | print('\t* Building model...')
47 |
48 | self.model = Model(self.hparams)
49 | self.model = self.model.to(self.device)
50 |
51 | # Use Multi-GPUs
52 | if -1 not in self.hparams.gpu_ids and len(self.hparams.gpu_ids) > 1:
53 | self.model = nn.DataParallel(self.model, self.hparams.gpu_ids)
54 |
55 | # =============================================================================
56 | # CRITERION
57 | # =============================================================================
58 | self.criterion = nn.BCEWithLogitsLoss()
59 |
60 | self.optimizer = optim.Adam(self.model.parameters(), lr=self.hparams.learning_rate)
61 | self.iterations = len(self.train_dataset) // self.hparams.virtual_batch_size
62 |
63 | def _setup_training(self):
64 | if self.hparams.save_dirpath == 'checkpoints/':
65 | self.save_dirpath = os.path.join(self.hparams.root_dir, self.hparams.save_dirpath)
66 | self.summary_writer = SummaryWriter(self.save_dirpath)
67 | self.checkpoint_manager = CheckpointManager(self.model, self.optimizer, self.save_dirpath, hparams=self.hparams)
68 |
69 | # If loading from checkpoint, adjust start epoch and load parameters.
70 | if self.hparams.load_pthpath == "":
71 | self.start_epoch = 1
72 | else:
73 | # "path/to/checkpoint_xx.pth" -> xx
74 | self.start_epoch = int(self.hparams.load_pthpath.split("_")[-1][:-4])
75 | self.start_epoch += 1
76 | model_state_dict, optimizer_state_dict = load_checkpoint(self.hparams.load_pthpath)
77 | if isinstance(self.model, nn.DataParallel):
78 | self.model.module.load_state_dict(model_state_dict)
79 | else:
80 | self.model.load_state_dict(model_state_dict)
81 | self.optimizer.load_state_dict(optimizer_state_dict)
82 | self.previous_model_path = self.hparams.load_pthpath
83 | print("Loaded model from {}".format(self.hparams.load_pthpath))
84 |
85 | print(
86 | """
87 | # -------------------------------------------------------------------------
88 | # Setup Training Finished
89 | # -------------------------------------------------------------------------
90 | """
91 | )
92 |
93 | def train(self):
94 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
95 |
96 | self._build_dataloader()
97 | self._build_model()
98 | self._setup_training()
99 |
100 | # Evaluation Setup
101 | evaluation = Evaluation(self.hparams, model=self.model, split="test")
102 |
103 | start_time = datetime.now().strftime('%H:%M:%S')
104 | self._logger.info("Start train model at %s" % start_time)
105 |
106 | train_begin = datetime.utcnow() # New
107 | global_iteration_step = 0
108 | accumulate_loss = 0
109 | accu_count = 0
110 | for epoch in range(self.start_epoch, self.hparams.num_epochs):
111 | self.model.train()
112 |
113 | tqdm_batch_iterator = tqdm(self.train_dataloader)
114 | accumulate_batch = 0
115 |
116 | for batch_idx, batch in enumerate(tqdm_batch_iterator):
117 | buffer_batch = batch.copy()
118 | for key in batch:
119 | buffer_batch[key] = buffer_batch[key].to(self.device)
120 |
121 | logits = self.model(buffer_batch)
122 | loss = self.criterion(logits, buffer_batch["label"])
123 |
124 | loss.backward()
125 | accumulate_loss += loss.item()
126 | accu_count += 1
127 |
128 | # TODO: virtual batch implementation
129 | accumulate_batch += buffer_batch["label"].shape[0]
130 | if self.hparams.virtual_batch_size == accumulate_batch \
131 | or batch_idx == (len(self.train_dataset) // self.hparams.train_batch_size): # last batch
132 | nn.utils.clip_grad_norm_(self.model.parameters(), self.hparams.max_gradient_norm)
133 |
134 | self.optimizer.step()
135 | self.optimizer.zero_grad()
136 | accumulate_batch = 0
137 |
138 | global_iteration_step += 1
139 | description = "[{}][Epoch: {:3d}][Iter: {:6d}][Loss: {:6f}][lr: {:7f}]".format(
140 | datetime.utcnow() - train_begin,
141 | epoch,
142 | global_iteration_step, (accumulate_loss / accu_count),
143 | self.optimizer.param_groups[0]['lr'])
144 | tqdm_batch_iterator.set_description(description)
145 |
146 | # tensorboard
147 | if global_iteration_step % self.hparams.tensorboard_step == 0:
148 | description = "[{}][Epoch: {:3d}][Iter: {:6d}][Loss: {:6f}][lr: {:7f}]".format(
149 | datetime.utcnow() - train_begin,
150 | epoch,
151 | global_iteration_step, (accumulate_loss / accu_count),
152 | self.optimizer.param_groups[0]['lr'],
153 | )
154 | self._logger.info(description)
155 | accumulate_loss, accu_count = 0, 0
156 |
157 | # -------------------------------------------------------------------------
158 | # ON EPOCH END (checkpointing and validation)
159 | # -------------------------------------------------------------------------
160 | self.checkpoint_manager.step(epoch)
161 | self.previous_model_path = os.path.join(self.checkpoint_manager.ckpt_dirpath, "checkpoint_%d.pth" % (epoch))
162 | self._logger.info(self.previous_model_path)
163 |
164 | torch.cuda.empty_cache()
165 |       self._logger.info("Evaluation after epoch %d" % epoch)
166 | evaluation.run_evaluate(self.previous_model_path)
167 | torch.cuda.empty_cache()
168 |
--------------------------------------------------------------------------------
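Note: the fine-tuning loop in train.py above implements gradient accumulation over a "virtual batch": loss.backward() runs for every mini-batch, but optimizer.step() / zero_grad() only fire once virtual_batch_size examples have accumulated. A minimal, self-contained sketch of the same pattern (the model, sizes and clipping value below are made up for illustration; the real values come from the hparams object):

import torch
from torch import nn

train_batch_size, virtual_batch_size = 8, 32        # illustrative sizes only
model = nn.Linear(16, 2)                             # stand-in for the BERT classifier
optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)
criterion = nn.CrossEntropyLoss()

accumulate_batch = 0
for step in range(100):
    x = torch.randn(train_batch_size, 16)            # dummy features
    y = torch.randint(0, 2, (train_batch_size,))     # dummy labels
    loss = criterion(model(x), y)
    loss.backward()                                   # gradients keep accumulating
    accumulate_batch += x.shape[0]
    if accumulate_batch >= virtual_batch_size:
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()                              # one update per virtual batch
        optimizer.zero_grad()
        accumulate_batch = 0
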
/post_train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 |
4 | from datetime import datetime
5 | from tqdm import tqdm
6 |
7 | import torch
8 | from torch import nn, optim
9 | from torch.utils.data import DataLoader
10 | from torch.utils.tensorboard import SummaryWriter
11 |
12 | from data.dataset import BertPostTrainingDataset
13 | from models.utils.checkpointing import CheckpointManager, load_checkpoint
14 | from models import Model
15 |
16 | class BERTDomainPostTraining(object):
17 | def __init__(self, hparams):
18 | self.hparams = hparams
19 | self._logger = logging.getLogger(__name__)
20 |
21 | def _build_dataloader(self):
22 | # =============================================================================
23 | # SETUP DATASET, DATALOADER
24 | # =============================================================================
25 | self.train_dataset = BertPostTrainingDataset(self.hparams, split="train")
26 | self.train_dataloader = DataLoader(
27 | self.train_dataset,
28 | batch_size=self.hparams.train_batch_size,
29 | num_workers=self.hparams.cpu_workers,
30 | shuffle=False,
31 | drop_last=True
32 | )
33 |
34 | print("""
35 | # -------------------------------------------------------------------------
36 | # DATALOADER FINISHED
37 | # -------------------------------------------------------------------------
38 | """)
39 |
40 | def _build_model(self):
41 | # =============================================================================
42 |   # MODEL : BERT Domain Post-Training
43 | # =============================================================================
44 | print('\t* Building model...')
45 |
46 | self.model = Model(self.hparams)
47 | self.model = self.model.to(self.device)
48 |
49 | # Use Multi-GPUs
50 | if -1 not in self.hparams.gpu_ids and len(self.hparams.gpu_ids) > 1:
51 | self.model = nn.DataParallel(self.model, self.hparams.gpu_ids)
52 |
53 | self.optimizer = optim.Adam(self.model.parameters(), lr=self.hparams.learning_rate)
54 | self.iterations = len(self.train_dataset) // self.hparams.virtual_batch_size
55 |
56 | print(
57 | """
58 | # -------------------------------------------------------------------------
59 | # Building Model Finished
60 | # -------------------------------------------------------------------------
61 | """
62 | )
63 |
64 | def _setup_training(self):
65 |     self.save_dirpath = self.hparams.save_dirpath if self.hparams.save_dirpath != 'checkpoints/' \
66 |       else os.path.join(self.hparams.root_dir, self.hparams.save_dirpath)
67 | self.summary_writer = SummaryWriter(self.save_dirpath)
68 | self.checkpoint_manager = CheckpointManager(self.model, self.optimizer, self.save_dirpath, hparams=self.hparams)
69 |
70 | # If loading from checkpoint, adjust start epoch and load parameters.
71 | if self.hparams.load_pthpath == "":
72 | self.start_epoch = 1
73 | else:
74 | # "path/to/checkpoint_xx.pth" -> xx
75 | self.start_epoch = int(self.hparams.load_pthpath.split("_")[-1][:-4])
76 | self.start_epoch += 1
77 | model_state_dict, optimizer_state_dict = load_checkpoint(self.hparams.load_pthpath)
78 | if isinstance(self.model, nn.DataParallel):
79 | self.model.module.load_state_dict(model_state_dict)
80 | else:
81 | self.model.load_state_dict(model_state_dict)
82 | self.optimizer.load_state_dict(optimizer_state_dict)
83 | self.previous_model_path = self.hparams.load_pthpath
84 | print("Loaded model from {}".format(self.hparams.load_pthpath))
85 |
86 | print(
87 | """
88 | # -------------------------------------------------------------------------
89 | # Setup Training Finished
90 | # -------------------------------------------------------------------------
91 | """
92 | )
93 |
94 | def train(self):
95 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
96 |
97 | self._build_dataloader()
98 | self._build_model()
99 | self._setup_training()
100 |
101 | start_time = datetime.now().strftime('%H:%M:%S')
102 |     self._logger.info("Start training the model at %s" % start_time)
103 |
104 | train_begin = datetime.utcnow() # New
105 | global_iteration_step = 0
106 | accu_mlm_loss, accu_nsp_loss = 0, 0
107 | accumulate_batch, accu_count = 0, 0
108 |
109 | for epoch in range(self.start_epoch, self.hparams.num_epochs):
110 | self.model.train()
111 |
112 | tqdm_batch_iterator = tqdm(self.train_dataloader)
113 | for batch_idx, batch in enumerate(tqdm_batch_iterator):
114 | buffer_batch = batch.copy()
115 | for key in batch:
116 | buffer_batch[key] = buffer_batch[key].to(self.device)
117 |
118 | mlm_loss, nsp_loss = self.model(buffer_batch)
119 | total_loss = mlm_loss.mean() + nsp_loss.mean()
120 | total_loss.backward()
121 | accu_mlm_loss += mlm_loss.mean().item()
122 | accu_nsp_loss += nsp_loss.mean().item()
123 | accu_count += 1
124 |
125 | # TODO: virtual batch implementation
126 | accumulate_batch += buffer_batch["next_sentence_labels"].shape[0]
127 |         if self.hparams.virtual_batch_size == accumulate_batch \
128 |             or batch_idx == len(self.train_dataloader) - 1: # last batch
129 |
130 | nn.utils.clip_grad_norm_(self.model.parameters(), self.hparams.max_gradient_norm)
131 |
132 | self.optimizer.step()
133 | self.optimizer.zero_grad()
134 |
135 | global_iteration_step += 1
136 | description = "[{}][Epoch: {:3d}][Iter: {:6d}][MLM_Loss: {:6f}][NSP_Loss: {:6f}][lr: {:7f}]".format(
137 | datetime.utcnow() - train_begin,
138 | epoch,
139 | global_iteration_step, (accu_mlm_loss / accu_count), (accu_nsp_loss / accu_count),
140 | self.optimizer.param_groups[0]['lr'])
141 | tqdm_batch_iterator.set_description(description)
142 |
143 | # tensorboard
144 | if global_iteration_step % self.hparams.tensorboard_step == 0:
145 |             description = "[{}][Epoch: {:3d}][Iter: {:6d}][MLM_Loss: {:6f}][NSP_Loss: {:6f}][lr: {:7f}]".format(
146 | datetime.utcnow() - train_begin,
147 | epoch,
148 | global_iteration_step, (accu_mlm_loss / accu_count), (accu_nsp_loss / accu_count),
149 | self.optimizer.param_groups[0]['lr'],
150 | )
151 | self._logger.info(description)
152 |
153 | accumulate_batch, accu_count = 0, 0
154 | accu_mlm_loss, accu_nsp_loss = 0, 0
155 |
156 | if global_iteration_step % self.hparams.checkpoint_save_step == 0:
157 | # -------------------------------------------------------------------------
158 |             # PERIODIC CHECKPOINTING (every checkpoint_save_step iterations)
159 | # -------------------------------------------------------------------------
160 | self.checkpoint_manager.step(global_iteration_step)
161 | self.previous_model_path = os.path.join(self.checkpoint_manager.ckpt_dirpath,
162 | "checkpoint_%d.pth" % (global_iteration_step))
163 | self._logger.info(self.previous_model_path)
--------------------------------------------------------------------------------
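Note: in post_train.py above, the model returns the masked-LM and next-sentence losses separately; the .mean() calls average whatever per-replica (or per-example) loss values the model returns, e.g. the per-GPU values gathered by nn.DataParallel, and the joint post-training objective is simply their sum. A toy illustration with made-up loss values:

import torch

# Dummy per-replica losses, e.g. as gathered from two GPUs by nn.DataParallel.
mlm_loss = torch.tensor([2.31, 2.45], requires_grad=True)
nsp_loss = torch.tensor([0.69, 0.71], requires_grad=True)

total_loss = mlm_loss.mean() + nsp_loss.mean()   # joint MLM + NSP objective
total_loss.backward()                            # gradients flow to both heads
print(float(total_loss))                         # ~3.08 for these dummy values
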
/models/pretrained_common/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """PyTorch optimization for BERT model."""
16 |
17 | import logging
18 | import math
19 |
20 | import torch
21 | from torch.optim import Optimizer
22 | from torch.optim.lr_scheduler import LambdaLR
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 | class ConstantLRSchedule(LambdaLR):
27 | """ Constant learning rate schedule.
28 | """
29 | def __init__(self, optimizer, last_epoch=-1):
30 | super(ConstantLRSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=last_epoch)
31 |
32 |
33 | class WarmupConstantSchedule(LambdaLR):
34 | """ Linear warmup and then constant.
35 | Linearly increases learning rate schedule from 0 to 1 over `warmup_steps` training steps.
36 | Keeps learning rate schedule equal to 1. after warmup_steps.
37 | """
38 | def __init__(self, optimizer, warmup_steps, last_epoch=-1):
39 | self.warmup_steps = warmup_steps
40 | super(WarmupConstantSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
41 |
42 | def lr_lambda(self, step):
43 | if step < self.warmup_steps:
44 | return float(step) / float(max(1.0, self.warmup_steps))
45 | return 1.
46 |
47 |
48 | class WarmupLinearSchedule(LambdaLR):
49 | """ Linear warmup and then linear decay.
50 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps.
51 | Linearly decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps.
52 | """
53 | def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1):
54 | self.warmup_steps = warmup_steps
55 | self.t_total = t_total
56 | super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
57 |
58 | def lr_lambda(self, step):
59 | if step < self.warmup_steps:
60 | return float(step) / float(max(1, self.warmup_steps))
61 | return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))
62 |
63 |
64 | class WarmupCosineSchedule(LambdaLR):
65 | """ Linear warmup and then cosine decay.
66 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps.
67 | Decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps following a cosine curve.
68 |         If `cycles` (default=0.5) is changed, the learning rate follows that many full cosine periods after warmup instead of a single decay from 1. to 0.
69 | """
70 | def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1):
71 | self.warmup_steps = warmup_steps
72 | self.t_total = t_total
73 | self.cycles = cycles
74 | super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
75 |
76 | def lr_lambda(self, step):
77 | if step < self.warmup_steps:
78 | return float(step) / float(max(1.0, self.warmup_steps))
79 | # progress after warmup
80 | progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
81 | return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))
82 |
83 |
84 | class WarmupCosineWithHardRestartsSchedule(LambdaLR):
85 | """ Linear warmup and then cosine cycles with hard restarts.
86 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps.
87 | If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying
88 | learning rate (with hard restarts).
89 | """
90 | def __init__(self, optimizer, warmup_steps, t_total, cycles=1., last_epoch=-1):
91 | self.warmup_steps = warmup_steps
92 | self.t_total = t_total
93 | self.cycles = cycles
94 | super(WarmupCosineWithHardRestartsSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
95 |
96 | def lr_lambda(self, step):
97 | if step < self.warmup_steps:
98 | return float(step) / float(max(1, self.warmup_steps))
99 | # progress after warmup
100 | progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
101 | if progress >= 1.0:
102 | return 0.0
103 | return max(0.0, 0.5 * (1. + math.cos(math.pi * ((float(self.cycles) * progress) % 1.0))))
104 |
105 |
106 | class AdamW(Optimizer):
107 | """ Implements Adam algorithm with weight decay fix.
108 |
109 | Parameters:
110 | lr (float): learning rate. Default 1e-3.
111 |         betas (tuple of 2 floats): Adam's beta parameters (b1, b2). Default: (0.9, 0.999)
112 |         eps (float): Adam's epsilon. Default: 1e-6
113 | weight_decay (float): Weight decay. Default: 0.0
114 | correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. like in Bert TF repository). Default True.
115 | """
116 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, correct_bias=True):
117 | if lr < 0.0:
118 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
119 | if not 0.0 <= betas[0] < 1.0:
120 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
121 | if not 0.0 <= betas[1] < 1.0:
122 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
123 | if not 0.0 <= eps:
124 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
125 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
126 | correct_bias=correct_bias)
127 | super(AdamW, self).__init__(params, defaults)
128 |
129 | def step(self, closure=None):
130 | """Performs a single optimization step.
131 |
132 | Arguments:
133 | closure (callable, optional): A closure that reevaluates the model
134 | and returns the loss.
135 | """
136 | loss = None
137 | if closure is not None:
138 | loss = closure()
139 |
140 | for group in self.param_groups:
141 | for p in group['params']:
142 | if p.grad is None:
143 | continue
144 | grad = p.grad.data
145 | if grad.is_sparse:
146 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
147 |
148 | state = self.state[p]
149 |
150 | # State initialization
151 | if len(state) == 0:
152 | state['step'] = 0
153 | # Exponential moving average of gradient values
154 | state['exp_avg'] = torch.zeros_like(p.data)
155 | # Exponential moving average of squared gradient values
156 | state['exp_avg_sq'] = torch.zeros_like(p.data)
157 |
158 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
159 | beta1, beta2 = group['betas']
160 |
161 | state['step'] += 1
162 |
163 | # Decay the first and second moment running average coefficient
164 | # In-place operations to update the averages at the same time
165 | exp_avg.mul_(beta1).add_(1.0 - beta1, grad)
166 | exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad)
167 | denom = exp_avg_sq.sqrt().add_(group['eps'])
168 |
169 | step_size = group['lr']
170 | if group['correct_bias']: # No bias correction for Bert
171 | bias_correction1 = 1.0 - beta1 ** state['step']
172 | bias_correction2 = 1.0 - beta2 ** state['step']
173 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
174 |
175 | p.data.addcdiv_(-step_size, exp_avg, denom)
176 |
177 | # Just adding the square of the weights to the loss function is *not*
178 | # the correct way of using L2 regularization/weight decay with Adam,
179 | # since that will interact with the m and v parameters in strange ways.
180 | #
181 | # Instead we want to decay the weights in a manner that doesn't interact
182 | # with the m/v parameters. This is equivalent to adding the square
183 | # of the weights to the loss with plain (non-momentum) SGD.
184 | # Add weight decay at the end (fixed version)
185 | if group['weight_decay'] > 0.0:
186 | p.data.add_(-group['lr'] * group['weight_decay'], p.data)
187 |
188 | return loss
189 |
--------------------------------------------------------------------------------
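Note: the schedules in optimization.py above are thin LambdaLR wrappers; the value returned by lr_lambda multiplies the optimizer's base learning rate each time scheduler.step() is called. A hedged usage sketch pairing AdamW with WarmupLinearSchedule, run from the repository root (the model, warm-up and total step counts are illustrative, not the project's settings):

import torch
from models.pretrained_common.optimization import AdamW, WarmupLinearSchedule

model = torch.nn.Linear(10, 2)                                   # stand-in model
optimizer = AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=100, t_total=1000)

for step in range(1000):
    loss = model(torch.randn(4, 10)).sum()
    loss.backward()
    optimizer.step()
    scheduler.step()          # advance the warm-up / linear-decay multiplier
    optimizer.zero_grad()
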
/models/pretrained_common/configuration_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ Configuration base class and utilities."""
17 |
18 | from __future__ import (absolute_import, division, print_function,
19 | unicode_literals)
20 |
21 | import copy
22 | import json
23 | import logging
24 | import os
25 | from io import open
26 |
27 | from .file_utils import cached_path, CONFIG_NAME
28 |
29 | logger = logging.getLogger(__name__)
30 |
31 | class PretrainedConfig(object):
32 | r""" Base class for all configuration classes.
33 | Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations.
34 |
35 | Note:
36 | A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to initialize a model does **not** load the model weights.
37 | It only affects the model's configuration.
38 |
39 | Class attributes (overridden by derived classes):
40 |             - ``pretrained_config_archive_map``: a python ``dict`` with `short-cut-names` (string) as keys and `url` (string) of associated pretrained model configurations as values.
41 |
42 | Parameters:
43 | ``finetuning_task``: string, default `None`. Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint.
44 | ``num_labels``: integer, default `2`. Number of classes to use when the model is a classification model (sequences/tokens)
45 |             ``output_attentions``: boolean, default `False`. Whether the model returns attention weights.
46 |             ``output_hidden_states``: boolean, default `False`. Whether the model returns all hidden states.
47 |             ``torchscript``: boolean, default `False`. Whether the model is to be used with TorchScript.
48 | """
49 | pretrained_config_archive_map = {}
50 |
51 | def __init__(self, **kwargs):
52 | self.finetuning_task = kwargs.pop('finetuning_task', None)
53 | self.num_labels = kwargs.pop('num_labels', 2)
54 | self.output_attentions = kwargs.pop('output_attentions', False)
55 | self.output_hidden_states = kwargs.pop('output_hidden_states', False)
56 | self.torchscript = kwargs.pop('torchscript', False)
57 | self.pruned_heads = kwargs.pop('pruned_heads', {})
58 |
59 | def save_pretrained(self, save_directory):
60 | """ Save a configuration object to the directory `save_directory`, so that it
61 | can be re-loaded using the :func:`~pytorch_transformers.PretrainedConfig.from_pretrained` class method.
62 | """
63 | assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
64 |
65 | # If we save using the predefined names, we can load using `from_pretrained`
66 | output_config_file = os.path.join(save_directory, CONFIG_NAME)
67 |
68 | self.to_json_file(output_config_file)
69 |
70 | @classmethod
71 | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
72 | r""" Instantiate a :class:`~pytorch_transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration.
73 |
74 | Parameters:
75 | pretrained_model_name_or_path: either:
76 |
77 | - a string with the `shortcut name` of a pre-trained model configuration to load from cache or download, e.g.: ``bert-base-uncased``.
78 | - a path to a `directory` containing a configuration file saved using the :func:`~pytorch_transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``.
79 | - a path or url to a saved configuration JSON `file`, e.g.: ``./my_model_directory/configuration.json``.
80 |
81 | cache_dir: (`optional`) string:
82 | Path to a directory in which a downloaded pre-trained model
83 | configuration should be cached if the standard cache should not be used.
84 |
85 | kwargs: (`optional`) dict: key/value pairs with which to update the configuration object after loading.
86 |
87 | - The values in kwargs of any keys which are configuration attributes will be used to override the loaded values.
88 | - Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled by the `return_unused_kwargs` keyword parameter.
89 |
90 | force_download: (`optional`) boolean, default False:
91 |                 Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
92 |
93 | proxies: (`optional`) dict, default None:
94 | A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
95 | The proxies are used on each request.
96 |
97 | return_unused_kwargs: (`optional`) bool:
98 |
99 | - If False, then this function returns just the final configuration object.
100 | - If True, then this functions returns a tuple `(config, unused_kwargs)` where `unused_kwargs` is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: ie the part of kwargs which has not been used to update `config` and is otherwise ignored.
101 |
102 | Examples::
103 |
104 | # We can't instantiate directly the base class `PretrainedConfig` so let's show the examples on a
105 | # derived class: BertConfig
106 | config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
107 | config = BertConfig.from_pretrained('./test/saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
108 | config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
109 |             config = BertConfig.from_pretrained('bert-base-uncased', output_attentions=True, foo=False)
110 |             assert config.output_attentions == True
111 |             config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attentions=True,
112 | foo=False, return_unused_kwargs=True)
113 |             assert config.output_attentions == True
114 | assert unused_kwargs == {'foo': False}
115 |
116 | """
117 | cache_dir = kwargs.pop('cache_dir', None)
118 | force_download = kwargs.pop('force_download', False)
119 | proxies = kwargs.pop('proxies', None)
120 | return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
121 |
122 | if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
123 | config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
124 | elif os.path.isdir(pretrained_model_name_or_path):
125 | config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
126 | else:
127 | config_file = pretrained_model_name_or_path
128 | # redirect to the cache, if necessary
129 | try:
130 | resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
131 | except EnvironmentError as e:
132 | if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
133 | logger.error(
134 | "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
135 | config_file))
136 | else:
137 | logger.error(
138 | "Model name '{}' was not found in model name list ({}). "
139 | "We assumed '{}' was a path or url but couldn't find any file "
140 | "associated to this path or url.".format(
141 | pretrained_model_name_or_path,
142 | ', '.join(cls.pretrained_config_archive_map.keys()),
143 | config_file))
144 | raise e
145 | if resolved_config_file == config_file:
146 | logger.info("loading configuration file {}".format(config_file))
147 | else:
148 | logger.info("loading configuration file {} from cache at {}".format(
149 | config_file, resolved_config_file))
150 |
151 | # Load config
152 | config = cls.from_json_file(resolved_config_file)
153 |
154 | if hasattr(config, 'pruned_heads'):
155 | config.pruned_heads = dict((int(key), set(value)) for key, value in config.pruned_heads.items())
156 |
157 | # Update config with kwargs if needed
158 | to_remove = []
159 | for key, value in kwargs.items():
160 | if hasattr(config, key):
161 | setattr(config, key, value)
162 | to_remove.append(key)
163 | for key in to_remove:
164 | kwargs.pop(key, None)
165 |
166 | logger.info("Model config %s", config)
167 | if return_unused_kwargs:
168 | return config, kwargs
169 | else:
170 | return config
171 |
172 | @classmethod
173 | def from_dict(cls, json_object):
174 | """Constructs a `Config` from a Python dictionary of parameters."""
175 | config = cls(vocab_size_or_config_json_file=-1)
176 | for key, value in json_object.items():
177 | config.__dict__[key] = value
178 | return config
179 |
180 | @classmethod
181 | def from_json_file(cls, json_file):
182 | """Constructs a `BertConfig` from a json file of parameters."""
183 | with open(json_file, "r", encoding='utf-8') as reader:
184 | text = reader.read()
185 | return cls.from_dict(json.loads(text))
186 |
187 | def __eq__(self, other):
188 | return self.__dict__ == other.__dict__
189 |
190 | def __repr__(self):
191 | return str(self.to_json_string())
192 |
193 | def to_dict(self):
194 | """Serializes this instance to a Python dictionary."""
195 | output = copy.deepcopy(self.__dict__)
196 | return output
197 |
198 | def to_json_string(self):
199 | """Serializes this instance to a JSON string."""
200 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
201 |
202 | def to_json_file(self, json_file_path):
203 | """ Save this instance to a json file."""
204 | with open(json_file_path, "w", encoding='utf-8') as writer:
205 | writer.write(self.to_json_string())
206 |
--------------------------------------------------------------------------------
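Note: a small hedged sketch of the save/load round trip that to_json_file and from_json_file in configuration_utils.py above implement. TinyConfig is a made-up class used only for illustration; the repository's real derived class is BertConfig, which follows the same pattern. Run from the repository root:

import os
import tempfile

from models.pretrained_common.configuration_utils import PretrainedConfig

class TinyConfig(PretrainedConfig):
    """Hypothetical config used only for illustration."""
    def __init__(self, hidden_size=64, **kwargs):
        super(TinyConfig, self).__init__(**kwargs)
        self.hidden_size = hidden_size

config = TinyConfig(hidden_size=128, num_labels=3)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "config.json")
    config.to_json_file(path)                   # serialise the config to JSON
    reloaded = TinyConfig.from_json_file(path)  # rebuild it from the JSON file
    assert reloaded.hidden_size == 128 and reloaded.num_labels == 3
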
/data/create_bert_post_training_data.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Create masked LM / next sentence prediction examples for BERT post-training (stored as HDF5)."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import sys
22 | import os
23 | sys.path.append(os.getcwd())
24 | import collections
25 | import random
26 | import argparse
27 | import h5py
28 | import numpy as np
29 | from tqdm import tqdm
30 |
31 | from models.bert import tokenization_bert
32 |
33 | class TrainingInstance(object):
34 | """A single training instance (sentence pair)."""
35 |
36 | def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
37 | is_random_next):
38 | self.tokens = tokens
39 | self.segment_ids = segment_ids
40 | self.is_random_next = is_random_next
41 | self.masked_lm_positions = masked_lm_positions
42 | self.masked_lm_labels = masked_lm_labels
43 |
44 | MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
45 | ["index", "label"])
46 |
47 | class CreateBertPretrainingData(object):
48 | def __init__(self, args):
49 | self.args = args
50 | self._bert_tokenizer_init(args.special_tok)
51 |
52 | def _bert_tokenizer_init(self, special_tok, bert_pretrained='bert-base-uncased'):
53 | bert_pretrained_dir = os.path.join("./resources", bert_pretrained)
54 | vocab_file_path = "%s-vocab.txt" % bert_pretrained
55 |
56 | self._bert_tokenizer = tokenization_bert.BertTokenizer(vocab_file=os.path.join(bert_pretrained_dir, vocab_file_path))
57 | self._bert_tokenizer.add_tokens([special_tok])
58 |
59 |     print("BERT tokenizer initialization complete")
60 |
61 | def _add_special_tokens(self, tokens, special_tok="[EOT]"):
62 | tokens = tokens + [special_tok]
63 | return tokens
64 |
65 | def create_training_instances(self, input_file, max_seq_length,
66 | dupe_factor, short_seq_prob, masked_lm_prob,
67 | max_predictions_per_seq, rng, special_tok=None):
68 | """Create `TrainingInstance`s from raw text."""
69 | all_documents = [[]]
70 |
71 | # Input file format:
72 | # (1) One sentence per line. These should ideally be actual sentences, not
73 | # entire paragraphs or arbitrary spans of text. (Because we use the
74 | # sentence boundaries for the "next sentence prediction" task).
75 | # (2) Blank lines between documents. Document boundaries are needed so
76 | # that the "next sentence prediction" task doesn't span between documents.
77 | document_cnt = 0
78 |     with open(input_file, "r", encoding="utf-8") as fr_handle:
79 | for line in tqdm(fr_handle):
80 | line = line.strip()
81 |
82 | # Empty lines are used as document delimiters
83 | if len(line) == 0:
84 | all_documents.append([])
85 |           document_cnt += 1
86 | if document_cnt % 50000 == 0:
87 | print("%d documents have been tokenized!" % document_cnt)
88 | tokens = self._bert_tokenizer.tokenize(line)
89 |
90 | if special_tok:
91 | tokens = self._add_special_tokens(tokens, special_tok) # special tok per sentence
92 |
93 | if tokens:
94 | all_documents[-1].append(tokens)
95 |
96 | # Remove empty documents
97 | all_documents = [x for x in all_documents if x]
98 | rng.shuffle(all_documents)
99 |
100 | vocab_words = list(self._bert_tokenizer.vocab.keys())
101 |
102 | self.feature_keys = ["input_ids", "attention_mask", "token_type_ids",
103 | "masked_lm_positions", "masked_lm_ids", "next_sentence_labels"]
104 |
105 | print("Total number of documents : %d" % len(all_documents))
106 | hf = h5py.File(self.args.output_file, "w")
107 | for d in range(dupe_factor):
108 | rng.shuffle(all_documents)
109 |
110 | self.all_doc_feat_dict = dict()
111 | for feat_key in self.feature_keys:
112 | self.all_doc_feat_dict[feat_key] = []
113 |
114 | for document_index in tqdm(range(len(all_documents))):
115 | instances = self.create_instances_from_document(
116 | all_documents, document_index, max_seq_length, short_seq_prob,
117 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
118 | self.instance_to_example_feature(instances, self.args.max_seq_length,
119 | self.args.max_predictions_per_seq)
120 |       self.build_h5_data(hf, d_idx=d)
121 | print("Current Dupe Factor : %d" % (d + 1))
122 |     print("Pretraining data creation complete!")
123 | hf.close()
124 |
125 | def build_h5_data(self, hf, d_idx=0):
126 | """
127 | features["input_ids"] = torch.tensor(input_ids).long()
128 | features["attention_mask"] = torch.tensor(input_mask).long()
129 | features["token_type_ids"] = torch.tensor(segment_ids).long()
130 | features["masked_lm_positions"] = torch.tensor(masked_lm_positions).long()
131 | features["masked_lm_ids"] = torch.tensor(masked_lm_ids).long() # masked_lm_ids
132 | features["next_sentence_labels"] = torch.tensor([next_sentence_label]).long()
133 | """
134 |
135 | h5_key_dict = {}
136 | print("Number of documents features", len(self.all_doc_feat_dict["next_sentence_labels"]))
137 | if d_idx == 0:
138 | for feat_key in self.feature_keys:
139 | key_size = [len(self.all_doc_feat_dict["next_sentence_labels"])] + [len(self.all_doc_feat_dict[feat_key][0])]
140 | h5_key_dict[feat_key] = hf.create_dataset(feat_key, tuple(key_size), dtype='i8', chunks=True,
141 | maxshape=(None, tuple(key_size)[1]),
142 | data=np.array(self.all_doc_feat_dict[feat_key]))
143 | else:
144 | for feat_key in self.feature_keys:
145 | hf[feat_key].resize((hf[feat_key].shape[0] + len(self.all_doc_feat_dict["next_sentence_labels"])), axis=0)
146 | hf[feat_key][-len(self.all_doc_feat_dict["next_sentence_labels"]):] = np.array(self.all_doc_feat_dict[feat_key])
147 |
148 | def create_instances_from_document(self, all_documents, document_index, max_seq_length, short_seq_prob,
149 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
150 | """Creates `TrainingInstance`s for a single document."""
151 | document = all_documents[document_index]
152 |
153 | # Account for [CLS], [SEP], [SEP]
154 | max_num_tokens = max_seq_length - 3
155 |
156 | # We *usually* want to fill up the entire sequence since we are padding
157 | # to `max_seq_length` anyways, so short sequences are generally wasted
158 | # computation. However, we *sometimes*
159 | # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
160 | # sequences to minimize the mismatch between pre-training and fine-tuning.
161 | # The `target_seq_length` is just a rough target however, whereas
162 | # `max_seq_length` is a hard limit.
163 | target_seq_length = max_num_tokens
164 | if rng.random() < short_seq_prob:
165 | target_seq_length = rng.randint(2, max_num_tokens)
166 |
167 | # We DON'T just concatenate all of the tokens from a document into a long
168 | # sequence and choose an arbitrary split point because this would make the
169 | # next sentence prediction task too easy. Instead, we split the input into
170 | # segments "A" and "B" based on the actual "sentences" provided by the user
171 | # input.
172 | instances = []
173 | current_chunk = []
174 | current_length = 0
175 | i = 0
176 |
177 | while i < len(document):
178 | segment = document[i]
179 | current_chunk.append(segment)
180 | current_length += len(segment)
181 | if i == len(document) - 1 or current_length >= target_seq_length:
182 | if current_chunk:
183 | # `a_end` is how many segments from `current_chunk` go into the `A`
184 | # (first) sentence.
185 | a_end = 1
186 | if len(current_chunk) >= 2:
187 | a_end = rng.randint(1, len(current_chunk) - 1)
188 |
189 | tokens_a = []
190 | for j in range(a_end):
191 | tokens_a.extend(current_chunk[j])
192 |
193 | tokens_b = []
194 | # Random next
195 | is_random_next = False
196 | if len(current_chunk) == 1 or rng.random() < 0.5:
197 | is_random_next = True
198 | target_b_length = target_seq_length - len(tokens_a)
199 |
200 | # This should rarely go for more than one iteration for large
201 | # corpora. However, just to be careful, we try to make sure that
202 | # the random document is not the same as the document
203 | # we're processing.
204 | for _ in range(10):
205 | random_document_index = rng.randint(0, len(all_documents) - 1)
206 | if random_document_index != document_index:
207 | break
208 |
209 | random_document = all_documents[random_document_index]
210 | random_start = rng.randint(0, len(random_document) - 1)
211 | for j in range(random_start, len(random_document)):
212 | tokens_b.extend(random_document[j])
213 | if len(tokens_b) >= target_b_length:
214 | break
215 | # We didn't actually use these segments so we "put them back" so
216 | # they don't go to waste.
217 | num_unused_segments = len(current_chunk) - a_end
218 | i -= num_unused_segments
219 | # Actual next
220 | else:
221 | is_random_next = False
222 | for j in range(a_end, len(current_chunk)):
223 | tokens_b.extend(current_chunk[j])
224 | self.truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
225 |
226 | assert len(tokens_a) >= 1
227 | assert len(tokens_b) >= 1
228 |
229 | tokens = []
230 | segment_ids = []
231 | tokens.append("[CLS]")
232 | segment_ids.append(0)
233 | for token in tokens_a:
234 | tokens.append(token)
235 | segment_ids.append(0)
236 |
237 | tokens.append("[SEP]")
238 | segment_ids.append(0)
239 |
240 | for token in tokens_b:
241 | tokens.append(token)
242 | segment_ids.append(1)
243 | tokens.append("[SEP]")
244 | segment_ids.append(1)
245 |
246 | (tokens, masked_lm_positions, masked_lm_labels) = self.create_masked_lm_predictions(
247 | tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
248 | instance = TrainingInstance(
249 | tokens=tokens,
250 | segment_ids=segment_ids,
251 | is_random_next=is_random_next,
252 | masked_lm_positions=masked_lm_positions,
253 | masked_lm_labels=masked_lm_labels)
254 | instances.append(instance)
255 |
256 | current_chunk = []
257 | current_length = 0
258 | i += 1
259 |
260 | return instances
261 |
262 | def create_masked_lm_predictions(self, tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
263 | """Creates the predictions for the masked LM objective."""
264 |
265 | cand_indexes = []
266 | for (i, token) in enumerate(tokens):
267 | if token == "[CLS]" or token == "[SEP]":
268 | continue
269 |       # Whole Word Masking means that we mask all of the wordpieces
270 |       # corresponding to an original word. When a word has been split into
271 |       # WordPieces, the first token does not have any marker and any subsequent
272 |       # tokens are prefixed with ##. So whenever we see the ## prefix, we
273 |       # append the token index to the previous set of word indexes.
274 | #
275 | # Note that Whole Word Masking does *not* change the training code
276 | # at all -- we still predict each WordPiece independently, softmaxed
277 | # over the entire vocabulary.
278 | if (self.args.do_whole_word_mask and len(cand_indexes) >= 1 and
279 | token.startswith("##")):
280 | cand_indexes[-1].append(i)
281 | else:
282 | cand_indexes.append([i])
283 |
284 | rng.shuffle(cand_indexes)
285 |
286 | output_tokens = list(tokens)
287 |
288 | num_to_predict = min(max_predictions_per_seq,
289 | max(1, int(round(len(tokens) * masked_lm_prob))))
290 |
291 | masked_lms = []
292 | covered_indexes = set()
293 | for index_set in cand_indexes:
294 | if len(masked_lms) >= num_to_predict:
295 | break
296 | # If adding a whole-word mask would exceed the maximum number of
297 | # predictions, then just skip this candidate.
298 | if len(masked_lms) + len(index_set) > num_to_predict:
299 | continue
300 | is_any_index_covered = False
301 | for index in index_set:
302 | if index in covered_indexes:
303 | is_any_index_covered = True
304 | break
305 | if is_any_index_covered:
306 | continue
307 | for index in index_set:
308 | covered_indexes.add(index)
309 |
310 | masked_token = None
311 | # 80% of the time, replace with [MASK]
312 | if rng.random() < 0.8:
313 | masked_token = "[MASK]"
314 | else:
315 | # 10% of the time, keep original
316 | if rng.random() < 0.5:
317 | masked_token = tokens[index]
318 | # 10% of the time, replace with random word
319 | else:
320 | masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
321 |
322 | output_tokens[index] = masked_token
323 |
324 | masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
325 |
326 | assert len(masked_lms) <= num_to_predict
327 | masked_lms = sorted(masked_lms, key=lambda x: x.index)
328 |
329 | masked_lm_positions = []
330 | masked_lm_labels = []
331 | for p in masked_lms:
332 | masked_lm_positions.append(p.index)
333 | masked_lm_labels.append(p.label)
334 |
335 | return (output_tokens, masked_lm_positions, masked_lm_labels)
336 |
337 | def truncate_seq_pair(self, tokens_a, tokens_b, max_num_tokens, rng):
338 | """Truncates a pair of sequences to a maximum sequence length."""
339 | while True:
340 | total_length = len(tokens_a) + len(tokens_b)
341 | if total_length <= max_num_tokens:
342 | break
343 |
344 | trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
345 | assert len(trunc_tokens) >= 1
346 |
347 | # We want to sometimes truncate from the front and sometimes from the
348 | # back to add more randomness and avoid biases.
349 | if rng.random() < 0.5:
350 | del trunc_tokens[0]
351 | else:
352 | trunc_tokens.pop()
353 |
354 | def instance_to_example_feature(self, instances, max_seq_length, max_predictions_per_seq):
355 | for instance in instances:
356 | input_ids = self._bert_tokenizer.convert_tokens_to_ids(instance.tokens)
357 | input_mask = [1] * len(input_ids)
358 | segment_ids = list(instance.segment_ids)
359 |
360 | assert len(input_ids) <= max_seq_length
361 |
362 | while len(input_ids) < max_seq_length:
363 | input_ids.append(0)
364 | input_mask.append(0)
365 | segment_ids.append(0)
366 |
367 | assert len(input_ids) == max_seq_length
368 | assert len(input_mask) == max_seq_length
369 | assert len(segment_ids) == max_seq_length
370 |
371 | masked_lm_positions = list(instance.masked_lm_positions)
372 | masked_lm_ids = self._bert_tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
373 | masked_lm_weights = [1.0] * len(masked_lm_ids)
374 |
375 | while len(masked_lm_positions) < max_predictions_per_seq:
376 | masked_lm_positions.append(0)
377 | masked_lm_ids.append(0)
378 | masked_lm_weights.append(0.0)
379 |
380 | next_sentence_label = 1 if instance.is_random_next else 0
381 |
382 | self.all_doc_feat_dict["input_ids"].append(input_ids)
383 | self.all_doc_feat_dict["attention_mask"].append(input_mask)
384 | self.all_doc_feat_dict["token_type_ids"].append(segment_ids)
385 | self.all_doc_feat_dict["masked_lm_positions"].append(masked_lm_positions)
386 | self.all_doc_feat_dict["masked_lm_ids"].append(masked_lm_ids)
387 | self.all_doc_feat_dict["next_sentence_labels"].append([next_sentence_label])
388 |
389 | if __name__ == "__main__":
390 | arg_parser = argparse.ArgumentParser(description="Bert / Create Pretraining Data")
391 | arg_parser.add_argument("--input_file", dest="input_file", type=str,
392 | default="./data/ubuntu_corpus_v1/ubuntu_post_training.txt",
393 | help="Input raw text file (or comma-separated list of files).")
394 | arg_parser.add_argument("--output_file", dest="output_file", type=str,
395 | default="./data/ubuntu_corpus_v1/ubuntu_post_training.hdf5",
396 |                           help="Output example HDF5 file.")
397 | arg_parser.add_argument("--do_lower_case", dest="do_lower_case", type=bool, default=True,
398 | help="Whether to lower case the input text. Should be True for uncased.")
399 | arg_parser.add_argument("--do_whole_word_mask", dest="do_whole_word_mask", type=bool, default=True,
400 | help="Whether to use whole word masking rather than per-WordPiece masking.")
401 | arg_parser.add_argument("--max_seq_length", dest="max_seq_length", type=int, default=512,
402 | help="Maximum sequence length.")
403 | arg_parser.add_argument("--max_predictions_per_seq", dest="max_predictions_per_seq", type=int, default=70,
404 | help="Maximum number of masked LM predictions per sequence.")
405 | arg_parser.add_argument("--random_seed", dest="random_seed", type=int, default=12345,
406 | help="Random seed for data generation.")
407 | arg_parser.add_argument("--dupe_factor", dest="dupe_factor", type=int, default=1,
408 | help="Number of times to duplicate the input data (with different masks).")
409 | arg_parser.add_argument("--masked_lm_prob", dest="masked_lm_prob", type=float, default=0.15,
410 | help="Masked LM probability.")
411 | arg_parser.add_argument("--short_seq_prob", dest="short_seq_prob", type=float, default=0.1,
412 | help="Probability of creating sequences which are shorter than the maximum length.")
413 | arg_parser.add_argument("--special_tok", dest="special_tok", type=str, default="[EOT]",
414 | help="Special Token.")
415 | args = arg_parser.parse_args()
416 |
417 | create_data = CreateBertPretrainingData(args)
418 |
419 | rng = random.Random(args.random_seed)
420 | create_data.create_training_instances(
421 | args.input_file, args.max_seq_length, args.dupe_factor,
422 | args.short_seq_prob, args.masked_lm_prob, args.max_predictions_per_seq, rng, args.special_tok)
423 |
--------------------------------------------------------------------------------
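Note: the masking branch in create_bert_post_training_data.py above follows the original BERT recipe: each position chosen for prediction becomes [MASK] 80% of the time, stays unchanged 10% of the time, and is replaced by a random vocabulary word the remaining 10%. A self-contained illustration of that decision (the vocabulary, sentence, and chosen index below are made up):

import random

rng = random.Random(12345)
vocab_words = ["the", "cat", "sat", "on", "mat", "dog"]   # toy vocabulary
tokens = ["[CLS]", "the", "cat", "sat", "[SEP]"]

def choose_masked_token(original, rng, vocab_words):
    # 80% -> [MASK], 10% -> keep the original token, 10% -> random vocabulary word
    if rng.random() < 0.8:
        return "[MASK]"
    if rng.random() < 0.5:
        return original
    return vocab_words[rng.randint(0, len(vocab_words) - 1)]

index = 2                                   # pretend position 2 ("cat") was sampled
print(choose_masked_token(tokens[index], rng, vocab_words))
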
/models/pretrained_common/file_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for working with the local dataset cache.
3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
4 | Copyright by the AllenNLP authors.
5 | """
6 |
7 | import fnmatch
8 | import json
9 | import logging
10 | import os
11 | import shutil
12 | import sys
13 | import tarfile
14 | import tempfile
15 | from contextlib import contextmanager
16 | from functools import partial, wraps
17 | from hashlib import sha256
18 | from typing import Optional
19 | from urllib.parse import urlparse
20 | from zipfile import ZipFile, is_zipfile
21 |
22 | import requests
23 | from filelock import FileLock
24 | from tqdm.auto import tqdm
25 |
26 | from . import __version__
27 |
28 |
29 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
30 |
31 | try:
32 | USE_TF = os.environ.get("USE_TF", "AUTO").upper()
33 | USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
34 | if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"):
35 | import torch
36 |
37 | _torch_available = True # pylint: disable=invalid-name
38 | logger.info("PyTorch version {} available.".format(torch.__version__))
39 | else:
40 | logger.info("Disabling PyTorch because USE_TF is set")
41 | _torch_available = False
42 | except ImportError:
43 | _torch_available = False # pylint: disable=invalid-name
44 |
45 | try:
46 | USE_TF = os.environ.get("USE_TF", "AUTO").upper()
47 | USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
48 |
49 | if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"):
50 | import tensorflow as tf
51 |
52 | assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
53 | _tf_available = True # pylint: disable=invalid-name
54 | logger.info("TensorFlow version {} available.".format(tf.__version__))
55 | else:
56 | logger.info("Disabling Tensorflow because USE_TORCH is set")
57 | _tf_available = False
58 | except (ImportError, AssertionError):
59 | _tf_available = False # pylint: disable=invalid-name
60 |
61 | try:
62 | from torch.hub import _get_torch_home
63 |
64 | torch_cache_home = _get_torch_home()
65 | except ImportError:
66 | torch_cache_home = os.path.expanduser(
67 | os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
68 | )
69 | default_cache_path = os.path.join(torch_cache_home, "transformers")
70 |
71 | try:
72 | from pathlib import Path
73 |
74 | PYTORCH_PRETRAINED_BERT_CACHE = Path(
75 | os.getenv("PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path))
76 | )
77 | except (AttributeError, ImportError):
78 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv(
79 | "PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
80 | )
81 |
82 | PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
83 | TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
84 |
85 | WEIGHTS_NAME = "pytorch_model.bin"
86 | TF2_WEIGHTS_NAME = "tf_model.h5"
87 | TF_WEIGHTS_NAME = "model.ckpt"
88 | CONFIG_NAME = "config.json"
89 | MODEL_CARD_NAME = "modelcard.json"
90 |
91 |
92 | MULTIPLE_CHOICE_DUMMY_INPUTS = [[[0], [1]], [[0], [1]]]
93 | DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
94 | DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
95 |
96 | S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
97 | CLOUDFRONT_DISTRIB_PREFIX = "https://d2ws9o8vfrpkyk.cloudfront.net"
98 |
99 |
100 | def is_torch_available():
101 | return _torch_available
102 |
103 |
104 | def is_tf_available():
105 | return _tf_available
106 |
107 |
108 | def add_start_docstrings(*docstr):
109 | def docstring_decorator(fn):
110 | fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
111 | return fn
112 |
113 | return docstring_decorator
114 |
115 |
116 | def add_start_docstrings_to_callable(*docstr):
117 | def docstring_decorator(fn):
118 | class_name = ":class:`~transformers.{}`".format(fn.__qualname__.split(".")[0])
119 | intro = " The {} forward method, overrides the :func:`__call__` special method.".format(class_name)
120 | note = r"""
121 |
122 | .. note::
123 | Although the recipe for forward pass needs to be defined within
124 | this function, one should call the :class:`Module` instance afterwards
125 | instead of this since the former takes care of running the
126 | pre and post processing steps while the latter silently ignores them.
127 | """
128 | fn.__doc__ = intro + note + "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
129 | return fn
130 |
131 | return docstring_decorator
132 |
133 |
134 | def add_end_docstrings(*docstr):
135 | def docstring_decorator(fn):
136 | fn.__doc__ = fn.__doc__ + "".join(docstr)
137 | return fn
138 |
139 | return docstring_decorator
140 |
141 |
142 | def is_remote_url(url_or_filename):
143 | parsed = urlparse(url_or_filename)
144 | return parsed.scheme in ("http", "https")
145 |
146 |
147 | def hf_bucket_url(identifier, postfix=None, cdn=False) -> str:
148 | endpoint = CLOUDFRONT_DISTRIB_PREFIX if cdn else S3_BUCKET_PREFIX
149 | if postfix is None:
150 | return "/".join((endpoint, identifier))
151 | else:
152 | return "/".join((endpoint, identifier, postfix))
153 |
154 |
155 | def url_to_filename(url, etag=None):
156 | """
157 | Convert `url` into a hashed filename in a repeatable way.
158 | If `etag` is specified, append its hash to the url's, delimited
159 | by a period.
160 | If the url ends with .h5 (Keras HDF5 weights) adds '.h5' to the name
161 | so that TF 2.0 can identify it as a HDF5 file
162 | (see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
163 | """
164 | url_bytes = url.encode("utf-8")
165 | url_hash = sha256(url_bytes)
166 | filename = url_hash.hexdigest()
167 |
168 | if etag:
169 | etag_bytes = etag.encode("utf-8")
170 | etag_hash = sha256(etag_bytes)
171 | filename += "." + etag_hash.hexdigest()
172 |
173 | if url.endswith(".h5"):
174 | filename += ".h5"
175 |
176 | return filename
177 |
178 |
179 | def filename_to_url(filename, cache_dir=None):
180 | """
181 | Return the url and etag (which may be ``None``) stored for `filename`.
182 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
183 | """
184 | if cache_dir is None:
185 | cache_dir = TRANSFORMERS_CACHE
186 | if isinstance(cache_dir, Path):
187 | cache_dir = str(cache_dir)
188 |
189 | cache_path = os.path.join(cache_dir, filename)
190 | if not os.path.exists(cache_path):
191 | raise EnvironmentError("file {} not found".format(cache_path))
192 |
193 | meta_path = cache_path + ".json"
194 | if not os.path.exists(meta_path):
195 | raise EnvironmentError("file {} not found".format(meta_path))
196 |
197 | with open(meta_path, encoding="utf-8") as meta_file:
198 | metadata = json.load(meta_file)
199 | url = metadata["url"]
200 | etag = metadata["etag"]
201 |
202 | return url, etag
203 |
204 |
205 | def cached_path(
206 | url_or_filename,
207 | cache_dir=None,
208 | force_download=False,
209 | proxies=None,
210 | resume_download=False,
211 | user_agent=None,
212 | extract_compressed_file=False,
213 | force_extract=False,
214 | local_files_only=False,
215 | ) -> Optional[str]:
216 | """
217 | Given something that might be a URL (or might be a local path),
218 | determine which. If it's a URL, download the file and cache it, and
219 | return the path to the cached file. If it's already a local path,
220 | make sure the file exists and then return the path.
221 | Args:
222 | cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
223 |         force_download: if True, re-download the file even if it's already cached in the cache dir.
224 |         resume_download: if True, resume the download if an incompletely received file is found.
225 |         user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
226 |         extract_compressed_file: if True and the path points to a zip or tar file, extract the compressed
227 |             file in a folder alongside the archive.
228 |         force_extract: if True when extract_compressed_file is True and the archive was already extracted,
229 |             re-extract the archive and override the folder where it was extracted.
230 |
231 | Return:
232 | None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
233 | Local path (string) otherwise
234 | """
235 | if cache_dir is None:
236 | cache_dir = TRANSFORMERS_CACHE
237 | if isinstance(url_or_filename, Path):
238 | url_or_filename = str(url_or_filename)
239 | if isinstance(cache_dir, Path):
240 | cache_dir = str(cache_dir)
241 |
242 | if is_remote_url(url_or_filename):
243 | # URL, so get it from the cache (downloading if necessary)
244 | output_path = get_from_cache(
245 | url_or_filename,
246 | cache_dir=cache_dir,
247 | force_download=force_download,
248 | proxies=proxies,
249 | resume_download=resume_download,
250 | user_agent=user_agent,
251 | local_files_only=local_files_only,
252 | )
253 | elif os.path.exists(url_or_filename):
254 | # File, and it exists.
255 | output_path = url_or_filename
256 | elif urlparse(url_or_filename).scheme == "":
257 | # File, but it doesn't exist.
258 | raise EnvironmentError("file {} not found".format(url_or_filename))
259 | else:
260 | # Something unknown
261 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
262 |
263 | if extract_compressed_file:
264 | if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
265 | return output_path
266 |
267 | # Path where we extract compressed archives
268 | # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
269 | output_dir, output_file = os.path.split(output_path)
270 | output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
271 | output_path_extracted = os.path.join(output_dir, output_extract_dir_name)
272 |
273 | if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
274 | return output_path_extracted
275 |
276 | # Prevent parallel extractions
277 | lock_path = output_path + ".lock"
278 | with FileLock(lock_path):
279 | shutil.rmtree(output_path_extracted, ignore_errors=True)
280 | os.makedirs(output_path_extracted)
281 | if is_zipfile(output_path):
282 | with ZipFile(output_path, "r") as zip_file:
283 | zip_file.extractall(output_path_extracted)
284 | zip_file.close()
285 | elif tarfile.is_tarfile(output_path):
286 | tar_file = tarfile.open(output_path)
287 | tar_file.extractall(output_path_extracted)
288 | tar_file.close()
289 | else:
290 | raise EnvironmentError("Archive format of {} could not be identified".format(output_path))
291 |
292 | return output_path_extracted
293 |
294 | return output_path
295 |
296 |
297 | def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
298 | ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
299 | if is_torch_available():
300 | ua += "; torch/{}".format(torch.__version__)
301 | if is_tf_available():
302 | ua += "; tensorflow/{}".format(tf.__version__)
303 | if isinstance(user_agent, dict):
304 | ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
305 | elif isinstance(user_agent, str):
306 | ua += "; " + user_agent
307 | headers = {"user-agent": ua}
308 | if resume_size > 0:
309 | headers["Range"] = "bytes=%d-" % (resume_size,)
310 | response = requests.get(url, stream=True, proxies=proxies, headers=headers)
311 | if response.status_code == 416: # Range not satisfiable
312 | return
313 | content_length = response.headers.get("Content-Length")
314 | total = resume_size + int(content_length) if content_length is not None else None
315 | progress = tqdm(
316 | unit="B",
317 | unit_scale=True,
318 | total=total,
319 | initial=resume_size,
320 | desc="Downloading",
321 | disable=bool(logger.getEffectiveLevel() == logging.NOTSET),
322 | )
323 | for chunk in response.iter_content(chunk_size=1024):
324 | if chunk: # filter out keep-alive new chunks
325 | progress.update(len(chunk))
326 | temp_file.write(chunk)
327 | progress.close()
328 |
329 |
330 | def get_from_cache(
331 | url,
332 | cache_dir=None,
333 | force_download=False,
334 | proxies=None,
335 | etag_timeout=10,
336 | resume_download=False,
337 | user_agent=None,
338 | local_files_only=False,
339 | ) -> Optional[str]:
340 | """
341 | Given a URL, look for the corresponding file in the local cache.
342 | If it's not there, download it. Then return the path to the cached file.
343 |
344 | Return:
345 | None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
346 | Local path (string) otherwise
347 | """
348 | if cache_dir is None:
349 | cache_dir = TRANSFORMERS_CACHE
350 | if isinstance(cache_dir, Path):
351 | cache_dir = str(cache_dir)
352 |
353 | os.makedirs(cache_dir, exist_ok=True)
354 |
355 | etag = None
356 | if not local_files_only:
357 | try:
358 | response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
359 | if response.status_code == 200:
360 | etag = response.headers.get("ETag")
361 | except (EnvironmentError, requests.exceptions.Timeout):
362 | # etag is already None
363 | pass
364 |
365 | filename = url_to_filename(url, etag)
366 |
367 | # get cache path to put the file
368 | cache_path = os.path.join(cache_dir, filename)
369 |
370 | # etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
371 | # try to get the last downloaded one
372 | if etag is None:
373 | if os.path.exists(cache_path):
374 | return cache_path
375 | else:
376 | matching_files = [
377 | file
378 | for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
379 | if not file.endswith(".json") and not file.endswith(".lock")
380 | ]
381 | if len(matching_files) > 0:
382 | return os.path.join(cache_dir, matching_files[-1])
383 | else:
384 | # If files cannot be found and local_files_only=True,
385 | # the models might've been found if local_files_only=False
386 | # Notify the user about that
387 | if local_files_only:
388 | raise ValueError(
389 | "Cannot find the requested files in the cached path and outgoing traffic has been"
390 | " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
391 | " to False."
392 | )
393 | return None
394 |
395 | # From now on, etag is not None.
396 | if os.path.exists(cache_path) and not force_download:
397 | return cache_path
398 |
399 | # Prevent parallel downloads of the same file with a lock.
400 | lock_path = cache_path + ".lock"
401 | with FileLock(lock_path):
402 |
403 | # If the download just completed while the lock was activated.
404 | if os.path.exists(cache_path) and not force_download:
405 | # Even if returning early like here, the lock will be released.
406 | return cache_path
407 |
408 | if resume_download:
409 | incomplete_path = cache_path + ".incomplete"
410 |
411 | @contextmanager
412 | def _resumable_file_manager():
413 | with open(incomplete_path, "a+b") as f:
414 | yield f
415 |
416 | temp_file_manager = _resumable_file_manager
417 | if os.path.exists(incomplete_path):
418 | resume_size = os.stat(incomplete_path).st_size
419 | else:
420 | resume_size = 0
421 | else:
422 | temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
423 | resume_size = 0
424 |
425 | # Download to temporary file, then copy to cache dir once finished.
426 | # Otherwise you get corrupt cache entries if the download gets interrupted.
427 | with temp_file_manager() as temp_file:
428 | logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
429 |
430 | http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
431 |
432 | logger.info("storing %s in cache at %s", url, cache_path)
433 | os.replace(temp_file.name, cache_path)
434 |
435 | logger.info("creating metadata file for %s", cache_path)
436 | meta = {"url": url, "etag": etag}
437 | meta_path = cache_path + ".json"
438 | with open(meta_path, "w") as meta_file:
439 | json.dump(meta, meta_file)
440 |
441 | return cache_path
442 |
443 |
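A minimal usage sketch for `get_from_cache` (the URL and cache directory are placeholders): the first call downloads the file under an etag-derived name, and later calls return the cached path without re-downloading.

    vocab_url = "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt"
    local_path = get_from_cache(vocab_url, cache_dir="./.cache")
    if local_path is not None:
        print("vocabulary cached at", local_path)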
444 | class cached_property(property):
445 | """
446 | Descriptor that mimics @property but caches output in member variable.
447 |
448 | From tensorflow_datasets
449 |
450 |     Built into functools as of Python 3.8.
451 | """
452 |
453 | def __get__(self, obj, objtype=None):
454 | # See docs.python.org/3/howto/descriptor.html#properties
455 | if obj is None:
456 | return self
457 | if self.fget is None:
458 | raise AttributeError("unreadable attribute")
459 | attr = "__cached_" + self.fget.__name__
460 | cached = getattr(obj, attr, None)
461 | if cached is None:
462 | cached = self.fget(obj)
463 | setattr(obj, attr, cached)
464 | return cached
465 |
466 |
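To illustrate `cached_property`, a small hypothetical class (the names are assumptions, not repository code): the getter runs once per instance and the result is stored under a `__cached_<name>` attribute.

    class ToyDataset:
        @cached_property
        def vocab(self):
            # expensive work happens only on first access
            return {token: idx for idx, token in enumerate(["[PAD]", "[UNK]", "[CLS]"])}

    dataset = ToyDataset()
    dataset.vocab  # computed and stored as dataset.__cached_vocab
    dataset.vocab  # served from the cached attribute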
467 | def torch_required(func):
468 | # Chose a different decorator name than in tests so it's clear they are not the same.
469 | @wraps(func)
470 | def wrapper(*args, **kwargs):
471 | if is_torch_available():
472 | return func(*args, **kwargs)
473 | else:
474 | raise ImportError(f"Method `{func.__name__}` requires PyTorch.")
475 |
476 | return wrapper
477 |
478 |
479 | def tf_required(func):
480 | # Chose a different decorator name than in tests so it's clear they are not the same.
481 | @wraps(func)
482 | def wrapper(*args, **kwargs):
483 | if is_tf_available():
484 | return func(*args, **kwargs)
485 | else:
486 | raise ImportError(f"Method `{func.__name__}` requires TF.")
487 |
488 | return wrapper
489 |
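As a quick illustration of the two decorators above (the function name is hypothetical): the wrapped callable fails fast with an ImportError when the corresponding backend is missing.

    @torch_required
    def count_parameters(model):
        # only reachable when is_torch_available() returns True
        return sum(p.numel() for p in model.parameters())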
--------------------------------------------------------------------------------
/models/bert/tokenization_bert.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import, division, print_function, unicode_literals
18 |
19 | import collections
20 | import logging
21 | import os
22 | import unicodedata
23 | from io import open
24 |
25 | from models.pretrained_common.tokenization_utils import PreTrainedTokenizer
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 | VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'}
30 |
31 | PRETRAINED_VOCAB_FILES_MAP = {
32 | 'vocab_file':
33 | {
34 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
35 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
36 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
37 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
38 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
39 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
40 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
41 | 'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
42 | 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
43 | 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
44 | 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
45 | 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
46 | 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
47 | }
48 | }
49 |
50 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
51 | 'bert-base-uncased': 512,
52 | 'bert-large-uncased': 512,
53 | 'bert-base-cased': 512,
54 | 'bert-large-cased': 512,
55 | 'bert-base-multilingual-uncased': 512,
56 | 'bert-base-multilingual-cased': 512,
57 | 'bert-base-chinese': 512,
58 | 'bert-base-german-cased': 512,
59 | 'bert-large-uncased-whole-word-masking': 512,
60 | 'bert-large-cased-whole-word-masking': 512,
61 | 'bert-large-uncased-whole-word-masking-finetuned-squad': 512,
62 | 'bert-large-cased-whole-word-masking-finetuned-squad': 512,
63 | 'bert-base-cased-finetuned-mrpc': 512,
64 | }
65 |
66 | def load_vocab(vocab_file):
67 | """Loads a vocabulary file into a dictionary."""
68 | vocab = collections.OrderedDict()
69 | with open(vocab_file, "r", encoding="utf-8") as reader:
70 | tokens = reader.readlines()
71 | for index, token in enumerate(tokens):
72 | token = token.rstrip('\n')
73 | vocab[token] = index
74 | return vocab
75 |
76 |
77 | def whitespace_tokenize(text):
78 | """Runs basic whitespace cleaning and splitting on a piece of text."""
79 | text = text.strip()
80 | if not text:
81 | return []
82 | tokens = text.split()
83 | return tokens
84 |
85 |
86 | class BertTokenizer(PreTrainedTokenizer):
87 | r"""
88 | Constructs a BertTokenizer.
89 | :class:`~pytorch_pretrained_bert.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece
90 | Args:
91 | vocab_file: Path to a one-wordpiece-per-line vocabulary file
92 | do_lower_case: Whether to lower case the input. Only has an effect when do_wordpiece_only=False
93 | do_basic_tokenize: Whether to do basic tokenization before wordpiece.
94 | max_len: An artificial maximum length to truncate tokenized sequences to; Effective maximum length is always the
95 | minimum of this value (if specified) and the underlying BERT model's sequence length.
96 | never_split: List of tokens which will never be split during tokenization. Only has an effect when
97 | do_wordpiece_only=False
98 | """
99 |
100 | vocab_files_names = VOCAB_FILES_NAMES
101 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
102 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
103 |
104 | def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None,
105 | unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]",
106 | mask_token="[MASK]", tokenize_chinese_chars=True, **kwargs):
107 | """Constructs a BertTokenizer.
108 | Args:
109 | **vocab_file**: Path to a one-wordpiece-per-line vocabulary file
110 | **do_lower_case**: (`optional`) boolean (default True)
111 | Whether to lower case the input
112 | Only has an effect when do_basic_tokenize=True
113 | **do_basic_tokenize**: (`optional`) boolean (default True)
114 | Whether to do basic tokenization before wordpiece.
115 | **never_split**: (`optional`) list of string
116 | List of tokens which will never be split during tokenization.
117 | Only has an effect when do_basic_tokenize=True
118 | **tokenize_chinese_chars**: (`optional`) boolean (default True)
119 | Whether to tokenize Chinese characters.
120 |                 This should likely be deactivated for Japanese:
121 | see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
122 | """
123 | super(BertTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token,
124 | pad_token=pad_token, cls_token=cls_token,
125 | mask_token=mask_token, **kwargs)
126 |
127 | if not os.path.isfile(vocab_file):
128 | raise ValueError(
129 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
130 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
131 |         logger.info("Loading vocabulary file %s", vocab_file)
132 | self.vocab = load_vocab(vocab_file)
133 | self.ids_to_tokens = collections.OrderedDict(
134 | [(ids, tok) for tok, ids in self.vocab.items()])
135 | self.do_basic_tokenize = do_basic_tokenize
136 | if do_basic_tokenize:
137 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
138 | never_split=never_split,
139 | tokenize_chinese_chars=tokenize_chinese_chars)
140 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
141 |
142 | @property
143 | def vocab_size(self):
144 | return len(self.vocab)
145 |
146 | def _tokenize(self, text):
147 | split_tokens = []
148 | if self.do_basic_tokenize:
149 | for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
150 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
151 | split_tokens.append(sub_token)
152 | else:
153 | split_tokens = self.wordpiece_tokenizer.tokenize(text)
154 | return split_tokens
155 |
156 | def _convert_token_to_id(self, token):
157 |         """ Converts a token (str/unicode) to an id using the vocab. """
158 | return self.vocab.get(token, self.vocab.get(self.unk_token))
159 |
160 | def _convert_id_to_token(self, index):
161 |         """Converts an index (integer) to a token (string/unicode) using the vocab."""
162 | return self.ids_to_tokens.get(index, self.unk_token)
163 |
164 | def convert_tokens_to_string(self, tokens):
165 |         """ Converts a sequence of tokens (string) to a single string. """
166 | out_string = ' '.join(tokens).replace(' ##', '').strip()
167 | return out_string
168 |
169 | def save_vocabulary(self, vocab_path):
170 | """Save the tokenizer vocabulary to a directory or file."""
171 | index = 0
172 |         # Write to the standard filename inside a directory; otherwise treat vocab_path as the target file.
173 |         vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file']) if os.path.isdir(vocab_path) else vocab_path
174 | with open(vocab_file, "w", encoding="utf-8") as writer:
175 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
176 | if index != token_index:
177 | logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive."
178 | " Please check that the vocabulary is not corrupted!".format(vocab_file))
179 | index = token_index
180 | writer.write(token + u'\n')
181 | index += 1
182 | return (vocab_file,)
183 |
184 | @classmethod
185 | def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
186 | """ Instantiate a BertTokenizer from pre-trained vocabulary files.
187 | """
188 | if pretrained_model_name_or_path in PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES:
189 | if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
190 | logger.warning("The pre-trained model you are loading is a cased model but you have not set "
191 | "`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
192 | "you may want to check this behavior.")
193 | kwargs['do_lower_case'] = False
194 | elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True):
195 | logger.warning("The pre-trained model you are loading is an uncased model but you have set "
196 | "`do_lower_case` to False. We are setting `do_lower_case=True` for you "
197 | "but you may want to check this behavior.")
198 | kwargs['do_lower_case'] = True
199 |
200 | return super(BertTokenizer, cls)._from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
201 |
202 |
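A hedged usage sketch for the tokenizer defined above (not part of this file; the shortcut name triggers a vocabulary download on first use):

    from models.bert.tokenization_bert import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    tokens = tokenizer.tokenize("Ubuntu response selection works well.")
    ids = tokenizer.convert_tokens_to_ids(tokens)
    text = tokenizer.decode(ids)  # rejoins wordpieces and cleans up spacing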
203 | class BasicTokenizer(object):
204 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
205 |
206 | def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True):
207 | """ Constructs a BasicTokenizer.
208 | Args:
209 | **do_lower_case**: Whether to lower case the input.
210 | **never_split**: (`optional`) list of str
211 | Kept for backward compatibility purposes.
212 | Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
213 |                 List of tokens not to split.
214 | **tokenize_chinese_chars**: (`optional`) boolean (default True)
215 | Whether to tokenize Chinese characters.
216 |                 This should likely be deactivated for Japanese:
217 | see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
218 | """
219 | if never_split is None:
220 | never_split = []
221 | self.do_lower_case = do_lower_case
222 | self.never_split = never_split
223 | self.tokenize_chinese_chars = tokenize_chinese_chars
224 |
225 | def tokenize(self, text, never_split=None):
226 | """ Basic Tokenization of a piece of text.
227 | Split on "white spaces" only, for sub-word tokenization, see WordPieceTokenizer.
228 | Args:
229 | **never_split**: (`optional`) list of str
230 | Kept for backward compatibility purposes.
231 | Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
232 |                 List of tokens not to split.
233 | """
234 | never_split = self.never_split + (never_split if never_split is not None else [])
235 | text = self._clean_text(text)
236 | # This was added on November 1st, 2018 for the multilingual and Chinese
237 | # models. This is also applied to the English models now, but it doesn't
238 | # matter since the English models were not trained on any Chinese data
239 | # and generally don't have any Chinese data in them (there are Chinese
240 | # characters in the vocabulary because Wikipedia does have some Chinese
241 |         # words in the English Wikipedia).
242 | if self.tokenize_chinese_chars:
243 | text = self._tokenize_chinese_chars(text)
244 | orig_tokens = whitespace_tokenize(text)
245 | split_tokens = []
246 | for token in orig_tokens:
247 | if self.do_lower_case and token not in never_split:
248 | token = token.lower()
249 | token = self._run_strip_accents(token)
250 | split_tokens.extend(self._run_split_on_punc(token))
251 |
252 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
253 | return output_tokens
254 |
255 | def _run_strip_accents(self, text):
256 | """Strips accents from a piece of text."""
257 | text = unicodedata.normalize("NFD", text)
258 | output = []
259 | for char in text:
260 | cat = unicodedata.category(char)
261 | if cat == "Mn":
262 | continue
263 | output.append(char)
264 | return "".join(output)
265 |
266 | def _run_split_on_punc(self, text, never_split=None):
267 | """Splits punctuation on a piece of text."""
268 | if never_split is not None and text in never_split:
269 | return [text]
270 | chars = list(text)
271 | i = 0
272 | start_new_word = True
273 | output = []
274 | while i < len(chars):
275 | char = chars[i]
276 | if _is_punctuation(char):
277 | output.append([char])
278 | start_new_word = True
279 | else:
280 | if start_new_word:
281 | output.append([])
282 | start_new_word = False
283 | output[-1].append(char)
284 | i += 1
285 |
286 | return ["".join(x) for x in output]
287 |
288 | def _tokenize_chinese_chars(self, text):
289 | """Adds whitespace around any CJK character."""
290 | output = []
291 | for char in text:
292 | cp = ord(char)
293 | if self._is_chinese_char(cp):
294 | output.append(" ")
295 | output.append(char)
296 | output.append(" ")
297 | else:
298 | output.append(char)
299 | return "".join(output)
300 |
301 | def _is_chinese_char(self, cp):
302 | """Checks whether CP is the codepoint of a CJK character."""
303 | # This defines a "chinese character" as anything in the CJK Unicode block:
304 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
305 | #
306 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
307 | # despite its name. The modern Korean Hangul alphabet is a different block,
308 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
309 | # space-separated words, so they are not treated specially and handled
310 |         # like all of the other languages.
311 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
312 | (cp >= 0x3400 and cp <= 0x4DBF) or #
313 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
314 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
315 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
316 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
317 | (cp >= 0xF900 and cp <= 0xFAFF) or #
318 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
319 | return True
320 |
321 | return False
322 |
323 | def _clean_text(self, text):
324 | """Performs invalid character removal and whitespace cleanup on text."""
325 | output = []
326 | for char in text:
327 | cp = ord(char)
328 | if cp == 0 or cp == 0xfffd or _is_control(char):
329 | continue
330 | if _is_whitespace(char):
331 | output.append(" ")
332 | else:
333 | output.append(char)
334 | return "".join(output)
335 |
336 |
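For orientation, a small example of `BasicTokenizer` in isolation (the input string is hypothetical): lower-casing, accent stripping, and punctuation splitting happen before any wordpiece step.

    basic = BasicTokenizer(do_lower_case=True)
    basic.tokenize("Héllo, World!")  # -> ['hello', ',', 'world', '!']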
337 | class WordpieceTokenizer(object):
338 | """Runs WordPiece tokenization."""
339 |
340 | def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
341 | self.vocab = vocab
342 | self.unk_token = unk_token
343 | self.max_input_chars_per_word = max_input_chars_per_word
344 |
345 | def tokenize(self, text):
346 | """Tokenizes a piece of text into its word pieces.
347 | This uses a greedy longest-match-first algorithm to perform tokenization
348 | using the given vocabulary.
349 | For example:
350 | input = "unaffable"
351 | output = ["un", "##aff", "##able"]
352 | Args:
353 | text: A single token or whitespace separated tokens. This should have
354 | already been passed through `BasicTokenizer`.
355 | Returns:
356 | A list of wordpiece tokens.
357 | """
358 |
359 | output_tokens = []
360 | for token in whitespace_tokenize(text):
361 | chars = list(token)
362 | if len(chars) > self.max_input_chars_per_word:
363 | output_tokens.append(self.unk_token)
364 | continue
365 |
366 | is_bad = False
367 | start = 0
368 | sub_tokens = []
369 | while start < len(chars):
370 | end = len(chars)
371 | cur_substr = None
372 | while start < end:
373 | substr = "".join(chars[start:end])
374 | if start > 0:
375 | substr = "##" + substr
376 | if substr in self.vocab:
377 | cur_substr = substr
378 | break
379 | end -= 1
380 | if cur_substr is None:
381 | is_bad = True
382 | break
383 | sub_tokens.append(cur_substr)
384 | start = end
385 |
386 | if is_bad:
387 | output_tokens.append(self.unk_token)
388 | else:
389 | output_tokens.extend(sub_tokens)
390 | return output_tokens
391 |
392 |
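The greedy longest-match-first behaviour described in the docstring can be checked with a toy vocabulary (hypothetical, for illustration only):

    toy_vocab = {"un": 0, "##aff": 1, "##able": 2, "[UNK]": 3}
    wordpiece = WordpieceTokenizer(vocab=toy_vocab, unk_token="[UNK]")
    wordpiece.tokenize("unaffable")   # -> ['un', '##aff', '##able']
    wordpiece.tokenize("impossible")  # -> ['[UNK]'] (no matching pieces)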
393 | def _is_whitespace(char):
394 | """Checks whether `chars` is a whitespace character."""
395 |     # \t, \n, and \r are technically control characters but we treat them
396 | # as whitespace since they are generally considered as such.
397 | if char == " " or char == "\t" or char == "\n" or char == "\r":
398 | return True
399 | cat = unicodedata.category(char)
400 | if cat == "Zs":
401 | return True
402 | return False
403 |
404 |
405 | def _is_control(char):
406 | """Checks whether `chars` is a control character."""
407 | # These are technically control characters but we count them as whitespace
408 | # characters.
409 | if char == "\t" or char == "\n" or char == "\r":
410 | return False
411 | cat = unicodedata.category(char)
412 | if cat.startswith("C"):
413 | return True
414 | return False
415 |
416 |
417 | def _is_punctuation(char):
418 | """Checks whether `chars` is a punctuation character."""
419 | cp = ord(char)
420 | # We treat all non-letter/number ASCII as punctuation.
421 | # Characters such as "^", "$", and "`" are not in the Unicode
422 | # Punctuation class but we treat them as punctuation anyways, for
423 | # consistency.
424 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
425 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
426 | return True
427 | cat = unicodedata.category(char)
428 | if cat.startswith("P"):
429 | return True
430 | return False
--------------------------------------------------------------------------------
/models/pretrained_common/tokenization_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes for OpenAI GPT."""
16 | from __future__ import (absolute_import, division, print_function,
17 | unicode_literals)
18 |
19 | import logging
20 | import os
21 | import json
22 | import six
23 | from io import open
24 |
25 | from models.pretrained_common.file_utils import cached_path
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 | SPECIAL_TOKENS_MAP_FILE = 'special_tokens_map.json'
30 | ADDED_TOKENS_FILE = 'added_tokens.json'
31 |
32 | class PreTrainedTokenizer(object):
33 |     """ An abstract class to handle downloading and loading pretrained tokenizers and adding tokens to the vocabulary.
34 |
35 | Derived class can set up a few special tokens to be used in common scripts and internals:
36 | bos_token, eos_token, EOP_TOKEN, EOD_TOKEN, unk_token, sep_token, pad_token, cls_token, mask_token
37 | additional_special_tokens = []
38 |
39 | We defined an added_tokens_encoder to add new tokens to the vocabulary without having to handle the
40 |     specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...).
41 | """
42 | vocab_files_names = {}
43 | pretrained_vocab_files_map = {}
44 | max_model_input_sizes = {}
45 |
46 | SPECIAL_TOKENS_ATTRIBUTES = ["bos_token", "eos_token", "unk_token", "sep_token",
47 | "pad_token", "cls_token", "mask_token",
48 | "additional_special_tokens"]
49 |
50 | @property
51 | def bos_token(self):
52 | if self._bos_token is None:
53 | logger.error("Using bos_token, but it is not set yet.")
54 | return self._bos_token
55 |
56 | @property
57 | def eos_token(self):
58 | if self._eos_token is None:
59 | logger.error("Using eos_token, but it is not set yet.")
60 | return self._eos_token
61 |
62 | @property
63 | def unk_token(self):
64 | if self._unk_token is None:
65 | logger.error("Using unk_token, but it is not set yet.")
66 | return self._unk_token
67 |
68 | @property
69 | def sep_token(self):
70 | if self._sep_token is None:
71 | logger.error("Using sep_token, but it is not set yet.")
72 | return self._sep_token
73 |
74 | @property
75 | def pad_token(self):
76 | if self._pad_token is None:
77 | logger.error("Using pad_token, but it is not set yet.")
78 | return self._pad_token
79 |
80 | @property
81 | def cls_token(self):
82 | if self._cls_token is None:
83 | logger.error("Using cls_token, but it is not set yet.")
84 | return self._cls_token
85 |
86 | @property
87 | def mask_token(self):
88 | if self._mask_token is None:
89 | logger.error("Using mask_token, but it is not set yet.")
90 | return self._mask_token
91 |
92 | @property
93 | def additional_special_tokens(self):
94 | if self._additional_special_tokens is None:
95 | logger.error("Using additional_special_tokens, but it is not set yet.")
96 | return self._additional_special_tokens
97 |
98 | @bos_token.setter
99 | def bos_token(self, value):
100 | self._bos_token = value
101 |
102 | @eos_token.setter
103 | def eos_token(self, value):
104 | self._eos_token = value
105 |
106 | @unk_token.setter
107 | def unk_token(self, value):
108 | self._unk_token = value
109 |
110 | @sep_token.setter
111 | def sep_token(self, value):
112 | self._sep_token = value
113 |
114 | @pad_token.setter
115 | def pad_token(self, value):
116 | self._pad_token = value
117 |
118 | @cls_token.setter
119 | def cls_token(self, value):
120 | self._cls_token = value
121 |
122 | @mask_token.setter
123 | def mask_token(self, value):
124 | self._mask_token = value
125 |
126 | @additional_special_tokens.setter
127 | def additional_special_tokens(self, value):
128 | self._additional_special_tokens = value
129 |
130 | def __init__(self, max_len=None, **kwargs):
131 | self._bos_token = None
132 | self._eos_token = None
133 | self._unk_token = None
134 | self._sep_token = None
135 | self._pad_token = None
136 | self._cls_token = None
137 | self._mask_token = None
138 | self._additional_special_tokens = []
139 |
140 | self.max_len = max_len if max_len is not None else int(1e12)
141 | self.added_tokens_encoder = {}
142 | self.added_tokens_decoder = {}
143 |
144 | for key, value in kwargs.items():
145 | if key in self.SPECIAL_TOKENS_ATTRIBUTES:
146 | setattr(self, key, value)
147 |
148 |
149 | @classmethod
150 | def from_pretrained(cls, *inputs, **kwargs):
151 | return cls._from_pretrained(*inputs, **kwargs)
152 |
153 |
154 | @classmethod
155 | def _from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
156 | """
157 | Instantiate a PreTrainedTokenizer from pre-trained vocabulary files.
158 | Download and cache the vocabulary files if needed.
159 | """
160 | s3_models = list(cls.max_model_input_sizes.keys())
161 | vocab_files = {}
162 | if pretrained_model_name_or_path in s3_models:
163 | # Get the vocabulary from AWS S3 bucket
164 | for file_id, map_list in cls.pretrained_vocab_files_map.items():
165 | vocab_files[file_id] = map_list[pretrained_model_name_or_path]
166 | else:
167 | # Get the vocabulary from local files
168 | logger.info(
169 | "Model name '{}' not found in model shortcut name list ({}). "
170 | "Assuming '{}' is a path or url to a directory containing tokenizer files.".format(
171 | pretrained_model_name_or_path, ', '.join(s3_models),
172 | pretrained_model_name_or_path))
173 |
174 | # Look for the tokenizer main vocabulary files
175 | for file_id, file_name in cls.vocab_files_names.items():
176 | if os.path.isdir(pretrained_model_name_or_path):
177 | # If a directory is provided we look for the standard filenames
178 | full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
179 | else:
180 | # If a path to a file is provided we use it (will only work for non-BPE tokenizer using a single vocabulary file)
181 | full_file_name = pretrained_model_name_or_path
182 | if not os.path.exists(full_file_name):
183 | logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
184 | full_file_name = None
185 | vocab_files[file_id] = full_file_name
186 |
187 | # Look for the additional tokens files
188 | all_vocab_files_names = {'added_tokens_file': ADDED_TOKENS_FILE,
189 | 'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE}
190 |
191 | # If a path to a file was provided, get the parent directory
192 | saved_directory = pretrained_model_name_or_path
193 | if os.path.exists(saved_directory) and not os.path.isdir(saved_directory):
194 | saved_directory = os.path.dirname(saved_directory)
195 |
196 | for file_id, file_name in all_vocab_files_names.items():
197 | full_file_name = os.path.join(saved_directory, file_name)
198 | if not os.path.exists(full_file_name):
199 | logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
200 | full_file_name = None
201 | vocab_files[file_id] = full_file_name
202 |
203 | if all(full_file_name is None for full_file_name in vocab_files.values()):
204 | logger.error(
205 | "Model name '{}' was not found in model name list ({}). "
206 |                 "We assumed '{}' was a path or url but couldn't find tokenizer files "
207 | "at this path or url.".format(
208 | pretrained_model_name_or_path, ', '.join(s3_models),
209 | pretrained_model_name_or_path, ))
210 | return None
211 |
212 | # Get files from url, cache, or disk depending on the case
213 | try:
214 | resolved_vocab_files = {}
215 | for file_id, file_path in vocab_files.items():
216 | if file_path is None:
217 | resolved_vocab_files[file_id] = None
218 | else:
219 | resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir)
220 | except EnvironmentError:
221 | if pretrained_model_name_or_path in s3_models:
222 | logger.error("Couldn't reach server to download vocabulary.")
223 | else:
224 | logger.error(
225 | "Model name '{}' was not found in model name list ({}). "
226 | "We assumed '{}' was a path or url but couldn't find files {} "
227 | "at this path or url.".format(
228 | pretrained_model_name_or_path, ', '.join(s3_models),
229 | pretrained_model_name_or_path, str(vocab_files.keys())))
230 | return None
231 |
232 | for file_id, file_path in vocab_files.items():
233 | if file_path == resolved_vocab_files[file_id]:
234 | logger.info("loading file {}".format(file_path))
235 | else:
236 | logger.info("loading file {} from cache at {}".format(
237 | file_path, resolved_vocab_files[file_id]))
238 |
239 | # Set max length if needed
240 | if pretrained_model_name_or_path in cls.max_model_input_sizes:
241 | # if we're using a pretrained model, ensure the tokenizer
242 |             # won't index sequences longer than the number of positional embeddings
243 | max_len = cls.max_model_input_sizes[pretrained_model_name_or_path]
244 | if max_len is not None and isinstance(max_len, (int, float)):
245 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
246 |
247 | # Merge resolved_vocab_files arguments in kwargs.
248 | added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None)
249 | special_tokens_map_file = resolved_vocab_files.pop('special_tokens_map_file', None)
250 | for args_name, file_path in resolved_vocab_files.items():
251 | if args_name not in kwargs:
252 | kwargs[args_name] = file_path
253 | if special_tokens_map_file is not None:
254 | special_tokens_map = json.load(open(special_tokens_map_file, encoding="utf-8"))
255 | for key, value in special_tokens_map.items():
256 | if key not in kwargs:
257 | kwargs[key] = value
258 |
259 | # Instantiate tokenizer.
260 | tokenizer = cls(*inputs, **kwargs)
261 |
262 | # Add supplementary tokens.
263 | if added_tokens_file is not None:
264 | added_tok_encoder = json.load(open(added_tokens_file, encoding="utf-8"))
265 | added_tok_decoder = {v:k for k, v in added_tok_encoder.items()}
266 | tokenizer.added_tokens_encoder.update(added_tok_encoder)
267 | tokenizer.added_tokens_decoder.update(added_tok_decoder)
268 |
269 | return tokenizer
270 |
271 |
272 | def save_pretrained(self, save_directory):
273 | """ Save the tokenizer vocabulary files (with added tokens) and the
274 | special-tokens-to-class-attributes-mapping to a directory, so that it
275 | can be re-loaded using the `from_pretrained(save_directory)` class method.
276 | """
277 | if not os.path.isdir(save_directory):
278 | logger.error("Saving directory ({}) should be a directory".format(save_directory))
279 | return
280 |
281 | special_tokens_map_file = os.path.join(save_directory, SPECIAL_TOKENS_MAP_FILE)
282 | added_tokens_file = os.path.join(save_directory, ADDED_TOKENS_FILE)
283 |
284 | with open(special_tokens_map_file, 'w', encoding='utf-8') as f:
285 | f.write(json.dumps(self.special_tokens_map, ensure_ascii=False))
286 |
287 | with open(added_tokens_file, 'w', encoding='utf-8') as f:
288 | if self.added_tokens_encoder:
289 | out_str = json.dumps(self.added_tokens_encoder, ensure_ascii=False)
290 | else:
291 | out_str = u"{}"
292 | f.write(out_str)
293 |
294 | vocab_files = self.save_vocabulary(save_directory)
295 |
296 | return vocab_files + (special_tokens_map_file, added_tokens_file)
297 |
298 |
299 | def save_vocabulary(self, save_directory):
300 | """ Save the tokenizer vocabulary to a directory. This method doesn't save added tokens
301 | and special token mappings.
302 |
303 | Please use `save_pretrained()` to save the full Tokenizer state so that it can be
304 | reloaded using the `from_pretrained(save_directory)` class method.
305 | """
306 | raise NotImplementedError
307 |
308 |
309 | def vocab_size(self):
310 | raise NotImplementedError
311 |
312 |
313 | def __len__(self):
314 | return self.vocab_size + len(self.added_tokens_encoder)
315 |
316 |
317 | def add_tokens(self, new_tokens):
318 | """ Add a list of new tokens to the tokenizer class. If the new tokens are not in the
319 | vocabulary, they are added to the added_tokens_encoder with indices starting from
320 | the last index of the current vocabulary.
321 |
322 | Returns:
323 | Number of tokens added to the vocabulary which can be used to correspondingly
324 | increase the size of the associated model embedding matrices.
325 | """
326 | if not new_tokens:
327 | return 0
328 |
329 | to_add_tokens = []
330 | for token in new_tokens:
331 | if self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token):
332 | to_add_tokens.append(token)
333 | logger.info("Adding %s to the vocabulary", token)
334 |
335 | added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(to_add_tokens))
336 | added_tok_decoder = {v:k for k, v in added_tok_encoder.items()}
337 | self.added_tokens_encoder.update(added_tok_encoder)
338 | self.added_tokens_decoder.update(added_tok_decoder)
339 |
340 | return len(to_add_tokens)
341 |
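A hypothetical example of extending the vocabulary (the token name and the follow-up resize call are assumptions, not repository code): new tokens get ids after the original vocabulary, so any wrapped model's embedding matrix has to grow to match.

    from models.bert.tokenization_bert import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    num_added = tokenizer.add_tokens(["[EOT]"])  # e.g. an end-of-turn marker
    # len(tokenizer) now equals vocab_size + num_added; a model built on top
    # would typically call model.resize_token_embeddings(len(tokenizer)).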
342 |
343 | def add_special_tokens(self, special_tokens_dict):
344 |         """ Add a dictionary of special tokens (eos, pad, cls...) to the encoder and link them
345 | to class attributes. If the special tokens are not in the vocabulary, they are added
346 | to it and indexed starting from the last index of the current vocabulary.
347 |
348 | Returns:
349 | Number of tokens added to the vocabulary which can be used to correspondingly
350 | increase the size of the associated model embedding matrices.
351 | """
352 | if not special_tokens_dict:
353 | return 0
354 |
355 | added_special_tokens = self.add_tokens(special_tokens_dict.values())
356 | for key, value in special_tokens_dict.items():
357 | logger.info("Assigning %s to the %s key of the tokenizer", value, key)
358 | setattr(self, key, value)
359 |
360 | return added_special_tokens
361 |
362 |
363 | def tokenize(self, text, **kwargs):
364 |         """ Converts a string to a sequence of tokens (string), using the tokenizer.
365 |             Split into words for word-based vocabulary or sub-words for sub-word-based
366 | vocabularies (BPE/SentencePieces/WordPieces).
367 |
368 | Take care of added tokens.
369 | """
370 | def split_on_tokens(tok_list, text):
371 | if not text:
372 | return []
373 | if not tok_list:
374 | return self._tokenize(text, **kwargs)
375 | tok = tok_list[0]
376 | split_text = text.split(tok)
377 | return sum((split_on_tokens(tok_list[1:], sub_text.strip()) + [tok] \
378 | for sub_text in split_text), [])[:-1]
379 |
380 | added_tokens = list(self.added_tokens_encoder.keys()) + self.all_special_tokens
381 | tokenized_text = split_on_tokens(added_tokens, text)
382 | return tokenized_text
383 |
384 | def _tokenize(self, text, **kwargs):
385 |         """ Converts a string to a sequence of tokens (string), using the tokenizer.
386 |             Split into words for word-based vocabulary or sub-words for sub-word-based
387 | vocabularies (BPE/SentencePieces/WordPieces).
388 |
389 | Don't take care of added tokens.
390 | """
391 | raise NotImplementedError
392 |
393 | def convert_tokens_to_ids(self, tokens):
394 |         """ Converts a single token or a sequence of tokens (str/unicode) to an integer id
395 |             (resp. a sequence of ids), using the vocabulary.
396 | """
397 | if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)):
398 | return self._convert_token_to_id_with_added_voc(tokens)
399 |
400 | ids = []
401 | for token in tokens:
402 | ids.append(self._convert_token_to_id_with_added_voc(token))
403 | if len(ids) > self.max_len:
404 | logger.warning("Token indices sequence length is longer than the specified maximum sequence length "
405 | "for this model ({} > {}). Running this sequence through the model will result in "
406 | "indexing errors".format(len(ids), self.max_len))
407 | return ids
408 |
409 | def _convert_token_to_id_with_added_voc(self, token):
410 | if token in self.added_tokens_encoder:
411 | return self.added_tokens_encoder[token]
412 | return self._convert_token_to_id(token)
413 |
414 | def _convert_token_to_id(self, token):
415 | raise NotImplementedError
416 |
417 |
418 | def encode(self, text):
419 |         """ Converts a string to a sequence of ids (integer), using the tokenizer and vocabulary.
420 |             Same as self.convert_tokens_to_ids(self.tokenize(text)).
421 | """
422 | return self.convert_tokens_to_ids(self.tokenize(text))
423 |
424 |
425 | def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
426 |         """ Converts a single index or a sequence of indices (integers) to a token
427 |             (resp. a sequence of tokens) (str/unicode), using the vocabulary and added tokens.
428 |
429 | Args:
430 | skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
431 | """
432 | if isinstance(ids, int):
433 | if ids in self.added_tokens_decoder:
434 | return self.added_tokens_decoder[ids]
435 | else:
436 | return self._convert_id_to_token(ids)
437 | tokens = []
438 | for index in ids:
439 | if index in self.all_special_ids and skip_special_tokens:
440 | continue
441 | if index in self.added_tokens_decoder:
442 | tokens.append(self.added_tokens_decoder[index])
443 | else:
444 | tokens.append(self._convert_id_to_token(index))
445 | return tokens
446 |
447 | def _convert_id_to_token(self, index):
448 | raise NotImplementedError
449 |
450 | def convert_tokens_to_string(self, tokens):
451 |         """ Converts a sequence of tokens (string) to a single string.
452 |             The simplest way to do it is ' '.join(tokens),
453 |             but we often want to remove sub-word tokenization artifacts at the same time.
454 |         """
455 |         return ' '.join(tokens)
456 |
457 | def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
458 | """ Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
459 | with options to remove special tokens and clean up tokenization spaces.
460 | """
461 | filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
462 | text = self.convert_tokens_to_string(filtered_tokens)
463 | if clean_up_tokenization_spaces:
464 | text = clean_up_tokenization(text)
465 | return text
466 |
467 | @property
468 | def special_tokens_map(self):
469 | """ A dictionary mapping special token class attribute (cls_token, unk_token...) to their
470 |             values ('<unk>', '<cls>'...)
471 | """
472 | set_attr = {}
473 | for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
474 | attr_value = getattr(self, "_" + attr)
475 | if attr_value:
476 | set_attr[attr] = attr_value
477 | return set_attr
478 |
479 | @property
480 | def all_special_tokens(self):
481 |         """ List all the special tokens ('<unk>', '<cls>'...) mapped to class attributes
482 | (cls_token, unk_token...).
483 | """
484 | all_toks = []
485 | set_attr = self.special_tokens_map
486 | for attr_value in set_attr.values():
487 | all_toks = all_toks + (attr_value if isinstance(attr_value, (list, tuple)) else [attr_value])
488 | all_toks = list(set(all_toks))
489 | return all_toks
490 |
491 | @property
492 | def all_special_ids(self):
493 |         """ List the vocabulary indices of the special tokens ('<unk>', '<cls>'...) mapped to
494 | class attributes (cls_token, unk_token...).
495 | """
496 | all_toks = self.all_special_tokens
497 | all_ids = list(self.convert_tokens_to_ids(t) for t in all_toks)
498 | return all_ids
499 |
500 |
501 |
502 | def clean_up_tokenization(out_string):
503 |     out_string = (out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',')
504 |                   .replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't")
505 |                   .replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re"))
506 | return out_string
507 |
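A short example of the cleanup helper above (the input string is hypothetical): spaces introduced by detokenization around punctuation and contractions are collapsed.

    clean_up_tokenization("hello , world ! it 's fine .")
    # -> "hello, world! it's fine."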
--------------------------------------------------------------------------------
/models/pretrained_common/modeling_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """PyTorch BERT model."""
17 |
18 | from __future__ import (absolute_import, division, print_function,
19 | unicode_literals)
20 |
21 | import copy
22 | import json
23 | from collections import OrderedDict
24 | import logging
25 | import os
26 | from io import open
27 |
28 | import six
29 | import torch
30 | from torch import nn
31 | from torch.nn import CrossEntropyLoss
32 | from torch.nn import functional as F
33 |
34 | from .configuration_utils import PretrainedConfig
35 | from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME
36 |
37 | logger = logging.getLogger(__name__)
38 |
39 |
40 | try:
41 | from torch.nn import Identity
42 | except ImportError:
43 | # Older PyTorch compatibility
44 | class Identity(nn.Module):
45 | r"""A placeholder identity operator that is argument-insensitive.
46 | """
47 | def __init__(self, *args, **kwargs):
48 | super(Identity, self).__init__()
49 |
50 | def forward(self, input):
51 | return input
52 |
53 | class PreTrainedModel(nn.Module):
54 | r""" Base class for all models.
55 | :class:`~pytorch_transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
56 |     as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention layers.
57 | Class attributes (overridden by derived classes):
58 | - ``config_class``: a class derived from :class:`~pytorch_transformers.PretrainedConfig` to use as configuration class for this model architecture.
59 |         - ``pretrained_model_archive_map``: a python ``dict`` with `short-cut-names` (string) as keys and `url` (string) of associated pretrained weights as values.
60 | - ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:
61 | - ``model``: an instance of the relevant subclass of :class:`~pytorch_transformers.PreTrainedModel`,
62 | - ``config``: an instance of the relevant subclass of :class:`~pytorch_transformers.PretrainedConfig`,
63 | - ``path``: a path (string) to the TensorFlow checkpoint.
64 | - ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
65 | """
66 | config_class = None
67 | pretrained_model_archive_map = {}
68 | load_tf_weights = lambda model, config, path: None
69 | base_model_prefix = ""
70 |
71 | def __init__(self, config, *inputs, **kwargs):
72 | super(PreTrainedModel, self).__init__()
73 | if not isinstance(config, PretrainedConfig):
74 | raise ValueError(
75 | "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
76 | "To create a model from a pretrained model use "
77 | "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
78 | self.__class__.__name__, self.__class__.__name__
79 | ))
80 | # Save config in model
81 | self.config = config
82 |
83 | def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
84 | """ Build a resized Embedding Module from a provided token Embedding Module.
85 | Increasing the size will add newly initialized vectors at the end
86 | Reducing the size will remove vectors from the end
87 | Args:
88 | new_num_tokens: (`optional`) int
89 | New number of tokens in the embedding matrix.
90 | Increasing the size will add newly initialized vectors at the end
91 | Reducing the size will remove vectors from the end
92 | If not provided or None: return the provided token Embedding Module.
93 | Return: ``torch.nn.Embeddings``
94 | Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
95 | """
96 | if new_num_tokens is None:
97 | return old_embeddings
98 |
99 | old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
100 | if old_num_tokens == new_num_tokens:
101 | return old_embeddings
102 |
103 | # Build new embeddings
104 | new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
105 | new_embeddings.to(old_embeddings.weight.device)
106 |
107 | # initialize all new embeddings (in particular added tokens)
108 | self._init_weights(new_embeddings)
109 |
110 | # Copy word embeddings from the previous weights
111 | num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
112 | new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
113 |
114 | return new_embeddings
115 |
116 | def _tie_or_clone_weights(self, first_module, second_module):
117 |         """ Tie or clone module weights depending on whether we are using TorchScript or not
118 | """
119 | if self.config.torchscript:
120 | first_module.weight = nn.Parameter(second_module.weight.clone())
121 | else:
122 | first_module.weight = second_module.weight
123 |
124 | if hasattr(first_module, 'bias') and first_module.bias is not None:
125 | first_module.bias.data = torch.nn.functional.pad(
126 | first_module.bias.data,
127 | (0, first_module.weight.shape[0] - first_module.bias.shape[0]),
128 | 'constant',
129 | 0
130 | )
131 |
132 | def resize_token_embeddings(self, new_num_tokens=None):
133 | """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
134 | Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
135 | Arguments:
136 | new_num_tokens: (`optional`) int:
137 | New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
138 | If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model.
139 | Return: ``torch.nn.Embeddings``
140 | Pointer to the input tokens Embeddings Module of the model
141 | """
142 | base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
143 | model_embeds = base_model._resize_token_embeddings(new_num_tokens) # word_embeds
144 | if new_num_tokens is None:
145 | return model_embeds
146 |
147 | # Update base model and current model config
148 | self.config.vocab_size = new_num_tokens
149 | base_model.vocab_size = new_num_tokens
150 |
151 | # Tie weights again if needed
152 | if hasattr(self, 'tie_weights'):
153 | self.tie_weights()
154 |
155 | return model_embeds
156 |
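A hedged sketch of how this method is typically combined with the tokenizer's `add_tokens` (the `model` and `tokenizer` variables are assumptions standing for an instantiated subclass and its matching tokenizer):

    # after tokenizer.add_tokens(["[EOT]"]) has grown the vocabulary by one:
    model.resize_token_embeddings(len(tokenizer))
    # config.vocab_size is updated and, if the model defines tie_weights(),
    # the output embeddings are re-tied to the resized input matrix.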
157 | def init_weights(self):
158 |         """ Initializes weights and prunes heads if needed. """
159 | # Initialize weights
160 | self.apply(self._init_weights)
161 |
162 | # Prune heads if needed
163 | if self.config.pruned_heads:
164 | self.prune_heads(self.config.pruned_heads)
165 |
166 | def prune_heads(self, heads_to_prune):
167 | """ Prunes heads of the base model.
168 | Arguments:
169 | heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
170 | E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
171 | """
172 | base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
173 |
174 | # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
175 | for layer, heads in heads_to_prune.items():
176 | union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
177 | self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON
178 |
179 | base_model._prune_heads(heads_to_prune)
180 |
181 | def save_pretrained(self, save_directory):
182 | """ Save a model and its configuration file to a directory, so that it
183 | can be re-loaded using the `:func:`~pytorch_transformers.PreTrainedModel.from_pretrained`` class method.
184 | """
185 | assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
186 |
187 |         # Only save the model itself if we are using distributed training
188 | model_to_save = self.module if hasattr(self, 'module') else self
189 |
190 | # Save configuration file
191 | model_to_save.config.save_pretrained(save_directory)
192 |
193 | # If we save using the predefined names, we can load using `from_pretrained`
194 | output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
195 |
196 | torch.save(model_to_save.state_dict(), output_model_file)
197 |
198 | @classmethod
199 | def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
200 | r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.
201 | The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated)
202 | To train the model, you should first set it back in training mode with ``model.train()``
203 | The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
204 | It is up to you to train those weights with a downstream fine-tuning task.
205 | The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.
206 | Parameters:
207 | pretrained_model_name_or_path: either:
208 | - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
209 | - a path to a `directory` containing model weights saved using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
210 | - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
211 | model_args: (`optional`) Sequence of positional arguments:
212 |                 All remaining positional arguments will be passed to the underlying model's ``__init__`` method
213 | config: (`optional`) instance of a class derived from :class:`~pytorch_transformers.PretrainedConfig`:
214 |                 Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
215 | - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
216 |                     - the model was saved using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
217 |                     - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
218 | state_dict: (`optional`) dict:
219 |                 an optional state dictionary for the model to use instead of a state dictionary loaded from the saved weights file.
220 | This option can be used if you want to create a model from a pretrained configuration but load your own weights.
221 | In this case though, you should check if using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained` and :func:`~pytorch_transformers.PreTrainedModel.from_pretrained` is not a simpler option.
222 | cache_dir: (`optional`) string:
223 | Path to a directory in which a downloaded pre-trained model
224 | configuration should be cached if the standard cache should not be used.
225 | force_download: (`optional`) boolean, default False:
226 |                 Force (re-)downloading the model weights and configuration files and override the cached versions if they exist.
227 | proxies: (`optional`) dict, default None:
228 | A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
229 | The proxies are used on each request.
230 | output_loading_info: (`optional`) boolean:
231 |                 Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
232 | kwargs: (`optional`) Remaining dictionary of keyword arguments:
233 | Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:
234 | - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
235 | - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~pytorch_transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
236 | Examples::
237 | model = BertModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache.
238 | model = BertModel.from_pretrained('./test/saved_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
239 | model = BertModel.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading
240 | assert model.config.output_attention == True
241 | # Loading from a TF checkpoint file instead of a PyTorch model (slower)
242 | config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
243 | model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
244 | """
245 | config = kwargs.pop('config', None)
246 | state_dict = kwargs.pop('state_dict', None)
247 | cache_dir = kwargs.pop('cache_dir', None)
248 | from_tf = kwargs.pop('from_tf', False)
249 | force_download = kwargs.pop('force_download', False)
250 | proxies = kwargs.pop('proxies', None)
251 | output_loading_info = kwargs.pop('output_loading_info', False)
252 |
253 | # Load config
254 | if config is None:
255 | config, model_kwargs = cls.config_class.from_pretrained(
256 | pretrained_model_name_or_path, *model_args,
257 | cache_dir=cache_dir, return_unused_kwargs=True,
258 | force_download=force_download,
259 | **kwargs
260 | )
261 | else:
262 | model_kwargs = kwargs
263 |
264 | # Load model
265 | if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
266 | archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
267 | elif os.path.isdir(pretrained_model_name_or_path):
268 | if from_tf:
269 | # Directly load from a TensorFlow checkpoint
270 | archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
271 | else:
272 | archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
273 | else:
274 | if from_tf:
275 | # Directly load from a TensorFlow checkpoint
276 | archive_file = pretrained_model_name_or_path + ".index"
277 | else:
278 | archive_file = pretrained_model_name_or_path
279 | # redirect to the cache, if necessary
280 | try:
281 | resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
282 | except EnvironmentError as e:
283 | if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
284 | logger.error(
285 | "Couldn't reach server at '{}' to download pretrained weights.".format(
286 | archive_file))
287 | else:
288 | logger.error(
289 | "Model name '{}' was not found in model name list ({}). "
290 | "We assumed '{}' was a path or url but couldn't find any file "
291 | "associated to this path or url.".format(
292 | pretrained_model_name_or_path,
293 | ', '.join(cls.pretrained_model_archive_map.keys()),
294 | archive_file))
295 | raise e
296 | if resolved_archive_file == archive_file:
297 | logger.info("loading weights file {}".format(archive_file))
298 | else:
299 | logger.info("loading weights file {} from cache at {}".format(
300 | archive_file, resolved_archive_file))
301 |
302 | # Instantiate model.
303 | model = cls(config, *model_args, **model_kwargs)
304 |
305 | if state_dict is None and not from_tf:
306 | state_dict = torch.load(resolved_archive_file, map_location='cpu')
307 | # taesun -> domain post-training checkpoint (.pth); its top-level state_dict keys are "model" and "optimizer"
308 | # flatten the wrapped "model" weights so they match this model's parameter names
309 | if resolved_archive_file.endswith('pth'):
310 | for state_key in state_dict["model"].keys():
311 | if state_key.startswith("_bert_model.bert."):
312 | state_dict[state_key[len("_bert_model.bert."):]] = state_dict["model"][state_key]
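# Only the encoder weights stored under the "_bert_model.bert." prefix are flattened; any other
# entries in the wrapped checkpoint (e.g. pre-training heads) are left as-is. The leftover
# "model"/"optimizer" wrapper keys do not match any module parameter, so at worst they are
# reported below as unexpected keys (logged, not fatal).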
313 |
314 | if from_tf:
315 | # Directly load from a TensorFlow checkpoint
316 | return cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index'
317 |
318 | # Convert old format to new format if needed from a PyTorch state_dict
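# (older TensorFlow-converted BERT checkpoints stored LayerNorm parameters as
#  'gamma'/'beta' instead of the PyTorch names 'weight'/'bias')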
319 | old_keys = []
320 | new_keys = []
321 | for key in state_dict.keys():
322 | new_key = None
323 | if 'gamma' in key:
324 | new_key = key.replace('gamma', 'weight')
325 | if 'beta' in key:
326 | new_key = key.replace('beta', 'bias')
327 | if new_key:
328 | old_keys.append(key)
329 | new_keys.append(new_key)
330 | for old_key, new_key in zip(old_keys, new_keys):
331 | state_dict[new_key] = state_dict.pop(old_key)
332 |
333 | # Load from a PyTorch state_dict
334 | missing_keys = []
335 | unexpected_keys = []
336 | error_msgs = []
337 | # copy state_dict so _load_from_state_dict can modify it
338 | metadata = getattr(state_dict, '_metadata', None)
339 | state_dict = state_dict.copy()
340 | if metadata is not None:
341 | state_dict._metadata = metadata
342 |
343 | def load(module, prefix=''):
344 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
345 | module._load_from_state_dict(
346 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
347 | for name, child in module._modules.items():
348 | if child is not None:
349 | load(child, prefix + name + '.')
350 |
351 | # Make sure we are able to load base models as well as derived models (with heads)
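# Two cases are handled below:
#  - the checkpoint was saved from a model with a head but we are loading a bare base model:
#    prepend the base-model prefix so the checkpoint keys line up (start_prefix);
#  - the checkpoint was saved from a bare base model but this model has a head:
#    load the weights directly into the base-model submodule (model_to_load).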
352 | start_prefix = ''
353 | model_to_load = model
354 | if not hasattr(model, cls.base_model_prefix) and any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
355 | start_prefix = cls.base_model_prefix + '.'
356 | if hasattr(model, cls.base_model_prefix) and not any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
357 | model_to_load = getattr(model, cls.base_model_prefix)
358 |
359 | load(model_to_load, prefix=start_prefix)
360 | if len(missing_keys) > 0:
361 | logger.info("Weights of {} not initialized from pretrained model: {}".format(
362 | model.__class__.__name__, missing_keys))
363 | if len(unexpected_keys) > 0:
364 | logger.info("Weights from pretrained model not used in {}: {}".format(
365 | model.__class__.__name__, unexpected_keys))
366 | if len(error_msgs) > 0:
367 | raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
368 | model.__class__.__name__, "\n\t".join(error_msgs)))
369 |
370 | if hasattr(model, 'tie_weights'):
371 | model.tie_weights() # make sure word embedding weights are still tied
372 |
373 | # Set model to evaluation mode to deactivate dropout modules by default
374 | model.eval()
375 |
376 | if output_loading_info:
377 | loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
378 | return model, loading_info
379 |
380 | return model
381 |
382 |
383 | class Conv1D(nn.Module):
384 | def __init__(self, nf, nx):
385 | """ Conv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)
386 | Basically works like a Linear layer but the weights are transposed
387 | """
388 | super(Conv1D, self).__init__()
389 | self.nf = nf
390 | w = torch.empty(nx, nf)
391 | nn.init.normal_(w, std=0.02)
392 | self.weight = nn.Parameter(w)
393 | self.bias = nn.Parameter(torch.zeros(nf))
394 |
395 | def forward(self, x):
396 | size_out = x.size()[:-1] + (self.nf,)
397 | x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
398 | x = x.view(*size_out)
399 | return x
400 |
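# A minimal usage sketch (hypothetical sizes, assuming a GPT-2-style hidden size of 768):
#
#   conv = Conv1D(nf=3 * 768, nx=768)    # behaves like nn.Linear(768, 2304) with transposed weights
#   x = torch.randn(2, 5, 768)           # (batch, seq_len, nx)
#   y = conv(x)                          # -> shape (2, 5, 2304) == (batch, seq_len, nf)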
401 |
402 | class PoolerStartLogits(nn.Module):
403 | """ Compute SQuAD start_logits from sequence hidden states. """
404 | def __init__(self, config):
405 | super(PoolerStartLogits, self).__init__()
406 | self.dense = nn.Linear(config.hidden_size, 1)
407 |
408 | def forward(self, hidden_states, p_mask=None):
409 | """ Args:
410 | **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
411 | Mask of invalid positions such as query and special symbols (PAD, SEP, CLS);
412 | 1.0 means the token should be masked.
413 | """
414 | x = self.dense(hidden_states).squeeze(-1)
415 |
416 | if p_mask is not None:
417 | x = x * (1 - p_mask) - 1e30 * p_mask
418 |
419 | return x
420 |
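# The masking arithmetic above is a standard additive trick: for a masked position
# (p_mask == 1.0) the logit becomes x * 0 - 1e30, i.e. effectively -inf under a softmax,
# while an unmasked position (p_mask == 0.0) keeps its original value x.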
421 |
422 | class PoolerEndLogits(nn.Module):
423 | """ Compute SQuAD end_logits from sequence hidden states and start token hidden state.
424 | """
425 | def __init__(self, config):
426 | super(PoolerEndLogits, self).__init__()
427 | self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
428 | self.activation = nn.Tanh()
429 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
430 | self.dense_1 = nn.Linear(config.hidden_size, 1)
431 |
432 | def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None):
433 | """ Args:
434 | One of ``start_states`` or ``start_positions`` should not be ``None``.
435 | If both are set, ``start_positions`` overrides ``start_states``.
436 | **start_states**: ``torch.FloatTensor`` of shape identical to ``hidden_states``
437 | hidden states of the first tokens for the labeled span.
438 | **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
439 | position of the first token for the labeled span.
440 | **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
441 | Mask of invalid positions such as query and special symbols (PAD, SEP, CLS);
442 | 1.0 means the token should be masked.
443 | """
444 | assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
445 | if start_positions is not None:
446 | slen, hsz = hidden_states.shape[-2:]
447 | start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
448 | start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
449 | start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
450 |
451 | x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
452 | x = self.activation(x)
453 | x = self.LayerNorm(x)
454 | x = self.dense_1(x).squeeze(-1)
455 |
456 | if p_mask is not None:
457 | x = x * (1 - p_mask) - 1e30 * p_mask
458 |
459 | return x
460 |
461 |
462 | class PoolerAnswerClass(nn.Module):
463 | """ Compute SQuAD 2.0 answer class from classification and start tokens hidden states. """
464 | def __init__(self, config):
465 | super(PoolerAnswerClass, self).__init__()
466 | self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
467 | self.activation = nn.Tanh()
468 | self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
469 |
470 | def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None):
471 | """
472 | Args:
473 | One of ``start_states`` or ``start_positions`` should not be ``None``.
474 | If both are set, ``start_positions`` overrides ``start_states``.
475 | **start_states**: ``torch.FloatTensor`` of shape identical to ``hidden_states``.
476 | hidden states of the first tokens for the labeled span.
477 | **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
478 | position of the first token for the labeled span.
479 | **cls_index**: ``torch.LongTensor`` of shape ``(batch_size,)``
480 | position of the CLS token. If None, take the last token.
481 | note(Original repo):
482 | no dependency on end_feature so that we can obtain one single `cls_logits`
483 | for each sample
484 | """
485 | hsz = hidden_states.shape[-1]
486 | assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
487 | if start_positions is not None:
488 | start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
489 | start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz)
490 |
491 | if cls_index is not None:
492 | cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
493 | cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz)
494 | else:
495 | cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz)
496 |
497 | x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
498 | x = self.activation(x)
499 | x = self.dense_1(x).squeeze(-1)
500 |
501 | return x
502 |
503 |
504 | class SQuADHead(nn.Module):
505 | r""" A SQuAD head inspired by XLNet.
506 | Parameters:
507 | config (:class:`~pytorch_transformers.XLNetConfig`): Model configuration class with all the parameters of the model.
508 | Inputs:
509 | **hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)``
510 | hidden states of sequence tokens
511 | **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
512 | position of the first token for the labeled span.
513 | **end_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
514 | position of the last token for the labeled span.
515 | **cls_index**: ``torch.LongTensor`` of shape ``(batch_size,)``
516 | position of the CLS token. If None, take the last token.
517 | **is_impossible**: ``torch.LongTensor`` of shape ``(batch_size,)``
518 | Whether the question has a possible answer in the paragraph or not.
519 | **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
520 | Mask of invalid positions such as query and special symbols (PAD, SEP, CLS);
521 | 1.0 means the token should be masked.
522 | Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
523 | **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
524 | Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
525 | **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
526 | ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
527 | Log probabilities for the top config.start_n_top start token possibilities (beam-search).
528 | **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
529 | ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
530 | Indices for the top config.start_n_top start token possibilities (beam-search).
531 | **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
532 | ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
533 | Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
534 | **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
535 | ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
536 | Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
537 | **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
538 | ``torch.FloatTensor`` of shape ``(batch_size,)``
539 | Log probabilities for the ``is_impossible`` label of the answers.
540 | """
541 | def __init__(self, config):
542 | super(SQuADHead, self).__init__()
543 | self.start_n_top = config.start_n_top
544 | self.end_n_top = config.end_n_top
545 |
546 | self.start_logits = PoolerStartLogits(config)
547 | self.end_logits = PoolerEndLogits(config)
548 | self.answer_class = PoolerAnswerClass(config)
549 |
550 | def forward(self, hidden_states, start_positions=None, end_positions=None,
551 | cls_index=None, is_impossible=None, p_mask=None):
552 | outputs = ()
553 |
554 | start_logits = self.start_logits(hidden_states, p_mask=p_mask)
555 |
556 | if start_positions is not None and end_positions is not None:
557 | # If we are on multi-GPU, let's remove the dimension added by batch splitting
558 | for x in (start_positions, end_positions, cls_index, is_impossible):
559 | if x is not None and x.dim() > 1:
560 | x.squeeze_(-1)
561 |
562 | # during training, compute the end logits based on the ground truth of the start position
563 | end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)
564 |
565 | loss_fct = CrossEntropyLoss()
566 | start_loss = loss_fct(start_logits, start_positions)
567 | end_loss = loss_fct(end_logits, end_positions)
568 | total_loss = (start_loss + end_loss) / 2
569 |
570 | if cls_index is not None and is_impossible is not None:
571 | # Predict answerability from the representation of CLS and START
572 | cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
573 | loss_fct_cls = nn.BCEWithLogitsLoss()
574 | cls_loss = loss_fct_cls(cls_logits, is_impossible)
575 |
576 | # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
577 | total_loss += cls_loss * 0.5
578 |
579 | outputs = (total_loss,) + outputs
580 |
581 | else:
582 | # during inference, compute the end logits based on beam search
583 | bsz, slen, hsz = hidden_states.size()
584 | start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
585 |
586 | start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1) # shape (bsz, start_n_top)
587 | start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
588 | start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
589 | start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
590 |
591 | hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states) # shape (bsz, slen, start_n_top, hsz)
592 | p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
593 | end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
594 | end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)
595 |
596 | end_top_log_probs, end_top_index = torch.topk(end_log_probs, self.end_n_top, dim=1) # shape (bsz, end_n_top, start_n_top)
597 | end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
598 | end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
599 |
600 | start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
601 | cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
602 |
603 | outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs
604 |
605 | # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
606 | # or (if labels are provided) (total_loss,)
607 | return outputs
608 |
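# A minimal usage sketch (hypothetical tensors; `starts`/`ends` would be LongTensors of
# shape (bsz,) and `config` a configuration with start_n_top/end_n_top set):
#
#   head = SQuADHead(config)
#   hidden = torch.randn(4, 128, config.hidden_size)   # (bsz, seq_len, hidden_size)
#   (loss,) = head(hidden, start_positions=starts, end_positions=ends)   # training: span loss only
#   start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = head(hidden)   # inference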
609 |
610 | class SequenceSummary(nn.Module):
611 | r""" Compute a single vector summary of a sequence hidden states according to various possibilities:
612 | Args of the config class:
613 | summary_type:
614 | - 'last' => [default] take the last token hidden state (like XLNet)
615 | - 'first' => take the first token hidden state (like Bert)
616 | - 'mean' => take the mean of all tokens hidden states
617 | - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2)
618 | - 'attn' => not implemented for now (would use multi-head attention)
619 | summary_use_proj: Add a projection after the vector extraction
620 | summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
621 | summary_activation: 'tanh' => add a tanh activation to the output; anything else => no activation (default)
622 | summary_first_dropout: Add a dropout before the projection and activation
623 | summary_last_dropout: Add a dropout after the projection and activation
624 | """
625 | def __init__(self, config):
626 | super(SequenceSummary, self).__init__()
627 |
628 | self.summary_type = config.summary_type if hasattr(config, 'summary_type') else 'last'
629 | if self.summary_type == 'attn':
630 | # We should use a standard multi-head attention module with absolute positional embedding for that.
631 | # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
632 | # We can probably just use the multi-head attention module of PyTorch >=1.1.0
633 | raise NotImplementedError
634 |
635 | self.summary = Identity()
636 | if hasattr(config, 'summary_use_proj') and config.summary_use_proj:
637 | if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0:
638 | num_classes = config.num_labels
639 | else:
640 | num_classes = config.hidden_size
641 | self.summary = nn.Linear(config.hidden_size, num_classes)
642 |
643 | self.activation = Identity()
644 | if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh':
645 | self.activation = nn.Tanh()
646 |
647 | self.first_dropout = Identity()
648 | if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0:
649 | self.first_dropout = nn.Dropout(config.summary_first_dropout)
650 |
651 | self.last_dropout = Identity()
652 | if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0:
653 | self.last_dropout = nn.Dropout(config.summary_last_dropout)
654 |
655 | def forward(self, hidden_states, cls_index=None):
656 | """ hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer.
657 | cls_index: [optional] position of the classification token if summary_type == 'cls_index',
658 | shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
659 | if summary_type == 'cls_index' and cls_index is None:
660 | we take the last token of the sequence as classification token
661 | """
662 | if self.summary_type == 'last':
663 | output = hidden_states[:, -1]
664 | elif self.summary_type == 'first':
665 | output = hidden_states[:, 0]
666 | elif self.summary_type == 'mean':
667 | output = hidden_states.mean(dim=1)
668 | elif self.summary_type == 'cls_index':
669 | if cls_index is None:
670 | cls_index = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2]-1, dtype=torch.long)
671 | else:
672 | cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
673 | cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
674 | # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
675 | output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
676 | elif self.summary_type == 'attn':
677 | raise NotImplementedError
678 |
679 | output = self.first_dropout(output)
680 | output = self.summary(output)
681 | output = self.activation(output)
682 | output = self.last_dropout(output)
683 |
684 | return output
685 |
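# A minimal usage sketch (hypothetical config values; a BERT-style first-token summary with a
# tanh projection, `hidden_states` of shape (bsz, seq_len, hidden_size)):
#
#   config.summary_type = 'first'
#   config.summary_use_proj = True
#   config.summary_proj_to_labels = False
#   config.summary_activation = 'tanh'
#   config.summary_first_dropout = 0.0
#   config.summary_last_dropout = 0.1
#   pooled = SequenceSummary(config)(hidden_states)   # -> shape (bsz, hidden_size)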
686 |
687 | def prune_linear_layer(layer, index, dim=0):
688 | """ Prune a linear layer (a model parameters) to keep only entries in index.
689 | Return the pruned layer as a new layer with requires_grad=True.
690 | Used to remove heads.
691 | """
692 | index = index.to(layer.weight.device)
693 | W = layer.weight.index_select(dim, index).clone().detach()
694 | if layer.bias is not None:
695 | if dim == 1:
696 | b = layer.bias.clone().detach()
697 | else:
698 | b = layer.bias[index].clone().detach()
699 | new_size = list(layer.weight.size())
700 | new_size[dim] = len(index)
701 | new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
702 | new_layer.weight.requires_grad = False
703 | new_layer.weight.copy_(W.contiguous())
704 | new_layer.weight.requires_grad = True
705 | if layer.bias is not None:
706 | new_layer.bias.requires_grad = False
707 | new_layer.bias.copy_(b.contiguous())
708 | new_layer.bias.requires_grad = True
709 | return new_layer
710 |
711 |
712 | def prune_conv1d_layer(layer, index, dim=1):
713 | """ Prune a Conv1D layer (a model parameters) to keep only entries in index.
714 | A Conv1D work as a Linear layer (see e.g. BERT) but the weights are transposed.
715 | Return the pruned layer as a new layer with requires_grad=True.
716 | Used to remove heads.
717 | """
718 | index = index.to(layer.weight.device)
719 | W = layer.weight.index_select(dim, index).clone().detach()
720 | if dim == 0:
721 | b = layer.bias.clone().detach()
722 | else:
723 | b = layer.bias[index].clone().detach()
724 | new_size = list(layer.weight.size())
725 | new_size[dim] = len(index)
726 | new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
727 | new_layer.weight.requires_grad = False
728 | new_layer.weight.copy_(W.contiguous())
729 | new_layer.weight.requires_grad = True
730 | new_layer.bias.requires_grad = False
731 | new_layer.bias.copy_(b.contiguous())
732 | new_layer.bias.requires_grad = True
733 | return new_layer
734 |
735 | def prune_layer(layer, index, dim=None):
736 | """ Prune a Conv1D or nn.Linear layer (a model parameters) to keep only entries in index.
737 | Return the pruned layer as a new layer with requires_grad=True.
738 | Used to remove heads.
739 | """
740 | if isinstance(layer, nn.Linear):
741 | return prune_linear_layer(layer, index, dim=0 if dim is None else dim)
742 | elif isinstance(layer, Conv1D):
743 | return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
744 | else:
745 | raise ValueError("Can't prune layer of class {}".format(layer.__class__))
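# A minimal usage sketch: keep only the first 512 output units of a hypothetical
# 768 -> 768 nn.Linear projection (dim=0, the default for nn.Linear):
#
#   layer = nn.Linear(768, 768)
#   keep = torch.arange(512)
#   pruned = prune_layer(layer, keep)
#   assert pruned.weight.shape == (512, 768) and pruned.bias.shape == (512,)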
--------------------------------------------------------------------------------