├── README.md
├── data
│   ├── dev.tsv
│   ├── test
│   └── train.tsv
├── download_glue_data.py
├── modeling.py
├── optimization.py
├── run.sh
├── run_classifier_word.py
└── tokenization_word.py

/README.md:
--------------------------------------------------------------------------------
1 | # bert-Chinese-classification-task
2 | BERT Chinese text classification practice
3 | 
4 | Add a NewsProcessor to run_classifier_word.py, i.e. the part that reads in and preprocesses the news data \
5 | In the main method, register the news data type so its labels are handled: \
6 | processors = { \
7 | "cola": ColaProcessor,\
8 | "mnli": MnliProcessor,\
9 | "mrpc": MrpcProcessor,\
10 | "news": NewsProcessor,\
11 | }
12 | 
13 | download_glue_data.py downloads the other public GLUE benchmark datasets from the BERT paper into glue_data
14 | 
15 | The data directory contains a sample of the news data
16 | 
17 | export GLUE_DIR=/search/odin/bert/extract_code/glue_data \
18 | export BERT_BASE_DIR=/search/odin/bert/chinese_L-12_H-768_A-12/ \
19 | export BERT_PYTORCH_DIR=/search/odin/bert/chinese_L-12_H-768_A-12/
20 | 
21 | python run_classifier_word.py \
22 | --task_name NEWS \
23 | --do_train \
24 | --do_eval \
25 | --data_dir $GLUE_DIR/NewsAll/ \
26 | --vocab_file $BERT_BASE_DIR/vocab.txt \
27 | --bert_config_file $BERT_BASE_DIR/bert_config.json \
28 | --init_checkpoint $BERT_PYTORCH_DIR/pytorch_model.bin \
29 | --max_seq_length 256 \
30 | --train_batch_size 32 \
31 | --learning_rate 2e-5 \
32 | --num_train_epochs 3.0 \
33 | --output_dir ./newsAll_output/ \
34 | --local_rank 3
35 | 
36 | Chinese classification task in practice
37 | 
38 | The experiment classifies Chinese text into 34 topics (including politics, entertainment, sports, etc.). A NewsProcessor module, similar to MrpcProcessor but with the text decoding adapted for Chinese, is added to the preprocessing step of run_classifier_word.py. The roughly 800,000 examples are split 4:1 into training and test data; on a single GPU, training takes 18 hours and reaches an accuracy of 92.8%.
39 | 
40 | eval_accuracy = 0.9281581998809113
41 | 
42 | eval_loss = 0.2222444740207354
43 | 
44 | global_step = 59826
45 | 
46 | loss = 0.14488934577978746
47 | 
--------------------------------------------------------------------------------
/data/dev.tsv:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NLPScott/bert-Chinese-classification-task/4106887959424f3c19a9f068f92f7d5729844fc1/data/dev.tsv
--------------------------------------------------------------------------------
/data/test:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/data/train.tsv:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NLPScott/bert-Chinese-classification-task/4106887959424f3c19a9f068f92f7d5729844fc1/data/train.tsv
--------------------------------------------------------------------------------
/download_glue_data.py:
--------------------------------------------------------------------------------
1 | ''' Script for downloading all GLUE data.
2 | 
3 | Note: for legal reasons, we are unable to host MRPC.
4 | You can either use the version hosted by the SentEval team, which is already tokenized, 
5 | or you can download the original data from (https://download.microsoft.com/download/D/4/6/D46FF87A-F6B9-4252-AA8B-3604ED519838/MSRParaphraseCorpus.msi) and extract the data from it manually.
6 | For Windows users, you can run the .msi file. For Mac and Linux users, consider an external library such as 'cabextract' (see below for an example).
7 | You should then rename and place specific files in a folder (see below for an example).
8 | 9 | mkdir MRPC 10 | cabextract MSRParaphraseCorpus.msi -d MRPC 11 | cat MRPC/_2DEC3DBE877E4DB192D17C0256E90F1D | tr -d $'\r' > MRPC/msr_paraphrase_train.txt 12 | cat MRPC/_D7B391F9EAFF4B1B8BCE8F21B20B1B61 | tr -d $'\r' > MRPC/msr_paraphrase_test.txt 13 | rm MRPC/_* 14 | rm MSRParaphraseCorpus.msi 15 | ''' 16 | 17 | import os 18 | import sys 19 | import shutil 20 | import argparse 21 | import tempfile 22 | import urllib.request 23 | import zipfile 24 | 25 | TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"] 26 | TASK2PATH = {"CoLA":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4', 27 | "SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8', 28 | "MRPC":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc', 29 | "QQP":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5', 30 | "STS":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5', 31 | "MNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce', 32 | "SNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df', 33 | "QNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLI.zip?alt=media&token=c24cad61-f2df-4f04-9ab6-aa576fa829d0', 34 | "RTE":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb', 35 | "WNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf', 36 | "diagnostic":'https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D'} 37 | 38 | MRPC_TRAIN = 'https://s3.amazonaws.com/senteval/senteval_data/msr_paraphrase_train.txt' 39 | MRPC_TEST = 'https://s3.amazonaws.com/senteval/senteval_data/msr_paraphrase_test.txt' 40 | 41 | def download_and_extract(task, data_dir): 42 | print("Downloading and extracting %s..." 
% task) 43 | data_file = "%s.zip" % task 44 | urllib.request.urlretrieve(TASK2PATH[task], data_file) 45 | with zipfile.ZipFile(data_file) as zip_ref: 46 | zip_ref.extractall(data_dir) 47 | os.remove(data_file) 48 | print("\tCompleted!") 49 | 50 | def format_mrpc(data_dir, path_to_data): 51 | print("Processing MRPC...") 52 | mrpc_dir = os.path.join(data_dir, "MRPC") 53 | if not os.path.isdir(mrpc_dir): 54 | os.mkdir(mrpc_dir) 55 | if path_to_data: 56 | mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt") 57 | mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt") 58 | else: 59 | mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt") 60 | mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt") 61 | urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file) 62 | urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file) 63 | assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file 64 | assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file 65 | urllib.request.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv")) 66 | 67 | dev_ids = [] 68 | with open(os.path.join(mrpc_dir, "dev_ids.tsv")) as ids_fh: 69 | for row in ids_fh: 70 | dev_ids.append(row.strip().split('\t')) 71 | 72 | with open(mrpc_train_file) as data_fh, \ 73 | open(os.path.join(mrpc_dir, "train.tsv"), 'w') as train_fh, \ 74 | open(os.path.join(mrpc_dir, "dev.tsv"), 'w') as dev_fh: 75 | header = data_fh.readline() 76 | train_fh.write(header) 77 | dev_fh.write(header) 78 | for row in data_fh: 79 | label, id1, id2, s1, s2 = row.strip().split('\t') 80 | if [id1, id2] in dev_ids: 81 | dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2)) 82 | else: 83 | train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2)) 84 | 85 | with open(mrpc_test_file) as data_fh, \ 86 | open(os.path.join(mrpc_dir, "test.tsv"), 'w') as test_fh: 87 | header = data_fh.readline() 88 | test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n") 89 | for idx, row in enumerate(data_fh): 90 | label, id1, id2, s1, s2 = row.strip().split('\t') 91 | test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2)) 92 | print("\tCompleted!") 93 | 94 | def download_diagnostic(data_dir): 95 | print("Downloading and extracting diagnostic...") 96 | if not os.path.isdir(os.path.join(data_dir, "diagnostic")): 97 | os.mkdir(os.path.join(data_dir, "diagnostic")) 98 | data_file = os.path.join(data_dir, "diagnostic", "diagnostic.tsv") 99 | urllib.request.urlretrieve(TASK2PATH["diagnostic"], data_file) 100 | print("\tCompleted!") 101 | return 102 | 103 | def get_tasks(task_names): 104 | task_names = task_names.split(',') 105 | if "all" in task_names: 106 | tasks = TASKS 107 | else: 108 | tasks = [] 109 | for task_name in task_names: 110 | assert task_name in TASKS, "Task %s not found!" 
% task_name 111 | tasks.append(task_name) 112 | return tasks 113 | 114 | def main(arguments): 115 | parser = argparse.ArgumentParser() 116 | parser.add_argument('--data_dir', help='directory to save data to', type=str, default='glue_data') 117 | parser.add_argument('--tasks', help='tasks to download data for as a comma separated string', 118 | type=str, default='all') 119 | parser.add_argument('--path_to_mrpc', help='path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_text.txt', 120 | type=str, default='') 121 | args = parser.parse_args(arguments) 122 | 123 | if not os.path.isdir(args.data_dir): 124 | os.mkdir(args.data_dir) 125 | tasks = get_tasks(args.tasks) 126 | 127 | for task in tasks: 128 | if task == 'MRPC': 129 | format_mrpc(args.data_dir, args.path_to_mrpc) 130 | elif task == 'diagnostic': 131 | download_diagnostic(args.data_dir) 132 | else: 133 | download_and_extract(task, args.data_dir) 134 | 135 | 136 | if __name__ == '__main__': 137 | sys.exit(main(sys.argv[1:])) -------------------------------------------------------------------------------- /modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch BERT model.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import copy 22 | import json 23 | import math 24 | import six 25 | import torch 26 | import torch.nn as nn 27 | from torch.nn import CrossEntropyLoss 28 | 29 | def gelu(x): 30 | """Implementation of the gelu activation function. 31 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 32 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 33 | """ 34 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 35 | 36 | 37 | class BertConfig(object): 38 | """Configuration class to store the configuration of a `BertModel`. 39 | """ 40 | def __init__(self, 41 | vocab_size, 42 | hidden_size=768, 43 | num_hidden_layers=12, 44 | num_attention_heads=12, 45 | intermediate_size=3072, 46 | hidden_act="gelu", 47 | hidden_dropout_prob=0.1, 48 | attention_probs_dropout_prob=0.1, 49 | max_position_embeddings=512, 50 | type_vocab_size=16, 51 | initializer_range=0.02): 52 | """Constructs BertConfig. 53 | 54 | Args: 55 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. 56 | hidden_size: Size of the encoder layers and the pooler layer. 57 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 58 | num_attention_heads: Number of attention heads for each attention layer in 59 | the Transformer encoder. 60 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 61 | layer in the Transformer encoder. 
62 | hidden_act: The non-linear activation function (function or string) in the 63 | encoder and pooler. 64 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 65 | layers in the embeddings, encoder, and pooler. 66 | attention_probs_dropout_prob: The dropout ratio for the attention 67 | probabilities. 68 | max_position_embeddings: The maximum sequence length that this model might 69 | ever be used with. Typically set this to something large just in case 70 | (e.g., 512 or 1024 or 2048). 71 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 72 | `BertModel`. 73 | initializer_range: The sttdev of the truncated_normal_initializer for 74 | initializing all weight matrices. 75 | """ 76 | self.vocab_size = vocab_size 77 | self.hidden_size = hidden_size 78 | self.num_hidden_layers = num_hidden_layers 79 | self.num_attention_heads = num_attention_heads 80 | self.hidden_act = hidden_act 81 | self.intermediate_size = intermediate_size 82 | self.hidden_dropout_prob = hidden_dropout_prob 83 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 84 | self.max_position_embeddings = max_position_embeddings 85 | self.type_vocab_size = type_vocab_size 86 | self.initializer_range = initializer_range 87 | 88 | @classmethod 89 | def from_dict(cls, json_object): 90 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 91 | config = BertConfig(vocab_size=None) 92 | for (key, value) in six.iteritems(json_object): 93 | config.__dict__[key] = value 94 | return config 95 | 96 | @classmethod 97 | def from_json_file(cls, json_file): 98 | """Constructs a `BertConfig` from a json file of parameters.""" 99 | with open(json_file, "r") as reader: 100 | text = reader.read() 101 | return cls.from_dict(json.loads(text)) 102 | 103 | def to_dict(self): 104 | """Serializes this instance to a Python dictionary.""" 105 | output = copy.deepcopy(self.__dict__) 106 | return output 107 | 108 | def to_json_string(self): 109 | """Serializes this instance to a JSON string.""" 110 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 111 | 112 | 113 | class BERTLayerNorm(nn.Module): 114 | def __init__(self, config, variance_epsilon=1e-12): 115 | """Construct a layernorm module in the TF style (epsilon inside the square root). 116 | """ 117 | super(BERTLayerNorm, self).__init__() 118 | self.gamma = nn.Parameter(torch.ones(config.hidden_size)) 119 | self.beta = nn.Parameter(torch.zeros(config.hidden_size)) 120 | self.variance_epsilon = variance_epsilon 121 | 122 | def forward(self, x): 123 | u = x.mean(-1, keepdim=True) 124 | s = (x - u).pow(2).mean(-1, keepdim=True) 125 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 126 | return self.gamma * x + self.beta 127 | 128 | class BERTEmbeddings(nn.Module): 129 | def __init__(self, config): 130 | super(BERTEmbeddings, self).__init__() 131 | """Construct the embedding module from word, position and token_type embeddings. 
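The word, position and token_type embeddings are summed in forward(), then layer-normalized and passed through dropout.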
132 | """ 133 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) 134 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 135 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 136 | 137 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 138 | # any TensorFlow checkpoint file 139 | self.LayerNorm = BERTLayerNorm(config) 140 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 141 | 142 | def forward(self, input_ids, token_type_ids=None): 143 | seq_length = input_ids.size(1) 144 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 145 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 146 | if token_type_ids is None: 147 | token_type_ids = torch.zeros_like(input_ids) 148 | 149 | words_embeddings = self.word_embeddings(input_ids) 150 | position_embeddings = self.position_embeddings(position_ids) 151 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 152 | 153 | embeddings = words_embeddings + position_embeddings + token_type_embeddings 154 | embeddings = self.LayerNorm(embeddings) 155 | embeddings = self.dropout(embeddings) 156 | return embeddings 157 | 158 | 159 | class BERTSelfAttention(nn.Module): 160 | def __init__(self, config): 161 | super(BERTSelfAttention, self).__init__() 162 | if config.hidden_size % config.num_attention_heads != 0: 163 | raise ValueError( 164 | "The hidden size (%d) is not a multiple of the number of attention " 165 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 166 | self.num_attention_heads = config.num_attention_heads 167 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 168 | self.all_head_size = self.num_attention_heads * self.attention_head_size 169 | 170 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 171 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 172 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 173 | 174 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 175 | 176 | def transpose_for_scores(self, x): 177 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 178 | x = x.view(*new_x_shape) 179 | return x.permute(0, 2, 1, 3) 180 | 181 | def forward(self, hidden_states, attention_mask): 182 | mixed_query_layer = self.query(hidden_states) 183 | mixed_key_layer = self.key(hidden_states) 184 | mixed_value_layer = self.value(hidden_states) 185 | 186 | query_layer = self.transpose_for_scores(mixed_query_layer) 187 | key_layer = self.transpose_for_scores(mixed_key_layer) 188 | value_layer = self.transpose_for_scores(mixed_value_layer) 189 | 190 | # Take the dot product between "query" and "key" to get the raw attention scores. 191 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 192 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 193 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 194 | attention_scores = attention_scores + attention_mask 195 | 196 | # Normalize the attention scores to probabilities. 197 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 198 | 199 | # This is actually dropping out entire tokens to attend to, which might 200 | # seem a bit unusual, but is taken from the original Transformer paper. 
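# Together with the matmul on value_layer below, this implements scaled dot-product
# attention per head: softmax(Q·K^T / sqrt(d_head) + mask) · V.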
201 | attention_probs = self.dropout(attention_probs) 202 | 203 | context_layer = torch.matmul(attention_probs, value_layer) 204 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 205 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 206 | context_layer = context_layer.view(*new_context_layer_shape) 207 | return context_layer 208 | 209 | 210 | class BERTSelfOutput(nn.Module): 211 | def __init__(self, config): 212 | super(BERTSelfOutput, self).__init__() 213 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 214 | self.LayerNorm = BERTLayerNorm(config) 215 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 216 | 217 | def forward(self, hidden_states, input_tensor): 218 | hidden_states = self.dense(hidden_states) 219 | hidden_states = self.dropout(hidden_states) 220 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 221 | return hidden_states 222 | 223 | 224 | class BERTAttention(nn.Module): 225 | def __init__(self, config): 226 | super(BERTAttention, self).__init__() 227 | self.self = BERTSelfAttention(config) 228 | self.output = BERTSelfOutput(config) 229 | 230 | def forward(self, input_tensor, attention_mask): 231 | self_output = self.self(input_tensor, attention_mask) 232 | attention_output = self.output(self_output, input_tensor) 233 | return attention_output 234 | 235 | 236 | class BERTIntermediate(nn.Module): 237 | def __init__(self, config): 238 | super(BERTIntermediate, self).__init__() 239 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 240 | self.intermediate_act_fn = gelu 241 | 242 | def forward(self, hidden_states): 243 | hidden_states = self.dense(hidden_states) 244 | hidden_states = self.intermediate_act_fn(hidden_states) 245 | return hidden_states 246 | 247 | 248 | class BERTOutput(nn.Module): 249 | def __init__(self, config): 250 | super(BERTOutput, self).__init__() 251 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 252 | self.LayerNorm = BERTLayerNorm(config) 253 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 254 | 255 | def forward(self, hidden_states, input_tensor): 256 | hidden_states = self.dense(hidden_states) 257 | hidden_states = self.dropout(hidden_states) 258 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 259 | return hidden_states 260 | 261 | 262 | class BERTLayer(nn.Module): 263 | def __init__(self, config): 264 | super(BERTLayer, self).__init__() 265 | self.attention = BERTAttention(config) 266 | self.intermediate = BERTIntermediate(config) 267 | self.output = BERTOutput(config) 268 | 269 | def forward(self, hidden_states, attention_mask): 270 | attention_output = self.attention(hidden_states, attention_mask) 271 | intermediate_output = self.intermediate(attention_output) 272 | layer_output = self.output(intermediate_output, attention_output) 273 | return layer_output 274 | 275 | 276 | class BERTEncoder(nn.Module): 277 | def __init__(self, config): 278 | super(BERTEncoder, self).__init__() 279 | layer = BERTLayer(config) 280 | self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) 281 | 282 | def forward(self, hidden_states, attention_mask): 283 | all_encoder_layers = [] 284 | for layer_module in self.layer: 285 | hidden_states = layer_module(hidden_states, attention_mask) 286 | all_encoder_layers.append(hidden_states) 287 | return all_encoder_layers 288 | 289 | 290 | class BERTPooler(nn.Module): 291 | def __init__(self, config): 292 | super(BERTPooler, self).__init__() 293 | 
self.dense = nn.Linear(config.hidden_size, config.hidden_size) 294 | self.activation = nn.Tanh() 295 | 296 | def forward(self, hidden_states): 297 | # We "pool" the model by simply taking the hidden state corresponding 298 | # to the first token. 299 | first_token_tensor = hidden_states[:, 0] 300 | pooled_output = self.dense(first_token_tensor) 301 | pooled_output = self.activation(pooled_output) 302 | return pooled_output 303 | 304 | 305 | class BertModel(nn.Module): 306 | """BERT model ("Bidirectional Embedding Representations from a Transformer"). 307 | 308 | Example usage: 309 | ```python 310 | # Already been converted into WordPiece token ids 311 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 312 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 313 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 314 | 315 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512, 316 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 317 | 318 | model = modeling.BertModel(config=config) 319 | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) 320 | ``` 321 | """ 322 | def __init__(self, config: BertConfig): 323 | """Constructor for BertModel. 324 | 325 | Args: 326 | config: `BertConfig` instance. 327 | """ 328 | super(BertModel, self).__init__() 329 | self.embeddings = BERTEmbeddings(config) 330 | self.encoder = BERTEncoder(config) 331 | self.pooler = BERTPooler(config) 332 | 333 | def forward(self, input_ids, token_type_ids=None, attention_mask=None): 334 | if attention_mask is None: 335 | attention_mask = torch.ones_like(input_ids) 336 | if token_type_ids is None: 337 | token_type_ids = torch.zeros_like(input_ids) 338 | 339 | # We create a 3D attention mask from a 2D tensor mask. 340 | # Sizes are [batch_size, 1, 1, to_seq_length] 341 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 342 | # this attention mask is more simple than the triangular masking of causal attention 343 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 344 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 345 | 346 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 347 | # masked positions, this operation will create a tensor which is 0.0 for 348 | # positions we want to attend and -10000.0 for masked positions. 349 | # Since we are adding it to the raw scores before the softmax, this is 350 | # effectively the same as removing these entirely. 351 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 352 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 353 | 354 | embedding_output = self.embeddings(input_ids, token_type_ids) 355 | all_encoder_layers = self.encoder(embedding_output, extended_attention_mask) 356 | sequence_output = all_encoder_layers[-1] 357 | pooled_output = self.pooler(sequence_output) 358 | return all_encoder_layers, pooled_output 359 | 360 | class BertForSequenceClassification(nn.Module): 361 | """BERT model for classification. 362 | This module is composed of the BERT model with a linear layer on top of 363 | the pooled output. 
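The pooled output is the final hidden state of the first ([CLS]) token passed through BERTPooler;
it goes through dropout and a linear layer with `num_labels` outputs. When `labels` are provided,
forward() returns the cross-entropy loss together with the logits, otherwise just the logits.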
364 | 365 | Example usage: 366 | ```python 367 | # Already been converted into WordPiece token ids 368 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 369 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 370 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 371 | 372 | config = BertConfig(vocab_size=32000, hidden_size=512, 373 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 374 | 375 | num_labels = 2 376 | 377 | model = BertForSequenceClassification(config, num_labels) 378 | logits = model(input_ids, token_type_ids, input_mask) 379 | ``` 380 | """ 381 | def __init__(self, config, num_labels): 382 | super(BertForSequenceClassification, self).__init__() 383 | self.bert = BertModel(config) 384 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 385 | self.classifier = nn.Linear(config.hidden_size, num_labels) 386 | 387 | def init_weights(module): 388 | if isinstance(module, (nn.Linear, nn.Embedding)): 389 | # Slightly different from the TF version which uses truncated_normal for initialization 390 | # cf https://github.com/pytorch/pytorch/pull/5617 391 | module.weight.data.normal_(mean=0.0, std=config.initializer_range) 392 | elif isinstance(module, BERTLayerNorm): 393 | module.beta.data.normal_(mean=0.0, std=config.initializer_range) 394 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range) 395 | if isinstance(module, nn.Linear): 396 | module.bias.data.zero_() 397 | self.apply(init_weights) 398 | 399 | def forward(self, input_ids, token_type_ids, attention_mask, labels=None): 400 | _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask) 401 | pooled_output = self.dropout(pooled_output) 402 | logits = self.classifier(pooled_output) 403 | 404 | if labels is not None: 405 | loss_fct = CrossEntropyLoss() 406 | loss = loss_fct(logits, labels) 407 | return loss, logits 408 | else: 409 | return logits 410 | 411 | class BertForQuestionAnswering(nn.Module): 412 | """BERT model for Question Answering (span extraction). 
413 | This module is composed of the BERT model with a linear layer on top of 414 | the sequence output that computes start_logits and end_logits 415 | 416 | Example usage: 417 | ```python 418 | # Already been converted into WordPiece token ids 419 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 420 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 421 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 422 | 423 | config = BertConfig(vocab_size=32000, hidden_size=512, 424 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 425 | 426 | model = BertForQuestionAnswering(config) 427 | start_logits, end_logits = model(input_ids, token_type_ids, input_mask) 428 | ``` 429 | """ 430 | def __init__(self, config): 431 | super(BertForQuestionAnswering, self).__init__() 432 | self.bert = BertModel(config) 433 | # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version 434 | # self.dropout = nn.Dropout(config.hidden_dropout_prob) 435 | self.qa_outputs = nn.Linear(config.hidden_size, 2) 436 | 437 | def init_weights(module): 438 | if isinstance(module, (nn.Linear, nn.Embedding)): 439 | # Slightly different from the TF version which uses truncated_normal for initialization 440 | # cf https://github.com/pytorch/pytorch/pull/5617 441 | module.weight.data.normal_(mean=0.0, std=config.initializer_range) 442 | elif isinstance(module, BERTLayerNorm): 443 | module.beta.data.normal_(mean=0.0, std=config.initializer_range) 444 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range) 445 | if isinstance(module, nn.Linear): 446 | module.bias.data.zero_() 447 | self.apply(init_weights) 448 | 449 | def forward(self, input_ids, token_type_ids, attention_mask, start_positions=None, end_positions=None): 450 | all_encoder_layers, _ = self.bert(input_ids, token_type_ids, attention_mask) 451 | sequence_output = all_encoder_layers[-1] 452 | logits = self.qa_outputs(sequence_output) 453 | start_logits, end_logits = logits.split(1, dim=-1) 454 | start_logits = start_logits.squeeze(-1) 455 | end_logits = end_logits.squeeze(-1) 456 | 457 | if start_positions is not None and end_positions is not None: 458 | # If we are on multi-GPU, split add a dimension 459 | if len(start_positions.size()) > 1: 460 | start_positions = start_positions.squeeze(-1) 461 | if len(end_positions.size()) > 1: 462 | end_positions = end_positions.squeeze(-1) 463 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 464 | ignored_index = start_logits.size(1) 465 | start_positions.clamp_(0, ignored_index) 466 | end_positions.clamp_(0, ignored_index) 467 | 468 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 469 | start_loss = loss_fct(start_logits, start_positions) 470 | end_loss = loss_fct(end_logits, end_positions) 471 | total_loss = (start_loss + end_loss) / 2 472 | return total_loss 473 | else: 474 | return start_logits, end_logits 475 | -------------------------------------------------------------------------------- /optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.nn.utils import clip_grad_norm_ 21 | 22 | def warmup_cosine(x, warmup=0.002): 23 | if x < warmup: 24 | return x/warmup 25 | return 0.5 * (1.0 + torch.cos(math.pi * x)) 26 | 27 | def warmup_constant(x, warmup=0.002): 28 | if x < warmup: 29 | return x/warmup 30 | return 1.0 31 | 32 | def warmup_linear(x, warmup=0.002): 33 | if x < warmup: 34 | return x/warmup 35 | return 1.0 - x 36 | 37 | SCHEDULES = { 38 | 'warmup_cosine':warmup_cosine, 39 | 'warmup_constant':warmup_constant, 40 | 'warmup_linear':warmup_linear, 41 | } 42 | 43 | 44 | class BERTAdam(Optimizer): 45 | """Implements BERT version of Adam algorithm with weight decay fix (and no ). 46 | Params: 47 | lr: learning rate 48 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 49 | t_total: total number of training steps for the learning 50 | rate schedule, -1 means constant learning rate. Default: -1 51 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 52 | b1: Adams b1. Default: 0.9 53 | b2: Adams b2. Default: 0.999 54 | e: Adams epsilon. Default: 1e-6 55 | weight_decay_rate: Weight decay. Default: 0.01 56 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). 
Default: 1.0 57 | """ 58 | def __init__(self, params, lr, warmup=-1, t_total=-1, schedule='warmup_linear', 59 | b1=0.9, b2=0.999, e=1e-6, weight_decay_rate=0.01, 60 | max_grad_norm=1.0): 61 | if not lr >= 0.0: 62 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 63 | if schedule not in SCHEDULES: 64 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 65 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 66 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 67 | if not 0.0 <= b1 < 1.0: 68 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 69 | if not 0.0 <= b2 < 1.0: 70 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 71 | if not e >= 0.0: 72 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 73 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 74 | b1=b1, b2=b2, e=e, weight_decay_rate=weight_decay_rate, 75 | max_grad_norm=max_grad_norm) 76 | super(BERTAdam, self).__init__(params, defaults) 77 | 78 | def get_lr(self): 79 | lr = [] 80 | for group in self.param_groups: 81 | for p in group['params']: 82 | state = self.state[p] 83 | if len(state) == 0: 84 | return [0] 85 | if group['t_total'] != -1: 86 | schedule_fct = SCHEDULES[group['schedule']] 87 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 88 | else: 89 | lr_scheduled = group['lr'] 90 | lr.append(lr_scheduled) 91 | return lr 92 | 93 | def to(self, device): 94 | """ Move the optimizer state to a specified device""" 95 | for state in self.state.values(): 96 | state['exp_avg'].to(device) 97 | state['exp_avg_sq'].to(device) 98 | 99 | def initialize_step(self, initial_step): 100 | """Initialize state with a defined step (but we don't have stored averaged). 101 | Arguments: 102 | initial_step (int): Initial step number. 103 | """ 104 | for group in self.param_groups: 105 | for p in group['params']: 106 | state = self.state[p] 107 | # State initialization 108 | state['step'] = initial_step 109 | # Exponential moving average of gradient values 110 | state['exp_avg'] = torch.zeros_like(p.data) 111 | # Exponential moving average of squared gradient values 112 | state['exp_avg_sq'] = torch.zeros_like(p.data) 113 | 114 | def step(self, closure=None): 115 | """Performs a single optimization step. 116 | Arguments: 117 | closure (callable, optional): A closure that reevaluates the model 118 | and returns the loss. 
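The update keeps exponential moving averages of the gradient and its square (no bias
correction), adds decoupled weight decay, and scales the step by the scheduled learning rate.
A minimal usage sketch (assuming a PyTorch version contemporary with this code; the Linear
model, random batch and t_total below are illustrative stand-ins, not the repo's training loop):
```python
import torch
from optimization import BERTAdam

model = torch.nn.Linear(10, 2)          # stand-in for the BERT classifier
optimizer = BERTAdam(model.parameters(), lr=2e-5,
                     warmup=0.1, t_total=1000)
loss = model(torch.randn(4, 10)).sum()  # dummy forward pass and loss
loss.backward()
optimizer.step()                        # schedule-scaled BERTAdam update
model.zero_grad()
```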
119 | """ 120 | loss = None 121 | if closure is not None: 122 | loss = closure() 123 | 124 | for group in self.param_groups: 125 | for p in group['params']: 126 | if p.grad is None: 127 | continue 128 | grad = p.grad.data 129 | if grad.is_sparse: 130 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 131 | 132 | state = self.state[p] 133 | 134 | # State initialization 135 | if len(state) == 0: 136 | state['step'] = 0 137 | # Exponential moving average of gradient values 138 | state['next_m'] = torch.zeros_like(p.data) 139 | # Exponential moving average of squared gradient values 140 | state['next_v'] = torch.zeros_like(p.data) 141 | 142 | next_m, next_v = state['next_m'], state['next_v'] 143 | beta1, beta2 = group['b1'], group['b2'] 144 | 145 | # Add grad clipping 146 | if group['max_grad_norm'] > 0: 147 | clip_grad_norm_(p, group['max_grad_norm']) 148 | 149 | # Decay the first and second moment running average coefficient 150 | # In-place operations to update the averages at the same time 151 | next_m.mul_(beta1).add_(1 - beta1, grad) 152 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 153 | update = next_m / (next_v.sqrt() + group['e']) 154 | 155 | # Just adding the square of the weights to the loss function is *not* 156 | # the correct way of using L2 regularization/weight decay with Adam, 157 | # since that will interact with the m and v parameters in strange ways. 158 | # 159 | # Instead we want ot decay the weights in a manner that doesn't interact 160 | # with the m/v parameters. This is equivalent to adding the square 161 | # of the weights to the loss with plain (non-momentum) SGD. 162 | if group['weight_decay_rate'] > 0.0: 163 | update += group['weight_decay_rate'] * p.data 164 | 165 | if group['t_total'] != -1: 166 | schedule_fct = SCHEDULES[group['schedule']] 167 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 168 | else: 169 | lr_scheduled = group['lr'] 170 | 171 | update_with_lr = lr_scheduled * update 172 | p.data.add_(-update_with_lr) 173 | 174 | state['step'] += 1 175 | 176 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 177 | # bias_correction1 = 1 - beta1 ** state['step'] 178 | # bias_correction2 = 1 - beta2 ** state['step'] 179 | 180 | return loss 181 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | export GLUE_DIR=/search/odin/wuyonggang/bert/extract_code/glue_data 4 | export BERT_BASE_DIR=/search/odin/wuyonggang/bert/chinese_L-12_H-768_A-12/ 5 | export BERT_PYTORCH_DIR=/search/odin/wuyonggang/bert/chinese_L-12_H-768_A-12/ 6 | 7 | python run_classifier_word.py \ 8 | --task_name NEWS \ 9 | --do_train \ 10 | --do_eval \ 11 | --data_dir $GLUE_DIR/NewsAll/ \ 12 | --vocab_file $BERT_BASE_DIR/vocab.txt \ 13 | --bert_config_file $BERT_BASE_DIR/bert_config.json \ 14 | --init_checkpoint $BERT_PYTORCH_DIR/pytorch_model.bin \ 15 | --max_seq_length 256 \ 16 | --train_batch_size 32 \ 17 | --learning_rate 2e-5 \ 18 | --num_train_epochs 3.0 \ 19 | --output_dir ./newsAll_output/ \ 20 | --local_rank 3 21 | 22 | #python run_classifier_word.py \ 23 | # --task_name NEWS \ 24 | # --do_train \ 25 | # --do_eval \ 26 | # --data_dir $GLUE_DIR/News/ \ 27 | # --vocab_file $BERT_BASE_DIR/vocab.txt \ 28 | # --bert_config_file $BERT_BASE_DIR/bert_config.json \ 29 | # --init_checkpoint $BERT_PYTORCH_DIR/pytorch_model.bin \ 30 | # 
--max_seq_length 128 \ 31 | # --train_batch_size 32 \ 32 | # --learning_rate 2e-5 \ 33 | # --num_train_epochs 3.0 \ 34 | # --output_dir ./news_output/ \ 35 | # --local_rank 2 36 | 37 | #python run_classifier.py \ 38 | # --task_name MRPC \ 39 | # --do_train \ 40 | # --do_eval \ 41 | # --do_lower_case \ 42 | # --data_dir $GLUE_DIR/MRPC/ \ 43 | # --vocab_file $BERT_BASE_DIR/vocab.txt \ 44 | # --bert_config_file $BERT_BASE_DIR/bert_config.json \ 45 | # --init_checkpoint $BERT_PYTORCH_DIR/pytorch_model.bin \ 46 | # --max_seq_length 128 \ 47 | # --train_batch_size 32 \ 48 | # --learning_rate 2e-5 \ 49 | # --num_train_epochs 3.0 \ 50 | # --output_dir ./mrpc_output/ 51 | -------------------------------------------------------------------------------- /run_classifier_word.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """BERT finetuning runner.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import csv 22 | import os 23 | import logging 24 | import argparse 25 | import random 26 | from tqdm import tqdm, trange 27 | 28 | import numpy as np 29 | import torch 30 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler 31 | from torch.utils.data.distributed import DistributedSampler 32 | 33 | import tokenization_word as tokenization 34 | from modeling import BertConfig, BertForSequenceClassification 35 | from optimization import BERTAdam 36 | 37 | 38 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 39 | datefmt = '%m/%d/%Y %H:%M:%S', 40 | level = logging.INFO) 41 | logger = logging.getLogger(__name__) 42 | 43 | 44 | class InputExample(object): 45 | """A single training/test example for simple sequence classification.""" 46 | 47 | def __init__(self, guid, text_a, text_b=None, label=None): 48 | """Constructs a InputExample. 49 | 50 | Args: 51 | guid: Unique id for the example. 52 | text_a: string. The untokenized text of the first sequence. For single 53 | sequence tasks, only this sequence must be specified. 54 | text_b: (Optional) string. The untokenized text of the second sequence. 55 | Only must be specified for sequence pair tasks. 56 | label: (Optional) string. The label of the example. This should be 57 | specified for train and dev examples, but not for test examples. 
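For the news task in this repo, `text_a` holds the document text, `text_b` stays None and
`label` is the topic string read from the first column of the TSV (see NewsProcessor below).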
58 | """ 59 | self.guid = guid 60 | self.text_a = text_a 61 | self.text_b = text_b 62 | self.label = label 63 | 64 | 65 | class InputFeatures(object): 66 | """A single set of features of data.""" 67 | 68 | def __init__(self, input_ids, input_mask, segment_ids, label_id): 69 | self.input_ids = input_ids 70 | self.input_mask = input_mask 71 | self.segment_ids = segment_ids 72 | self.label_id = label_id 73 | 74 | 75 | class DataProcessor(object): 76 | """Base class for data converters for sequence classification data sets.""" 77 | 78 | def get_train_examples(self, data_dir): 79 | """Gets a collection of `InputExample`s for the train set.""" 80 | raise NotImplementedError() 81 | 82 | def get_dev_examples(self, data_dir): 83 | """Gets a collection of `InputExample`s for the dev set.""" 84 | raise NotImplementedError() 85 | 86 | def get_labels(self): 87 | """Gets the list of labels for this data set.""" 88 | raise NotImplementedError() 89 | 90 | @classmethod 91 | def _read_tsv(cls, input_file, quotechar=None): 92 | """Reads a tab separated value file.""" 93 | file_in = open(input_file, "rb") 94 | lines = [] 95 | for line in file_in: 96 | lines.append(line.decode("gbk").split("\t")) 97 | return lines 98 | 99 | class NewsProcessor(DataProcessor): 100 | """Processor for the MRPC data set (GLUE version).""" 101 | 102 | def __init__(self): 103 | self.labels = set() 104 | 105 | def get_train_examples(self, data_dir): 106 | """See base class.""" 107 | logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv"))) 108 | return self._create_examples( 109 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 110 | 111 | def get_dev_examples(self, data_dir): 112 | """See base class.""" 113 | return self._create_examples( 114 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 115 | 116 | def get_labels(self): 117 | """See base class.""" 118 | return list(self.labels) 119 | 120 | def _create_examples(self, lines, set_type): 121 | """Creates examples for the training and dev sets.""" 122 | examples = [] 123 | for (i, line) in enumerate(lines): 124 | guid = "%s-%s" % (set_type, i) 125 | text_a = tokenization.convert_to_unicode(line[1]) 126 | label = tokenization.convert_to_unicode(line[0]) 127 | self.labels.add(label) 128 | examples.append( 129 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 130 | 131 | 132 | return examples 133 | 134 | class MrpcProcessor(DataProcessor): 135 | """Processor for the MRPC data set (GLUE version).""" 136 | 137 | def get_train_examples(self, data_dir): 138 | """See base class.""" 139 | logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv"))) 140 | return self._create_examples( 141 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 142 | 143 | def get_dev_examples(self, data_dir): 144 | """See base class.""" 145 | return self._create_examples( 146 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 147 | 148 | def get_labels(self): 149 | """See base class.""" 150 | return ["0", "1"] 151 | 152 | def _create_examples(self, lines, set_type): 153 | """Creates examples for the training and dev sets.""" 154 | examples = [] 155 | for (i, line) in enumerate(lines): 156 | if i == 0: 157 | continue 158 | guid = "%s-%s" % (set_type, i) 159 | text_a = tokenization.convert_to_unicode(line[3]) 160 | text_b = tokenization.convert_to_unicode(line[4]) 161 | label = tokenization.convert_to_unicode(line[0]) 162 | examples.append( 163 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 164 | return 
examples 165 | 166 | class MnliProcessor(DataProcessor): 167 | """Processor for the MultiNLI data set (GLUE version).""" 168 | 169 | def get_train_examples(self, data_dir): 170 | """See base class.""" 171 | return self._create_examples( 172 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 173 | 174 | def get_dev_examples(self, data_dir): 175 | """See base class.""" 176 | return self._create_examples( 177 | self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), 178 | "dev_matched") 179 | 180 | def get_labels(self): 181 | """See base class.""" 182 | return ["contradiction", "entailment", "neutral"] 183 | 184 | def _create_examples(self, lines, set_type): 185 | """Creates examples for the training and dev sets.""" 186 | examples = [] 187 | for (i, line) in enumerate(lines): 188 | if i == 0: 189 | continue 190 | guid = "%s-%s" % (set_type, tokenization.convert_to_unicode(line[0])) 191 | text_a = tokenization.convert_to_unicode(line[8]) 192 | text_b = tokenization.convert_to_unicode(line[9]) 193 | label = tokenization.convert_to_unicode(line[-1]) 194 | examples.append( 195 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 196 | return examples 197 | 198 | 199 | class ColaProcessor(DataProcessor): 200 | """Processor for the CoLA data set (GLUE version).""" 201 | 202 | def get_train_examples(self, data_dir): 203 | """See base class.""" 204 | return self._create_examples( 205 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 206 | 207 | def get_dev_examples(self, data_dir): 208 | """See base class.""" 209 | return self._create_examples( 210 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 211 | 212 | def get_labels(self): 213 | """See base class.""" 214 | return ["0", "1"] 215 | 216 | def _create_examples(self, lines, set_type): 217 | """Creates examples for the training and dev sets.""" 218 | examples = [] 219 | for (i, line) in enumerate(lines): 220 | guid = "%s-%s" % (set_type, i) 221 | text_a = tokenization.convert_to_unicode(line[3]) 222 | label = tokenization.convert_to_unicode(line[1]) 223 | examples.append( 224 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 225 | return examples 226 | 227 | 228 | def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer): 229 | """Loads a data file into a list of `InputBatch`s.""" 230 | 231 | label_map = {} 232 | for (i, label) in enumerate(label_list): 233 | label_map[label] = i 234 | features = [] 235 | for (ex_index, example) in enumerate(examples): 236 | tokens_a = tokenizer.tokenize(example.text_a) 237 | 238 | tokens_b = None 239 | if example.text_b: 240 | tokens_b = tokenizer.tokenize(example.text_b) 241 | 242 | if tokens_b: 243 | # Modifies `tokens_a` and `tokens_b` in place so that the total 244 | # length is less than the specified length. 245 | # Account for [CLS], [SEP], [SEP] with "- 3" 246 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 247 | else: 248 | # Account for [CLS] and [SEP] with "- 2" 249 | if len(tokens_a) > max_seq_length - 2: 250 | tokens_a = tokens_a[0:(max_seq_length - 2)] 251 | 252 | # The convention in BERT is: 253 | # (a) For sequence pairs: 254 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 255 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 256 | # (b) For single sequences: 257 | # tokens: [CLS] the dog is hairy . [SEP] 258 | # type_ids: 0 0 0 0 0 0 0 259 | # 260 | # Where "type_ids" are used to indicate whether this is the first 261 | # sequence or the second sequence. 
The embedding vectors for `type=0` and 262 | # `type=1` were learned during pre-training and are added to the wordpiece 263 | # embedding vector (and position vector). This is not *strictly* necessary 264 | # since the [SEP] token unambigiously separates the sequences, but it makes 265 | # it easier for the model to learn the concept of sequences. 266 | # 267 | # For classification tasks, the first vector (corresponding to [CLS]) is 268 | # used as as the "sentence vector". Note that this only makes sense because 269 | # the entire model is fine-tuned. 270 | tokens = [] 271 | segment_ids = [] 272 | tokens.append("[CLS]") 273 | segment_ids.append(0) 274 | for token in tokens_a: 275 | tokens.append(token) 276 | segment_ids.append(0) 277 | tokens.append("[SEP]") 278 | segment_ids.append(0) 279 | 280 | if tokens_b: 281 | for token in tokens_b: 282 | tokens.append(token) 283 | segment_ids.append(1) 284 | tokens.append("[SEP]") 285 | segment_ids.append(1) 286 | 287 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 288 | 289 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 290 | # tokens are attended to. 291 | input_mask = [1] * len(input_ids) 292 | 293 | # Zero-pad up to the sequence length. 294 | while len(input_ids) < max_seq_length: 295 | input_ids.append(0) 296 | input_mask.append(0) 297 | segment_ids.append(0) 298 | 299 | assert len(input_ids) == max_seq_length 300 | assert len(input_mask) == max_seq_length 301 | assert len(segment_ids) == max_seq_length 302 | 303 | label_id = label_map[example.label] 304 | if ex_index < 5: 305 | logger.info("*** Example ***") 306 | logger.info("guid: %s" % (example.guid)) 307 | logger.info("tokens: %s" % " ".join( 308 | [tokenization.printable_text(x) for x in tokens])) 309 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 310 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 311 | logger.info( 312 | "segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 313 | logger.info("label: %s (id = %d)" % (example.label, label_id)) 314 | 315 | features.append( 316 | InputFeatures( 317 | input_ids=input_ids, 318 | input_mask=input_mask, 319 | segment_ids=segment_ids, 320 | label_id=label_id)) 321 | return features 322 | 323 | 324 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 325 | """Truncates a sequence pair in place to the maximum length.""" 326 | 327 | # This is a simple heuristic which will always truncate the longer sequence 328 | # one token at a time. This makes more sense than truncating an equal percent 329 | # of tokens from each, since if one sequence is very short then each token 330 | # that's truncated likely contains more information than a longer sequence. 331 | while True: 332 | total_length = len(tokens_a) + len(tokens_b) 333 | if total_length <= max_length: 334 | break 335 | if len(tokens_a) > len(tokens_b): 336 | tokens_a.pop() 337 | else: 338 | tokens_b.pop() 339 | 340 | def accuracy(out, labels): 341 | outputs = np.argmax(out, axis=1) 342 | return np.sum(outputs==labels) 343 | 344 | def copy_optimizer_params_to_model(named_params_model, named_params_optimizer): 345 | """ Utility function for optimize_on_cpu and 16-bits training. 
346 | Copy the parameters optimized on CPU/RAM back to the model on GPU 347 | """ 348 | for (name_opti, param_opti), (name_model, param_model) in zip(named_params_optimizer, named_params_model): 349 | if name_opti != name_model: 350 | logger.error("name_opti != name_model: {} {}".format(name_opti, name_model)) 351 | raise ValueError 352 | param_model.data.copy_(param_opti.data) 353 | 354 | def set_optimizer_params_grad(named_params_optimizer, named_params_model, test_nan=False): 355 | """ Utility function for optimize_on_cpu and 16-bits training. 356 | Copy the gradient of the GPU parameters to the CPU/RAMM copy of the model 357 | """ 358 | is_nan = False 359 | for (name_opti, param_opti), (name_model, param_model) in zip(named_params_optimizer, named_params_model): 360 | if name_opti != name_model: 361 | logger.error("name_opti != name_model: {} {}".format(name_opti, name_model)) 362 | raise ValueError 363 | if test_nan and torch.isnan(param_model.grad).sum() > 0: 364 | is_nan = True 365 | if param_opti.grad is None: 366 | param_opti.grad = torch.nn.Parameter(param_opti.data.new().resize_(*param_opti.data.size())) 367 | param_opti.grad.data.copy_(param_model.grad.data) 368 | return is_nan 369 | 370 | def main(): 371 | parser = argparse.ArgumentParser() 372 | 373 | ## Required parameters 374 | parser.add_argument("--data_dir", 375 | default=None, 376 | type=str, 377 | required=True, 378 | help="The input data dir. Should contain the .tsv files (or other data files) for the task.") 379 | parser.add_argument("--bert_config_file", 380 | default=None, 381 | type=str, 382 | required=True, 383 | help="The config json file corresponding to the pre-trained BERT model. \n" 384 | "This specifies the model architecture.") 385 | parser.add_argument("--task_name", 386 | default=None, 387 | type=str, 388 | required=True, 389 | help="The name of the task to train.") 390 | parser.add_argument("--vocab_file", 391 | default=None, 392 | type=str, 393 | required=True, 394 | help="The vocabulary file that the BERT model was trained on.") 395 | parser.add_argument("--output_dir", 396 | default=None, 397 | type=str, 398 | required=True, 399 | help="The output directory where the model checkpoints will be written.") 400 | 401 | ## Other parameters 402 | parser.add_argument("--init_checkpoint", 403 | default=None, 404 | type=str, 405 | help="Initial checkpoint (usually from a pre-trained BERT model).") 406 | parser.add_argument("--do_lower_case", 407 | default=False, 408 | action='store_true', 409 | help="Whether to lower case the input text. True for uncased models, False for cased models.") 410 | parser.add_argument("--max_seq_length", 411 | default=128, 412 | type=int, 413 | help="The maximum total input sequence length after WordPiece tokenization. 
\n" 414 | "Sequences longer than this will be truncated, and sequences shorter \n" 415 | "than this will be padded.") 416 | parser.add_argument("--do_train", 417 | default=False, 418 | action='store_true', 419 | help="Whether to run training.") 420 | parser.add_argument("--do_eval", 421 | default=False, 422 | action='store_true', 423 | help="Whether to run eval on the dev set.") 424 | parser.add_argument("--train_batch_size", 425 | default=32, 426 | type=int, 427 | help="Total batch size for training.") 428 | parser.add_argument("--eval_batch_size", 429 | default=8, 430 | type=int, 431 | help="Total batch size for eval.") 432 | parser.add_argument("--learning_rate", 433 | default=5e-5, 434 | type=float, 435 | help="The initial learning rate for Adam.") 436 | parser.add_argument("--num_train_epochs", 437 | default=3.0, 438 | type=float, 439 | help="Total number of training epochs to perform.") 440 | parser.add_argument("--warmup_proportion", 441 | default=0.1, 442 | type=float, 443 | help="Proportion of training to perform linear learning rate warmup for. " 444 | "E.g., 0.1 = 10%% of training.") 445 | parser.add_argument("--save_checkpoints_steps", 446 | default=1000, 447 | type=int, 448 | help="How often to save the model checkpoint.") 449 | parser.add_argument("--no_cuda", 450 | default=False, 451 | action='store_true', 452 | help="Whether not to use CUDA when available") 453 | parser.add_argument("--local_rank", 454 | type=int, 455 | default=-1, 456 | help="local_rank for distributed training on gpus") 457 | parser.add_argument('--seed', 458 | type=int, 459 | default=42, 460 | help="random seed for initialization") 461 | parser.add_argument('--gradient_accumulation_steps', 462 | type=int, 463 | default=1, 464 | help="Number of updates steps to accumualte before performing a backward/update pass.") 465 | parser.add_argument('--optimize_on_cpu', 466 | default=False, 467 | action='store_true', 468 | help="Whether to perform optimization and keep the optimizer averages on CPU") 469 | parser.add_argument('--fp16', 470 | default=False, 471 | action='store_true', 472 | help="Whether to use 16-bit float precision instead of 32-bit") 473 | parser.add_argument('--loss_scale', 474 | type=float, default=128, 475 | help='Loss scaling, positive power of 2 values can improve fp16 convergence.') 476 | 477 | args = parser.parse_args() 478 | 479 | processors = { 480 | "cola": ColaProcessor, 481 | "mnli": MnliProcessor, 482 | "mrpc": MrpcProcessor, 483 | "news": NewsProcessor, 484 | } 485 | 486 | if args.local_rank == -1 or args.no_cuda: 487 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 488 | n_gpu = torch.cuda.device_count() 489 | else: 490 | device = torch.device("cuda", args.local_rank) 491 | n_gpu = 1 492 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 493 | # torch.distributed.init_process_group(backend='nccl') 494 | if args.fp16: 495 | logger.info("16-bits training currently not supported in distributed training") 496 | args.fp16 = False # (see https://github.com/pytorch/pytorch/pull/13496) 497 | logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1)) 498 | 499 | if args.gradient_accumulation_steps < 1: 500 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( 501 | args.gradient_accumulation_steps)) 502 | 503 | args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps) 504 | 505 | 
random.seed(args.seed) 506 | np.random.seed(args.seed) 507 | torch.manual_seed(args.seed) 508 | if n_gpu > 0: 509 | torch.cuda.manual_seed_all(args.seed) 510 | 511 | if not args.do_train and not args.do_eval: 512 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 513 | 514 | bert_config = BertConfig.from_json_file(args.bert_config_file) 515 | 516 | if args.max_seq_length > bert_config.max_position_embeddings: 517 | raise ValueError( 518 | "Cannot use sequence length {} because the BERT model was only trained up to sequence length {}".format( 519 | args.max_seq_length, bert_config.max_position_embeddings)) 520 | 521 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir): 522 | raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) 523 | os.makedirs(args.output_dir, exist_ok=True) 524 | 525 | task_name = args.task_name.lower() 526 | 527 | if task_name not in processors: 528 | raise ValueError("Task not found: %s" % (task_name)) 529 | 530 | 531 | processor = processors[task_name]() 532 | 533 | tokenizer = tokenization.FullTokenizer( 534 | vocab_file=args.vocab_file, do_lower_case=args.do_lower_case) 535 | 536 | train_examples = None 537 | num_train_steps = None 538 | if args.do_train: 539 | train_examples = processor.get_train_examples(args.data_dir) 540 | num_train_steps = int( 541 | len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs) 542 | 543 | label_list = processor.get_labels() 544 | 545 | print("label_list.size:%d\n" %(len(label_list))) 546 | 547 | # Prepare model 548 | model = BertForSequenceClassification(bert_config, len(label_list)) 549 | if args.init_checkpoint is not None: 550 | model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) 551 | if args.fp16: 552 | model.half() 553 | model.to(device) 554 | #if args.local_rank != -1: 555 | #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 556 | # output_device=args.local_rank) 557 | #elif n_gpu > 1: 558 | # model = torch.nn.DataParallel(model) 559 | 560 | # Prepare optimizer 561 | if args.fp16: 562 | param_optimizer = [(n, param.clone().detach().to('cpu').float().requires_grad_()) \ 563 | for n, param in model.named_parameters()] 564 | elif args.optimize_on_cpu: 565 | param_optimizer = [(n, param.clone().detach().to('cpu').requires_grad_()) \ 566 | for n, param in model.named_parameters()] 567 | else: 568 | param_optimizer = list(model.named_parameters()) 569 | no_decay = ['bias', 'gamma', 'beta'] 570 | optimizer_grouped_parameters = [ 571 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01}, 572 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0} 573 | ] 574 | optimizer = BERTAdam(optimizer_grouped_parameters, 575 | lr=args.learning_rate, 576 | warmup=args.warmup_proportion, 577 | t_total=num_train_steps) 578 | 579 | global_step = 0 580 | if args.do_train: 581 | train_features = convert_examples_to_features( 582 | train_examples, label_list, args.max_seq_length, tokenizer) 583 | logger.info("***** Running training *****") 584 | logger.info(" Num examples = %d", len(train_examples)) 585 | logger.info(" Batch size = %d", args.train_batch_size) 586 | logger.info(" Num steps = %d", num_train_steps) 587 | all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long) 588 | all_input_mask = torch.tensor([f.input_mask for f in train_features], 
dtype=torch.long) 589 | all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long) 590 | all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long) 591 | train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) 592 | if args.local_rank == -1: 593 | train_sampler = RandomSampler(train_data) 594 | else: 595 | 596 | train_sampler = RandomSampler(train_data) 597 | #train_sampler = DistributedSampler(train_data) 598 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) 599 | 600 | model.train() 601 | for _ in trange(int(args.num_train_epochs), desc="Epoch"): 602 | tr_loss = 0 603 | nb_tr_examples, nb_tr_steps = 0, 0 604 | for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): 605 | batch = tuple(t.to(device) for t in batch) 606 | input_ids, input_mask, segment_ids, label_ids = batch 607 | loss, _ = model(input_ids, segment_ids, input_mask, label_ids) 608 | if n_gpu > 1: 609 | loss = loss.mean() # mean() to average on multi-gpu. 610 | if args.fp16 and args.loss_scale != 1.0: 611 | # rescale loss for fp16 training 612 | # see https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html 613 | loss = loss * args.loss_scale 614 | if args.gradient_accumulation_steps > 1: 615 | loss = loss / args.gradient_accumulation_steps 616 | loss.backward() 617 | tr_loss += loss.item() 618 | nb_tr_examples += input_ids.size(0) 619 | nb_tr_steps += 1 620 | if (step + 1) % args.gradient_accumulation_steps == 0: 621 | if args.fp16 or args.optimize_on_cpu: 622 | if args.fp16 and args.loss_scale != 1.0: 623 | # scale down gradients for fp16 training 624 | for param in model.parameters(): 625 | param.grad.data = param.grad.data / args.loss_scale 626 | is_nan = set_optimizer_params_grad(param_optimizer, model.named_parameters(), test_nan=True) 627 | if is_nan: 628 | logger.info("FP16 TRAINING: Nan in gradients, reducing loss scaling") 629 | args.loss_scale = args.loss_scale / 2 630 | model.zero_grad() 631 | continue 632 | optimizer.step() 633 | copy_optimizer_params_to_model(model.named_parameters(), param_optimizer) 634 | else: 635 | optimizer.step() 636 | model.zero_grad() 637 | global_step += 1 638 | 639 | if args.do_eval: 640 | eval_examples = processor.get_dev_examples(args.data_dir) 641 | eval_features = convert_examples_to_features( 642 | eval_examples, label_list, args.max_seq_length, tokenizer) 643 | logger.info("***** Running evaluation *****") 644 | logger.info(" Num examples = %d", len(eval_examples)) 645 | logger.info(" Batch size = %d", args.eval_batch_size) 646 | all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) 647 | all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) 648 | all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) 649 | all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long) 650 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) 651 | if args.local_rank == -1: 652 | eval_sampler = SequentialSampler(eval_data) 653 | else: 654 | 655 | eval_sampler = SequentialSampler(eval_data) 656 | #eval_sampler = DistributedSampler(eval_data) 657 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size) 658 | 659 | model.eval() 660 | eval_loss, eval_accuracy = 0, 0 661 | nb_eval_steps, nb_eval_examples = 0, 0 662 | for input_ids, 
input_mask, segment_ids, label_ids in eval_dataloader: 663 | input_ids = input_ids.to(device) 664 | input_mask = input_mask.to(device) 665 | segment_ids = segment_ids.to(device) 666 | label_ids = label_ids.to(device) 667 | 668 | with torch.no_grad(): 669 | tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids) 670 | 671 | logits = logits.detach().cpu().numpy() 672 | label_ids = label_ids.to('cpu').numpy() 673 | tmp_eval_accuracy = accuracy(logits, label_ids) 674 | 675 | eval_loss += tmp_eval_loss.mean().item() 676 | eval_accuracy += tmp_eval_accuracy 677 | 678 | nb_eval_examples += input_ids.size(0) 679 | nb_eval_steps += 1 680 | 681 | eval_loss = eval_loss / nb_eval_steps 682 | eval_accuracy = eval_accuracy / nb_eval_examples 683 | 684 | result = {'eval_loss': eval_loss, 685 | 'eval_accuracy': eval_accuracy, 686 | 'global_step': global_step, 687 | 'loss': tr_loss/nb_tr_steps} 688 | 689 | output_eval_file = os.path.join(args.output_dir, "eval_results.txt") 690 | with open(output_eval_file, "w") as writer: 691 | logger.info("***** Eval results *****") 692 | for key in sorted(result.keys()): 693 | logger.info(" %s = %s", key, str(result[key])) 694 | writer.write("%s = %s\n" % (key, str(result[key]))) 695 | 696 | if __name__ == "__main__": 697 | main() 698 | -------------------------------------------------------------------------------- /tokenization_word.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | import six 24 | 25 | def convert_to_unicode(text): 26 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 27 | if six.PY3: 28 | if isinstance(text, str): 29 | return text 30 | elif isinstance(text, bytes): 31 | return text.decode("utf-8", "ignore") 32 | else: 33 | raise ValueError("Unsupported string type: %s" % (type(text))) 34 | elif six.PY2: 35 | if isinstance(text, str): 36 | return text.decode("utf-8", "ignore") 37 | elif isinstance(text, unicode): 38 | return text 39 | else: 40 | raise ValueError("Unsupported string type: %s" % (type(text))) 41 | else: 42 | raise ValueError("Not running on Python2 or Python 3?") 43 | 44 | 45 | def printable_text(text): 46 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 47 | 48 | # These functions want `str` for both Python2 and Python3, but in one case 49 | # it's a Unicode string and in the other it's a byte string. 
50 | if six.PY3: 51 | if isinstance(text, str): 52 | return text 53 | elif isinstance(text, bytes): 54 | return text.decode("utf-8", "ignore") 55 | else: 56 | raise ValueError("Unsupported string type: %s" % (type(text))) 57 | elif six.PY2: 58 | if isinstance(text, str): 59 | return text 60 | elif isinstance(text, unicode): 61 | return text.encode("utf-8") 62 | else: 63 | raise ValueError("Unsupported string type: %s" % (type(text))) 64 | else: 65 | raise ValueError("Not running on Python 2 or Python 3?") 66 | 67 | 68 | def load_vocab(vocab_file): 69 | """Loads a vocabulary file into a dictionary.""" 70 | vocab = collections.OrderedDict() 71 | 72 | index_vocab = collections.OrderedDict() 73 | index = 0 74 | with open(vocab_file, "rb") as reader: 75 | while True: 76 | tmp = reader.readline() 77 | token = convert_to_unicode(tmp) 78 | 79 | 80 | if not token: 81 | break 82 | 83 | #file_out.write("%d\t%s\n" %(index,token)) 84 | token = token.strip() 85 | vocab[token] = index 86 | index_vocab[index]=token 87 | index += 1 88 | 89 | 90 | return vocab,index_vocab 91 | 92 | 93 | def convert_tokens_to_ids(vocab, tokens): 94 | """Converts a sequence of tokens into ids using the vocab.""" 95 | ids = [] 96 | for token in tokens: 97 | ids.append(vocab[token]) 98 | return ids 99 | 100 | 101 | def whitespace_tokenize(text): 102 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 103 | text = text.strip() 104 | if not text: 105 | return [] 106 | tokens = text.split() 107 | return tokens 108 | 109 | 110 | class FullTokenizer(object): 111 | """Runs end-to-end tokenization.""" 112 | 113 | def __init__(self, vocab_file, do_lower_case=True): 114 | self.vocab,self.index_vocab = load_vocab(vocab_file) 115 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 116 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 117 | 118 | def tokenize(self, text): 119 | split_tokens = [] 120 | for token in self.basic_tokenizer.tokenize(text): 121 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 122 | split_tokens.append(sub_token) 123 | 124 | return split_tokens 125 | 126 | def convert_tokens_to_ids(self, tokens): 127 | return convert_tokens_to_ids(self.vocab, tokens) 128 | 129 | 130 | class BasicTokenizer(object): 131 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 132 | 133 | def __init__(self, do_lower_case=True): 134 | """Constructs a BasicTokenizer. 135 | 136 | Args: 137 | do_lower_case: Whether to lower case the input. 138 | """ 139 | self.do_lower_case = do_lower_case 140 | 141 | def tokenize(self, text): 142 | """Tokenizes a piece of text.""" 143 | text = convert_to_unicode(text) 144 | text = self._clean_text(text) 145 | # This was added on November 1st, 2018 for the multilingual and Chinese 146 | # models. This is also applied to the English models now, but it doesn't 147 | # matter since the English models were not trained on any Chinese data 148 | # and generally don't have any Chinese data in them (there are Chinese 149 | # characters in the vocabulary because Wikipedia does have some Chinese 150 | # words in the English Wikipedia.). 
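# For example, "今天NBA比赛" becomes " 今  天 NBA 比  赛 ", so each CJK character
# is split out as its own token while Latin-script text is left to WordPiece.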
151 | text = self._tokenize_chinese_chars(text) 152 | orig_tokens = whitespace_tokenize(text) 153 | split_tokens = [] 154 | for token in orig_tokens: 155 | if self.do_lower_case: 156 | token = token.lower() 157 | token = self._run_strip_accents(token) 158 | split_tokens.extend(self._run_split_on_punc(token)) 159 | 160 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 161 | return output_tokens 162 | 163 | def _run_strip_accents(self, text): 164 | """Strips accents from a piece of text.""" 165 | text = unicodedata.normalize("NFD", text) 166 | output = [] 167 | for char in text: 168 | cat = unicodedata.category(char) 169 | if cat == "Mn": 170 | continue 171 | output.append(char) 172 | return "".join(output) 173 | 174 | def _run_split_on_punc(self, text): 175 | """Splits punctuation on a piece of text.""" 176 | chars = list(text) 177 | i = 0 178 | start_new_word = True 179 | output = [] 180 | while i < len(chars): 181 | char = chars[i] 182 | if _is_punctuation(char): 183 | output.append([char]) 184 | start_new_word = True 185 | else: 186 | if start_new_word: 187 | output.append([]) 188 | start_new_word = False 189 | output[-1].append(char) 190 | i += 1 191 | 192 | return ["".join(x) for x in output] 193 | 194 | def _tokenize_chinese_chars(self, text): 195 | """Adds whitespace around any CJK character.""" 196 | output = [] 197 | for char in text: 198 | cp = ord(char) 199 | if self._is_chinese_char(cp): 200 | output.append(" ") 201 | output.append(char) 202 | output.append(" ") 203 | else: 204 | output.append(char) 205 | return "".join(output) 206 | 207 | def _is_chinese_char(self, cp): 208 | """Checks whether CP is the codepoint of a CJK character.""" 209 | # This defines a "chinese character" as anything in the CJK Unicode block: 210 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 211 | # 212 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 213 | # despite its name. The modern Korean Hangul alphabet is a different block, 214 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 215 | # space-separated words, so they are not treated specially and handled 216 | # like the all of the other languages. 217 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 218 | (cp >= 0x3400 and cp <= 0x4DBF) or # 219 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 220 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 221 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 222 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 223 | (cp >= 0xF900 and cp <= 0xFAFF) or # 224 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 225 | return True 226 | 227 | return False 228 | 229 | def _clean_text(self, text): 230 | """Performs invalid character removal and whitespace cleanup on text.""" 231 | output = [] 232 | for char in text: 233 | cp = ord(char) 234 | if cp == 0 or cp == 0xfffd or _is_control(char): 235 | continue 236 | if _is_whitespace(char): 237 | output.append(" ") 238 | else: 239 | output.append(char) 240 | return "".join(output) 241 | 242 | 243 | class WordpieceTokenizer(object): 244 | """Runs WordPiece tokenization.""" 245 | 246 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 247 | self.vocab = vocab 248 | self.unk_token = unk_token 249 | self.max_input_chars_per_word = max_input_chars_per_word 250 | 251 | def tokenize(self, text): 252 | """Tokenizes a piece of text into its word pieces. 253 | 254 | This uses a greedy longest-match-first algorithm to perform tokenization 255 | using the given vocabulary. 
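Tokens that cannot be decomposed into in-vocabulary word pieces are mapped to `unk_token` ("[UNK]") instead.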
256 | 257 | For example: 258 | input = "unaffable" 259 | output = ["un", "##aff", "##able"] 260 | 261 | Args: 262 | text: A single token or whitespace separated tokens. This should have 263 | already been passed through `BasicTokenizer`. 264 | 265 | Returns: 266 | A list of wordpiece tokens. 267 | """ 268 | 269 | text = convert_to_unicode(text) 270 | 271 | output_tokens = [] 272 | for token in whitespace_tokenize(text): 273 | chars = list(token) 274 | if len(chars) > self.max_input_chars_per_word: 275 | output_tokens.append(self.unk_token) 276 | continue 277 | 278 | is_bad = False 279 | start = 0 280 | sub_tokens = [] 281 | while start < len(chars): 282 | end = len(chars) 283 | cur_substr = None 284 | while start < end: 285 | substr = "".join(chars[start:end]) 286 | if start > 0: 287 | substr = "##" + substr 288 | if substr in self.vocab: 289 | cur_substr = substr 290 | break 291 | end -= 1 292 | if cur_substr is None: 293 | is_bad = True 294 | break 295 | sub_tokens.append(cur_substr) 296 | start = end 297 | 298 | if is_bad: 299 | output_tokens.append(self.unk_token) 300 | else: 301 | output_tokens.extend(sub_tokens) 302 | return output_tokens 303 | 304 | 305 | def _is_whitespace(char): 306 | """Checks whether `char` is a whitespace character.""" 307 | # \t, \n, and \r are technically control characters but we treat them 308 | # as whitespace since they are generally considered as such. 309 | if char == " " or char == "\t" or char == "\n" or char == "\r": 310 | return True 311 | cat = unicodedata.category(char) 312 | if cat == "Zs": 313 | return True 314 | return False 315 | 316 | 317 | def _is_control(char): 318 | """Checks whether `char` is a control character.""" 319 | # These are technically control characters but we count them as whitespace 320 | # characters. 321 | if char == "\t" or char == "\n" or char == "\r": 322 | return False 323 | cat = unicodedata.category(char) 324 | if cat.startswith("C"): 325 | return True 326 | return False 327 | 328 | 329 | def _is_punctuation(char): 330 | """Checks whether `char` is a punctuation character.""" 331 | cp = ord(char) 332 | # We treat all non-letter/number ASCII as punctuation. 333 | # Characters such as "^", "$", and "`" are not in the Unicode 334 | # Punctuation class but we treat them as punctuation anyways, for 335 | # consistency. 336 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 337 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 338 | return True 339 | cat = unicodedata.category(char) 340 | if cat.startswith("P"): 341 | return True 342 | return False 343 | --------------------------------------------------------------------------------
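A minimal usage sketch (not a file in this repository) showing how FullTokenizer from tokenization_word.py is typically driven on Chinese text; the vocab path below is a placeholder for the vocab.txt shipped with the chinese_L-12_H-768_A-12 checkpoint.

from tokenization_word import FullTokenizer

# Point vocab_file at the vocab.txt of the downloaded Chinese BERT checkpoint (path is illustrative).
tokenizer = FullTokenizer(vocab_file="chinese_L-12_H-768_A-12/vocab.txt", do_lower_case=True)

tokens = tokenizer.tokenize("今天NBA比赛很精彩")   # BasicTokenizer splits every CJK character into its own token
ids = tokenizer.convert_tokens_to_ids(tokens)      # raises KeyError for any token missing from the vocab
print(tokens)
print(ids)

run_classifier_word.py builds its tokenizer the same way from --vocab_file and --do_lower_case before converting examples to features.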