├── convert_bert_pytorch_and_tf
│   ├── README.md
│   ├── convert_bert_original_tf_checkpoint_to_pytorch.py
│   └── convert_bert_pytorch_checkpoint_to_original_tf.py
├── README.md
├── bert_pytorch_source_code
│   ├── README.md
│   ├── optimization.py
│   ├── tokenization_word.py
│   ├── modeling.py
│   └── run_classifier_word.py
└── utils
    ├── distance_3d_line.py
    └── common_function.py

/convert_bert_pytorch_and_tf/README.md:
--------------------------------------------------------------------------------
Reference: [transformers](https://github.com/huggingface/transformers) by Hugging Face
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## CodeShare

This repository collects useful, reusable code snippets and is updated continuously.

Some of the code comes from other open-source repositories; in each case the source is indicated and the original license is followed.
--------------------------------------------------------------------------------
/bert_pytorch_source_code/README.md:
--------------------------------------------------------------------------------
Reference: https://github.com/xieyufei1993/Bert-Pytorch-Chinese-TextClassification

Chinese comments have been added on top of the code from the repository above.

# Bert-Pytorch-Chinese-TextClassification
Fine-tuning BERT in PyTorch for Chinese text classification

### Step 1

Download the pretrained TensorFlow model: [chinese_L-12_H-768_A-12](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip)

### Step 2

Convert the TensorFlow pretrained model to PyTorch

```shell
cd convert_tf_to_pytorch
```

```shell
export BERT_BASE_DIR=/workspace/mnt/group/ocr/xieyufei/bert-tf-chinese/chinese_L-12_H-768_A-12

python3 convert_tf_checkpoint_to_pytorch.py \
  --tf_checkpoint_path $BERT_BASE_DIR/bert_model.ckpt \
  --bert_config_file $BERT_BASE_DIR/bert_config.json \
  --pytorch_dump_path $BERT_BASE_DIR/pytorch_model.bin
```

### Step 3

Download the Chinese news dataset: [Train](https://pan.baidu.com/s/15rkzx-YRbP5XRNeapzYWLw) (50k samples) and [Dev](https://pan.baidu.com/s/1HuYTacgAQFqGAJ8FYXNqOw) (5k samples)

### Step 4

Train and evaluate

```shell
cd src
```

```shell
export GLUE_DIR=/workspace/mnt/group/ocr/xieyufei/bert-tf-chinese/glue_data
export BERT_BASE_DIR=/workspace/mnt/group/ocr/xieyufei/bert-tf-chinese/chinese_L-12_H-768_A-12/
export BERT_PYTORCH_DIR=/workspace/mnt/group/ocr/xieyufei/bert-tf-chinese/chinese_L-12_H-768_A-12/

python3 run_classifier_word.py \
  --task_name NEWS \
  --do_train \
  --do_eval \
  --data_dir $GLUE_DIR/SouGou/ \
  --vocab_file $BERT_BASE_DIR/vocab.txt \
  --bert_config_file $BERT_BASE_DIR/bert_config.json \
  --init_checkpoint $BERT_PYTORCH_DIR/pytorch_model.bin \
  --max_seq_length 256 \
  --train_batch_size 24 \
  --learning_rate 2e-5 \
  --num_train_epochs 50.0 \
  --output_dir ./newsAll_output/ \
  --local_rank 3
```

Result after one epoch:

```
eval_accuracy = 0.9742
eval_loss = 0.10202122390270234
global_step = 2084
loss = 0.15899521649851786
```

--------------------------------------------------------------------------------
/utils/distance_3d_line.py:
--------------------------------------------------------------------------------
# Author: hichenway
# Purpose: compute the distance between two skew lines in 3D space and the feet of their common perpendicular
# Corresponding blog: https://blog.csdn.net/songyunli1111

# Reference 1: https://www.jianshu.com/p/34a7c4e1f3f5
# Reference 2: https://www.cnblogs.com/mazhenyu/p/7154449.html
import math
import numpy as np

def distance(point1, point2):
    num = sum([(point1[i] - point2[i]) ** 2 for i in range(len(point1))])
    return math.sqrt(num)

def distance_3d(p1, p2, q1, q2):
    # distance between two 3D lines via the cross product

    # p1, p2 are two points on line L1; q1, q2 are two points on line L2
    v1 = np.array([p2[i] - p1[i] for i in range(len(p1))])
    v2 = np.array([q2[i] - q1[i] for i in range(len(q1))])

    # cross product formula: (a1,a2,a3) x (b1,b2,b3) = (a2b3-a3b2, a3b1-a1b3, a1b2-a2b1)
    cross_prod = np.array([v1[1]*v2[2] - v1[2]*v2[1], v1[2]*v2[0] - v1[0]*v2[2], v1[0]*v2[1] - v1[1]*v2[0]])
    temp = np.array([p1[i] - q1[i] for i in range(len(p1))])
    dis_res = abs(sum(cross_prod * temp)) / math.sqrt(sum([cross_prod[i]**2 for i in range(len(cross_prod))]))
    return dis_res

def cross(p1, p2, q1, q2):
    v1 = np.array([p2[i] - p1[i] for i in range(len(p1))])
    v2 = np.array([q2[i] - q1[i] for i in range(len(q1))])
    # l1 = p1 + t1 * v1
    # l2 = q1 + t2 * v2
    a = sum(v1 * v2)
    b = sum(v1 * v1)
    c = sum(v2 * v2)
    d = sum(np.array([q1[i] - p1[i] for i in range(len(p1))]) * v1)
    e = sum(np.array([q1[i] - p1[i] for i in range(len(p1))]) * v2)
    isParallel = False
    if a == 0:  # a == 0 means the two direction vectors are perpendicular
        t1 = d / b
        t2 = -e / c
    elif abs(a*a - b*c) > 0.001:  # general case; compare against a small tolerance instead of exact zero because of floating-point error
        t1 = (a * e - c * d) / (a * a - b * c)
        t2 = b * t1 / a - d / a
    else:  # the two lines are parallel: there are infinitely many pairs of perpendicular feet;
           # fixing an arbitrary point on one line determines the corresponding foot on the other
        isParallel = True
        t1 = 0
        t2 = - d / a

    point1 = [p1[i] + t1 * v1[i] for i in range(len(p1))]
    point2 = [q1[i] + t2 * v2[i] for i in range(len(q1))]
    dis = distance(point1, point2)

    return point1, point2, dis, isParallel

if __name__ == "__main__":
    p1 = (2, 0, 0)
    p2 = (0, 2, 0)
    q1 = (2, 0, 2)
    q2 = (0, 2, 2)
    res = distance_3d(p1, p2, q1, q2)
    point1, point2, dis, isParallel = cross(p1, p2, q1, q2)
    print(dis, res)
--------------------------------------------------------------------------------
/convert_bert_pytorch_and_tf/convert_bert_original_tf_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
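# Example invocation (added for illustration only; the paths are placeholders and the
# flags are the ones defined in this script's argparse section below — point them at
# your own TensorFlow BERT checkpoint directory):
#
#   python convert_bert_original_tf_checkpoint_to_pytorch.py \
#       --tf_checkpoint_path $BERT_BASE_DIR/bert_model.ckpt \
#       --bert_config_file   $BERT_BASE_DIR/bert_config.json \
#       --pytorch_dump_path  $BERT_BASE_DIR/pytorch_model.bin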
15 | """Convert BERT checkpoint.""" 16 | 17 | 18 | import argparse 19 | import logging 20 | 21 | import torch 22 | 23 | from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert 24 | 25 | 26 | logging.basicConfig(level=logging.INFO) 27 | 28 | 29 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): 30 | # Initialise PyTorch model 31 | config = BertConfig.from_json_file(bert_config_file) 32 | print("Building PyTorch model from configuration: {}".format(str(config))) 33 | model = BertForPreTraining(config) 34 | 35 | # Load weights from tf checkpoint 36 | load_tf_weights_in_bert(model, config, tf_checkpoint_path) 37 | 38 | # Save pytorch-model 39 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 40 | torch.save(model.state_dict(), pytorch_dump_path) 41 | 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser() 45 | # Required parameters 46 | parser.add_argument( 47 | "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." 48 | ) 49 | parser.add_argument( 50 | "--bert_config_file", 51 | default=None, 52 | type=str, 53 | required=True, 54 | help="The config json file corresponding to the pre-trained BERT model. \n" 55 | "This specifies the model architecture.", 56 | ) 57 | parser.add_argument( 58 | "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." 59 | ) 60 | args = parser.parse_args() 61 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path) -------------------------------------------------------------------------------- /convert_bert_pytorch_and_tf/convert_bert_pytorch_checkpoint_to_original_tf.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint.""" 17 | 18 | import argparse 19 | import os 20 | 21 | import numpy as np 22 | import tensorflow as tf 23 | import torch 24 | 25 | from transformers import BertModel 26 | 27 | 28 | def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str): 29 | 30 | """ 31 | :param model:BertModel Pytorch model instance to be converted 32 | :param ckpt_dir: Tensorflow model directory 33 | :param model_name: model name 34 | :return: 35 | Currently supported HF models: 36 | Y BertModel 37 | N BertForMaskedLM 38 | N BertForPreTraining 39 | N BertForMultipleChoice 40 | N BertForNextSentencePrediction 41 | N BertForSequenceClassification 42 | N BertForQuestionAnswering 43 | """ 44 | 45 | tensors_to_transpose = ("dense.weight", "attention.self.query", "attention.self.key", "attention.self.value") 46 | 47 | var_map = ( 48 | ("layer.", "layer_"), 49 | ("word_embeddings.weight", "word_embeddings"), 50 | ("position_embeddings.weight", "position_embeddings"), 51 | ("token_type_embeddings.weight", "token_type_embeddings"), 52 | (".", "/"), 53 | ("LayerNorm/weight", "LayerNorm/gamma"), 54 | ("LayerNorm/bias", "LayerNorm/beta"), 55 | ("weight", "kernel"), 56 | ) 57 | 58 | if not os.path.isdir(ckpt_dir): 59 | os.makedirs(ckpt_dir) 60 | 61 | state_dict = model.state_dict() 62 | 63 | def to_tf_var_name(name: str): 64 | for patt, repl in iter(var_map): 65 | name = name.replace(patt, repl) 66 | return "bert/{}".format(name) 67 | 68 | def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session): 69 | tf_dtype = tf.dtypes.as_dtype(tensor.dtype) 70 | tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer()) 71 | session.run(tf.variables_initializer([tf_var])) 72 | session.run(tf_var) 73 | return tf_var 74 | 75 | tf.reset_default_graph() 76 | with tf.Session() as session: 77 | for var_name in state_dict: 78 | tf_name = to_tf_var_name(var_name) 79 | torch_tensor = state_dict[var_name].numpy() 80 | if any([x in var_name for x in tensors_to_transpose]): 81 | torch_tensor = torch_tensor.T 82 | tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session) 83 | tf.keras.backend.set_value(tf_var, torch_tensor) 84 | tf_weight = session.run(tf_var) 85 | print("Successfully created {}: {}".format(tf_name, np.allclose(tf_weight, torch_tensor))) 86 | 87 | saver = tf.train.Saver(tf.trainable_variables()) 88 | saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt")) 89 | 90 | 91 | def main(raw_args=None): 92 | parser = argparse.ArgumentParser() 93 | parser.add_argument("--model_name", type=str, required=True, help="model name e.g. 
bert-base-uncased") 94 | parser.add_argument( 95 | "--cache_dir", type=str, default=None, required=False, help="Directory containing pytorch model" 96 | ) 97 | parser.add_argument("--pytorch_model_path", type=str, required=True, help="/path/to/.bin") 98 | parser.add_argument("--tf_cache_dir", type=str, required=True, help="Directory in which to save tensorflow model") 99 | args = parser.parse_args(raw_args) 100 | 101 | model = BertModel.from_pretrained( 102 | pretrained_model_name_or_path=args.model_name, 103 | state_dict=torch.load(args.pytorch_model_path), 104 | cache_dir=args.cache_dir, 105 | ) 106 | 107 | convert_pytorch_checkpoint_to_tf(model=model, ckpt_dir=args.tf_cache_dir, model_name=args.model_name) 108 | 109 | 110 | if __name__ == "__main__": 111 | main() -------------------------------------------------------------------------------- /utils/common_function.py: -------------------------------------------------------------------------------- 1 | # Author:hichenway 2 | # 该脚本包括了常用的文件或数据处理的功能函数,方便查找和应用,减少重复造轮子 3 | 4 | import os 5 | import time 6 | 7 | 8 | # 时间计算装饰器 9 | def deco_time(is_deco = True): 10 | # 可通过 is_deco 参数去设置是否使用该装饰器 11 | if is_deco: 12 | def _deco_time(func): 13 | def time_spent(*args, **kwargs): 14 | start_time = time.time() 15 | result = func(*args, **kwargs) 16 | end_time = time.time() 17 | spent_time = (end_time - start_time) * 1000 18 | print("The spend time is: %f ms"%spent_time) # 这里也可以通过日志去记录 19 | return result 20 | return time_spent 21 | else: 22 | def _deco_time(func): 23 | return func 24 | return _deco_time 25 | 26 | 27 | # 装饰器测试用例 28 | class Test: 29 | @deco_time(True) 30 | def test1(self, n): 31 | sum_num = 0 32 | for i in range(n): 33 | sum_num += 1 34 | return sum_num 35 | 36 | 37 | def set_seed(args): 38 | """ 39 | 随机数种子设置,保证实验可重复性 40 | :param args: 41 | :return: 42 | """ 43 | import random 44 | import numpy as np 45 | import torch 46 | random.seed(args.seed) 47 | np.random.seed(args.seed) 48 | torch.manual_seed(args.seed) 49 | if args.n_gpu > 0: 50 | torch.cuda.manual_seed_all(args.seed) 51 | 52 | 53 | import time 54 | # 另一个计时装饰器 55 | def timing(f): 56 | """Decorator for timing functions 57 | Usage: 58 | @timing 59 | def function(a): 60 | pass 61 | """ 62 | def wrapper(*args, **kwargs): 63 | start = time.time() 64 | result = f(*args, **kwargs) 65 | end = time.time() 66 | print('function:%r took: %2.2f sec' % (f.__name__, end - start)) 67 | return result 68 | return wrapper 69 | 70 | 71 | # 把一个文件夹下的所有文件移到另一个文件夹下 72 | import shutil 73 | def move_file(file_path, to_path): 74 | ''' 75 | :param file_path: 要移动的文件路径 76 | :param to_path: 要移动到的文件路径 77 | :return: None 78 | ''' 79 | file_list = os.listdir(file_path) 80 | for file_name in file_list: 81 | file = file_path + "/" + file_name 82 | shutil.move(file, to_path) 83 | 84 | 85 | # 在保存文件时加上当前时间 86 | def file_add_time(file_name): 87 | cur_time = time.strftime("%Y-%m-%d_%H_%M", time.localtime()) 88 | name, form = file_name.split('.') 89 | return '.'.join([name+cur_time,form]) 90 | 91 | 92 | def plot_text_length(data): 93 | """ 94 | 输出文本列表的长度直方图,方便分析和取padding 的 max_length 95 | :param data: 文本数据列表,如:["我想吃饭","我想吃东西"...] 
96 | :return: 97 | """ 98 | import matplotlib.pyplot as plt 99 | from collections import Counter 100 | data_len = [len(sentence) for sentence in data] 101 | len_dict = Counter(data_len) 102 | X = list(len_dict.keys()) 103 | Y = list(len_dict.values()) 104 | plt.bar(X, Y) 105 | plt.show() 106 | 107 | 108 | # 删除句末的符号 109 | # string 还有很多格式形式,参见: https://docs.python.org/2/library/string.html 110 | # string.printable 代表 '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ \t\n\r\x0b\x0c' 111 | # string.printable 没有中文符号 112 | import string 113 | symbol = ",。、!(),./!~·`" # 要去除的句尾标点,不包括问号,可加上 114 | def drop_symbol(seq): 115 | seq = seq.strip() 116 | seq = seq.rstrip(symbol) 117 | seq = seq.rstrip(string.printable) 118 | return seq 119 | 120 | 121 | # 删除两个特定字符之间的内容,比如微博话题:#新冠疫情#, 一般可以用于文本的清洗 122 | # 这两个特定的字符不限于单个字符,像这样的也可以:--> 123 | # 不同模式匹配下的示例:"#除夕夜#万家灯火通明鞭炮齐鸣也是极好的#春节放鞭炮#",start_char和 end_char都是 '#' 124 | # 最大模式是贪婪匹配,会匹配整句,最小模式下仅匹配:#除夕夜# 125 | def delete_special_two_chars_inner(start_char, end_char, content): 126 | import re 127 | # 最大模式匹配: 128 | pattern = re.compile(r'({})(.*)({})'.format(start_char, end_char)) 129 | 130 | # 最小模式匹配: 131 | # pattern = re.compile(r'({})(.*?)({})'.format(start_char, end_char)) 132 | 133 | return pattern.sub(r'', content, count=1) #这里的count还可以设置替换次数 134 | 135 | # 如果想保留start_char和end_char,则用: 136 | # return pattern.sub(r'{}{}'.format(start_char, end_char),content) 137 | 138 | 139 | # 文本清理函数 140 | def clean_function(seq): 141 | import re 142 | # 去除超链接 143 | pattern_url = re.compile(r'http://[a-zA-Z0-9.?/&=:]*',re.S) 144 | seq = pattern_url.sub(r'',seq) 145 | 146 | # 去除空格 147 | seq = seq.replace(" ", "") 148 | return seq 149 | 150 | 151 | if __name__ == "__main__": 152 | file_name = "file.csv" 153 | new_fime_name = file_add_time(file_name) 154 | 155 | a=Test() 156 | sum_num = a.test1(10000) 157 | print(sum_num) 158 | 159 | test = "【史上最强悍促销】闹得沸沸扬扬" 160 | print(delete_special_two_chars_inner('【', '】',test)) -------------------------------------------------------------------------------- /bert_pytorch_source_code/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
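# Illustrative usage sketch (not part of the original file): constructing the BERTAdam
# optimizer defined below. `model` and `num_train_steps` are assumed to come from the
# training script (see run_classifier_word.py); the learning rate and warmup values are
# placeholders.
#
#   optimizer = BERTAdam(model.parameters(),
#                        lr=2e-5,
#                        warmup=0.1,               # first 10% of steps are linear warmup
#                        t_total=num_train_steps,  # enables the lr schedule; -1 keeps lr constant
#                        schedule='warmup_linear')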
15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.nn.utils import clip_grad_norm_ 21 | 22 | def warmup_cosine(x, warmup=0.002): 23 | if x < warmup: 24 | return x/warmup 25 | return 0.5 * (1.0 + torch.cos(math.pi * x)) 26 | 27 | def warmup_constant(x, warmup=0.002): 28 | if x < warmup: 29 | return x/warmup 30 | return 1.0 31 | 32 | def warmup_linear(x, warmup=0.002): 33 | if x < warmup: 34 | return x/warmup 35 | return 1.0 - x 36 | 37 | SCHEDULES = { 38 | 'warmup_cosine':warmup_cosine, 39 | 'warmup_constant':warmup_constant, 40 | 'warmup_linear':warmup_linear, 41 | } 42 | 43 | 44 | class BERTAdam(Optimizer): 45 | """Implements BERT version of Adam algorithm with weight decay fix (and no ). 46 | Params: 47 | lr: learning rate 48 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 49 | t_total: total number of training steps for the learning 50 | rate schedule, -1 means constant learning rate. Default: -1 51 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 52 | b1: Adams b1. Default: 0.9 53 | b2: Adams b2. Default: 0.999 54 | e: Adams epsilon. Default: 1e-6 55 | weight_decay_rate: Weight decay. Default: 0.01 56 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0 57 | """ 58 | def __init__(self, params, lr, warmup=-1, t_total=-1, schedule='warmup_linear', 59 | b1=0.9, b2=0.999, e=1e-6, weight_decay_rate=0.01, 60 | max_grad_norm=1.0): 61 | if not lr >= 0.0: 62 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 63 | if schedule not in SCHEDULES: 64 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 65 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 66 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 67 | if not 0.0 <= b1 < 1.0: 68 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 69 | if not 0.0 <= b2 < 1.0: 70 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 71 | if not e >= 0.0: 72 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 73 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 74 | b1=b1, b2=b2, e=e, weight_decay_rate=weight_decay_rate, 75 | max_grad_norm=max_grad_norm) 76 | super(BERTAdam, self).__init__(params, defaults) 77 | 78 | def get_lr(self): 79 | lr = [] 80 | for group in self.param_groups: 81 | for p in group['params']: 82 | state = self.state[p] 83 | if len(state) == 0: 84 | return [0] 85 | if group['t_total'] != -1: 86 | schedule_fct = SCHEDULES[group['schedule']] 87 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 88 | else: 89 | lr_scheduled = group['lr'] 90 | lr.append(lr_scheduled) 91 | return lr 92 | 93 | def to(self, device): 94 | """ Move the optimizer state to a specified device""" 95 | for state in self.state.values(): 96 | state['exp_avg'].to(device) 97 | state['exp_avg_sq'].to(device) 98 | 99 | def initialize_step(self, initial_step): 100 | """Initialize state with a defined step (but we don't have stored averaged). 101 | Arguments: 102 | initial_step (int): Initial step number. 
103 | """ 104 | for group in self.param_groups: 105 | for p in group['params']: 106 | state = self.state[p] 107 | # State initialization 108 | state['step'] = initial_step 109 | # Exponential moving average of gradient values 110 | state['exp_avg'] = torch.zeros_like(p.data) 111 | # Exponential moving average of squared gradient values 112 | state['exp_avg_sq'] = torch.zeros_like(p.data) 113 | 114 | def step(self, closure=None): 115 | """Performs a single optimization step. 116 | Arguments: 117 | closure (callable, optional): A closure that reevaluates the model 118 | and returns the loss. 119 | """ 120 | loss = None 121 | if closure is not None: 122 | loss = closure() 123 | 124 | for group in self.param_groups: 125 | for p in group['params']: 126 | if p.grad is None: 127 | continue 128 | grad = p.grad.data 129 | if grad.is_sparse: 130 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 131 | 132 | state = self.state[p] 133 | 134 | # State initialization 135 | if len(state) == 0: 136 | state['step'] = 0 137 | # Exponential moving average of gradient values 138 | state['next_m'] = torch.zeros_like(p.data) 139 | # Exponential moving average of squared gradient values 140 | state['next_v'] = torch.zeros_like(p.data) 141 | 142 | next_m, next_v = state['next_m'], state['next_v'] 143 | beta1, beta2 = group['b1'], group['b2'] 144 | 145 | # Add grad clipping 146 | if group['max_grad_norm'] > 0: 147 | clip_grad_norm_(p, group['max_grad_norm']) 148 | 149 | # Decay the first and second moment running average coefficient 150 | # In-place operations to update the averages at the same time 151 | next_m.mul_(beta1).add_(1 - beta1, grad) 152 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 153 | update = next_m / (next_v.sqrt() + group['e']) 154 | 155 | # Just adding the square of the weights to the loss function is *not* 156 | # the correct way of using L2 regularization/weight decay with Adam, 157 | # since that will interact with the m and v parameters in strange ways. 158 | # 159 | # Instead we want ot decay the weights in a manner that doesn't interact 160 | # with the m/v parameters. This is equivalent to adding the square 161 | # of the weights to the loss with plain (non-momentum) SGD. 162 | if group['weight_decay_rate'] > 0.0: 163 | update += group['weight_decay_rate'] * p.data 164 | 165 | if group['t_total'] != -1: 166 | schedule_fct = SCHEDULES[group['schedule']] 167 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 168 | else: 169 | lr_scheduled = group['lr'] 170 | 171 | update_with_lr = lr_scheduled * update 172 | p.data.add_(-update_with_lr) 173 | 174 | state['step'] += 1 175 | 176 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 177 | # bias_correction1 = 1 - beta1 ** state['step'] 178 | # bias_correction2 = 1 - beta2 ** state['step'] 179 | 180 | return loss 181 | -------------------------------------------------------------------------------- /bert_pytorch_source_code/tokenization_word.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | import six 24 | 25 | 26 | def convert_to_unicode(text): 27 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 28 | if six.PY3: 29 | if isinstance(text, str): 30 | return text 31 | elif isinstance(text, bytes): 32 | return text.decode("utf-8", "ignore") 33 | else: 34 | raise ValueError("Unsupported string type: %s" % (type(text))) 35 | elif six.PY2: 36 | if isinstance(text, str): 37 | return text.decode("utf-8", "ignore") 38 | elif isinstance(text, unicode): 39 | return text 40 | else: 41 | raise ValueError("Unsupported string type: %s" % (type(text))) 42 | else: 43 | raise ValueError("Not running on Python2 or Python 3?") 44 | 45 | 46 | def printable_text(text): 47 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 48 | 49 | # These functions want `str` for both Python2 and Python3, but in one case 50 | # it's a Unicode string and in the other it's a byte string. 51 | if six.PY3: 52 | if isinstance(text, str): 53 | return text 54 | elif isinstance(text, bytes): 55 | return text.decode("utf-8", "ignore") 56 | else: 57 | raise ValueError("Unsupported string type: %s" % (type(text))) 58 | elif six.PY2: 59 | if isinstance(text, str): 60 | return text 61 | elif isinstance(text, unicode): 62 | return text.encode("utf-8") 63 | else: 64 | raise ValueError("Unsupported string type: %s" % (type(text))) 65 | else: 66 | raise ValueError("Not running on Python2 or Python 3?") 67 | 68 | 69 | def load_vocab(vocab_file): 70 | """Loads a vocabulary file into a dictionary.""" 71 | vocab = collections.OrderedDict() 72 | 73 | index_vocab = collections.OrderedDict() 74 | index = 0 75 | with open(vocab_file, "rb") as reader: 76 | while True: 77 | tmp = reader.readline() 78 | token = convert_to_unicode(tmp) 79 | 80 | if not token: 81 | break 82 | 83 | # file_out.write("%d\t%s\n" %(index,token)) 84 | token = token.strip() 85 | vocab[token] = index 86 | index_vocab[index] = token 87 | index += 1 88 | 89 | return vocab, index_vocab 90 | 91 | 92 | def convert_tokens_to_ids(vocab, tokens): 93 | """Converts a sequence of tokens into ids using the vocab.""" 94 | ids = [] 95 | for token in tokens: 96 | ids.append(vocab[token]) 97 | return ids 98 | 99 | 100 | def whitespace_tokenize(text): 101 | """Runs basic whitespace cleaning and splitting on a peice of text.""" 102 | text = text.strip() 103 | if not text: 104 | return [] 105 | tokens = text.split() 106 | return tokens 107 | 108 | 109 | class FullTokenizer(object): 110 | """Runs end-to-end tokenziation.""" 111 | 112 | def __init__(self, vocab_file, do_lower_case=True): 113 | self.vocab, self.index_vocab = load_vocab(vocab_file) 114 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 115 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 116 | 117 | def tokenize(self, text): 118 | split_tokens = [] 119 | for token in 
self.basic_tokenizer.tokenize(text): 120 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 121 | split_tokens.append(sub_token) 122 | 123 | return split_tokens 124 | 125 | def convert_tokens_to_ids(self, tokens): 126 | return convert_tokens_to_ids(self.vocab, tokens) 127 | 128 | 129 | class BasicTokenizer(object): 130 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 131 | 132 | def __init__(self, do_lower_case=True): 133 | """Constructs a BasicTokenizer. 134 | 135 | Args: 136 | do_lower_case: Whether to lower case the input. 137 | """ 138 | self.do_lower_case = do_lower_case 139 | 140 | def tokenize(self, text): 141 | """Tokenizes a piece of text.""" 142 | text = convert_to_unicode(text) 143 | text = self._clean_text(text) 144 | # This was added on November 1st, 2018 for the multilingual and Chinese 145 | # models. This is also applied to the English models now, but it doesn't 146 | # matter since the English models were not trained on any Chinese data 147 | # and generally don't have any Chinese data in them (there are Chinese 148 | # characters in the vocabulary because Wikipedia does have some Chinese 149 | # words in the English Wikipedia.). 150 | text = self._tokenize_chinese_chars(text) 151 | orig_tokens = whitespace_tokenize(text) 152 | split_tokens = [] 153 | for token in orig_tokens: 154 | if self.do_lower_case: 155 | token = token.lower() 156 | token = self._run_strip_accents(token) 157 | split_tokens.extend(self._run_split_on_punc(token)) 158 | 159 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 160 | return output_tokens 161 | 162 | def _run_strip_accents(self, text): 163 | """Strips accents from a piece of text.""" 164 | text = unicodedata.normalize("NFD", text) 165 | output = [] 166 | for char in text: 167 | cat = unicodedata.category(char) 168 | if cat == "Mn": 169 | continue 170 | output.append(char) 171 | return "".join(output) 172 | 173 | def _run_split_on_punc(self, text): 174 | """Splits punctuation on a piece of text.""" 175 | chars = list(text) 176 | i = 0 177 | start_new_word = True 178 | output = [] 179 | while i < len(chars): 180 | char = chars[i] 181 | if _is_punctuation(char): 182 | output.append([char]) 183 | start_new_word = True 184 | else: 185 | if start_new_word: 186 | output.append([]) 187 | start_new_word = False 188 | output[-1].append(char) 189 | i += 1 190 | 191 | return ["".join(x) for x in output] 192 | 193 | def _tokenize_chinese_chars(self, text): 194 | """Adds whitespace around any CJK character.""" 195 | output = [] 196 | for char in text: 197 | cp = ord(char) 198 | if self._is_chinese_char(cp): 199 | output.append(" ") 200 | output.append(char) 201 | output.append(" ") 202 | else: 203 | output.append(char) 204 | return "".join(output) 205 | 206 | def _is_chinese_char(self, cp): 207 | """Checks whether CP is the codepoint of a CJK character.""" 208 | # This defines a "chinese character" as anything in the CJK Unicode block: 209 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 210 | # 211 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 212 | # despite its name. The modern Korean Hangul alphabet is a different block, 213 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 214 | # space-separated words, so they are not treated specially and handled 215 | # like the all of the other languages. 
216 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 217 | (cp >= 0x3400 and cp <= 0x4DBF) or # 218 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 219 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 220 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 221 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 222 | (cp >= 0xF900 and cp <= 0xFAFF) or # 223 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 224 | return True 225 | 226 | return False 227 | 228 | def _clean_text(self, text): 229 | """Performs invalid character removal and whitespace cleanup on text.""" 230 | output = [] 231 | for char in text: 232 | cp = ord(char) 233 | if cp == 0 or cp == 0xfffd or _is_control(char): 234 | continue 235 | if _is_whitespace(char): 236 | output.append(" ") 237 | else: 238 | output.append(char) 239 | return "".join(output) 240 | 241 | 242 | class WordpieceTokenizer(object): 243 | """Runs WordPiece tokenization.""" 244 | 245 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 246 | self.vocab = vocab 247 | self.unk_token = unk_token 248 | self.max_input_chars_per_word = max_input_chars_per_word 249 | 250 | def tokenize(self, text): 251 | """Tokenizes a piece of text into its word pieces. 252 | 253 | This uses a greedy longest-match-first algorithm to perform tokenization 254 | using the given vocabulary. 255 | 256 | For example: 257 | input = "unaffable" 258 | output = ["un", "##aff", "##able"] 259 | 260 | Args: 261 | text: A single token or whitespace separated tokens. This should have 262 | already been passed through `BasicTokenizer. 263 | 264 | Returns: 265 | A list of wordpiece tokens. 266 | """ 267 | 268 | text = convert_to_unicode(text) 269 | 270 | output_tokens = [] 271 | for token in whitespace_tokenize(text): 272 | chars = list(token) 273 | if len(chars) > self.max_input_chars_per_word: 274 | output_tokens.append(self.unk_token) 275 | continue 276 | 277 | is_bad = False 278 | start = 0 279 | sub_tokens = [] 280 | while start < len(chars): 281 | end = len(chars) 282 | cur_substr = None 283 | while start < end: 284 | substr = "".join(chars[start:end]) 285 | if start > 0: 286 | substr = "##" + substr 287 | if substr in self.vocab: 288 | cur_substr = substr 289 | break 290 | end -= 1 291 | if cur_substr is None: 292 | is_bad = True 293 | break 294 | sub_tokens.append(cur_substr) 295 | start = end 296 | 297 | if is_bad: 298 | output_tokens.append(self.unk_token) 299 | else: 300 | output_tokens.extend(sub_tokens) 301 | return output_tokens 302 | 303 | 304 | def _is_whitespace(char): 305 | """Checks whether `chars` is a whitespace character.""" 306 | # \t, \n, and \r are technically contorl characters but we treat them 307 | # as whitespace since they are generally considered as such. 308 | if char == " " or char == "\t" or char == "\n" or char == "\r": 309 | return True 310 | cat = unicodedata.category(char) 311 | if cat == "Zs": 312 | return True 313 | return False 314 | 315 | 316 | def _is_control(char): 317 | """Checks whether `chars` is a control character.""" 318 | # These are technically control characters but we count them as whitespace 319 | # characters. 320 | if char == "\t" or char == "\n" or char == "\r": 321 | return False 322 | cat = unicodedata.category(char) 323 | if cat.startswith("C"): 324 | return True 325 | return False 326 | 327 | 328 | def _is_punctuation(char): 329 | """Checks whether `chars` is a punctuation character.""" 330 | cp = ord(char) 331 | # We treat all non-letter/number ASCII as punctuation. 
332 | # Characters such as "^", "$", and "`" are not in the Unicode 333 | # Punctuation class but we treat them as punctuation anyways, for 334 | # consistency. 335 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 336 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 337 | return True 338 | cat = unicodedata.category(char) 339 | if cat.startswith("P"): 340 | return True 341 | return False 342 | -------------------------------------------------------------------------------- /bert_pytorch_source_code/modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch BERT model.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import copy 22 | import json 23 | import math 24 | import six 25 | import torch 26 | import torch.nn as nn 27 | from torch.nn import CrossEntropyLoss 28 | 29 | def gelu(x): 30 | """Implementation of the gelu activation function. 31 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 32 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 33 | """ 34 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 35 | 36 | 37 | class BertConfig(object): 38 | """Configuration class to store the configuration of a `BertModel`. 39 | """ 40 | def __init__(self, 41 | vocab_size, 42 | hidden_size=768, 43 | num_hidden_layers=12, 44 | num_attention_heads=12, 45 | intermediate_size=3072, 46 | hidden_act="gelu", 47 | hidden_dropout_prob=0.1, 48 | attention_probs_dropout_prob=0.1, 49 | max_position_embeddings=512, 50 | type_vocab_size=16, 51 | initializer_range=0.02): 52 | """Constructs BertConfig. 53 | 54 | Args: 55 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. 56 | hidden_size: Size of the encoder layers and the pooler layer. 57 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 58 | num_attention_heads: Number of attention heads for each attention layer in 59 | the Transformer encoder. 60 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 61 | layer in the Transformer encoder. 62 | hidden_act: The non-linear activation function (function or string) in the 63 | encoder and pooler. 64 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 65 | layers in the embeddings, encoder, and pooler. 66 | attention_probs_dropout_prob: The dropout ratio for the attention 67 | probabilities. 68 | max_position_embeddings: The maximum sequence length that this model might 69 | ever be used with. Typically set this to something large just in case 70 | (e.g., 512 or 1024 or 2048). 71 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 72 | `BertModel`. 
73 | initializer_range: The sttdev of the truncated_normal_initializer for 74 | initializing all weight matrices. 75 | """ 76 | self.vocab_size = vocab_size 77 | self.hidden_size = hidden_size 78 | self.num_hidden_layers = num_hidden_layers 79 | self.num_attention_heads = num_attention_heads 80 | self.hidden_act = hidden_act 81 | self.intermediate_size = intermediate_size 82 | self.hidden_dropout_prob = hidden_dropout_prob 83 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 84 | self.max_position_embeddings = max_position_embeddings 85 | self.type_vocab_size = type_vocab_size 86 | self.initializer_range = initializer_range 87 | 88 | @classmethod 89 | def from_dict(cls, json_object): 90 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 91 | config = BertConfig(vocab_size=None) 92 | for (key, value) in six.iteritems(json_object): 93 | config.__dict__[key] = value 94 | return config 95 | 96 | @classmethod 97 | def from_json_file(cls, json_file): 98 | """Constructs a `BertConfig` from a json file of parameters.""" 99 | with open(json_file, "r") as reader: 100 | text = reader.read() 101 | return cls.from_dict(json.loads(text)) 102 | 103 | def to_dict(self): 104 | """Serializes this instance to a Python dictionary.""" 105 | output = copy.deepcopy(self.__dict__) 106 | return output 107 | 108 | def to_json_string(self): 109 | """Serializes this instance to a JSON string.""" 110 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 111 | 112 | 113 | class BERTLayerNorm(nn.Module): 114 | def __init__(self, config, variance_epsilon=1e-12): 115 | """Construct a layernorm module in the TF style (epsilon inside the square root). 116 | """ 117 | super(BERTLayerNorm, self).__init__() 118 | self.gamma = nn.Parameter(torch.ones(config.hidden_size)) 119 | self.beta = nn.Parameter(torch.zeros(config.hidden_size)) 120 | self.variance_epsilon = variance_epsilon # 一个很小的常数,防止除0 121 | 122 | def forward(self, x): 123 | u = x.mean(-1, keepdim=True) # LN是对最后一个维度做Norm 124 | s = (x - u).pow(2).mean(-1, keepdim=True) 125 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 126 | return self.gamma * x + self.beta 127 | 128 | class BERTEmbeddings(nn.Module): 129 | def __init__(self, config): 130 | super(BERTEmbeddings, self).__init__() 131 | """Construct the embedding module from word, position and token_type embeddings. 
132 | 三种embedding都是可学习的,输入为单句是,只需要给input_ids,双句时才要给token_type_ids,position_ids都不用给 133 | """ 134 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) 135 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 136 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 137 | 138 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 139 | # any TensorFlow checkpoint file 140 | self.LayerNorm = BERTLayerNorm(config) 141 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 142 | 143 | def forward(self, input_ids, token_type_ids=None): 144 | seq_length = input_ids.size(1) 145 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 146 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 147 | if token_type_ids is None: 148 | token_type_ids = torch.zeros_like(input_ids) 149 | 150 | words_embeddings = self.word_embeddings(input_ids) 151 | position_embeddings = self.position_embeddings(position_ids) 152 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 153 | 154 | embeddings = words_embeddings + position_embeddings + token_type_embeddings 155 | embeddings = self.LayerNorm(embeddings) 156 | embeddings = self.dropout(embeddings) 157 | return embeddings 158 | 159 | 160 | class BERTSelfAttention(nn.Module): 161 | """self attention 是Bert的精髓,但它维度的变化确实比较复杂""" 162 | def __init__(self, config): 163 | super(BERTSelfAttention, self).__init__() 164 | if config.hidden_size % config.num_attention_heads != 0: 165 | raise ValueError( 166 | "The hidden size (%d) is not a multiple of the number of attention " 167 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 168 | self.num_attention_heads = config.num_attention_heads 169 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 170 | self.all_head_size = self.num_attention_heads * self.attention_head_size 171 | # 注意:这里的 all_head_size 就等于config.hidden_size,应该是一种简化,或者是为了从embedding到最后输出维度都保持一致 172 | # 这样使得多个attention头合起来维度还是config.hidden_size 173 | # 而attention_head_size就是每个attention头的维度,这个维度其实是可以人为指定的,但实现起来代码会比较麻烦 174 | 175 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 176 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 177 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 178 | 179 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 180 | 181 | def transpose_for_scores(self, x): 182 | # shape of x: batch_size * seq_length * hidden_size 183 | # 这个操作是把hidden_size分解为 self.num_attention_heads * self.attention_head_size 184 | # 然后再交换 seq_length维度 和 num_attention_heads维度 185 | # 为什么要做这一步:因为attention是要对query中的每个字和key中的每个字做点积,即是在 seq_length 维度上 186 | # query和key的点积是 [seq_length * attention_head_size] * [attention_head_size * seq_length]=[seq_length * seq_length] 187 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) # 这里是一个维度拼接:(1,2)+(4,5) -> (1, 2, 4, 5) 188 | x = x.view(*new_x_shape) 189 | return x.permute(0, 2, 1, 3) 190 | 191 | def forward(self, hidden_states, attention_mask): 192 | # shape of hidden_states: batch_size * seq_length * hidden_size 193 | # shape of mixed_*_layer: batch_size * seq_length * hidden_size 194 | mixed_query_layer = self.query(hidden_states) 195 | mixed_key_layer = self.key(hidden_states) 196 | mixed_value_layer = self.value(hidden_states) 197 | 198 | # shape of query_layer: batch_size * 
num_attention_heads * seq_length * attention_head_size 199 | query_layer = self.transpose_for_scores(mixed_query_layer) 200 | key_layer = self.transpose_for_scores(mixed_key_layer) 201 | value_layer = self.transpose_for_scores(mixed_value_layer) 202 | 203 | # Take the dot product between "query" and "key" to get the raw attention scores. 204 | # shape of attention_scores: batch_size * num_attention_heads * seq_length * seq_length 205 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 206 | attention_scores /= math.sqrt(self.attention_head_size) 207 | 208 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 209 | # shape of attention_mask: batch_size * 1 * 1 * seq_length. 它可以自动广播到和attention_scores一样的维度 210 | # 我们初始输入的attention_mask是:batch_size * seq_length,做了两次unsqueeze之后得到当前的attention_mask 211 | attention_scores = attention_scores + attention_mask 212 | 213 | # Normalize the attention scores to probabilities. Softmax 不改变维度 214 | # shape of attention_scores: batch_size * num_attention_heads * seq_length * seq_length 215 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 216 | attention_probs = self.dropout(attention_probs) 217 | 218 | # shape of value_layer: batch_size * num_attention_heads * seq_length * attention_head_size 219 | # shape of first context_layer: batch_size * num_attention_heads * seq_length * attention_head_size 220 | # shape of second context_layer: batch_size * seq_length * num_attention_heads * attention_head_size 221 | # context_layer 维度恢复到:batch_size * seq_length * hidden_size 222 | context_layer = torch.matmul(attention_probs, value_layer) 223 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 224 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 225 | context_layer = context_layer.view(*new_context_layer_shape) 226 | return context_layer 227 | 228 | 229 | class BERTSelfOutput(nn.Module): 230 | """BERTSelfAttention 之后还有一个 feed forward,dropout,add and norm """ 231 | def __init__(self, config): 232 | super(BERTSelfOutput, self).__init__() 233 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 234 | self.LayerNorm = BERTLayerNorm(config) 235 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 236 | 237 | def forward(self, hidden_states, input_tensor): 238 | hidden_states = self.dense(hidden_states) 239 | hidden_states = self.dropout(hidden_states) 240 | hidden_states = self.LayerNorm(hidden_states + input_tensor) # skip connection 在这里 241 | return hidden_states 242 | 243 | 244 | class BERTAttention(nn.Module): 245 | """一个BERT block中的前面部分""" 246 | def __init__(self, config): 247 | super(BERTAttention, self).__init__() 248 | self.self = BERTSelfAttention(config) 249 | self.output = BERTSelfOutput(config) 250 | 251 | def forward(self, input_tensor, attention_mask): 252 | self_output = self.self(input_tensor, attention_mask) 253 | attention_output = self.output(self_output, input_tensor) 254 | return attention_output 255 | 256 | 257 | class BERTIntermediate(nn.Module): 258 | """BERT模型中唯一用到了激活函数的地方, BERTIntermediate只是在中间扩充了一下维度,在BERTOutput中又转回去了""" 259 | def __init__(self, config): 260 | super(BERTIntermediate, self).__init__() 261 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 262 | self.intermediate_act_fn = gelu 263 | 264 | def forward(self, hidden_states): 265 | hidden_states = self.dense(hidden_states) 266 | hidden_states = self.intermediate_act_fn(hidden_states) 267 | return hidden_states 268 | 269 | 270 | class 
BERTOutput(nn.Module): 271 | """一个BERT block中的后面部分,和前面的BERTSelfOutput几乎相同""" 272 | def __init__(self, config): 273 | super(BERTOutput, self).__init__() 274 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 275 | self.LayerNorm = BERTLayerNorm(config) 276 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 277 | 278 | def forward(self, hidden_states, input_tensor): 279 | hidden_states = self.dense(hidden_states) 280 | hidden_states = self.dropout(hidden_states) 281 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 282 | return hidden_states 283 | 284 | 285 | class BERTLayer(nn.Module): 286 | """一个BERT block包括三个部分:BERTAttention, BERTIntermediate, BERTOutput""" 287 | def __init__(self, config): 288 | super(BERTLayer, self).__init__() 289 | self.attention = BERTAttention(config) 290 | self.intermediate = BERTIntermediate(config) 291 | self.output = BERTOutput(config) 292 | 293 | def forward(self, hidden_states, attention_mask): 294 | attention_output = self.attention(hidden_states, attention_mask) 295 | intermediate_output = self.intermediate(attention_output) 296 | layer_output = self.output(intermediate_output, attention_output) 297 | return layer_output 298 | 299 | 300 | class BERTEncoder(nn.Module): 301 | """12 个BERT block, 中间一定要用copy.deepcopy,否则指代的会是同一个block""" 302 | def __init__(self, config): 303 | super(BERTEncoder, self).__init__() 304 | layer = BERTLayer(config) 305 | self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) 306 | 307 | def forward(self, hidden_states, attention_mask): 308 | all_encoder_layers = [] 309 | for layer_module in self.layer: 310 | hidden_states = layer_module(hidden_states, attention_mask) 311 | all_encoder_layers.append(hidden_states) 312 | return all_encoder_layers # 记录了第一层到最后一层,所有time_step的输出 313 | 314 | 315 | class BERTPooler(nn.Module): 316 | """取的[CLS]位输出做分类""" 317 | def __init__(self, config): 318 | super(BERTPooler, self).__init__() 319 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 320 | self.activation = nn.Tanh() 321 | 322 | def forward(self, hidden_states): 323 | # We "pool" the model by simply taking the hidden state corresponding 324 | # to the first token. 325 | first_token_tensor = hidden_states[:, 0] 326 | pooled_output = self.dense(first_token_tensor) 327 | pooled_output = self.activation(pooled_output) 328 | return pooled_output 329 | 330 | 331 | class BertModel(nn.Module): 332 | """BERT model ("Bidirectional Embedding Representations from a Transformer"). 333 | 334 | Example usage: 335 | ```python 336 | # Already been converted into WordPiece token ids 337 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 338 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 339 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 340 | 341 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512, 342 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 343 | 344 | model = modeling.BertModel(config=config) 345 | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) 346 | ``` 347 | """ 348 | def __init__(self, config: BertConfig): 349 | """Constructor for BertModel. 350 | 351 | Args: 352 | config: `BertConfig` instance. 
353 | """ 354 | super(BertModel, self).__init__() 355 | self.embeddings = BERTEmbeddings(config) 356 | self.encoder = BERTEncoder(config) 357 | self.pooler = BERTPooler(config) 358 | 359 | def forward(self, input_ids, token_type_ids=None, attention_mask=None): 360 | if attention_mask is None: 361 | attention_mask = torch.ones_like(input_ids) 362 | if token_type_ids is None: 363 | token_type_ids = torch.zeros_like(input_ids) 364 | 365 | # We create a 3D attention mask from a 2D tensor mask. 366 | # Sizes are [batch_size, 1, 1, to_seq_length] 367 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 368 | # this attention mask is more simple than the triangular masking of causal attention 369 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 370 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 371 | 372 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 373 | # masked positions, this operation will create a tensor which is 0.0 for 374 | # positions we want to attend and -10000.0 for masked positions. 375 | # Since we are adding it to the raw scores before the softmax, this is 376 | # effectively the same as removing these entirely. 377 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 378 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 379 | 380 | embedding_output = self.embeddings(input_ids, token_type_ids) 381 | all_encoder_layers = self.encoder(embedding_output, extended_attention_mask) 382 | sequence_output = all_encoder_layers[-1] 383 | pooled_output = self.pooler(sequence_output) 384 | return all_encoder_layers, pooled_output 385 | 386 | class BertForSequenceClassification(nn.Module): 387 | """BERT model for classification. 388 | This module is composed of the BERT model with a linear layer on top of 389 | the pooled output. 
390 | 391 | Example usage: 392 | ```python 393 | # Already been converted into WordPiece token ids 394 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 395 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 396 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 397 | 398 | config = BertConfig(vocab_size=32000, hidden_size=512, 399 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 400 | 401 | num_labels = 2 402 | 403 | model = BertForSequenceClassification(config, num_labels) 404 | logits = model(input_ids, token_type_ids, input_mask) 405 | ``` 406 | """ 407 | def __init__(self, config, num_labels): 408 | super(BertForSequenceClassification, self).__init__() 409 | self.bert = BertModel(config) 410 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 411 | self.classifier = nn.Linear(config.hidden_size, num_labels) 412 | 413 | def init_weights(module): 414 | if isinstance(module, (nn.Linear, nn.Embedding)): 415 | # Slightly different from the TF version which uses truncated_normal for initialization 416 | # cf https://github.com/pytorch/pytorch/pull/5617 417 | module.weight.data.normal_(mean=0.0, std=config.initializer_range) 418 | elif isinstance(module, BERTLayerNorm): 419 | module.beta.data.normal_(mean=0.0, std=config.initializer_range) 420 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range) 421 | if isinstance(module, nn.Linear): 422 | module.bias.data.zero_() 423 | self.apply(init_weights) 424 | 425 | def forward(self, input_ids, token_type_ids, attention_mask, labels=None): 426 | _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask) 427 | pooled_output = self.dropout(pooled_output) 428 | logits = self.classifier(pooled_output) 429 | 430 | if labels is not None: 431 | loss_fct = CrossEntropyLoss() 432 | loss = loss_fct(logits, labels) 433 | return loss, logits 434 | else: 435 | return logits 436 | 437 | class BertForQuestionAnswering(nn.Module): 438 | """BERT model for Question Answering (span extraction). 
439 | This module is composed of the BERT model with a linear layer on top of 440 | the sequence output that computes start_logits and end_logits 441 | 442 | Example usage: 443 | ```python 444 | # Already been converted into WordPiece token ids 445 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 446 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 447 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 448 | 449 | config = BertConfig(vocab_size=32000, hidden_size=512, 450 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 451 | 452 | model = BertForQuestionAnswering(config) 453 | start_logits, end_logits = model(input_ids, token_type_ids, input_mask) 454 | ``` 455 | """ 456 | def __init__(self, config): 457 | super(BertForQuestionAnswering, self).__init__() 458 | self.bert = BertModel(config) 459 | # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version 460 | # self.dropout = nn.Dropout(config.hidden_dropout_prob) 461 | self.qa_outputs = nn.Linear(config.hidden_size, 2) 462 | 463 | def init_weights(module): 464 | if isinstance(module, (nn.Linear, nn.Embedding)): 465 | # Slightly different from the TF version which uses truncated_normal for initialization 466 | # cf https://github.com/pytorch/pytorch/pull/5617 467 | module.weight.data.normal_(mean=0.0, std=config.initializer_range) 468 | elif isinstance(module, BERTLayerNorm): 469 | module.beta.data.normal_(mean=0.0, std=config.initializer_range) 470 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range) 471 | if isinstance(module, nn.Linear): 472 | module.bias.data.zero_() 473 | self.apply(init_weights) 474 | 475 | def forward(self, input_ids, token_type_ids, attention_mask, start_positions=None, end_positions=None): 476 | all_encoder_layers, _ = self.bert(input_ids, token_type_ids, attention_mask) 477 | sequence_output = all_encoder_layers[-1] 478 | logits = self.qa_outputs(sequence_output) 479 | start_logits, end_logits = logits.split(1, dim=-1) 480 | start_logits = start_logits.squeeze(-1) 481 | end_logits = end_logits.squeeze(-1) 482 | 483 | if start_positions is not None and end_positions is not None: 484 | # If we are on multi-GPU, split add a dimension 485 | if len(start_positions.size()) > 1: 486 | start_positions = start_positions.squeeze(-1) 487 | if len(end_positions.size()) > 1: 488 | end_positions = end_positions.squeeze(-1) 489 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 490 | ignored_index = start_logits.size(1) 491 | start_positions.clamp_(0, ignored_index) 492 | end_positions.clamp_(0, ignored_index) 493 | 494 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 495 | start_loss = loss_fct(start_logits, start_positions) 496 | end_loss = loss_fct(end_logits, end_positions) 497 | total_loss = (start_loss + end_loss) / 2 498 | return total_loss 499 | else: 500 | return start_logits, end_logits 501 | 502 | if __name__=="__main__": 503 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 504 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 505 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 506 | 507 | config = BertConfig(vocab_size=32000, hidden_size=512, 508 | num_hidden_layers=8, num_attention_heads=8, intermediate_size=1024) 509 | 510 | model = BertModel(config=config) 511 | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) 512 | print(pooled_output) 
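    # Added illustrative snippet (not in the original file): feed the same toy inputs
    # through the classification head defined above, mirroring the usage example in the
    # BertForSequenceClassification docstring. num_labels=2 is an arbitrary choice.
    num_labels = 2
    cls_model = BertForSequenceClassification(config, num_labels)
    logits = cls_model(input_ids, token_type_ids, input_mask)
    print(logits.shape)  # expected: torch.Size([2, 2]) -> (batch_size, num_labels)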
-------------------------------------------------------------------------------- /bert_pytorch_source_code/run_classifier_word.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """BERT finetuning runner.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import csv 22 | import os 23 | import logging 24 | import argparse 25 | import random 26 | from tqdm import tqdm, trange 27 | 28 | import numpy as np 29 | import torch 30 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler 31 | from torch.utils.data.distributed import DistributedSampler 32 | 33 | import tokenization_word as tokenization 34 | from modeling import BertConfig, BertForSequenceClassification 35 | from optimization import BERTAdam 36 | 37 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 38 | datefmt='%m/%d/%Y %H:%M:%S', 39 | level=logging.INFO) 40 | logger = logging.getLogger(__name__) 41 | 42 | 43 | class InputExample(object): 44 | """A single training/test example for simple sequence classification.""" 45 | 46 | def __init__(self, guid, text_a, text_b=None, label=None): 47 | """Constructs an InputExample. 48 | 49 | Args: 50 | guid: Unique id for the example. 51 | text_a: string. The untokenized text of the first sequence. For single 52 | sequence tasks, only this sequence must be specified. 53 | text_b: (Optional) string. The untokenized text of the second sequence. 54 | Must only be specified for sequence pair tasks. 55 | label: (Optional) string. The label of the example. This should be 56 | specified for train and dev examples, but not for test examples.
57 | """ 58 | self.guid = guid 59 | self.text_a = text_a 60 | self.text_b = text_b 61 | self.label = label 62 | 63 | 64 | class InputFeatures(object): 65 | """A single set of features of data.""" 66 | 67 | def __init__(self, input_ids, input_mask, segment_ids, label_id): 68 | self.input_ids = input_ids 69 | self.input_mask = input_mask 70 | self.segment_ids = segment_ids 71 | self.label_id = label_id 72 | 73 | 74 | class DataProcessor(object): 75 | """Base class for data converters for sequence classification data sets.""" 76 | 77 | def get_train_examples(self, data_dir): 78 | """Gets a collection of `InputExample`s for the train set.""" 79 | raise NotImplementedError() 80 | 81 | def get_dev_examples(self, data_dir): 82 | """Gets a collection of `InputExample`s for the dev set.""" 83 | raise NotImplementedError() 84 | 85 | def get_labels(self): 86 | """Gets the list of labels for this data set.""" 87 | raise NotImplementedError() 88 | 89 | @classmethod 90 | def _read_tsv(cls, input_file, quotechar=None): 91 | """Reads a tab separated value file.""" 92 | file_in = open(input_file, "rb") 93 | lines = [] 94 | for line in file_in: 95 | lines.append(line.decode("utf-8").split("\t")) 96 | return lines 97 | 98 | 99 | class NewsProcessor(DataProcessor): 100 | """Processor for the Chinese news classification data set.""" 101 | 102 | def __init__(self): 103 | self.labels = set() 104 | 105 | def get_train_examples(self, data_dir): 106 | """See base class.""" 107 | logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv"))) 108 | return self._create_examples( 109 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 110 | 111 | def get_dev_examples(self, data_dir): 112 | """See base class.""" 113 | return self._create_examples( 114 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 115 | 116 | def get_labels(self): 117 | """See base class.""" 118 | return list(self.labels) 119 | 120 | def _create_examples(self, lines, set_type): 121 | """Creates examples for the training and dev sets.""" 122 | examples = [] 123 | for (i, line) in enumerate(lines): 124 | guid = "%s-%s" % (set_type, i) 125 | text_a = tokenization.convert_to_unicode(line[1]) 126 | label = tokenization.convert_to_unicode(line[0]) 127 | self.labels.add(label) 128 | examples.append( 129 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 130 | 131 | return examples 132 | 133 | 134 | class MrpcProcessor(DataProcessor): 135 | """Processor for the MRPC data set (GLUE version).""" 136 | 137 | def get_train_examples(self, data_dir): 138 | """See base class.""" 139 | logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv"))) 140 | return self._create_examples( 141 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 142 | 143 | def get_dev_examples(self, data_dir): 144 | """See base class.""" 145 | return self._create_examples( 146 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 147 | 148 | def get_labels(self): 149 | """See base class.""" 150 | return ["0", "1"] 151 | 152 | def _create_examples(self, lines, set_type): 153 | """Creates examples for the training and dev sets.""" 154 | examples = [] 155 | for (i, line) in enumerate(lines): 156 | if i == 0: 157 | continue 158 | guid = "%s-%s" % (set_type, i) 159 | text_a = tokenization.convert_to_unicode(line[3]) 160 | text_b = tokenization.convert_to_unicode(line[4]) 161 | label = tokenization.convert_to_unicode(line[0]) 162 | examples.append( 163 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 164 |
return examples 165 | 166 | 167 | class MnliProcessor(DataProcessor): 168 | """Processor for the MultiNLI data set (GLUE version).""" 169 | 170 | def get_train_examples(self, data_dir): 171 | """See base class.""" 172 | return self._create_examples( 173 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 174 | 175 | def get_dev_examples(self, data_dir): 176 | """See base class.""" 177 | return self._create_examples( 178 | self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), 179 | "dev_matched") 180 | 181 | def get_labels(self): 182 | """See base class.""" 183 | return ["contradiction", "entailment", "neutral"] 184 | 185 | def _create_examples(self, lines, set_type): 186 | """Creates examples for the training and dev sets.""" 187 | examples = [] 188 | for (i, line) in enumerate(lines): 189 | if i == 0: 190 | continue 191 | guid = "%s-%s" % (set_type, tokenization.convert_to_unicode(line[0])) 192 | text_a = tokenization.convert_to_unicode(line[8]) 193 | text_b = tokenization.convert_to_unicode(line[9]) 194 | label = tokenization.convert_to_unicode(line[-1]) 195 | examples.append( 196 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 197 | return examples 198 | 199 | 200 | class ColaProcessor(DataProcessor): 201 | """Processor for the CoLA data set (GLUE version).""" 202 | 203 | def get_train_examples(self, data_dir): 204 | """See base class.""" 205 | return self._create_examples( 206 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 207 | 208 | def get_dev_examples(self, data_dir): 209 | """See base class.""" 210 | return self._create_examples( 211 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 212 | 213 | def get_labels(self): 214 | """See base class.""" 215 | return ["0", "1"] 216 | 217 | def _create_examples(self, lines, set_type): 218 | """Creates examples for the training and dev sets.""" 219 | examples = [] 220 | for (i, line) in enumerate(lines): 221 | guid = "%s-%s" % (set_type, i) 222 | text_a = tokenization.convert_to_unicode(line[3]) 223 | label = tokenization.convert_to_unicode(line[1]) 224 | examples.append( 225 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 226 | return examples 227 | 228 | 229 | def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer): 230 | """Loads a data file into a list of `InputBatch`s.""" 231 | 232 | label_map = {} 233 | for (i, label) in enumerate(label_list): 234 | label_map[label] = i 235 | features = [] 236 | for (ex_index, example) in enumerate(examples): 237 | tokens_a = tokenizer.tokenize(example.text_a) 238 | 239 | tokens_b = None 240 | if example.text_b: 241 | tokens_b = tokenizer.tokenize(example.text_b) 242 | 243 | if tokens_b: 244 | # Modifies `tokens_a` and `tokens_b` in place so that the total 245 | # length is less than the specified length. 246 | # Account for [CLS], [SEP], [SEP] with "- 3" 247 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 248 | else: 249 | # Account for [CLS] and [SEP] with "- 2" 250 | if len(tokens_a) > max_seq_length - 2: 251 | tokens_a = tokens_a[0:(max_seq_length - 2)] 252 | 253 | # The convention in BERT is: 254 | # (a) For sequence pairs: 255 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 256 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 257 | # (b) For single sequences: 258 | # tokens: [CLS] the dog is hairy . [SEP] 259 | # type_ids: 0 0 0 0 0 0 0 260 | # 261 | # Where "type_ids" are used to indicate whether this is the first 262 | # sequence or the second sequence. 
The embedding vectors for `type=0` and 263 | # `type=1` were learned during pre-training and are added to the wordpiece 264 | # embedding vector (and position vector). This is not *strictly* necessary 265 | # since the [SEP] token unambiguously separates the sequences, but it makes 266 | # it easier for the model to learn the concept of sequences. 267 | # 268 | # For classification tasks, the first vector (corresponding to [CLS]) is 269 | # used as the "sentence vector". Note that this only makes sense because 270 | # the entire model is fine-tuned. 271 | tokens = [] 272 | segment_ids = [] 273 | tokens.append("[CLS]") 274 | segment_ids.append(0) 275 | for token in tokens_a: 276 | tokens.append(token) 277 | segment_ids.append(0) 278 | tokens.append("[SEP]") 279 | segment_ids.append(0) 280 | 281 | if tokens_b: 282 | for token in tokens_b: 283 | tokens.append(token) 284 | segment_ids.append(1) 285 | tokens.append("[SEP]") 286 | segment_ids.append(1) 287 | 288 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 289 | 290 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 291 | # tokens are attended to. 292 | input_mask = [1] * len(input_ids) 293 | 294 | # Zero-pad up to the sequence length. 295 | while len(input_ids) < max_seq_length: 296 | input_ids.append(0) 297 | input_mask.append(0) 298 | segment_ids.append(0) 299 | 300 | assert len(input_ids) == max_seq_length 301 | assert len(input_mask) == max_seq_length 302 | assert len(segment_ids) == max_seq_length 303 | 304 | label_id = label_map[example.label] 305 | if ex_index < 5: 306 | logger.info("*** Example ***") 307 | logger.info("guid: %s" % (example.guid)) 308 | logger.info("tokens: %s" % " ".join( 309 | [tokenization.printable_text(x) for x in tokens])) 310 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 311 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 312 | logger.info( 313 | "segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 314 | logger.info("label: %s (id = %d)" % (example.label, label_id)) 315 | 316 | features.append( 317 | InputFeatures( 318 | input_ids=input_ids, 319 | input_mask=input_mask, 320 | segment_ids=segment_ids, 321 | label_id=label_id)) 322 | return features 323 | 324 | 325 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 326 | """Truncates a sequence pair in place to the maximum length.""" 327 | 328 | # This is a simple heuristic which will always truncate the longer sequence 329 | # one token at a time. This makes more sense than truncating an equal percent 330 | # of tokens from each, since if one sequence is very short then each token 331 | # that's truncated likely contains more information than a longer sequence. 332 | while True: 333 | total_length = len(tokens_a) + len(tokens_b) 334 | if total_length <= max_length: 335 | break 336 | if len(tokens_a) > len(tokens_b): 337 | tokens_a.pop() 338 | else: 339 | tokens_b.pop() 340 | 341 | 342 | def accuracy(out, labels): 343 | outputs = np.argmax(out, axis=1) 344 | return np.sum(outputs == labels) 345 | 346 | 347 | def copy_optimizer_params_to_model(named_params_model, named_params_optimizer): 348 | """ Utility function for optimize_on_cpu and 16-bits training.
349 | Copy the parameters optimized on CPU/RAM back to the model on GPU 350 | """ 351 | for (name_opti, param_opti), (name_model, param_model) in zip(named_params_optimizer, named_params_model): 352 | if name_opti != name_model: 353 | logger.error("name_opti != name_model: {} {}".format(name_opti, name_model)) 354 | raise ValueError 355 | param_model.data.copy_(param_opti.data) 356 | 357 | 358 | def set_optimizer_params_grad(named_params_optimizer, named_params_model, test_nan=False): 359 | """ Utility function for optimize_on_cpu and 16-bits training. 360 | Copy the gradient of the GPU parameters to the CPU/RAM copy of the model 361 | """ 362 | is_nan = False 363 | for (name_opti, param_opti), (name_model, param_model) in zip(named_params_optimizer, named_params_model): 364 | if name_opti != name_model: 365 | logger.error("name_opti != name_model: {} {}".format(name_opti, name_model)) 366 | raise ValueError 367 | if test_nan and torch.isnan(param_model.grad).sum() > 0: 368 | is_nan = True 369 | if param_opti.grad is None: 370 | param_opti.grad = torch.nn.Parameter(param_opti.data.new().resize_(*param_opti.data.size())) 371 | param_opti.grad.data.copy_(param_model.grad.data) 372 | return is_nan 373 | 374 | 375 | def main(): 376 | parser = argparse.ArgumentParser() 377 | 378 | ## Required parameters 379 | parser.add_argument("--data_dir", 380 | default=None, 381 | type=str, 382 | required=True, 383 | help="The input data dir. Should contain the .tsv files (or other data files) for the task.") 384 | parser.add_argument("--bert_config_file", 385 | default=None, 386 | type=str, 387 | required=True, 388 | help="The config json file corresponding to the pre-trained BERT model. \n" 389 | "This specifies the model architecture.") 390 | parser.add_argument("--task_name", 391 | default=None, 392 | type=str, 393 | required=True, 394 | help="The name of the task to train.") 395 | parser.add_argument("--vocab_file", 396 | default=None, 397 | type=str, 398 | required=True, 399 | help="The vocabulary file that the BERT model was trained on.") 400 | parser.add_argument("--output_dir", 401 | default=None, 402 | type=str, 403 | required=True, 404 | help="The output directory where the model checkpoints will be written.") 405 | 406 | ## Other parameters 407 | parser.add_argument("--init_checkpoint", 408 | default=None, 409 | type=str, 410 | help="Initial checkpoint (usually from a pre-trained BERT model).") 411 | parser.add_argument("--do_lower_case", 412 | default=False, 413 | action='store_true', 414 | help="Whether to lower case the input text. True for uncased models, False for cased models.") 415 | parser.add_argument("--max_seq_length", 416 | default=128, 417 | type=int, 418 | help="The maximum total input sequence length after WordPiece tokenization.
\n" 419 | "Sequences longer than this will be truncated, and sequences shorter \n" 420 | "than this will be padded.") 421 | parser.add_argument("--do_train", 422 | default=False, 423 | action='store_true', 424 | help="Whether to run training.") 425 | parser.add_argument("--do_eval", 426 | default=False, 427 | action='store_true', 428 | help="Whether to run eval on the dev set.") 429 | parser.add_argument("--train_batch_size", 430 | default=32, 431 | type=int, 432 | help="Total batch size for training.") 433 | parser.add_argument("--eval_batch_size", 434 | default=8, 435 | type=int, 436 | help="Total batch size for eval.") 437 | parser.add_argument("--learning_rate", 438 | default=5e-5, 439 | type=float, 440 | help="The initial learning rate for Adam.") 441 | parser.add_argument("--num_train_epochs", 442 | default=3.0, 443 | type=float, 444 | help="Total number of training epochs to perform.") 445 | parser.add_argument("--warmup_proportion", 446 | default=0.1, 447 | type=float, 448 | help="Proportion of training to perform linear learning rate warmup for. " 449 | "E.g., 0.1 = 10%% of training.") 450 | parser.add_argument("--save_checkpoints_steps", 451 | default=1000, 452 | type=int, 453 | help="How often to save the model checkpoint.") 454 | parser.add_argument("--no_cuda", 455 | default=False, 456 | action='store_true', 457 | help="Whether not to use CUDA when available") 458 | parser.add_argument("--local_rank", 459 | type=int, 460 | default=-1, 461 | help="local_rank for distributed training on gpus") 462 | parser.add_argument('--seed', 463 | type=int, 464 | default=42, 465 | help="random seed for initialization") 466 | parser.add_argument('--gradient_accumulation_steps', 467 | type=int, 468 | default=1, 469 | help="Number of updates steps to accumualte before performing a backward/update pass.") 470 | parser.add_argument('--optimize_on_cpu', 471 | default=False, 472 | action='store_true', 473 | help="Whether to perform optimization and keep the optimizer averages on CPU") 474 | parser.add_argument('--fp16', 475 | default=False, 476 | action='store_true', 477 | help="Whether to use 16-bit float precision instead of 32-bit") 478 | parser.add_argument('--loss_scale', 479 | type=float, default=128, 480 | help='Loss scaling, positive power of 2 values can improve fp16 convergence.') 481 | 482 | args = parser.parse_args() 483 | 484 | processors = { 485 | "cola": ColaProcessor, 486 | "mnli": MnliProcessor, 487 | "mrpc": MrpcProcessor, 488 | "news": NewsProcessor, 489 | } 490 | 491 | if args.local_rank == -1 or args.no_cuda: 492 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 493 | n_gpu = torch.cuda.device_count() 494 | else: 495 | device = torch.device("cuda", args.local_rank) 496 | n_gpu = 1 497 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 498 | # torch.distributed.init_process_group(backend='nccl') 499 | if args.fp16: 500 | logger.info("16-bits training currently not supported in distributed training") 501 | args.fp16 = False # (see https://github.com/pytorch/pytorch/pull/13496) 502 | logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1)) 503 | 504 | if args.gradient_accumulation_steps < 1: 505 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( 506 | args.gradient_accumulation_steps)) 507 | 508 | args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps) 509 | 510 | 
random.seed(args.seed) 511 | np.random.seed(args.seed) 512 | torch.manual_seed(args.seed) 513 | if n_gpu > 0: 514 | torch.cuda.manual_seed_all(args.seed) 515 | 516 | if not args.do_train and not args.do_eval: 517 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 518 | 519 | bert_config = BertConfig.from_json_file(args.bert_config_file) 520 | 521 | if args.max_seq_length > bert_config.max_position_embeddings: 522 | raise ValueError( 523 | "Cannot use sequence length {} because the BERT model was only trained up to sequence length {}".format( 524 | args.max_seq_length, bert_config.max_position_embeddings)) 525 | 526 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir): 527 | raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) 528 | os.makedirs(args.output_dir, exist_ok=True) 529 | 530 | task_name = args.task_name.lower() 531 | 532 | if task_name not in processors: 533 | raise ValueError("Task not found: %s" % (task_name)) 534 | 535 | processor = processors[task_name]() 536 | 537 | tokenizer = tokenization.FullTokenizer( 538 | vocab_file=args.vocab_file, do_lower_case=args.do_lower_case) 539 | 540 | train_examples = None 541 | num_train_steps = None 542 | if args.do_train: 543 | train_examples = processor.get_train_examples(args.data_dir) 544 | num_train_steps = int( 545 | len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs) 546 | 547 | label_list = processor.get_labels() 548 | 549 | print("label_list.size:%d\n" % (len(label_list))) 550 | 551 | # Prepare model 552 | model = BertForSequenceClassification(bert_config, len(label_list)) 553 | if args.init_checkpoint is not None: 554 | model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) 555 | if args.fp16: 556 | model.half() 557 | model.to(device) 558 | # if args.local_rank != -1: 559 | # model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 560 | # output_device=args.local_rank) 561 | # elif n_gpu > 1: 562 | # model = torch.nn.DataParallel(model) 563 | 564 | # Prepare optimizer 565 | if args.fp16: 566 | param_optimizer = [(n, param.clone().detach().to('cpu').float().requires_grad_()) \ 567 | for n, param in model.named_parameters()] 568 | elif args.optimize_on_cpu: 569 | param_optimizer = [(n, param.clone().detach().to('cpu').requires_grad_()) \ 570 | for n, param in model.named_parameters()] 571 | else: 572 | param_optimizer = list(model.named_parameters()) 573 | no_decay = ['bias', 'gamma', 'beta'] 574 | optimizer_grouped_parameters = [ 575 | {'params': [p for n, p in param_optimizer if n not in no_decay], 'weight_decay_rate': 0.01}, 576 | {'params': [p for n, p in param_optimizer if n in no_decay], 'weight_decay_rate': 0.0} 577 | ] 578 | optimizer = BERTAdam(optimizer_grouped_parameters, 579 | lr=args.learning_rate, 580 | warmup=args.warmup_proportion, 581 | t_total=num_train_steps) 582 | 583 | global_step = 0 584 | if args.do_train: 585 | train_features = convert_examples_to_features( 586 | train_examples, label_list, args.max_seq_length, tokenizer) 587 | logger.info("***** Running training *****") 588 | logger.info(" Num examples = %d", len(train_examples)) 589 | logger.info(" Batch size = %d", args.train_batch_size) 590 | logger.info(" Num steps = %d", num_train_steps) 591 | all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long) 592 | all_input_mask = torch.tensor([f.input_mask for f in train_features], 
dtype=torch.long) 593 | all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long) 594 | all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long) 595 | train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) 596 | if args.local_rank == -1: 597 | train_sampler = RandomSampler(train_data) 598 | else: 599 | 600 | train_sampler = RandomSampler(train_data) 601 | # train_sampler = DistributedSampler(train_data) 602 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) 603 | 604 | model.train() 605 | for _ in trange(int(args.num_train_epochs), desc="Epoch"): 606 | tr_loss = 0 607 | nb_tr_examples, nb_tr_steps = 0, 0 608 | for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): 609 | batch = tuple(t.to(device) for t in batch) 610 | input_ids, input_mask, segment_ids, label_ids = batch 611 | loss, _ = model(input_ids, segment_ids, input_mask, label_ids) 612 | if n_gpu > 1: 613 | loss = loss.mean() # mean() to average on multi-gpu. 614 | if args.fp16 and args.loss_scale != 1.0: 615 | # rescale loss for fp16 training 616 | # see https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html 617 | loss = loss * args.loss_scale 618 | if args.gradient_accumulation_steps > 1: 619 | loss = loss / args.gradient_accumulation_steps 620 | loss.backward() 621 | tr_loss += loss.item() 622 | nb_tr_examples += input_ids.size(0) 623 | nb_tr_steps += 1 624 | if (step + 1) % args.gradient_accumulation_steps == 0: 625 | if args.fp16 or args.optimize_on_cpu: 626 | if args.fp16 and args.loss_scale != 1.0: 627 | # scale down gradients for fp16 training 628 | for param in model.parameters(): 629 | param.grad.data = param.grad.data / args.loss_scale 630 | is_nan = set_optimizer_params_grad(param_optimizer, model.named_parameters(), test_nan=True) 631 | if is_nan: 632 | logger.info("FP16 TRAINING: Nan in gradients, reducing loss scaling") 633 | args.loss_scale = args.loss_scale / 2 634 | model.zero_grad() 635 | continue 636 | optimizer.step() 637 | copy_optimizer_params_to_model(model.named_parameters(), param_optimizer) 638 | else: 639 | optimizer.step() 640 | model.zero_grad() 641 | global_step += 1 642 | 643 | if args.do_eval: 644 | eval_examples = processor.get_dev_examples(args.data_dir) 645 | eval_features = convert_examples_to_features( 646 | eval_examples, label_list, args.max_seq_length, tokenizer) 647 | logger.info("***** Running evaluation *****") 648 | logger.info(" Num examples = %d", len(eval_examples)) 649 | logger.info(" Batch size = %d", args.eval_batch_size) 650 | all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) 651 | all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) 652 | all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) 653 | all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long) 654 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) 655 | if args.local_rank == -1: 656 | eval_sampler = SequentialSampler(eval_data) 657 | else: 658 | 659 | eval_sampler = SequentialSampler(eval_data) 660 | # eval_sampler = DistributedSampler(eval_data) 661 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size) 662 | 663 | model.eval() 664 | eval_loss, eval_accuracy = 0, 0 665 | nb_eval_steps, nb_eval_examples = 0, 0 666 | for input_ids, 
input_mask, segment_ids, label_ids in eval_dataloader: 667 | input_ids = input_ids.to(device) 668 | input_mask = input_mask.to(device) 669 | segment_ids = segment_ids.to(device) 670 | label_ids = label_ids.to(device) 671 | 672 | with torch.no_grad(): 673 | tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids) 674 | 675 | logits = logits.detach().cpu().numpy() 676 | label_ids = label_ids.to('cpu').numpy() 677 | tmp_eval_accuracy = accuracy(logits, label_ids) 678 | 679 | eval_loss += tmp_eval_loss.mean().item() 680 | eval_accuracy += tmp_eval_accuracy 681 | 682 | nb_eval_examples += input_ids.size(0) 683 | nb_eval_steps += 1 684 | 685 | eval_loss = eval_loss / nb_eval_steps 686 | eval_accuracy = eval_accuracy / nb_eval_examples 687 | 688 | result = {'eval_loss': eval_loss, 689 | 'eval_accuracy': eval_accuracy, 690 | 'global_step': global_step, 691 | 'loss': tr_loss / nb_tr_steps} 692 | 693 | output_eval_file = os.path.join(args.output_dir, "eval_results.txt") 694 | with open(output_eval_file, "w") as writer: 695 | logger.info("***** Eval results *****") 696 | for key in sorted(result.keys()): 697 | logger.info(" %s = %s", key, str(result[key])) 698 | writer.write("%s = %s\n" % (key, str(result[key]))) 699 | 700 | 701 | if __name__ == "__main__": 702 | main() 703 | --------------------------------------------------------------------------------