├── README.md ├── __init__.py ├── convert_ernie_to_pytorch.py ├── fine_tune_ernie.py └── pyernie ├── __init__.py ├── callback ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── lrscheduler.cpython-36.pyc │ ├── modelcheckpoint.cpython-36.pyc │ ├── progressbar.cpython-36.pyc │ └── trainingmonitor.cpython-36.pyc ├── earlystopping.py ├── lrscheduler.py ├── modelcheckpoint.py ├── optimizater.py ├── progressbar.py ├── trainingmonitor.py └── writetensorboard.py ├── config ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── basic_config.cpython-36.pyc └── basic_config.py ├── dataset ├── __init__.py ├── processed │ └── __init__.py └── raw │ └── __init__.py ├── io ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── data_transformer.cpython-36.pyc │ └── dataset.cpython-36.pyc ├── data_transformer.py └── dataset.py ├── model ├── __init__.py ├── ernie │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── file_utils.cpython-36.pyc │ │ ├── modeling.cpython-36.pyc │ │ ├── optimization.cpython-36.pyc │ │ └── tokenization.cpython-36.pyc │ ├── file_utils.py │ ├── modeling.py │ ├── optimization.py │ └── tokenization.py ├── nn │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ └── ernie_fine.cpython-36.pyc │ └── ernie_fine.py └── pretrain │ ├── __init__.py │ └── ernie_base │ └── __init__.py ├── output ├── __init__.py ├── checkpoints │ └── __init__.py ├── embedding │ └── __init__.py ├── feature │ └── __init__.py ├── figure │ └── __init__.py ├── log │ └── __init__.py └── result │ └── __init__.py ├── preprocessing ├── __init__.py ├── augmentation.py └── preprocessor.py ├── test └── __init__.py ├── train ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── losses.cpython-36.pyc │ ├── metrics.cpython-36.pyc │ └── trainer.cpython-36.pyc ├── losses.py ├── metrics.py ├── train_utils.py └── trainer.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc ├── logginger.cpython-36.pyc └── utils.cpython-36.pyc ├── logginger.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # ERNIE text classification by PyTorch 2 | 3 | This repo contains a PyTorch implementation of a pretrained ERNIE model for text classification. 4 | 5 | arxiv: https://arxiv.org/abs/1904.09223v1 6 | 7 | ## Structure of the code 8 | 9 | At the root of the project, you will see: 10 | 11 | ```text 12 | ├── pyernie 13 | | └── callback 14 | | | └── lrscheduler.py   15 | | | └── trainingmonitor.py  16 | | | └── ... 17 | | └── config 18 | | | └── basic_config.py #a configuration file for storing model parameters 19 | | └── dataset    20 | | └── io     21 | | | └── dataset.py   22 | | | └── data_transformer.py   23 | | └── model 24 | | | └── nn  25 | | | └── pretrain  26 | | └── output #save the ouput of model 27 | | └── preprocessing #text preprocessing 28 | | └── train #used for training a model 29 | | | └── trainer.py 30 | | | └── ... 31 | | └── utils # a set of utility functions 32 | ├── convert_ernie_to_pytorch.py 33 | ├── fine_tune_ernie.py 34 | ``` 35 | ## Dependencies 36 | 37 | - csv 38 | - tqdm 39 | - numpy 40 | - pickle 41 | - scikit-learn 42 | - PyTorch 1.0 43 | - matplotlib 44 | - tensorboardX 45 | - Tensorflow (to be able to run TensorboardX) 46 | 47 | ## How to use the code 48 | 49 | you need download pretrained ERNIE model 50 | 51 | 1. 
Download the pretrained ERNIE model from [baiduPan](https://pan.baidu.com/s/1BQlwbc9PZjAoVB7Kfq_Ihg) (password: uwds) and place it in the `/pyernie/model/pretrain` directory.

2. Prepare the Chinese raw data (for example, news data). You can modify `pyernie/io/data_transformer.py` to adapt it to your own data.

3. Modify the configuration in `pyernie/config/basic_config.py` (data paths, hyperparameters, ...).

4. Run `fine_tune_ernie.py`.

## Fine-tuning results

### training

Epoch: 4 - loss: 0.0136 - f1: 0.9967 - valid_loss: 0.0761 - valid_f1: 0.9798

### train classify_report

| label | precision | recall | f1-score | support |
| :---------: | :-------: | :----: | :------: | :-----: |
| 财经 | 0.99 | 0.99 | 0.99 | 3500 |
| 体育 | 1.00 | 1.00 | 1.00 | 3500 |
| 娱乐 | 1.00 | 1.00 | 1.00 | 3500 |
| 家居 | 1.00 | 1.00 | 1.00 | 3500 |
| 房产 | 0.99 | 0.99 | 0.99 | 3500 |
| 教育 | 1.00 | 0.99 | 1.00 | 3500 |
| 时尚 | 1.00 | 1.00 | 1.00 | 3500 |
| 时政 | 1.00 | 1.00 | 1.00 | 3500 |
| 游戏 | 1.00 | 1.00 | 1.00 | 3500 |
| 科技 | 0.99 | 1.00 | 1.00 | 3500 |
| avg / total | 1.00 | 1.00 | 1.00 | 35000 |

### valid classify_report

| label | precision | recall | f1-score | support |
| :---------: | :-------: | :----: | :------: | :-----: |
| 财经 | 0.97 | 0.96 | 0.96 | 1500 |
| 体育 | 1.00 | 1.00 | 1.00 | 1500 |
| 娱乐 | 0.99 | 0.99 | 0.99 | 1500 |
| 家居 | 0.99 | 0.99 | 0.99 | 1500 |
| 房产 | 0.96 | 0.96 | 0.96 | 1500 |
| 教育 | 0.98 | 0.98 | 0.98 | 1500 |
| 时尚 | 0.99 | 0.99 | 0.99 | 1500 |
| 时政 | 0.97 | 0.98 | 0.98 | 1500 |
| 游戏 | 0.99 | 0.99 | 0.99 | 1500 |
| 科技 | 0.97 | 0.97 | 0.97 | 1500 |
| avg / total | 0.98 | 0.98 | 0.98 | 15000 |

### training figure

![](https://lonepatient-1257945978.cos.ap-chengdu.myqcloud.com/20190519002915.png)

## Tips

- When converting the TensorFlow checkpoint to PyTorch, choose `bert_model.ckpt` (not `bert_model.ckpt.index`) as the input file. Otherwise the model will learn nothing and produce almost identical, near-random outputs for any input, which means the real checkpoint was never loaded.
- When using multiple GPUs, non-tensor calculations such as accuracy and F1 score are not supported by a `DataParallel` instance.
- As recommended by Jacob Devlin et al. in the BERT paper (https://arxiv.org/pdf/1810.04805.pdf), reasonable fine-tuning hyperparameters are: **batch_size**: 16 or 32, **learning_rate**: 5e-5, 3e-5, or 2e-5, **num_train_epochs**: 3 or 4.
- The pretrained model limits the input to at most 512 tokens (the maximum position embedding size). Data flows into the model as: raw data -> WordPieces -> model. Since the WordPiece sequence is usually longer than the raw text, a safe maximum raw-text length is roughly 128-256 characters.
- In our tests, fine-tuning all layers gave much better results than fine-tuning only the last classifier layer; the latter is essentially a feature-based approach.
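If you convert the PaddlePaddle checkpoint yourself with `convert_ernie_to_pytorch.py`, a quick way to verify that the conversion (and the checkpoint-choice tip above) actually worked is to inspect the saved weights directly. The sketch below is illustrative and not part of the repo; it assumes the script's default output directory `./ERNIE` and relies only on `torch.load`:

```python
# Sanity-check a converted ERNIE checkpoint (illustrative sketch, not part of this repo).
import torch

state_dict = torch.load("./ERNIE/pytorch_model.bin", map_location="cpu")
print(f"{len(state_dict)} tensors in the checkpoint")

# A correctly converted checkpoint contains BERT-style parameter names such as
# "bert.embeddings.word_embeddings.weight"; its first dimension is the vocab size.
emb = state_dict["bert.embeddings.word_embeddings.weight"]
print("word embedding shape:", tuple(emb.shape))
```

If the key lookup fails or the shapes look wrong, the conversion most likely used the wrong input checkpoint.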

--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
#encoding:utf-8
--------------------------------------------------------------------------------
/convert_ernie_to_pytorch.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# encoding: utf-8
import collections
import os
import sys
import numpy as np
import argparse
import paddle.fluid as fluid
import torch
import json

if not os.path.exists('LARK'):
    os.system('git clone https://github.com/PaddlePaddle/LARK.git')
sys.path = ['./LARK/ERNIE'] + sys.path
try:
    from model.ernie import ErnieConfig
    from finetune.classifier import create_model
except:
    raise Exception('Please clone ERNIE first')


def if_exist(var):
    return os.path.exists(os.path.join(args.init_pretraining_params, var.name))


def build_weight_map():
    # maps PaddlePaddle ERNIE parameter names to PyTorch BERT parameter names
    weight_map = collections.OrderedDict({
        'word_embedding': 'bert.embeddings.word_embeddings.weight',
        'pos_embedding': 'bert.embeddings.position_embeddings.weight',
        'sent_embedding': 'bert.embeddings.token_type_embeddings.weight',
        'pre_encoder_layer_norm_scale': 'bert.embeddings.LayerNorm.gamma',
        'pre_encoder_layer_norm_bias': 'bert.embeddings.LayerNorm.beta',
    })

    def add_w_and_b(ernie_pre, pytorch_pre):
        weight_map[ernie_pre + ".w_0"] = pytorch_pre + ".weight"
        weight_map[ernie_pre + ".b_0"] = pytorch_pre + ".bias"

    def add_one_encoder_layer(layer_number):
        # attention
        add_w_and_b(f"encoder_layer_{layer_number}_multi_head_att_query_fc",
                    f"bert.encoder.layer.{layer_number}.attention.self.query")
        add_w_and_b(f"encoder_layer_{layer_number}_multi_head_att_key_fc",
                    f"bert.encoder.layer.{layer_number}.attention.self.key")
        add_w_and_b(f"encoder_layer_{layer_number}_multi_head_att_value_fc",
                    f"bert.encoder.layer.{layer_number}.attention.self.value")
        add_w_and_b(f"encoder_layer_{layer_number}_multi_head_att_output_fc",
                    f"bert.encoder.layer.{layer_number}.attention.output.dense")
        weight_map[f"encoder_layer_{layer_number}_post_att_layer_norm_bias"] = \
            f"bert.encoder.layer.{layer_number}.attention.output.LayerNorm.bias"
        weight_map[f"encoder_layer_{layer_number}_post_att_layer_norm_scale"] = \
            f"bert.encoder.layer.{layer_number}.attention.output.LayerNorm.weight"
        # intermediate
        add_w_and_b(f"encoder_layer_{layer_number}_ffn_fc_0", f"bert.encoder.layer.{layer_number}.intermediate.dense")
        # output
        add_w_and_b(f"encoder_layer_{layer_number}_ffn_fc_1", f"bert.encoder.layer.{layer_number}.output.dense")
        weight_map[f"encoder_layer_{layer_number}_post_ffn_layer_norm_bias"] = \
            f"bert.encoder.layer.{layer_number}.output.LayerNorm.bias"
        weight_map[f"encoder_layer_{layer_number}_post_ffn_layer_norm_scale"] = \
            f"bert.encoder.layer.{layer_number}.output.LayerNorm.weight"

    for i in range(12):
        add_one_encoder_layer(i)
    add_w_and_b('pooled_fc', 'bert.pooler.dense')
    return weight_map


def extract_weights(args):
    # add ERNIE to environment
    print('extract weights start'.center(60, '='))
    startup_prog = fluid.Program()
    test_prog = fluid.Program()
    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
75 | exe.run(startup_prog) 76 | args.max_seq_len = 512 77 | args.use_fp16 = False 78 | args.num_labels = 2 79 | args.loss_scaling = 1.0 80 | print('model config:') 81 | ernie_config = ErnieConfig(args.ernie_config_path) 82 | ernie_config.print_config() 83 | with fluid.program_guard(test_prog, startup_prog): 84 | with fluid.unique_name.guard(): 85 | _, _ = create_model( 86 | args, 87 | pyreader_name='train', 88 | ernie_config=ernie_config) 89 | fluid.io.load_vars(exe, args.init_pretraining_params, main_program=test_prog, predicate=if_exist) 90 | state_dict = collections.OrderedDict() 91 | weight_map = build_weight_map() 92 | for ernie_name, pytorch_name in weight_map.items(): 93 | fluid_tensor = fluid.global_scope().find_var(ernie_name).get_tensor() 94 | fluid_array = np.array(fluid_tensor, dtype=np.float32) 95 | if 'w_0' in ernie_name: 96 | fluid_array = fluid_array.transpose() 97 | state_dict[pytorch_name] = fluid_array 98 | print(f'{ernie_name} -> {pytorch_name} {fluid_array.shape}') 99 | print('extract weights done!'.center(60, '=')) 100 | return state_dict 101 | 102 | 103 | def save_model(state_dict, dump_path): 104 | print('save model start'.center(60, '=')) 105 | if not os.path.exists(dump_path): 106 | os.makedirs(dump_path) 107 | # save model 108 | for key in state_dict: 109 | state_dict[key] = torch.FloatTensor(state_dict[key]) 110 | torch.save(state_dict, os.path.join(dump_path, "pytorch_model.bin")) 111 | print('finish save model') 112 | # save config 113 | ernie_config = ErnieConfig(args.ernie_config_path)._config_dict 114 | # set layer_norm_eps, more detail see: https://github.com/PaddlePaddle/LARK/issues/75 115 | ernie_config['layer_norm_eps'] = 1e-5 116 | with open(os.path.join(dump_path, "bert_config.json"), 'wt', encoding='utf-8') as f: 117 | json.dump(ernie_config, f, indent=4) 118 | print('finish save config') 119 | # save vocab.txt 120 | vocab_f = open(os.path.join(dump_path, "vocab.txt"), "wt", encoding='utf-8') 121 | with open("./LARK/ERNIE/config/vocab.txt", "rt", encoding='utf-8') as f: 122 | for line in f: 123 | data = line.strip().split("\t") 124 | vocab_f.writelines(data[0] + "\n") 125 | vocab_f.close() 126 | print('finish save vocab') 127 | print('save model done!'.center(60, '=')) 128 | 129 | 130 | if __name__ == "__main__": 131 | parser = argparse.ArgumentParser() 132 | parser.add_argument("--init_pretraining_params", default='./ERNIE_stable-1.0.1.tar/params', type=str, help=".") 133 | parser.add_argument("--ernie_config_path", default='./ERNIE_stable-1.0.1.tar/ernie_config.json', type=str, help=".") 134 | parser.add_argument("--output_dir", default='./ERNIE', type=str, help=".") 135 | args = parser.parse_args() 136 | state_dict = extract_weights(args) 137 | save_model(state_dict, args.output_dir) 138 | -------------------------------------------------------------------------------- /fine_tune_ernie.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | import torch 3 | import warnings 4 | from pyernie.train.metrics import F1Score 5 | from pyernie.train.losses import CrossEntropy 6 | from pyernie.train.trainer import Trainer 7 | from torch.utils.data import DataLoader 8 | from pyernie.io.dataset import CreateDataset 9 | from pyernie.utils.logginger import init_logger 10 | from pyernie.utils.utils import seed_everything 11 | from pyernie.callback.lrscheduler import BertLR 12 | from pyernie.model.nn.ernie_fine import ErnieFine 13 | from pyernie.io.data_transformer import DataTransformer 14 | from 
pyernie.train.metrics import ClassReport,Accuracy 15 | from pyernie.config.basic_config import configs as config 16 | from pyernie.callback.modelcheckpoint import ModelCheckpoint 17 | from pyernie.callback.trainingmonitor import TrainingMonitor 18 | from pyernie.model.ernie.optimization import BertAdam 19 | warnings.filterwarnings("ignore") 20 | 21 | # 主函数 22 | def main(): 23 | # **************************** 基础信息 *********************** 24 | logger = init_logger(log_name=config['arch'], log_dir=config['log_dir']) 25 | logger.info("seed is %d"%config['seed']) 26 | device = 'cuda:%d' % config['n_gpu'][0] if len(config['n_gpu']) else 'cpu' 27 | seed_everything(seed=config['seed'],device=device) 28 | logger.info('starting load data from disk') 29 | config['id_to_label'] = {v:k for k,v in config['label_to_id'].items()} 30 | target_names = [config['id_to_label'][x] for x in range(len(config['label_to_id']))] 31 | # **************************** 数据生成 *********************** 32 | # data_transformer = DataTransformer(logger = logger, 33 | # label_to_id = config['label_to_id'], 34 | # train_file = config['train_file_path'], 35 | # valid_file = config['valid_file_path'], 36 | # valid_size = config['valid_size'], 37 | # seed = config['seed'], 38 | # shuffle = True, 39 | # skip_header = False, 40 | # preprocess = None, 41 | # raw_data_path=config['raw_data_path']) 42 | # 读取数据集以及数据划分 43 | # data_transformer.read_data() 44 | # train 45 | train_dataset = CreateDataset(data_path = config['train_file_path'], 46 | vocab_path = config['vocab_path'], 47 | max_seq_len = config['max_seq_len'], 48 | seed = config['seed'], 49 | example_type = 'train') 50 | # valid 51 | valid_dataset = CreateDataset( 52 | data_path = config['valid_file_path'], 53 | vocab_path = config['vocab_path'], 54 | max_seq_len = config['max_seq_len'], 55 | seed = config['seed'], 56 | example_type = 'valid' 57 | ) 58 | #加载训练数据集 59 | train_loader = DataLoader(dataset = train_dataset, 60 | batch_size = config['batch_size'], 61 | num_workers = config['num_workers'], 62 | shuffle = True, 63 | drop_last = False, 64 | pin_memory = False) 65 | # 验证数据集 66 | valid_loader = DataLoader(dataset = valid_dataset, 67 | batch_size = config['batch_size'], 68 | num_workers = config['num_workers'], 69 | shuffle = False, 70 | drop_last = False, 71 | pin_memory = False) 72 | 73 | # **************************** 模型 *********************** 74 | logger.info("initializing model") 75 | model = ErnieFine.from_pretrained(config['ernie_model_dir'], 76 | num_classes = len(config['label_to_id'])) 77 | 78 | # ************************** 优化器 ************************* 79 | param_optimizer = list(model.named_parameters()) 80 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 81 | optimizer_grouped_parameters = [ 82 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 83 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 84 | ] 85 | num_train_steps = int( 86 | len(train_dataset.examples) / config['batch_size'] / config['gradient_accumulation_steps'] * config['epochs']) 87 | # t_total: total number of training steps for the learning rate schedule 88 | # warmup: portion of t_total for the warmup 89 | optimizer = BertAdam(optimizer_grouped_parameters, 90 | lr = config['learning_rate'], 91 | warmup = config['warmup_proportion'], 92 | t_total = num_train_steps) 93 | 94 | # **************************** callbacks *********************** 95 | logger.info("initializing 
callbacks") 96 | # 模型保存 97 | model_checkpoint = ModelCheckpoint(checkpoint_dir = config['checkpoint_dir'], 98 | mode = config['mode'], 99 | monitor = config['monitor'], 100 | save_best_only = config['save_best_only'], 101 | arch = config['arch'], 102 | logger = logger) 103 | # 监控训练过程 104 | train_monitor = TrainingMonitor(file_dir = config['figure_dir'],arch = config['arch']) 105 | # 学习率机制 106 | lr_scheduler = BertLR(optimizer=optimizer, 107 | learning_rate = config['learning_rate'], 108 | t_total = num_train_steps, 109 | warmup = config['warmup_proportion']) 110 | 111 | # **************************** training model *********************** 112 | logger.info('training model....') 113 | train_configs = { 114 | 'model': model, 115 | 'logger': logger, 116 | 'optimizer': optimizer, 117 | 'resume': config['resume'], 118 | 'epochs': config['epochs'], 119 | 'n_gpu': config['n_gpu'], 120 | 'gradient_accumulation_steps': config['gradient_accumulation_steps'], 121 | 'epoch_metrics':[F1Score(average='macro',task_type='multiclass'), 122 | ClassReport(target_names=target_names)], 123 | 'batch_metrics':[Accuracy(topK=1)], 124 | 'criterion': CrossEntropy(), 125 | 'model_checkpoint': model_checkpoint, 126 | 'training_monitor': train_monitor, 127 | 'lr_scheduler': lr_scheduler, 128 | 'early_stopping': None, 129 | 'verbose': 1 130 | } 131 | trainer = Trainer(train_configs=train_configs) 132 | # 拟合模型 133 | trainer.train(train_data = train_loader,valid_data=valid_loader) 134 | # 释放显存 135 | if len(config['n_gpu']) > 0: 136 | torch.cuda.empty_cache() 137 | 138 | if __name__ == '__main__': 139 | main() 140 | -------------------------------------------------------------------------------- /pyernie/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 -------------------------------------------------------------------------------- /pyernie/callback/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 -------------------------------------------------------------------------------- /pyernie/callback/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lonePatient/ERNIE-text-classification-pytorch/01d93c594b0d102ea0a8c1d310a162d6fe5a5328/pyernie/callback/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /pyernie/callback/__pycache__/lrscheduler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lonePatient/ERNIE-text-classification-pytorch/01d93c594b0d102ea0a8c1d310a162d6fe5a5328/pyernie/callback/__pycache__/lrscheduler.cpython-36.pyc -------------------------------------------------------------------------------- /pyernie/callback/__pycache__/modelcheckpoint.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lonePatient/ERNIE-text-classification-pytorch/01d93c594b0d102ea0a8c1d310a162d6fe5a5328/pyernie/callback/__pycache__/modelcheckpoint.cpython-36.pyc -------------------------------------------------------------------------------- /pyernie/callback/__pycache__/progressbar.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lonePatient/ERNIE-text-classification-pytorch/01d93c594b0d102ea0a8c1d310a162d6fe5a5328/pyernie/callback/__pycache__/progressbar.cpython-36.pyc -------------------------------------------------------------------------------- /pyernie/callback/__pycache__/trainingmonitor.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lonePatient/ERNIE-text-classification-pytorch/01d93c594b0d102ea0a8c1d310a162d6fe5a5328/pyernie/callback/__pycache__/trainingmonitor.cpython-36.pyc -------------------------------------------------------------------------------- /pyernie/callback/earlystopping.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | import numpy as np 3 | 4 | class EarlyStopping(object): 5 | ''' 6 | early stopping 功能 7 | # Arguments 8 | min_delta: 最小变化 9 | patience: 多少个epoch未提高,就停止训练 10 | verbose: 信息大于,默认打印信息 11 | mode: 计算模式 12 | monitor: 计算指标 13 | baseline: 基线 14 | ''' 15 | def __init__(self, 16 | min_delta = 0, 17 | patience = 10, 18 | verbose = 1, 19 | mode = 'min', 20 | monitor = 'loss', 21 | logger = None, 22 | baseline = None): 23 | 24 | self.baseline = baseline 25 | self.patience = patience 26 | self.verbose = verbose 27 | self.min_delta = min_delta 28 | self.monitor = monitor 29 | self.logger = logger 30 | 31 | assert mode in ['min','max'] 32 | 33 | if mode == 'min': 34 | self.monitor_op = np.less 35 | elif mode == 'max': 36 | self.monitor_op = np.greater 37 | if self.monitor_op == np.greater: 38 | self.min_delta *= 1 39 | else: 40 | self.min_delta *= -1 41 | self.reset() 42 | 43 | def reset(self): 44 | # Allow instances to be re-used 45 | self.wait = 0 46 | self.stop_training = False 47 | if self.baseline is not None: 48 | self.best = self.baseline 49 | else: 50 | self.best = np.Inf if self.monitor_op == np.less else -np.Inf 51 | 52 | def epoch_step(self,current): 53 | if self.monitor_op(current - self.min_delta, self.best): 54 | self.best = current 55 | self.wait = 0 56 | else: 57 | self.wait += 1 58 | if self.wait >= self.patience: 59 | if self.verbose >0: 60 | self.logger.info(f"{self.patience} epochs with no improvement after which training will be stopped") 61 | self.stop_training = True 62 | -------------------------------------------------------------------------------- /pyernie/callback/lrscheduler.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | import math 3 | import numpy as np 4 | import warnings 5 | from torch.optim.optimizer import Optimizer 6 | 7 | 8 | __all__ = ['CustomDecayLR', 9 | 'BertLR', 10 | 'CyclicLR', 11 | 'ReduceLROnPlateau', 12 | 'ReduceLRWDOnPlateau', 13 | 'CosineLRWithRestarts', 14 | ] 15 | 16 | class CustomDecayLR(object): 17 | ''' 18 | 自定义学习率变化机制 19 | Example: 20 | >>> scheduler = CustomDecayLR(optimizer) 21 | >>> for epoch in range(100): 22 | >>> scheduler.epoch_step() 23 | >>> train(...) 24 | >>> ... 25 | >>> optimizer.zero_grad() 26 | >>> loss.backward() 27 | >>> optimizer.step() 28 | >>> validate(...) 
29 | ''' 30 | def __init__(self,optimizer,lr): 31 | self.optimizer = optimizer 32 | self.lr = lr 33 | 34 | def epoch_step(self,epoch): 35 | lr = self.lr 36 | if epoch > 12: 37 | lr = lr / 1000 38 | elif epoch > 8: 39 | lr = lr / 100 40 | elif epoch > 4: 41 | lr = lr / 10 42 | for param_group in self.optimizer.param_groups: 43 | param_group['lr'] = lr 44 | 45 | class BertLR(object): 46 | ''' 47 | Bert模型内定的学习率变化机制 48 | Example: 49 | >>> scheduler = BertLR(optimizer) 50 | >>> for epoch in range(100): 51 | >>> scheduler.step() 52 | >>> train(...) 53 | >>> ... 54 | >>> optimizer.zero_grad() 55 | >>> loss.backward() 56 | >>> optimizer.step() 57 | >>> scheduler.batch_step() 58 | >>> validate(...) 59 | ''' 60 | def __init__(self,optimizer,learning_rate,t_total,warmup): 61 | self.learning_rate = learning_rate 62 | self.optimizer = optimizer 63 | self.t_total = t_total 64 | self.warmup = warmup 65 | 66 | # 线性预热方式 67 | def warmup_linear(self,x, warmup=0.002): 68 | if x < warmup: 69 | return x / warmup 70 | return 1.0 - x 71 | 72 | def batch_step(self,training_step): 73 | lr_this_step = self.learning_rate * self.warmup_linear(training_step / self.t_total,self.warmup) 74 | for param_group in self.optimizer.param_groups: 75 | param_group['lr'] = lr_this_step 76 | 77 | class CyclicLR(object): 78 | ''' 79 | Cyclical learning rates for training neural networks 80 | Example: 81 | >>> scheduler = CyclicLR(optimizer) 82 | >>> for epoch in range(100): 83 | >>> scheduler.step() 84 | >>> train(...) 85 | >>> ... 86 | >>> optimizer.zero_grad() 87 | >>> loss.backward() 88 | >>> optimizer.step() 89 | >>> scheduler.batch_step() 90 | >>> validate(...) 91 | ''' 92 | def __init__(self, optimizer, base_lr=1e-3, max_lr=6e-3, 93 | step_size=2000, mode='triangular', gamma=1., 94 | scale_fn=None, scale_mode='cycle', last_batch_iteration=-1): 95 | 96 | if not isinstance(optimizer, Optimizer): 97 | raise TypeError('{} is not an Optimizer'.format( 98 | type(optimizer).__name__)) 99 | 100 | self.optimizer = optimizer 101 | 102 | if isinstance(base_lr, list) or isinstance(base_lr, tuple): 103 | if len(base_lr) != len(optimizer.param_groups): 104 | raise ValueError("expected {} base_lr, got {}".format( 105 | len(optimizer.param_groups), len(base_lr))) 106 | self.base_lrs = list(base_lr) 107 | else: 108 | self.base_lrs = [base_lr] * len(optimizer.param_groups) 109 | 110 | if isinstance(max_lr, list) or isinstance(max_lr, tuple): 111 | if len(max_lr) != len(optimizer.param_groups): 112 | raise ValueError("expected {} max_lr, got {}".format( 113 | len(optimizer.param_groups), len(max_lr))) 114 | self.max_lrs = list(max_lr) 115 | else: 116 | self.max_lrs = [max_lr] * len(optimizer.param_groups) 117 | 118 | self.step_size = step_size 119 | 120 | if mode not in ['triangular', 'triangular2', 'exp_range'] \ 121 | and scale_fn is None: 122 | raise ValueError('mode is invalid and scale_fn is None') 123 | 124 | self.mode = mode 125 | self.gamma = gamma 126 | 127 | if scale_fn is None: 128 | if self.mode == 'triangular': 129 | self.scale_fn = self._triangular_scale_fn 130 | self.scale_mode = 'cycle' 131 | elif self.mode == 'triangular2': 132 | self.scale_fn = self._triangular2_scale_fn 133 | self.scale_mode = 'cycle' 134 | elif self.mode == 'exp_range': 135 | self.scale_fn = self._exp_range_scale_fn 136 | self.scale_mode = 'iterations' 137 | else: 138 | self.scale_fn = scale_fn 139 | self.scale_mode = scale_mode 140 | 141 | self.batch_step(last_batch_iteration + 1) 142 | self.last_batch_iteration = last_batch_iteration 143 | 144 | def 
_triangular_scale_fn(self, x): 145 | return 1. 146 | 147 | def _triangular2_scale_fn(self, x): 148 | return 1 / (2. ** (x - 1)) 149 | 150 | def _exp_range_scale_fn(self, x): 151 | return self.gamma**(x) 152 | 153 | def get_lr(self): 154 | step_size = float(self.step_size) 155 | cycle = np.floor(1 + self.last_batch_iteration / (2 * step_size)) 156 | x = np.abs(self.last_batch_iteration / step_size - 2 * cycle + 1) 157 | 158 | lrs = [] 159 | param_lrs = zip(self.optimizer.param_groups, self.base_lrs, self.max_lrs) 160 | for param_group, base_lr, max_lr in param_lrs: 161 | base_height = (max_lr - base_lr) * np.maximum(0, (1 - x)) 162 | if self.scale_mode == 'cycle': 163 | lr = base_lr + base_height * self.scale_fn(cycle) 164 | else: 165 | lr = base_lr + base_height * self.scale_fn(self.last_batch_iteration) 166 | lrs.append(lr) 167 | return lrs 168 | 169 | def batch_step(self, batch_iteration=None): 170 | if batch_iteration is None: 171 | batch_iteration = self.last_batch_iteration + 1 172 | self.last_batch_iteration = batch_iteration 173 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 174 | param_group['lr'] = lr 175 | 176 | class ReduceLROnPlateau(object): 177 | """Reduce learning rate when a metric has stopped improving. 178 | Models often benefit from reducing the learning rate by a factor 179 | of 2-10 once learning stagnates. This scheduler reads a metrics 180 | quantity and if no improvement is seen for a 'patience' number 181 | of epochs, the learning rate is reduced. 182 | 183 | Args: 184 | factor: factor by which the learning rate will 185 | be reduced. new_lr = lr * factor 186 | patience: number of epochs with no improvement 187 | after which learning rate will be reduced. 188 | verbose: int. 0: quiet, 1: update messages. 189 | mode: one of {min, max}. In `min` mode, 190 | lr will be reduced when the quantity 191 | monitored has stopped decreasing; in `max` 192 | mode it will be reduced when the quantity 193 | monitored has stopped increasing. 194 | epsilon: threshold for measuring the new optimum, 195 | to only focus on significant changes. 196 | cooldown: number of epochs to wait before resuming 197 | normal operation after lr has been reduced. 198 | min_lr: lower bound on the learning rate. 199 | 200 | 201 | Example: 202 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) 203 | >>> scheduler = ReduceLROnPlateau(optimizer, 'min') 204 | >>> for epoch in range(10): 205 | >>> train(...) 206 | >>> val_acc, val_loss = validate(...) 207 | >>> scheduler.epoch_step(val_loss, epoch) 208 | """ 209 | 210 | def __init__(self, optimizer, mode='min', factor=0.1, patience=10, 211 | verbose=0, epsilon=1e-4, cooldown=0, min_lr=0,eps=1e-8): 212 | 213 | super(ReduceLROnPlateau, self).__init__() 214 | assert isinstance(optimizer, Optimizer) 215 | if factor >= 1.0: 216 | raise ValueError('ReduceLROnPlateau ' 217 | 'does not support a factor >= 1.0.') 218 | self.factor = factor 219 | self.min_lr = min_lr 220 | self.epsilon = epsilon 221 | self.patience = patience 222 | self.verbose = verbose 223 | self.cooldown = cooldown 224 | self.cooldown_counter = 0 # Cooldown counter. 225 | self.monitor_op = None 226 | self.wait = 0 227 | self.best = 0 228 | self.mode = mode 229 | self.optimizer = optimizer 230 | self.eps = eps 231 | self._reset() 232 | 233 | def _reset(self): 234 | """Resets wait counter and cooldown counter. 
235 | """ 236 | if self.mode not in ['min', 'max']: 237 | raise RuntimeError('Learning Rate Plateau Reducing mode %s is unknown!') 238 | if self.mode == 'min': 239 | self.monitor_op = lambda a, b: np.less(a, b - self.epsilon) 240 | self.best = np.Inf 241 | else: 242 | self.monitor_op = lambda a, b: np.greater(a, b + self.epsilon) 243 | self.best = -np.Inf 244 | self.cooldown_counter = 0 245 | self.wait = 0 246 | 247 | def reset(self): 248 | self._reset() 249 | 250 | def epoch_step(self, metrics, epoch): 251 | current = metrics 252 | if current is None: 253 | warnings.warn('Learning Rate Plateau Reducing requires metrics available!', RuntimeWarning) 254 | else: 255 | if self.in_cooldown(): 256 | self.cooldown_counter -= 1 257 | self.wait = 0 258 | 259 | if self.monitor_op(current, self.best): 260 | self.best = current 261 | self.wait = 0 262 | elif not self.in_cooldown(): 263 | if self.wait >= self.patience: 264 | for param_group in self.optimizer.param_groups: 265 | old_lr = float(param_group['lr']) 266 | if old_lr > self.min_lr + self.eps: 267 | new_lr = old_lr * self.factor 268 | new_lr = max(new_lr, self.min_lr) 269 | param_group['lr'] = new_lr 270 | if self.verbose > 0: 271 | print('\nEpoch %05d: reducing learning rate to %s.' % (epoch, new_lr)) 272 | self.cooldown_counter = self.cooldown 273 | self.wait = 0 274 | self.wait += 1 275 | 276 | def in_cooldown(self): 277 | return self.cooldown_counter > 0 278 | 279 | class ReduceLRWDOnPlateau(ReduceLROnPlateau): 280 | """Reduce learning rate and weight decay when a metric has stopped 281 | improving. Models often benefit from reducing the learning rate by 282 | a factor of 2-10 once learning stagnates. This scheduler reads a metric 283 | quantity and if no improvement is seen for a 'patience' number 284 | of epochs, the learning rate and weight decay factor is reduced for 285 | optimizers that implement the the weight decay method from the paper 286 | `Fixing Weight Decay Regularization in Adam`_. 287 | 288 | .. _Fixing Weight Decay Regularization in Adam: 289 | https://arxiv.org/abs/1711.05101 290 | for AdamW or SGDW 291 | Example: 292 | >>> optimizer = AdamW(model.parameters(), lr=0.1, weight_decay=1e-3) 293 | >>> scheduler = ReduceLRWDOnPlateau(optimizer, 'min') 294 | >>> for epoch in range(10): 295 | >>> train(...) 296 | >>> val_loss = validate(...) 297 | >>> # Note that step should be called after validate() 298 | >>> scheduler.epoch_step(val_loss) 299 | """ 300 | def epoch_step(self, metrics, epoch): 301 | current = metrics 302 | if current is None: 303 | warnings.warn('Learning Rate Plateau Reducing requires metrics available!', RuntimeWarning) 304 | else: 305 | if self.in_cooldown(): 306 | self.cooldown_counter -= 1 307 | self.wait = 0 308 | 309 | if self.monitor_op(current, self.best): 310 | self.best = current 311 | self.wait = 0 312 | elif not self.in_cooldown(): 313 | if self.wait >= self.patience: 314 | for param_group in self.optimizer.param_groups: 315 | old_lr = float(param_group['lr']) 316 | if old_lr > self.min_lr + self.eps: 317 | new_lr = old_lr * self.factor 318 | new_lr = max(new_lr, self.min_lr) 319 | param_group['lr'] = new_lr 320 | if self.verbose > 0: 321 | print('\nEpoch %d: reducing learning rate to %s.' 
% (epoch, new_lr)) 322 | if param_group['weight_decay'] != 0: 323 | old_weight_decay = float(param_group['weight_decay']) 324 | new_weight_decay = max(old_weight_decay * self.factor, self.min_lr) 325 | if old_weight_decay > new_weight_decay + self.eps: 326 | param_group['weight_decay'] = new_weight_decay 327 | if self.verbose: 328 | print('\nEpoch {epoch}: reducing weight decay factor of group {i} to {new_weight_decay:.4e}.') 329 | self.cooldown_counter = self.cooldown 330 | self.wait = 0 331 | self.wait += 1 332 | 333 | class CosineLRWithRestarts(object): 334 | """Decays learning rate with cosine annealing, normalizes weight decay 335 | hyperparameter value, implements restarts. 336 | https://arxiv.org/abs/1711.05101 337 | 338 | Args: 339 | optimizer (Optimizer): Wrapped optimizer. 340 | batch_size: minibatch size 341 | epoch_size: training samples per epoch 342 | restart_period: epoch count in the first restart period 343 | t_mult: multiplication factor by which the next restart period will extend/shrink 344 | 345 | Example: 346 | >>> scheduler = CosineLRWithRestarts(optimizer, 32, 1024, restart_period=5, t_mult=1.2) 347 | >>> for epoch in range(100): 348 | >>> scheduler.step() 349 | >>> train(...) 350 | >>> ... 351 | >>> optimizer.zero_grad() 352 | >>> loss.backward() 353 | >>> optimizer.step() 354 | >>> scheduler.batch_step() 355 | >>> validate(...) 356 | """ 357 | 358 | def __init__(self, optimizer, batch_size, epoch_size, restart_period=100, 359 | t_mult=2, last_epoch=-1, eta_threshold=1000, verbose=False): 360 | if not isinstance(optimizer, Optimizer): 361 | raise TypeError('{} is not an Optimizer'.format( 362 | type(optimizer).__name__)) 363 | self.optimizer = optimizer 364 | if last_epoch == -1: 365 | for group in optimizer.param_groups: 366 | group.setdefault('initial_lr', group['lr']) 367 | else: 368 | for i, group in enumerate(optimizer.param_groups): 369 | if 'initial_lr' not in group: 370 | raise KeyError("param 'initial_lr' is not specified " 371 | "in param_groups[{}] when resuming an" 372 | " optimizer".format(i)) 373 | self.base_lrs = list(map(lambda group: group['initial_lr'], 374 | optimizer.param_groups)) 375 | 376 | self.last_epoch = last_epoch 377 | self.batch_size = batch_size 378 | self.iteration = 0 379 | self.epoch_size = epoch_size 380 | self.eta_threshold = eta_threshold 381 | self.t_mult = t_mult 382 | self.verbose = verbose 383 | self.base_weight_decays = list(map(lambda group: group['weight_decay'], 384 | optimizer.param_groups)) 385 | self.restart_period = restart_period 386 | self.restarts = 0 387 | self.t_epoch = -1 388 | self.batch_increments = [] 389 | self._set_batch_increment() 390 | 391 | def _schedule_eta(self): 392 | """ 393 | Threshold value could be adjusted to shrink eta_min and eta_max values. 394 | """ 395 | eta_min = 0 396 | eta_max = 1 397 | if self.restarts <= self.eta_threshold: 398 | return eta_min, eta_max 399 | else: 400 | d = self.restarts - self.eta_threshold 401 | k = d * 0.09 402 | return (eta_min + k, eta_max - k) 403 | 404 | def get_lr(self, t_cur): 405 | eta_min, eta_max = self._schedule_eta() 406 | 407 | eta_t = (eta_min + 0.5 * (eta_max - eta_min) 408 | * (1. 
+ math.cos(math.pi * 409 | (t_cur / self.restart_period)))) 410 | 411 | weight_decay_norm_multi = math.sqrt(self.batch_size / 412 | (self.epoch_size * 413 | self.restart_period)) 414 | lrs = [base_lr * eta_t for base_lr in self.base_lrs] 415 | weight_decays = [base_weight_decay * eta_t * weight_decay_norm_multi 416 | for base_weight_decay in self.base_weight_decays] 417 | 418 | if self.t_epoch % self.restart_period < self.t_epoch: 419 | if self.verbose: 420 | print("Restart at epoch {}".format(self.last_epoch)) 421 | self.restart_period *= self.t_mult 422 | self.restarts += 1 423 | self.t_epoch = 0 424 | 425 | return zip(lrs, weight_decays) 426 | 427 | def _set_batch_increment(self): 428 | d, r = divmod(self.epoch_size, self.batch_size) 429 | batches_in_epoch = d + 2 if r > 0 else d + 1 430 | self.iteration = 0 431 | self.batch_increments = list(np.linspace(0, 1, batches_in_epoch)) 432 | 433 | def batch_step(self): 434 | self.last_epoch += 1 435 | self.t_epoch += 1 436 | self._set_batch_increment() 437 | try: 438 | t_cur = self.t_epoch + self.batch_increments[self.iteration] 439 | self.iteration += 1 440 | except (IndexError): 441 | raise RuntimeError("Epoch size and batch size used in the " 442 | "training loop and while initializing " 443 | "scheduler should be the same.") 444 | 445 | for param_group, (lr, weight_decay) in zip(self.optimizer.param_groups,self.get_lr(t_cur)): 446 | param_group['lr'] = lr 447 | param_group['weight_decay'] = weight_decay 448 | -------------------------------------------------------------------------------- /pyernie/callback/modelcheckpoint.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | import os 3 | from pathlib import Path 4 | import numpy as np 5 | import torch 6 | 7 | class ModelCheckpoint(object): 8 | ''' 9 | 模型保存,两种模式: 10 | 1. 直接保存最好模型 11 | 2. 
按照epoch频率保存模型 12 | ''' 13 | def __init__(self, checkpoint_dir,monitor,logger, 14 | arch,mode='min',epoch_freq=1,best = None, 15 | save_best_only = True, 16 | ): 17 | if isinstance(checkpoint_dir,Path): 18 | self.base_path = checkpoint_dir 19 | else: 20 | self.base_path = Path(checkpoint_dir) 21 | 22 | self.arch = arch 23 | self.logger = logger 24 | self.monitor = monitor 25 | self.epoch_freq = epoch_freq 26 | self.save_best_only = save_best_only 27 | 28 | # 计算模式 29 | if mode == 'min': 30 | self.monitor_op = np.less 31 | self.best = np.Inf 32 | 33 | elif mode == 'max': 34 | self.monitor_op = np.greater 35 | self.best = -np.Inf 36 | # 这里主要重新加载模型时候 37 | #对best重新赋值 38 | if best: 39 | self.best = best 40 | 41 | if save_best_only: 42 | self.model_name = f"best_{arch}_model.pth" 43 | 44 | def epoch_step(self, state,current): 45 | # 是否保存最好模型 46 | if self.save_best_only: 47 | if self.monitor_op(current, self.best): 48 | self.logger.info(f"\nEpoch {state['epoch']}: {self.monitor} improved from {self.best:.5f} to {current:.5f}") 49 | self.best = current 50 | state['best'] = self.best 51 | best_path = self.base_path/ self.model_name 52 | torch.save(state, str(best_path)) 53 | # 每隔几个epoch保存下模型 54 | else: 55 | filename = self.base_path / f"epoch_{state['epoch']}_{state[self.monitor]}_{self.arch}_model.pth" 56 | if state['epoch'] % self.epoch_freq == 0: 57 | self.logger.info("\nEpoch %d: save model to disk."%(state['epoch'])) 58 | torch.save(state, str(filename)) 59 | -------------------------------------------------------------------------------- /pyernie/callback/optimizater.py: -------------------------------------------------------------------------------- 1 | #encofing:utf-8 2 | import math 3 | import torch 4 | from torch.optim.optimizer import Optimizer 5 | 6 | __call__ = ['SGDW','AdamW','AdaBound'] 7 | 8 | 9 | class SGDW(Optimizer): 10 | r"""Implements stochastic gradient descent (optionally with momentum) with 11 | weight decay from the paper `Fixing Weight Decay Regularization in Adam`_. 12 | 13 | Nesterov momentum is based on the formula from 14 | `On the importance of initialization and momentum in deep learning`__. 15 | 16 | Args: 17 | params (iterable): iterable of parameters to optimize or dicts defining 18 | parameter groups 19 | lr (float): learning rate 20 | momentum (float, optional): momentum factor (default: 0) 21 | weight_decay (float, optional): weight decay factor (default: 0) 22 | dampening (float, optional): dampening for momentum (default: 0) 23 | nesterov (bool, optional): enables Nesterov momentum (default: False) 24 | 25 | .. 
_Fixing Weight Decay Regularization in Adam: 26 | https://arxiv.org/abs/1711.05101 27 | 28 | Example: 29 | >>> model = resnet() 30 | >>> optimizer = SGDW(model.parameters(), lr=0.1, momentum=0.9,weight_decay=1e-5) 31 | """ 32 | 33 | def __init__(self, params, lr=0.1, momentum=0, dampening=0, 34 | weight_decay=0, nesterov=False): 35 | if lr < 0.0: 36 | raise ValueError("Invalid learning rate: {}".format(lr)) 37 | if momentum < 0.0: 38 | raise ValueError("Invalid momentum value: {}".format(momentum)) 39 | if weight_decay < 0.0: 40 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 41 | 42 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening, 43 | weight_decay=weight_decay, nesterov=nesterov) 44 | if nesterov and (momentum <= 0 or dampening != 0): 45 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 46 | super(SGDW, self).__init__(params, defaults) 47 | 48 | def __setstate__(self, state): 49 | super(SGDW, self).__setstate__(state) 50 | for group in self.param_groups: 51 | group.setdefault('nesterov', False) 52 | 53 | def step(self, closure=None): 54 | """Performs a single optimization step. 55 | 56 | Arguments: 57 | closure (callable, optional): A closure that reevaluates the model 58 | and returns the loss. 59 | """ 60 | loss = None 61 | if closure is not None: 62 | loss = closure() 63 | 64 | for group in self.param_groups: 65 | weight_decay = group['weight_decay'] 66 | momentum = group['momentum'] 67 | dampening = group['dampening'] 68 | nesterov = group['nesterov'] 69 | 70 | for p in group['params']: 71 | if p.grad is None: 72 | continue 73 | d_p = p.grad.data 74 | 75 | if momentum != 0: 76 | param_state = self.state[p] 77 | if 'momentum_buffer' not in param_state: 78 | buf = param_state['momentum_buffer'] = torch.zeros_like(p.data) 79 | buf.mul_(momentum).add_(d_p) 80 | else: 81 | buf = param_state['momentum_buffer'] 82 | buf.mul_(momentum).add_(1 - dampening, d_p) 83 | if nesterov: 84 | d_p = d_p.add(momentum, buf) 85 | else: 86 | d_p = buf 87 | 88 | if weight_decay != 0: 89 | p.data.add_(-weight_decay, p.data) 90 | 91 | p.data.add_(-group['lr'], d_p) 92 | 93 | return loss 94 | 95 | 96 | class AdamW(Optimizer): 97 | """Implements Adam algorithm. 
98 | 99 | Arguments: 100 | params (iterable): iterable of parameters to optimize or dicts defining 101 | parameter groups 102 | lr (float, optional): learning rate (default: 1e-3) 103 | betas (Tuple[float, float], optional): coefficients used for computing 104 | running averages of gradient and its square (default: (0.9, 0.999)) 105 | eps (float, optional): term added to the denominator to improve 106 | numerical stability (default: 1e-8) 107 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 108 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 109 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 110 | 111 | Example: 112 | >>> model = resnet() 113 | >>> optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5) 114 | """ 115 | 116 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 117 | weight_decay=0, amsgrad=False): 118 | if lr < 0.0: 119 | raise ValueError("Invalid learning rate: {}".format(lr)) 120 | if not 0.0 <= betas[0] < 1.0: 121 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 122 | if not 0.0 <= betas[1] < 1.0: 123 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 124 | defaults = dict(lr=lr, betas=betas, eps=eps,weight_decay=weight_decay, amsgrad=amsgrad) 125 | #super(AdamW, self).__init__(params, defaults) 126 | super().__init__(params, defaults) 127 | 128 | def step(self, closure=None): 129 | """Performs a single optimization step. 130 | 131 | Arguments: 132 | closure (callable, optional): A closure that reevaluates the model 133 | and returns the loss. 134 | """ 135 | loss = None 136 | if closure is not None: 137 | loss = closure() 138 | 139 | for group in self.param_groups: 140 | for p in group['params']: 141 | if p.grad is None: 142 | continue 143 | grad = p.grad.data 144 | if grad.is_sparse: 145 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 146 | amsgrad = group['amsgrad'] 147 | 148 | state = self.state[p] 149 | 150 | # State initialization 151 | if len(state) == 0: 152 | state['step'] = 0 153 | # Exponential moving average of gradient values 154 | state['exp_avg'] = torch.zeros_like(p.data) 155 | # Exponential moving average of squared gradient values 156 | state['exp_avg_sq'] = torch.zeros_like(p.data) 157 | if amsgrad: 158 | # Maintains max of all exp. moving avg. of sq. grad. values 159 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 160 | 161 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 162 | if amsgrad: 163 | max_exp_avg_sq = state['max_exp_avg_sq'] 164 | beta1, beta2 = group['betas'] 165 | 166 | state['step'] += 1 167 | 168 | # Decay the first and second moment running average coefficient 169 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 170 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 171 | if amsgrad: 172 | # Maintains the maximum of all 2nd moment running avg. till now 173 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 174 | # Use the max. for normalizing running avg. 
of gradient 175 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 176 | else: 177 | denom = exp_avg_sq.sqrt().add_(group['eps']) 178 | 179 | bias_correction1 = 1 - beta1 ** state['step'] 180 | bias_correction2 = 1 - beta2 ** state['step'] 181 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 182 | 183 | if group['weight_decay'] != 0: 184 | decayed_weights = torch.mul(p.data, group['weight_decay']) 185 | p.data.addcdiv_(-step_size, exp_avg, denom) 186 | p.data.sub_(decayed_weights) 187 | else: 188 | p.data.addcdiv_(-step_size, exp_avg, denom) 189 | 190 | return loss 191 | 192 | class AdaBound(Optimizer): 193 | """Implements AdaBound algorithm. 194 | It has been proposed in `Adaptive Gradient Methods with Dynamic Bound of Learning Rate`_. 195 | Arguments: 196 | params (iterable): iterable of parameters to optimize or dicts defining 197 | parameter groups 198 | lr (float, optional): Adam learning rate (default: 1e-3) 199 | betas (Tuple[float, float], optional): coefficients used for computing 200 | running averages of gradient and its square (default: (0.9, 0.999)) 201 | final_lr (float, optional): final (SGD) learning rate (default: 0.1) 202 | gamma (float, optional): convergence speed of the bound functions (default: 1e-3) 203 | eps (float, optional): term added to the denominator to improve 204 | numerical stability (default: 1e-8) 205 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 206 | amsbound (boolean, optional): whether to use the AMSBound variant of this algorithm 207 | .. Adaptive Gradient Methods with Dynamic Bound of Learning Rate: 208 | https://openreview.net/forum?id=Bkg3g2R9FX 209 | """ 210 | 211 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), final_lr=0.1, gamma=1e-3, 212 | eps=1e-8, weight_decay=0, amsbound=False): 213 | if not 0.0 <= lr: 214 | raise ValueError("Invalid learning rate: {}".format(lr)) 215 | if not 0.0 <= eps: 216 | raise ValueError("Invalid epsilon value: {}".format(eps)) 217 | if not 0.0 <= betas[0] < 1.0: 218 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 219 | if not 0.0 <= betas[1] < 1.0: 220 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 221 | if not 0.0 <= final_lr: 222 | raise ValueError("Invalid final learning rate: {}".format(final_lr)) 223 | if not 0.0 <= gamma < 1.0: 224 | raise ValueError("Invalid gamma parameter: {}".format(gamma)) 225 | defaults = dict(lr=lr, betas=betas, final_lr=final_lr, gamma=gamma, eps=eps, 226 | weight_decay=weight_decay, amsbound=amsbound) 227 | super(AdaBound, self).__init__(params, defaults) 228 | 229 | self.base_lrs = list(map(lambda group: group['lr'], self.param_groups)) 230 | 231 | def __setstate__(self, state): 232 | super(AdaBound, self).__setstate__(state) 233 | for group in self.param_groups: 234 | group.setdefault('amsbound', False) 235 | 236 | def step(self, closure=None): 237 | """Performs a single optimization step. 238 | Arguments: 239 | closure (callable, optional): A closure that reevaluates the model 240 | and returns the loss. 
241 | Examples: 242 | >>> model = resnet() 243 | >>> optimizer = adabound.AdaBound(model.parameters(), lr=1e-3, final_lr=0.1) 244 | """ 245 | loss = None 246 | if closure is not None: 247 | loss = closure() 248 | 249 | for group, base_lr in zip(self.param_groups, self.base_lrs): 250 | for p in group['params']: 251 | if p.grad is None: 252 | continue 253 | grad = p.grad.data 254 | if grad.is_sparse: 255 | raise RuntimeError( 256 | 'Adam does not support sparse gradients, please consider SparseAdam instead') 257 | amsbound = group['amsbound'] 258 | 259 | state = self.state[p] 260 | 261 | # State initialization 262 | if len(state) == 0: 263 | state['step'] = 0 264 | # Exponential moving average of gradient values 265 | state['exp_avg'] = torch.zeros_like(p.data) 266 | # Exponential moving average of squared gradient values 267 | state['exp_avg_sq'] = torch.zeros_like(p.data) 268 | if amsbound: 269 | # Maintains max of all exp. moving avg. of sq. grad. values 270 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 271 | 272 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 273 | if amsbound: 274 | max_exp_avg_sq = state['max_exp_avg_sq'] 275 | beta1, beta2 = group['betas'] 276 | 277 | state['step'] += 1 278 | 279 | if group['weight_decay'] != 0: 280 | grad = grad.add(group['weight_decay'], p.data) 281 | 282 | # Decay the first and second moment running average coefficient 283 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 284 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 285 | if amsbound: 286 | # Maintains the maximum of all 2nd moment running avg. till now 287 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 288 | # Use the max. for normalizing running avg. of gradient 289 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 290 | else: 291 | denom = exp_avg_sq.sqrt().add_(group['eps']) 292 | 293 | bias_correction1 = 1 - beta1 ** state['step'] 294 | bias_correction2 = 1 - beta2 ** state['step'] 295 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 296 | 297 | # Applies bounds on actual learning rate 298 | # lr_scheduler cannot affect final_lr, this is a workaround to apply lr decay 299 | final_lr = group['final_lr'] * group['lr'] / base_lr 300 | lower_bound = final_lr * (1 - 1 / (group['gamma'] * state['step'] + 1)) 301 | upper_bound = final_lr * (1 + 1 / (group['gamma'] * state['step'])) 302 | step_size = torch.full_like(denom, step_size) 303 | step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(exp_avg) 304 | 305 | p.data.add_(-step_size) 306 | 307 | return loss -------------------------------------------------------------------------------- /pyernie/callback/progressbar.py: -------------------------------------------------------------------------------- 1 | #encoding:Utf-8 2 | class ProgressBar(object): 3 | def __init__(self,n_batch, 4 | width=30): 5 | self.width = width 6 | self.n_batch = n_batch 7 | def batch_step(self,batch_idx,info,use_time): 8 | recv_per = int(100 * (batch_idx + 1) / self.n_batch) 9 | if recv_per >= 100: 10 | recv_per = 100 11 | # 进度条模式 12 | show_bar = f"\r[{int(self.width * recv_per / 100) * '>':<{self.width}s}]{recv_per}%" 13 | # 打印信息 14 | show_info = f'\r[training] {batch_idx+1}/{self.n_batch} {show_bar} -{use_time:.1f}s/step '+\ 15 | "-".join([f' {key}: {value:.4f} ' for key,value in info.items()]) 16 | print(show_info,end='') 17 | 18 | 19 | -------------------------------------------------------------------------------- /pyernie/callback/trainingmonitor.py: 
-------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | import json 3 | import numpy as np 4 | from pathlib import Path 5 | import matplotlib.pyplot as plt 6 | plt.switch_backend('agg') # 防止ssh上绘图问题 7 | 8 | class TrainingMonitor(): 9 | def __init__(self, file_dir,arch,start_at=0): 10 | ''' 11 | :param startAt: 重新开始训练的epoch点 12 | ''' 13 | if isinstance(file_dir,Path): 14 | pass 15 | else: 16 | file_dir = Path(file_dir) 17 | file_dir.mkdir(parents=True, exist_ok=True) 18 | 19 | self.arch = arch 20 | self.file_dir = file_dir 21 | self.start_at = start_at 22 | self.H = {} 23 | self.json_path = file_dir / (arch+"_training_monitor.json") 24 | self.reset() 25 | 26 | def reset(self): 27 | if self.start_at > 0: 28 | # 如果jsonPath文件存在,咋加载历史训练数据 29 | if self.json_path is not None: 30 | if self.json_path.exists(): 31 | self.H = json.loads(open(str(self.json_path)).read()) 32 | for k in self.H.keys(): 33 | self.H[k] = self.H[k][:self.start_at] 34 | 35 | def epoch_step(self,logs={}): 36 | for (k, v) in logs.items(): 37 | l = self.H.get(k, []) 38 | # np.float32会报错 39 | if not isinstance(v,np.float): 40 | v = round(float(v),4) 41 | l.append(v) 42 | self.H[k] = l 43 | 44 | # 写入文件 45 | if self.json_path is not None: 46 | f = open(str(self.json_path), "w") 47 | f.write(json.dumps(self.H)) 48 | f.close() 49 | 50 | #保存train图像 51 | if len(self.H["loss"]) == 1: 52 | self.paths = {key: self.file_dir / (self.arch + f'_{key}') for key in self.H.keys()} 53 | if len(self.H["loss"]) > 1: 54 | # 指标变化曲线 55 | # 需要成对出现 56 | keys = [key for key,_ in self.H.items() if '_' not in key] 57 | for key in keys: 58 | N = np.arange(0, len(self.H[key])) 59 | plt.style.use("ggplot") 60 | plt.figure() 61 | plt.plot(N, self.H[key],label=f"train_{key}") 62 | plt.plot(N, self.H[f"valid_{key}"],label=f"valid_{key}") 63 | plt.legend() 64 | plt.xlabel("Epoch #") 65 | plt.ylabel(key) 66 | plt.title(f"Training {key} [Epoch {len(self.H[key])}]") 67 | plt.savefig(str(self.paths[key])) 68 | plt.close() 69 | 70 | 71 | -------------------------------------------------------------------------------- /pyernie/callback/writetensorboard.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | import importlib 3 | import warnings 4 | from pathlib import Path 5 | 6 | class WriterTensorboardX(): 7 | def __init__(self, writer_dir, logger, enable): 8 | self.writer = None 9 | 10 | if not isinstance(writer_dir,Path): 11 | writer_dir = Path(writer_dir) 12 | 13 | if enable: 14 | log_path = writer_dir 15 | try: 16 | self.writer = importlib.import_module('tensorboardX').SummaryWriter(log_path) 17 | except ModuleNotFoundError: 18 | message = """TensorboardX visualization is configured to use, but currently not installed on this machine. Please install the package by 'pip install tensorboardx' command or turn off the option in the 'configs.json' file.""" 19 | warnings.warn(message, UserWarning) 20 | logger.warn() 21 | self.step = 0 22 | self.mode = '' 23 | 24 | self.tensorboard_writer_ftns = ['add_scalar', 'add_scalars', 'add_image', 'add_audio', 'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding'] 25 | 26 | def set_step(self, step, mode='train'): 27 | self.mode = mode 28 | self.step = step 29 | 30 | def __getattr__(self, name): 31 | """ 32 | If visualization is configured to use: 33 | return add_data() methods of tensorboard with additional information (step, tag) added. 
34 | Otherwise: 35 | return blank function handle that does nothing 36 | """ 37 | if name in self.tensorboard_writer_ftns: 38 | add_data = getattr(self.writer, name, None) 39 | def wrapper(tag, data, *args, **kwargs): 40 | if add_data is not None: 41 | add_data(f'{self.mode}/{tag}', data, self.step, *args, **kwargs) 42 | return wrapper 43 | else: 44 | # default action for returning methods defined in this class, set_step() for instance. 45 | try: 46 | attr = object.__getattr__(name) 47 | except AttributeError: 48 | raise AttributeError(f"type object 'WriterTensorboardX' has no attribute '{name}'") 49 | return attr -------------------------------------------------------------------------------- /pyernie/config/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 -------------------------------------------------------------------------------- /pyernie/config/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lonePatient/ERNIE-text-classification-pytorch/01d93c594b0d102ea0a8c1d310a162d6fe5a5328/pyernie/config/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /pyernie/config/__pycache__/basic_config.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lonePatient/ERNIE-text-classification-pytorch/01d93c594b0d102ea0a8c1d310a162d6fe5a5328/pyernie/config/__pycache__/basic_config.cpython-36.pyc -------------------------------------------------------------------------------- /pyernie/config/basic_config.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | from os import path 3 | import multiprocessing 4 | 5 | """Note: 6 | pytorch BERT 模型包含三个文件:模型、vocab.txt, bert_config.json, 有两种加载方式: 7 | (1)在线下载。这种方式下,模型和vocab会通过url的方式下载,只需将bert_model设置为 "bert_model=bert-base-chinese" 8 | 另外,还需要设置cache_dir路径,用来存储下载的文件。 9 | (2)先下载好文件。下载好的文件是tensorflow的ckpt格式的,首先要利用convert_tf_checkpoint_to_pytorch转换成pytorch格式存储 10 | 这种方式是通过本地文件夹直接加载的,要注意这时的文件命名方式。首先指定bert_model=存储模型的文件夹 11 | 第二,将vocab.txt和bert_config.json放入该目录下,并在配置文件中指定VOCAB_FILE路径。当然vocab.txt可以不和模型放在一起, 12 | 但是bert_config.json文件必须和模型文件在一起。具体可见源代码file_utils 13 | """ 14 | 15 | BASE_DIR = 'pyernie' 16 | 17 | configs = { 18 | 'arch':'ernie', 19 | 'raw_data_path': path.sep.join([BASE_DIR,'dataset/raw/cnews.train.txt']), # 总的数据,一般是将train和test何在一起构建语料库 20 | 'train_file_path': path.sep.join([BASE_DIR,'dataset/processed/train.tsv']), 21 | 'valid_file_path': path.sep.join([BASE_DIR,'dataset/processed/valid.tsv']), 22 | 23 | 'log_dir': path.sep.join([BASE_DIR, 'output/log']), # 模型运行日志 24 | 'writer_dir': path.sep.join([BASE_DIR, 'output/TSboard']),# TSboard信息保存路径 25 | 'figure_dir': path.sep.join([BASE_DIR, 'output/figure']), # 图形保存路径 26 | 'checkpoint_dir': path.sep.join([BASE_DIR, 'output/checkpoints']),# 模型保存路径 27 | 'cache_dir': path.sep.join([BASE_DIR,'model/']), 28 | 29 | 'vocab_path': path.sep.join([BASE_DIR, 'model/pretrain/ernie_base/vocab.txt']), 30 | 'ernie_config_file': path.sep.join([BASE_DIR, 'model/pretrain/ernie_base/ernie_config.json']), 31 | 'pytorch_model_path': path.sep.join([BASE_DIR, 'model/pretrain/ernie_base/pytorch_model.bin']), 32 | 'ernie_model_dir': path.sep.join([BASE_DIR, 'model/pretrain/ernie_base']), 33 | 34 | 'valid_size': 0.3, # valid数据集大小 35 | 'max_seq_len': 256, # 
word文本平均长度,按照覆盖95%样本的标准,取截断长度:np.percentile(list,95.0) 36 | 37 | 'batch_size': 16, # how many samples to process at once 38 | 'epochs': 5, # number of epochs to train 39 | 'start_epoch': 1, 40 | 'warmup_proportion': 0.1, # Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training. 41 | 'gradient_accumulation_steps':1, # Number of updates steps to accumulate before performing a backward/update pass. 42 | 'learning_rate': 2e-5, 43 | 'n_gpu': [0], # GPU个数,如果只写一个数字,则表示gpu标号从0开始,并且默认使用gpu:0作为controller, 44 | # 如果以列表形式表示,即[1,3,5],则我们默认list[0]作为controller 45 | 46 | 'num_workers': multiprocessing.cpu_count(), # 线程个数 47 | 'resume':False, 48 | 'seed': 2018, 49 | 'lr_patience': 5, # number of epochs with no improvement after which learning rate will be reduced. 50 | 'mode': 'min', # one of {min, max} 51 | 'monitor': 'valid_loss', # 计算指标 52 | 'early_patience': 10, # early_stopping 53 | 'save_best_only': True, # 是否保存最好模型 54 | 'save_checkpoint_freq': 10, #保存模型频率,当save_best_only为False时候,指定才有作用 55 | 56 | 'label_to_id' : { # 标签映射 57 | "财经": 0, 58 | "体育": 1, 59 | "娱乐": 2, 60 | "家居": 3, 61 | "房产": 4, 62 | "教育": 5, 63 | "时尚": 6, 64 | "时政": 7, 65 | "游戏": 8, 66 | "科技": 9, 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /pyernie/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 -------------------------------------------------------------------------------- /pyernie/dataset/processed/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 -------------------------------------------------------------------------------- /pyernie/dataset/raw/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 -------------------------------------------------------------------------------- /pyernie/io/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 -------------------------------------------------------------------------------- /pyernie/io/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lonePatient/ERNIE-text-classification-pytorch/01d93c594b0d102ea0a8c1d310a162d6fe5a5328/pyernie/io/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /pyernie/io/__pycache__/data_transformer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lonePatient/ERNIE-text-classification-pytorch/01d93c594b0d102ea0a8c1d310a162d6fe5a5328/pyernie/io/__pycache__/data_transformer.cpython-36.pyc -------------------------------------------------------------------------------- /pyernie/io/__pycache__/dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lonePatient/ERNIE-text-classification-pytorch/01d93c594b0d102ea0a8c1d310a162d6fe5a5328/pyernie/io/__pycache__/dataset.cpython-36.pyc -------------------------------------------------------------------------------- /pyernie/io/data_transformer.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | import random 3 | from tqdm import tqdm 4 | from ..utils.utils import text_write 5 | 6 | class 
DataTransformer(object): 7 | def __init__(self, 8 | logger, 9 | label_to_id, 10 | train_file, 11 | valid_file, 12 | valid_size, 13 | skip_header, 14 | preprocess, 15 | raw_data_path, 16 | shuffle, 17 | seed, 18 | ): 19 | self.seed = seed 20 | self.logger = logger 21 | self.valid_size = valid_size 22 | self.train_file = train_file 23 | self.valid_file = valid_file 24 | self.raw_data_path = raw_data_path 25 | self.skip_header = skip_header 26 | self.label_to_id = label_to_id 27 | self.preprocess = preprocess 28 | self.shuffle = shuffle 29 | 30 | # 将原始数据集分割成train和valid 31 | def train_val_split(self,X, y): 32 | self.logger.info('train val split') 33 | train, valid = [], [] 34 | bucket = [[] for _ in self.label_to_id] 35 | for data_x, data_y in tqdm(zip(X, y), desc='bucket'): 36 | bucket[int(data_y)].append((data_x, data_y)) 37 | del X, y 38 | for bt in tqdm(bucket, desc='split'): 39 | N = len(bt) 40 | if N == 0: 41 | continue 42 | test_size = int(N * self.valid_size) 43 | if self.shuffle: 44 | random.seed(self.seed) 45 | random.shuffle(bt) 46 | valid.extend(bt[:test_size]) 47 | train.extend(bt[test_size:]) 48 | # 混洗train数据集 49 | if self.shuffle: 50 | random.seed(self.seed) 51 | random.shuffle(train) 52 | return train, valid 53 | 54 | # 读取原始数据集 55 | def read_data(self): 56 | targets,sentences = [],[] 57 | with open(self.raw_data_path,'r') as fr: 58 | for i,line in enumerate(fr): 59 | # 如果首行为列名,则skip_header=True 60 | if i == 0 and self.skip_header: 61 | continue 62 | lines = line.strip().split('\t') 63 | target = self.label_to_id[lines[0]] 64 | sentence = str(lines[1]) 65 | # 预处理 66 | if self.preprocess: 67 | sentence = self.preprocess(sentence) 68 | if sentence: 69 | targets.append(target) 70 | sentences.append(sentence) 71 | # 保存数据 72 | if self.valid_size: 73 | train,valid = self.train_val_split(X = sentences,y = targets) 74 | text_write(filename = self.train_file,data = train) 75 | text_write(filename = self.valid_file,data = valid) 76 | 77 | 78 | -------------------------------------------------------------------------------- /pyernie/io/dataset.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | import csv 3 | import numpy as np 4 | from torch.utils.data import Dataset 5 | from ..model.ernie.tokenization import BertTokenizer 6 | 7 | class InputExample(object): 8 | def __init__(self, guid, text_a, text_b=None, label=None): 9 | """创建一个输入实例 10 | Args: 11 | guid: 每个example拥有唯一的id 12 | text_a: 第一个句子的原始文本,一般对于文本分类来说,只需要text_a 13 | text_b: 第二个句子的原始文本,在句子对的任务中才有,分类问题中为None 14 | label: example对应的标签,对于训练集和验证集应非None,测试集为None 15 | """ 16 | self.guid = guid # 该样本的唯一ID 17 | self.text_a = text_a 18 | self.text_b = text_b 19 | self.label = label 20 | 21 | class InputFeature(object): 22 | ''' 23 | 数据的feature集合 24 | ''' 25 | def __init__(self,input_ids,input_mask,segment_ids,label_id): 26 | self.input_ids = input_ids # tokens的索引 27 | self.input_mask = input_mask 28 | self.segment_ids = segment_ids 29 | self.label_id = label_id 30 | 31 | class CreateDataset(Dataset): 32 | def __init__(self,data_path,max_seq_len,vocab_path,example_type,seed): 33 | self.seed = seed 34 | self.max_seq_len = max_seq_len 35 | self.example_type = example_type 36 | self.data_path = data_path 37 | self.vocab_path = vocab_path 38 | self.reset() 39 | 40 | # 初始化 41 | def reset(self): 42 | # 加载语料库,这是pretrained Bert模型自带的 43 | self.tokenizer = BertTokenizer(vocab_file=self.vocab_path) 44 | # 构建examples 45 | self.build_examples() 46 | 47 | # 读取数据集 48 | def read_data(self,quotechar = 
None): 49 | ''' 50 | 默认是以tab分割的数据 51 | :param quotechar: 52 | :return: 53 | ''' 54 | lines = [] 55 | with open(self.data_path,'r',encoding='utf-8') as fr: 56 | reader = csv.reader(fr,delimiter = '\t',quotechar = quotechar) 57 | for line in reader: 58 | lines.append(line) 59 | return lines 60 | 61 | # 构建数据examples 62 | def build_examples(self): 63 | lines = self.read_data() 64 | self.examples = [] 65 | for i,line in enumerate(lines): 66 | guid = '%s-%d'%(self.example_type,i) 67 | label = line[0] 68 | text_a = line[1] 69 | example = InputExample(guid = guid,text_a = text_a,label= label) 70 | self.examples.append(example) 71 | del lines 72 | 73 | # 将example转化为feature 74 | def build_features(self,example): 75 | ''' 76 | # 对于两个句子: 77 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 78 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 79 | 80 | # 对于单个句子: 81 | # tokens: [CLS] the dog is hairy . [SEP] 82 | # type_ids: 0 0 0 0 0 0 0 83 | # type_ids:表示是第一个句子还是第二个句子 84 | ''' 85 | #转化为token 86 | tokens_a = self.tokenizer.tokenize(example.text_a) 87 | # Account for [CLS] and [SEP] with "- 2" 88 | if len(tokens_a) > self.max_seq_len - 2: 89 | tokens_a = tokens_a[:(self.max_seq_len - 2)] 90 | # 句子首尾加入标示符 91 | tokens = ['[CLS]'] + tokens_a + ['[SEP]'] 92 | segment_ids = [0] * len(tokens) # 对应type_ids 93 | # 将词转化为语料库中对应的id 94 | input_ids = self.tokenizer.convert_tokens_to_ids(tokens) 95 | # 输入mask 96 | input_mask = [1] * len(input_ids) 97 | # padding,使用0进行填充 98 | padding = [0] * (self.max_seq_len - len(input_ids)) 99 | 100 | input_ids += padding 101 | input_mask += padding 102 | segment_ids += padding 103 | 104 | # 标签 105 | label_id = int(example.label) 106 | feature = InputFeature(input_ids = input_ids,input_mask = input_mask, 107 | segment_ids = segment_ids,label_id = label_id) 108 | return feature 109 | 110 | def _preprocess(self,index): 111 | example = self.examples[index] 112 | feature = self.build_features(example) 113 | return np.array(feature.input_ids),np.array(feature.input_mask),\ 114 | np.array(feature.segment_ids),np.array(feature.label_id) 115 | 116 | def __getitem__(self, index): 117 | return self._preprocess(index) 118 | 119 | def __len__(self): 120 | return len(self.examples) 121 | -------------------------------------------------------------------------------- /pyernie/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lonePatient/ERNIE-text-classification-pytorch/01d93c594b0d102ea0a8c1d310a162d6fe5a5328/pyernie/model/__init__.py -------------------------------------------------------------------------------- /pyernie/model/ernie/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.4.0" 2 | from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer 3 | from .modeling import (BertConfig, BertModel, BertForPreTraining, 4 | BertForMaskedLM, BertForNextSentencePrediction, 5 | BertForSequenceClassification, BertForMultipleChoice, 6 | BertForTokenClassification, BertForQuestionAnswering) 7 | from .optimization import BertAdam 8 | from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE 9 | -------------------------------------------------------------------------------- /pyernie/model/ernie/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lonePatient/ERNIE-text-classification-pytorch/01d93c594b0d102ea0a8c1d310a162d6fe5a5328/pyernie/model/ernie/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /pyernie/model/ernie/__pycache__/file_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lonePatient/ERNIE-text-classification-pytorch/01d93c594b0d102ea0a8c1d310a162d6fe5a5328/pyernie/model/ernie/__pycache__/file_utils.cpython-36.pyc -------------------------------------------------------------------------------- /pyernie/model/ernie/__pycache__/modeling.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lonePatient/ERNIE-text-classification-pytorch/01d93c594b0d102ea0a8c1d310a162d6fe5a5328/pyernie/model/ernie/__pycache__/modeling.cpython-36.pyc -------------------------------------------------------------------------------- /pyernie/model/ernie/__pycache__/optimization.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lonePatient/ERNIE-text-classification-pytorch/01d93c594b0d102ea0a8c1d310a162d6fe5a5328/pyernie/model/ernie/__pycache__/optimization.cpython-36.pyc -------------------------------------------------------------------------------- /pyernie/model/ernie/__pycache__/tokenization.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lonePatient/ERNIE-text-classification-pytorch/01d93c594b0d102ea0a8c1d310a162d6fe5a5328/pyernie/model/ernie/__pycache__/tokenization.cpython-36.pyc -------------------------------------------------------------------------------- /pyernie/model/ernie/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 5 | """ 6 | 7 | import os 8 | import logging 9 | import shutil 10 | import tempfile 11 | import json 12 | from urllib.parse import urlparse 13 | from pathlib import Path 14 | from typing import Optional, Tuple, Union, IO, Callable, Set 15 | from hashlib import sha256 16 | from functools import wraps 17 | 18 | from tqdm import tqdm 19 | 20 | import boto3 21 | from botocore.exceptions import ClientError 22 | import requests 23 | 24 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 25 | 26 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 27 | Path.home() / '.pytorch_pretrained_bert')) 28 | 29 | 30 | def url_to_filename(url: str, etag: str = None) -> str: 31 | """ 32 | Convert `url` into a hashed filename in a repeatable way. 33 | If `etag` is specified, append its hash to the url's, delimited 34 | by a period. 35 | """ 36 | url_bytes = url.encode('utf-8') 37 | url_hash = sha256(url_bytes) 38 | filename = url_hash.hexdigest() 39 | 40 | if etag: 41 | etag_bytes = etag.encode('utf-8') 42 | etag_hash = sha256(etag_bytes) 43 | filename += '.' + etag_hash.hexdigest() 44 | 45 | return filename 46 | 47 | 48 | def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]: 49 | """ 50 | Return the url and etag (which may be ``None``) stored for `filename`. 
51 | Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist. 52 | """ 53 | if cache_dir is None: 54 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 55 | if isinstance(cache_dir, Path): 56 | cache_dir = str(cache_dir) 57 | 58 | cache_path = os.path.join(cache_dir, filename) 59 | if not os.path.exists(cache_path): 60 | raise FileNotFoundError("file {} not found".format(cache_path)) 61 | 62 | meta_path = cache_path + '.json' 63 | if not os.path.exists(meta_path): 64 | raise FileNotFoundError("file {} not found".format(meta_path)) 65 | 66 | with open(meta_path) as meta_file: 67 | metadata = json.load(meta_file) 68 | url = metadata['url'] 69 | etag = metadata['etag'] 70 | 71 | return url, etag 72 | 73 | 74 | def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str: 75 | """ 76 | Given something that might be a URL (or might be a local path), 77 | determine which. If it's a URL, download the file and cache it, and 78 | return the path to the cached file. If it's already a local path, 79 | make sure the file exists and then return the path. 80 | """ 81 | if cache_dir is None: 82 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 83 | if isinstance(url_or_filename, Path): 84 | url_or_filename = str(url_or_filename) 85 | if isinstance(cache_dir, Path): 86 | cache_dir = str(cache_dir) 87 | 88 | parsed = urlparse(url_or_filename) 89 | 90 | if parsed.scheme in ('http', 'https', 's3'): 91 | # URL, so get it from the cache (downloading if necessary) 92 | return get_from_cache(url_or_filename, cache_dir) 93 | elif os.path.exists(url_or_filename): 94 | # File, and it exists. 95 | return url_or_filename 96 | elif parsed.scheme == '': 97 | # File, but it doesn't exist. 98 | raise FileNotFoundError("file {} not found".format(url_or_filename)) 99 | else: 100 | # Something unknown 101 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 102 | 103 | 104 | def split_s3_path(url: str) -> Tuple[str, str]: 105 | """Split a full s3 path into the bucket name and path.""" 106 | parsed = urlparse(url) 107 | if not parsed.netloc or not parsed.path: 108 | raise ValueError("bad s3 path {}".format(url)) 109 | bucket_name = parsed.netloc 110 | s3_path = parsed.path 111 | # Remove '/' at beginning of path. 112 | if s3_path.startswith("/"): 113 | s3_path = s3_path[1:] 114 | return bucket_name, s3_path 115 | 116 | 117 | def s3_request(func: Callable): 118 | """ 119 | Wrapper function for s3 requests in order to create more helpful error 120 | messages. 
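    A 404 ``ClientError`` from S3 is re-raised as ``FileNotFoundError`` so that a
    missing S3 object is reported the same way as a missing local file; any other
    client error is propagated unchanged.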
121 | """ 122 | 123 | @wraps(func) 124 | def wrapper(url: str, *args, **kwargs): 125 | try: 126 | return func(url, *args, **kwargs) 127 | except ClientError as exc: 128 | if int(exc.response["Error"]["Code"]) == 404: 129 | raise FileNotFoundError("file {} not found".format(url)) 130 | else: 131 | raise 132 | 133 | return wrapper 134 | 135 | 136 | @s3_request 137 | def s3_etag(url: str) -> Optional[str]: 138 | """Check ETag on S3 object.""" 139 | s3_resource = boto3.resource("s3") 140 | bucket_name, s3_path = split_s3_path(url) 141 | s3_object = s3_resource.Object(bucket_name, s3_path) 142 | return s3_object.e_tag 143 | 144 | 145 | @s3_request 146 | def s3_get(url: str, temp_file: IO) -> None: 147 | """Pull a file directly from S3.""" 148 | s3_resource = boto3.resource("s3") 149 | bucket_name, s3_path = split_s3_path(url) 150 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 151 | 152 | 153 | def http_get(url: str, temp_file: IO) -> None: 154 | req = requests.get(url, stream=True) 155 | content_length = req.headers.get('Content-Length') 156 | total = int(content_length) if content_length is not None else None 157 | progress = tqdm(unit="B", total=total) 158 | for chunk in req.iter_content(chunk_size=1024): 159 | if chunk: # filter out keep-alive new chunks 160 | progress.update(len(chunk)) 161 | temp_file.write(chunk) 162 | progress.close() 163 | 164 | 165 | def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str: 166 | """ 167 | Given a URL, look for the corresponding dataset in the local cache. 168 | If it's not there, download it. Then return the path to the cached file. 169 | """ 170 | if cache_dir is None: 171 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 172 | if isinstance(cache_dir, Path): 173 | cache_dir = str(cache_dir) 174 | 175 | os.makedirs(cache_dir, exist_ok=True) 176 | 177 | # Get eTag to add to filename, if it exists. 178 | if url.startswith("s3://"): 179 | etag = s3_etag(url) 180 | else: 181 | response = requests.head(url, allow_redirects=True) 182 | if response.status_code != 200: 183 | raise IOError("HEAD request failed for url {} with status code {}" 184 | .format(url, response.status_code)) 185 | etag = response.headers.get("ETag") 186 | 187 | filename = url_to_filename(url, etag) 188 | 189 | # get cache path to put the file 190 | cache_path = os.path.join(cache_dir, filename) 191 | 192 | if not os.path.exists(cache_path): 193 | # Download to temporary file, then copy to cache dir once finished. 194 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
195 | with tempfile.NamedTemporaryFile() as temp_file: 196 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 197 | 198 | # GET file object 199 | if url.startswith("s3://"): 200 | s3_get(url, temp_file) 201 | else: 202 | http_get(url, temp_file) 203 | 204 | # we are copying the file before closing it, so flush to avoid truncation 205 | temp_file.flush() 206 | # shutil.copyfileobj() starts at the current position, so go to the start 207 | temp_file.seek(0) 208 | 209 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 210 | with open(cache_path, 'wb') as cache_file: 211 | shutil.copyfileobj(temp_file, cache_file) 212 | 213 | logger.info("creating metadata file for %s", cache_path) 214 | meta = {'url': url, 'etag': etag} 215 | meta_path = cache_path + '.json' 216 | with open(meta_path, 'w') as meta_file: 217 | json.dump(meta, meta_file) 218 | 219 | logger.info("removing temp file %s", temp_file.name) 220 | 221 | return cache_path 222 | 223 | 224 | def read_set_from_file(filename: str) -> Set[str]: 225 | ''' 226 | Extract a de-duped collection (set) of text from a file. 227 | Expected file format is one item per line. 228 | ''' 229 | collection = set() 230 | with open(filename, 'r', encoding='utf-8') as file_: 231 | for line in file_: 232 | collection.add(line.rstrip()) 233 | return collection 234 | 235 | 236 | def get_file_extension(path: str, dot=True, lower: bool = True): 237 | ext = os.path.splitext(path)[1] 238 | ext = ext if dot else ext[1:] 239 | return ext.lower() if lower else ext 240 | -------------------------------------------------------------------------------- /pyernie/model/ernie/modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
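# ---------------------------------------------------------------------------
# Illustrative sketch only: the repo's actual fine-tuning head lives in
# pyernie/model/nn/ernie_fine.py and is driven by fine_tune_ernie.py. The
# commented snippet below just shows how the classes defined in this module
# fit together for the 10-class news task. Paths follow
# pyernie/config/basic_config.py; it is assumed (not verified here) that the
# converted checkpoint's keys match BertModel once an optional 'bert.' prefix
# is stripped -- from_pretrained() further below performs a similar key
# clean-up automatically.
#
#     import torch
#     from torch import nn
#     from pyernie.model.ernie import BertConfig, BertModel
#
#     config = BertConfig.from_json_file(
#         'pyernie/model/pretrain/ernie_base/ernie_config.json')
#     bert = BertModel(config)
#     state_dict = torch.load(
#         'pyernie/model/pretrain/ernie_base/pytorch_model.bin', map_location='cpu')
#     bert.load_state_dict(
#         {k[len('bert.'):] if k.startswith('bert.') else k: v
#          for k, v in state_dict.items()}, strict=False)
#
#     # input_ids, segment_ids, input_mask come from a CreateDataset batch
#     # (see pyernie/io/dataset.py).
#     classifier = nn.Linear(config.hidden_size, 10)  # 10 news categories
#     _, pooled = bert(input_ids, segment_ids, input_mask,
#                      output_all_encoded_layers=False)
#     logits = classifier(pooled)                     # [batch_size, 10]
# ---------------------------------------------------------------------------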
16 | """PyTorch BERT model.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | import copy 24 | import json 25 | import math 26 | import logging 27 | import tarfile 28 | import tempfile 29 | import shutil 30 | 31 | import torch 32 | from torch import nn 33 | from torch.nn import CrossEntropyLoss 34 | 35 | from .file_utils import cached_path 36 | 37 | logger = logging.getLogger(__name__) 38 | 39 | PRETRAINED_MODEL_ARCHIVE_MAP = { 40 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", 41 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", 42 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", 43 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", 44 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", 45 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", 46 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", 47 | } 48 | CONFIG_NAME = 'bert_config.json' 49 | WEIGHTS_NAME = 'pytorch_model.bin' 50 | 51 | def gelu(x): 52 | """Implementation of the gelu activation function. 53 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 54 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 55 | """ 56 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 57 | 58 | 59 | def swish(x): 60 | return x * torch.sigmoid(x) 61 | 62 | 63 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} 64 | 65 | 66 | class BertConfig(object): 67 | """Configuration class to store the configuration of a `BertModel`. 68 | """ 69 | def __init__(self, 70 | vocab_size_or_config_json_file, 71 | hidden_size=768, 72 | num_hidden_layers=12, 73 | num_attention_heads=12, 74 | intermediate_size=3072, 75 | hidden_act="gelu", 76 | hidden_dropout_prob=0.1, 77 | attention_probs_dropout_prob=0.1, 78 | max_position_embeddings=512, 79 | type_vocab_size=2, 80 | initializer_range=0.02): 81 | """Constructs BertConfig. 82 | 83 | Args: 84 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. 85 | hidden_size: Size of the encoder layers and the pooler layer. 86 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 87 | num_attention_heads: Number of attention heads for each attention layer in 88 | the Transformer encoder. 89 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 90 | layer in the Transformer encoder. 91 | hidden_act: The non-linear activation function (function or string) in the 92 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported. 93 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 94 | layers in the embeddings, encoder, and pooler. 95 | attention_probs_dropout_prob: The dropout ratio for the attention 96 | probabilities. 97 | max_position_embeddings: The maximum sequence length that this model might 98 | ever be used with. Typically set this to something large just in case 99 | (e.g., 512 or 1024 or 2048). 100 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 101 | `BertModel`. 
102 | initializer_range: The sttdev of the truncated_normal_initializer for 103 | initializing all weight matrices. 104 | """ 105 | if isinstance(vocab_size_or_config_json_file, str): 106 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 107 | json_config = json.loads(reader.read()) 108 | for key, value in json_config.items(): 109 | self.__dict__[key] = value 110 | elif isinstance(vocab_size_or_config_json_file, int): 111 | self.vocab_size = vocab_size_or_config_json_file 112 | self.hidden_size = hidden_size 113 | self.num_hidden_layers = num_hidden_layers 114 | self.num_attention_heads = num_attention_heads 115 | self.hidden_act = hidden_act 116 | self.intermediate_size = intermediate_size 117 | self.hidden_dropout_prob = hidden_dropout_prob 118 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 119 | self.max_position_embeddings = max_position_embeddings 120 | self.type_vocab_size = type_vocab_size 121 | self.initializer_range = initializer_range 122 | else: 123 | raise ValueError("First argument must be either a vocabulary size (int)" 124 | "or the path to a pretrained model config file (str)") 125 | 126 | @classmethod 127 | def from_dict(cls, json_object): 128 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 129 | config = BertConfig(vocab_size_or_config_json_file=-1) 130 | for key, value in json_object.items(): 131 | config.__dict__[key] = value 132 | return config 133 | 134 | @classmethod 135 | def from_json_file(cls, json_file): 136 | """Constructs a `BertConfig` from a json file of parameters.""" 137 | with open(json_file, "r", encoding='utf-8') as reader: 138 | text = reader.read() 139 | return cls.from_dict(json.loads(text)) 140 | 141 | def __repr__(self): 142 | return str(self.to_json_string()) 143 | 144 | def to_dict(self): 145 | """Serializes this instance to a Python dictionary.""" 146 | output = copy.deepcopy(self.__dict__) 147 | return output 148 | 149 | def to_json_string(self): 150 | """Serializes this instance to a JSON string.""" 151 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 152 | 153 | try: 154 | from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm 155 | except ImportError: 156 | print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.") 157 | class BertLayerNorm(nn.Module): 158 | def __init__(self, hidden_size, eps=1e-12): 159 | """Construct a layernorm module in the TF style (epsilon inside the square root). 160 | """ 161 | super(BertLayerNorm, self).__init__() 162 | self.weight = nn.Parameter(torch.ones(hidden_size)) 163 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 164 | self.variance_epsilon = eps 165 | 166 | def forward(self, x): 167 | u = x.mean(-1, keepdim=True) 168 | s = (x - u).pow(2).mean(-1, keepdim=True) 169 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 170 | return self.weight * x + self.bias 171 | 172 | class BertEmbeddings(nn.Module): 173 | """Construct the embeddings from word, position and token_type embeddings. 
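    The three embeddings are summed, then layer-normalized and passed through
    dropout in forward().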
174 | """ 175 | def __init__(self, config): 176 | super(BertEmbeddings, self).__init__() 177 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) 178 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 179 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 180 | 181 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 182 | # any TensorFlow checkpoint file 183 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 184 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 185 | 186 | def forward(self, input_ids, token_type_ids=None): 187 | seq_length = input_ids.size(1) 188 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 189 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 190 | if token_type_ids is None: 191 | token_type_ids = torch.zeros_like(input_ids) 192 | 193 | words_embeddings = self.word_embeddings(input_ids) 194 | position_embeddings = self.position_embeddings(position_ids) 195 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 196 | 197 | embeddings = words_embeddings + position_embeddings + token_type_embeddings 198 | embeddings = self.LayerNorm(embeddings) 199 | embeddings = self.dropout(embeddings) 200 | return embeddings 201 | 202 | 203 | class BertSelfAttention(nn.Module): 204 | def __init__(self, config): 205 | super(BertSelfAttention, self).__init__() 206 | if config.hidden_size % config.num_attention_heads != 0: 207 | raise ValueError( 208 | "The hidden size (%d) is not a multiple of the number of attention " 209 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 210 | self.num_attention_heads = config.num_attention_heads 211 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 212 | self.all_head_size = self.num_attention_heads * self.attention_head_size 213 | 214 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 215 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 216 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 217 | 218 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 219 | 220 | def transpose_for_scores(self, x): 221 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 222 | x = x.view(*new_x_shape) 223 | return x.permute(0, 2, 1, 3) 224 | 225 | def forward(self, hidden_states, attention_mask): 226 | mixed_query_layer = self.query(hidden_states) 227 | mixed_key_layer = self.key(hidden_states) 228 | mixed_value_layer = self.value(hidden_states) 229 | 230 | query_layer = self.transpose_for_scores(mixed_query_layer) 231 | key_layer = self.transpose_for_scores(mixed_key_layer) 232 | value_layer = self.transpose_for_scores(mixed_value_layer) 233 | 234 | # Take the dot product between "query" and "key" to get the raw attention scores. 235 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 236 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 237 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 238 | attention_scores = attention_scores + attention_mask 239 | 240 | # Normalize the attention scores to probabilities. 
241 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 242 | 243 | # This is actually dropping out entire tokens to attend to, which might 244 | # seem a bit unusual, but is taken from the original Transformer paper. 245 | attention_probs = self.dropout(attention_probs) 246 | 247 | context_layer = torch.matmul(attention_probs, value_layer) 248 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 249 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 250 | context_layer = context_layer.view(*new_context_layer_shape) 251 | return context_layer 252 | 253 | 254 | class BertSelfOutput(nn.Module): 255 | def __init__(self, config): 256 | super(BertSelfOutput, self).__init__() 257 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 258 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 259 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 260 | 261 | def forward(self, hidden_states, input_tensor): 262 | hidden_states = self.dense(hidden_states) 263 | hidden_states = self.dropout(hidden_states) 264 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 265 | return hidden_states 266 | 267 | 268 | class BertAttention(nn.Module): 269 | def __init__(self, config): 270 | super(BertAttention, self).__init__() 271 | self.self = BertSelfAttention(config) 272 | self.output = BertSelfOutput(config) 273 | 274 | def forward(self, input_tensor, attention_mask): 275 | self_output = self.self(input_tensor, attention_mask) 276 | attention_output = self.output(self_output, input_tensor) 277 | return attention_output 278 | 279 | 280 | class BertIntermediate(nn.Module): 281 | def __init__(self, config): 282 | super(BertIntermediate, self).__init__() 283 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 284 | self.intermediate_act_fn = ACT2FN[config.hidden_act] \ 285 | if isinstance(config.hidden_act, str) else config.hidden_act 286 | 287 | def forward(self, hidden_states): 288 | hidden_states = self.dense(hidden_states) 289 | hidden_states = self.intermediate_act_fn(hidden_states) 290 | return hidden_states 291 | 292 | 293 | class BertOutput(nn.Module): 294 | def __init__(self, config): 295 | super(BertOutput, self).__init__() 296 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 297 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 298 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 299 | 300 | def forward(self, hidden_states, input_tensor): 301 | hidden_states = self.dense(hidden_states) 302 | hidden_states = self.dropout(hidden_states) 303 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 304 | return hidden_states 305 | 306 | 307 | class BertLayer(nn.Module): 308 | def __init__(self, config): 309 | super(BertLayer, self).__init__() 310 | self.attention = BertAttention(config) 311 | self.intermediate = BertIntermediate(config) 312 | self.output = BertOutput(config) 313 | 314 | def forward(self, hidden_states, attention_mask): 315 | attention_output = self.attention(hidden_states, attention_mask) 316 | intermediate_output = self.intermediate(attention_output) 317 | layer_output = self.output(intermediate_output, attention_output) 318 | return layer_output 319 | 320 | 321 | class BertEncoder(nn.Module): 322 | def __init__(self, config): 323 | super(BertEncoder, self).__init__() 324 | layer = BertLayer(config) 325 | self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) 326 | 327 | def forward(self, hidden_states, 
attention_mask, output_all_encoded_layers=True): 328 | all_encoder_layers = [] 329 | for layer_module in self.layer: 330 | hidden_states = layer_module(hidden_states, attention_mask) 331 | if output_all_encoded_layers: 332 | all_encoder_layers.append(hidden_states) 333 | if not output_all_encoded_layers: 334 | all_encoder_layers.append(hidden_states) 335 | return all_encoder_layers 336 | 337 | 338 | class BertPooler(nn.Module): 339 | def __init__(self, config): 340 | super(BertPooler, self).__init__() 341 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 342 | self.activation = nn.Tanh() 343 | 344 | def forward(self, hidden_states): 345 | # We "pool" the model by simply taking the hidden state corresponding 346 | # to the first token. 347 | first_token_tensor = hidden_states[:, 0] 348 | pooled_output = self.dense(first_token_tensor) 349 | pooled_output = self.activation(pooled_output) 350 | return pooled_output 351 | 352 | 353 | class BertPredictionHeadTransform(nn.Module): 354 | def __init__(self, config): 355 | super(BertPredictionHeadTransform, self).__init__() 356 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 357 | self.transform_act_fn = ACT2FN[config.hidden_act] \ 358 | if isinstance(config.hidden_act, str) else config.hidden_act 359 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 360 | 361 | def forward(self, hidden_states): 362 | hidden_states = self.dense(hidden_states) 363 | hidden_states = self.transform_act_fn(hidden_states) 364 | hidden_states = self.LayerNorm(hidden_states) 365 | return hidden_states 366 | 367 | 368 | class BertLMPredictionHead(nn.Module): 369 | def __init__(self, config, bert_model_embedding_weights): 370 | super(BertLMPredictionHead, self).__init__() 371 | self.transform = BertPredictionHeadTransform(config) 372 | 373 | # The output weights are the same as the input embeddings, but there is 374 | # an output-only bias for each token. 
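        # The decoder's weight is tied to the input word-embedding matrix just below,
        # so no second vocab-sized projection matrix is stored; only the per-token
        # output bias is a new parameter.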
375 | self.decoder = nn.Linear(bert_model_embedding_weights.size(1), 376 | bert_model_embedding_weights.size(0), 377 | bias=False) 378 | self.decoder.weight = bert_model_embedding_weights 379 | self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0))) 380 | 381 | def forward(self, hidden_states): 382 | hidden_states = self.transform(hidden_states) 383 | hidden_states = self.decoder(hidden_states) + self.bias 384 | return hidden_states 385 | 386 | 387 | class BertOnlyMLMHead(nn.Module): 388 | def __init__(self, config, bert_model_embedding_weights): 389 | super(BertOnlyMLMHead, self).__init__() 390 | self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) 391 | 392 | def forward(self, sequence_output): 393 | prediction_scores = self.predictions(sequence_output) 394 | return prediction_scores 395 | 396 | 397 | class BertOnlyNSPHead(nn.Module): 398 | def __init__(self, config): 399 | super(BertOnlyNSPHead, self).__init__() 400 | self.seq_relationship = nn.Linear(config.hidden_size, 2) 401 | 402 | def forward(self, pooled_output): 403 | seq_relationship_score = self.seq_relationship(pooled_output) 404 | return seq_relationship_score 405 | 406 | 407 | class BertPreTrainingHeads(nn.Module): 408 | def __init__(self, config, bert_model_embedding_weights): 409 | super(BertPreTrainingHeads, self).__init__() 410 | self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) 411 | self.seq_relationship = nn.Linear(config.hidden_size, 2) 412 | 413 | def forward(self, sequence_output, pooled_output): 414 | prediction_scores = self.predictions(sequence_output) 415 | seq_relationship_score = self.seq_relationship(pooled_output) 416 | return prediction_scores, seq_relationship_score 417 | 418 | 419 | class PreTrainedBertModel(nn.Module): 420 | """ An abstract class to handle weights initialization and 421 | a simple interface for dowloading and loading pretrained models. 422 | """ 423 | def __init__(self, config, *inputs, **kwargs): 424 | super(PreTrainedBertModel, self).__init__() 425 | if not isinstance(config, BertConfig): 426 | raise ValueError( 427 | "Parameter config in `{}(config)` should be an instance of class `BertConfig`. " 428 | "To create a model from a Google pretrained model use " 429 | "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( 430 | self.__class__.__name__, self.__class__.__name__ 431 | )) 432 | self.config = config 433 | 434 | def init_bert_weights(self, module): 435 | """ Initialize the weights. 436 | """ 437 | if isinstance(module, (nn.Linear, nn.Embedding)): 438 | # Slightly different from the TF version which uses truncated_normal for initialization 439 | # cf https://github.com/pytorch/pytorch/pull/5617 440 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 441 | elif isinstance(module, BertLayerNorm): 442 | module.bias.data.normal_(mean=0.0, std=self.config.initializer_range) 443 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 444 | if isinstance(module, nn.Linear) and module.bias is not None: 445 | module.bias.data.zero_() 446 | 447 | @classmethod 448 | def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs): 449 | """ 450 | Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict. 451 | Download and cache the pre-trained model file if needed. 
452 | 453 | Params: 454 | pretrained_model_name: either: 455 | - a str with the name of a pre-trained model to load selected in the list of: 456 | . `bert-base-uncased` 457 | . `bert-large-uncased` 458 | . `bert-base-cased` 459 | . `bert-base-multilingual` 460 | . `bert-base-chinese` 461 | - a path or url to a pretrained model archive containing: 462 | . `bert_config.json` a configuration file for the model 463 | . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance 464 | cache_dir: an optional path to a folder in which the pre-trained models will be cached. 465 | state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models 466 | *inputs, **kwargs: additional input for the specific Bert class 467 | (ex: num_labels for BertForSequenceClassification) 468 | """ 469 | if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP: 470 | archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name] 471 | else: 472 | archive_file = pretrained_model_name 473 | # redirect to the cache, if necessary 474 | try: 475 | resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) 476 | except FileNotFoundError: 477 | logger.error( 478 | "Model name '{}' was not found in model name list ({}). " 479 | "We assumed '{}' was a path or url but couldn't find any file " 480 | "associated to this path or url.".format( 481 | pretrained_model_name, 482 | ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), 483 | archive_file)) 484 | return None 485 | if resolved_archive_file == archive_file: 486 | logger.info("loading archive file {}".format(archive_file)) 487 | else: 488 | logger.info("loading archive file {} from cache at {}".format( 489 | archive_file, resolved_archive_file)) 490 | tempdir = None 491 | if os.path.isdir(resolved_archive_file): 492 | serialization_dir = resolved_archive_file 493 | else: 494 | # Extract archive to temp dir 495 | tempdir = tempfile.mkdtemp() 496 | logger.info("extracting archive file {} to temp dir {}".format( 497 | resolved_archive_file, tempdir)) 498 | with tarfile.open(resolved_archive_file, 'r:gz') as archive: 499 | archive.extractall(tempdir) 500 | serialization_dir = tempdir 501 | # Load config 502 | config_file = os.path.join(serialization_dir, CONFIG_NAME) 503 | config = BertConfig.from_json_file(config_file) 504 | logger.info("Model config {}".format(config)) 505 | # Instantiate model. 
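        # After instantiation, TF-style parameter names in the checkpoint
        # ('gamma'/'beta') are renamed to the PyTorch LayerNorm names
        # ('weight'/'bias'), and the state dict is loaded with a 'bert.' prefix
        # when the target is a bare BertModel (no `bert` attribute), so checkpoints
        # saved from task-specific heads still load.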
506 | model = cls(config, *inputs, **kwargs) 507 | if state_dict is None: 508 | weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) 509 | state_dict = torch.load(weights_path) 510 | 511 | old_keys = [] 512 | new_keys = [] 513 | for key in state_dict.keys(): 514 | new_key = None 515 | if 'gamma' in key: 516 | new_key = key.replace('gamma', 'weight') 517 | if 'beta' in key: 518 | new_key = key.replace('beta', 'bias') 519 | if new_key: 520 | old_keys.append(key) 521 | new_keys.append(new_key) 522 | for old_key, new_key in zip(old_keys, new_keys): 523 | state_dict[new_key] = state_dict.pop(old_key) 524 | 525 | missing_keys = [] 526 | unexpected_keys = [] 527 | error_msgs = [] 528 | # copy state_dict so _load_from_state_dict can modify it 529 | metadata = getattr(state_dict, '_metadata', None) 530 | state_dict = state_dict.copy() 531 | if metadata is not None: 532 | state_dict._metadata = metadata 533 | 534 | def load(module, prefix=''): 535 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 536 | module._load_from_state_dict( 537 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 538 | for name, child in module._modules.items(): 539 | if child is not None: 540 | load(child, prefix + name + '.') 541 | load(model, prefix='' if hasattr(model, 'bert') else 'bert.') 542 | if len(missing_keys) > 0: 543 | logger.info("Weights of {} not initialized from pretrained model: {}".format( 544 | model.__class__.__name__, missing_keys)) 545 | if len(unexpected_keys) > 0: 546 | logger.info("Weights from pretrained model not used in {}: {}".format( 547 | model.__class__.__name__, unexpected_keys)) 548 | if tempdir: 549 | # Clean up temp dir 550 | shutil.rmtree(tempdir) 551 | return model 552 | 553 | 554 | class BertModel(PreTrainedBertModel): 555 | """BERT model ("Bidirectional Embedding Representations from a Transformer"). 556 | 557 | Params: 558 | config: a BertConfig class instance with the configuration to build a new model 559 | 560 | Inputs: 561 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 562 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 563 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 564 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 565 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 566 | a `sentence B` token (see BERT paper for more details). 567 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 568 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 569 | input sequence length in the current batch. It's the mask that we typically use for attention when 570 | a batch has varying length sentences. 571 | `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. 572 | 573 | Outputs: Tuple of (encoded_layers, pooled_output) 574 | `encoded_layers`: controled by `output_all_encoded_layers` argument: 575 | - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end 576 | of each attention block (i.e. 
12 full sequences for BERT-base, 24 for BERT-large), each 577 | encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], 578 | - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding 579 | to the last attention block of shape [batch_size, sequence_length, hidden_size], 580 | `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a 581 | classifier pretrained on top of the hidden state associated to the first character of the 582 | input (`CLF`) to train on the Next-Sentence task (see BERT's paper). 583 | 584 | Example usage: 585 | ```python 586 | # Already been converted into WordPiece token ids 587 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 588 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 589 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 590 | 591 | config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 592 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 593 | 594 | model = modeling.BertModel(config=config) 595 | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) 596 | ``` 597 | """ 598 | def __init__(self, config): 599 | super(BertModel, self).__init__(config) 600 | self.embeddings = BertEmbeddings(config) 601 | self.encoder = BertEncoder(config) 602 | self.pooler = BertPooler(config) 603 | self.apply(self.init_bert_weights) 604 | 605 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True): 606 | if attention_mask is None: 607 | attention_mask = torch.ones_like(input_ids) 608 | if token_type_ids is None: 609 | token_type_ids = torch.zeros_like(input_ids) 610 | 611 | # We create a 3D attention mask from a 2D tensor mask. 612 | # Sizes are [batch_size, 1, 1, to_seq_length] 613 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 614 | # this attention mask is more simple than the triangular masking of causal attention 615 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 616 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 617 | 618 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 619 | # masked positions, this operation will create a tensor which is 0.0 for 620 | # positions we want to attend and -10000.0 for masked positions. 621 | # Since we are adding it to the raw scores before the softmax, this is 622 | # effectively the same as removing these entirely. 623 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 624 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 625 | 626 | embedding_output = self.embeddings(input_ids, token_type_ids) 627 | encoded_layers = self.encoder(embedding_output, 628 | extended_attention_mask, 629 | output_all_encoded_layers=output_all_encoded_layers) 630 | sequence_output = encoded_layers[-1] 631 | pooled_output = self.pooler(sequence_output) 632 | if not output_all_encoded_layers: 633 | encoded_layers = encoded_layers[-1] 634 | return encoded_layers, pooled_output 635 | 636 | 637 | class BertForPreTraining(PreTrainedBertModel): 638 | """BERT model with pre-training heads. 639 | This module comprises the BERT model followed by the two pre-training heads: 640 | - the masked language modeling head, and 641 | - the next sentence classification head. 
642 | 643 | Params: 644 | config: a BertConfig class instance with the configuration to build a new model. 645 | 646 | Inputs: 647 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 648 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 649 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 650 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 651 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 652 | a `sentence B` token (see BERT paper for more details). 653 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 654 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 655 | input sequence length in the current batch. It's the mask that we typically use for attention when 656 | a batch has varying length sentences. 657 | `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] 658 | with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss 659 | is only computed for the labels set in [0, ..., vocab_size] 660 | `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size] 661 | with indices selected in [0, 1]. 662 | 0 => next sentence is the continuation, 1 => next sentence is a random sentence. 663 | 664 | Outputs: 665 | if `masked_lm_labels` and `next_sentence_label` are not `None`: 666 | Outputs the total_loss which is the sum of the masked language modeling loss and the next 667 | sentence classification loss. 668 | if `masked_lm_labels` or `next_sentence_label` is `None`: 669 | Outputs a tuple comprising 670 | - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and 671 | - the next sentence classification logits of shape [batch_size, 2]. 
672 | 673 | Example usage: 674 | ```python 675 | # Already been converted into WordPiece token ids 676 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 677 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 678 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 679 | 680 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 681 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 682 | 683 | model = BertForPreTraining(config) 684 | masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask) 685 | ``` 686 | """ 687 | def __init__(self, config): 688 | super(BertForPreTraining, self).__init__(config) 689 | self.bert = BertModel(config) 690 | self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight) 691 | self.apply(self.init_bert_weights) 692 | 693 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None): 694 | sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, 695 | output_all_encoded_layers=False) 696 | prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) 697 | 698 | if masked_lm_labels is not None and next_sentence_label is not None: 699 | loss_fct = CrossEntropyLoss(ignore_index=-1) 700 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) 701 | next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) 702 | total_loss = masked_lm_loss + next_sentence_loss 703 | return total_loss 704 | else: 705 | return prediction_scores, seq_relationship_score 706 | 707 | 708 | class BertForMaskedLM(PreTrainedBertModel): 709 | """BERT model with the masked language modeling head. 710 | This module comprises the BERT model followed by the masked language modeling head. 711 | 712 | Params: 713 | config: a BertConfig class instance with the configuration to build a new model. 714 | 715 | Inputs: 716 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 717 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 718 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 719 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 720 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 721 | a `sentence B` token (see BERT paper for more details). 722 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 723 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 724 | input sequence length in the current batch. It's the mask that we typically use for attention when 725 | a batch has varying length sentences. 726 | `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] 727 | with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss 728 | is only computed for the labels set in [0, ..., vocab_size] 729 | 730 | Outputs: 731 | if `masked_lm_labels` is `None`: 732 | Outputs the masked language modeling loss. 733 | if `masked_lm_labels` is `None`: 734 | Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size]. 
735 | 736 | Example usage: 737 | ```python 738 | # Already been converted into WordPiece token ids 739 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 740 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 741 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 742 | 743 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 744 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 745 | 746 | model = BertForMaskedLM(config) 747 | masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask) 748 | ``` 749 | """ 750 | def __init__(self, config): 751 | super(BertForMaskedLM, self).__init__(config) 752 | self.bert = BertModel(config) 753 | self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight) 754 | self.apply(self.init_bert_weights) 755 | 756 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None): 757 | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, 758 | output_all_encoded_layers=False) 759 | prediction_scores = self.cls(sequence_output) 760 | 761 | if masked_lm_labels is not None: 762 | loss_fct = CrossEntropyLoss(ignore_index=-1) 763 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) 764 | return masked_lm_loss 765 | else: 766 | return prediction_scores 767 | 768 | 769 | class BertForNextSentencePrediction(PreTrainedBertModel): 770 | """BERT model with next sentence prediction head. 771 | This module comprises the BERT model followed by the next sentence classification head. 772 | 773 | Params: 774 | config: a BertConfig class instance with the configuration to build a new model. 775 | 776 | Inputs: 777 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 778 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 779 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 780 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 781 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 782 | a `sentence B` token (see BERT paper for more details). 783 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 784 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 785 | input sequence length in the current batch. It's the mask that we typically use for attention when 786 | a batch has varying length sentences. 787 | `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size] 788 | with indices selected in [0, 1]. 789 | 0 => next sentence is the continuation, 1 => next sentence is a random sentence. 790 | 791 | Outputs: 792 | if `next_sentence_label` is not `None`: 793 | Outputs the next sentence classification loss (a CrossEntropy loss computed from the logits 794 | and `next_sentence_label`). 795 | if `next_sentence_label` is `None`: 796 | Outputs the next sentence classification logits of shape [batch_size, 2].
797 | 798 | Example usage: 799 | ```python 800 | # Already been converted into WordPiece token ids 801 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 802 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 803 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 804 | 805 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 806 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 807 | 808 | model = BertForNextSentencePrediction(config) 809 | seq_relationship_logits = model(input_ids, token_type_ids, input_mask) 810 | ``` 811 | """ 812 | def __init__(self, config): 813 | super(BertForNextSentencePrediction, self).__init__(config) 814 | self.bert = BertModel(config) 815 | self.cls = BertOnlyNSPHead(config) 816 | self.apply(self.init_bert_weights) 817 | 818 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None): 819 | _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, 820 | output_all_encoded_layers=False) 821 | seq_relationship_score = self.cls( pooled_output) 822 | 823 | if next_sentence_label is not None: 824 | loss_fct = CrossEntropyLoss(ignore_index=-1) 825 | next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) 826 | return next_sentence_loss 827 | else: 828 | return seq_relationship_score 829 | 830 | 831 | class BertForSequenceClassification(PreTrainedBertModel): 832 | """BERT model for classification. 833 | This module is composed of the BERT model with a linear layer on top of 834 | the pooled output. 835 | 836 | Params: 837 | `config`: a BertConfig class instance with the configuration to build a new model. 838 | `num_labels`: the number of classes for the classifier. Default = 2. 839 | 840 | Inputs: 841 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 842 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 843 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 844 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 845 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 846 | a `sentence B` token (see BERT paper for more details). 847 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 848 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 849 | input sequence length in the current batch. It's the mask that we typically use for attention when 850 | a batch has varying length sentences. 851 | `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] 852 | with indices selected in [0, ..., num_labels]. 853 | 854 | Outputs: 855 | if `labels` is not `None`: 856 | Outputs the CrossEntropy classification loss of the output with the labels. 857 | if `labels` is `None`: 858 | Outputs the classification logits of shape [batch_size, num_labels]. 
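Since this head is the one most directly relevant to classification fine-tuning, here is a hedged sketch of a single training step followed by inference; it assumes the class is importable from `pyernie.model.ernie.modeling`, uses toy tensors, and uses plain `torch.optim.Adam` only to keep the snippet short (a `BertAdam` variant with warmup is defined later in `optimization.py`):

```python
import torch
from pyernie.model.ernie.modeling import BertConfig, BertForSequenceClassification  # assumed import path

input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
labels = torch.LongTensor([1, 0])   # one class id per sequence, in [0, num_labels - 1]

config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                    num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForSequenceClassification(config, num_labels=2)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

model.train()
loss = model(input_ids, token_type_ids, input_mask, labels)  # labels given -> scalar CrossEntropy loss
loss.backward()
optimizer.step()
optimizer.zero_grad()

model.eval()
with torch.no_grad():
    logits = model(input_ids, token_type_ids, input_mask)    # no labels -> [batch_size, num_labels]
```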
859 | 860 | Example usage: 861 | ```python 862 | # Already been converted into WordPiece token ids 863 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 864 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 865 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 866 | 867 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 868 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 869 | 870 | num_labels = 2 871 | 872 | model = BertForSequenceClassification(config, num_labels) 873 | logits = model(input_ids, token_type_ids, input_mask) 874 | ``` 875 | """ 876 | def __init__(self, config, num_labels=2): 877 | super(BertForSequenceClassification, self).__init__(config) 878 | self.num_labels = num_labels 879 | self.bert = BertModel(config) 880 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 881 | self.classifier = nn.Linear(config.hidden_size, num_labels) 882 | self.apply(self.init_bert_weights) 883 | 884 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): 885 | _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) 886 | pooled_output = self.dropout(pooled_output) 887 | logits = self.classifier(pooled_output) 888 | 889 | if labels is not None: 890 | loss_fct = CrossEntropyLoss() 891 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 892 | return loss 893 | else: 894 | return logits 895 | 896 | 897 | class BertForMultipleChoice(PreTrainedBertModel): 898 | """BERT model for multiple choice tasks. 899 | This module is composed of the BERT model with a linear layer on top of 900 | the pooled output. 901 | 902 | Params: 903 | `config`: a BertConfig class instance with the configuration to build a new model. 904 | `num_choices`: the number of choices for the classifier. Default = 2. 905 | 906 | Inputs: 907 | `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] 908 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 909 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 910 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] 911 | with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` 912 | and type 1 corresponds to a `sentence B` token (see BERT paper for more details). 913 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices 914 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 915 | input sequence length in the current batch. It's the mask that we typically use for attention when 916 | a batch has varying length sentences. 917 | `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] 918 | with indices selected in [0, ..., num_choices]. 919 | 920 | Outputs: 921 | if `labels` is not `None`: 922 | Outputs the CrossEntropy classification loss of the output with the labels. 923 | if `labels` is `None`: 924 | Outputs the classification logits of shape [batch_size, num_choices].
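The three-dimensional input layout is the part that differs from the other heads, so a small sketch of the expected shapes may help; the import path and the toy ids are assumptions made for illustration:

```python
import torch
from pyernie.model.ernie.modeling import BertConfig, BertForMultipleChoice  # assumed import path

# 2 examples, 2 candidate choices each, 3 WordPiece ids per (example, choice) pair
input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]],
                              [[12, 16, 42], [14, 28, 57]]])   # [batch_size, num_choices, sequence_length]
input_mask = torch.ones_like(input_ids)
token_type_ids = torch.zeros_like(input_ids)
labels = torch.LongTensor([0, 1])                              # index of the correct choice per example

config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                    num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForMultipleChoice(config, num_choices=2)

loss = model(input_ids, token_type_ids, input_mask, labels)     # scalar CrossEntropy loss
reshaped_logits = model(input_ids, token_type_ids, input_mask)  # [batch_size, num_choices]
```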
925 | 926 | Example usage: 927 | ```python 928 | # Already been converted into WordPiece token ids 929 | input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]]) 930 | input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]]) 931 | token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]]) 932 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 933 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 934 | 935 | num_choices = 2 936 | 937 | model = BertForMultipleChoice(config, num_choices) 938 | logits = model(input_ids, token_type_ids, input_mask) 939 | ``` 940 | """ 941 | def __init__(self, config, num_choices=2): 942 | super(BertForMultipleChoice, self).__init__(config) 943 | self.num_choices = num_choices 944 | self.bert = BertModel(config) 945 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 946 | self.classifier = nn.Linear(config.hidden_size, 1) 947 | self.apply(self.init_bert_weights) 948 | 949 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): 950 | flat_input_ids = input_ids.view(-1, input_ids.size(-1)) 951 | flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) 952 | flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) 953 | _, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False) 954 | pooled_output = self.dropout(pooled_output) 955 | logits = self.classifier(pooled_output) 956 | reshaped_logits = logits.view(-1, self.num_choices) 957 | 958 | if labels is not None: 959 | loss_fct = CrossEntropyLoss() 960 | loss = loss_fct(reshaped_logits, labels) 961 | return loss 962 | else: 963 | return reshaped_logits 964 | 965 | 966 | class BertForTokenClassification(PreTrainedBertModel): 967 | """BERT model for token-level classification. 968 | This module is composed of the BERT model with a linear layer on top of 969 | the full hidden state of the last layer. 970 | 971 | Params: 972 | `config`: a BertConfig class instance with the configuration to build a new model. 973 | `num_labels`: the number of classes for the classifier. Default = 2. 974 | 975 | Inputs: 976 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 977 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 978 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 979 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 980 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 981 | a `sentence B` token (see BERT paper for more details). 982 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 983 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 984 | input sequence length in the current batch. It's the mask that we typically use for attention when 985 | a batch has varying length sentences. 986 | `labels`: labels for the classification output: torch.LongTensor of shape [batch_size, sequence_length] 987 | with indices selected in [0, ..., num_labels]. 988 | 989 | Outputs: 990 | if `labels` is not `None`: 991 | Outputs the CrossEntropy classification loss of the output with the labels. 992 | if `labels` is `None`: 993 | Outputs the classification logits of shape [batch_size, sequence_length, num_labels].
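A brief sketch of the per-token label layout (import path and toy tensors are assumptions); note that, as written, the loss in `forward` uses no `ignore_index`, so padded positions also contribute to it:

```python
import torch
from pyernie.model.ernie.modeling import BertConfig, BertForTokenClassification  # assumed import path

input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
# One label id per token, i.e. shape [batch_size, sequence_length]
labels = torch.LongTensor([[0, 1, 0], [1, 0, 0]])

config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                    num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForTokenClassification(config, num_labels=2)

loss = model(input_ids, token_type_ids, input_mask, labels)   # scalar loss over all token positions
logits = model(input_ids, token_type_ids, input_mask)         # [batch_size, sequence_length, num_labels]
```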
994 | 995 | Example usage: 996 | ```python 997 | # Already been converted into WordPiece token ids 998 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 999 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 1000 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 1001 | 1002 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 1003 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 1004 | 1005 | num_labels = 2 1006 | 1007 | model = BertForTokenClassification(config, num_labels) 1008 | logits = model(input_ids, token_type_ids, input_mask) 1009 | ``` 1010 | """ 1011 | def __init__(self, config, num_labels=2): 1012 | super(BertForTokenClassification, self).__init__(config) 1013 | self.num_labels = num_labels 1014 | self.bert = BertModel(config) 1015 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 1016 | self.classifier = nn.Linear(config.hidden_size, num_labels) 1017 | self.apply(self.init_bert_weights) 1018 | 1019 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): 1020 | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) 1021 | sequence_output = self.dropout(sequence_output) 1022 | logits = self.classifier(sequence_output) 1023 | 1024 | if labels is not None: 1025 | loss_fct = CrossEntropyLoss() 1026 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 1027 | return loss 1028 | else: 1029 | return logits 1030 | 1031 | 1032 | class BertForQuestionAnswering(PreTrainedBertModel): 1033 | """BERT model for Question Answering (span extraction). 1034 | This module is composed of the BERT model with a linear layer on top of 1035 | the sequence output that computes start_logits and end_logits 1036 | 1037 | Params: 1038 | `config`: either 1039 | - a BertConfig class instance with the configuration to build a new model, or 1040 | - a str with the name of a pre-trained model to load selected in the list of: 1041 | . `bert-base-uncased` 1042 | . `bert-large-uncased` 1043 | . `bert-base-cased` 1044 | . `bert-base-multilingual` 1045 | . `bert-base-chinese` 1046 | The pre-trained model will be downloaded and cached if needed. 1047 | 1048 | Inputs: 1049 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 1050 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 1051 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 1052 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 1053 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 1054 | a `sentence B` token (see BERT paper for more details). 1055 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 1056 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 1057 | input sequence length in the current batch. It's the mask that we typically use for attention when 1058 | a batch has varying length sentences. 1059 | `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size]. 1060 | Positions are clamped to the length of the sequence and position outside of the sequence are not taken 1061 | into account for computing the loss. 1062 | `end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size]. 
1063 | Positions are clamped to the length of the sequence and position outside of the sequence are not taken 1064 | into account for computing the loss. 1065 | 1066 | Outputs: 1067 | if `start_positions` and `end_positions` are not `None`: 1068 | Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions. 1069 | if `start_positions` or `end_positions` is `None`: 1070 | Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end 1071 | position tokens of shape [batch_size, sequence_length]. 1072 | 1073 | Example usage: 1074 | ```python 1075 | # Already been converted into WordPiece token ids 1076 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 1077 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 1078 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 1079 | 1080 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 1081 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 1082 | 1083 | model = BertForQuestionAnswering(config) 1084 | start_logits, end_logits = model(input_ids, token_type_ids, input_mask) 1085 | ``` 1086 | """ 1087 | def __init__(self, config): 1088 | super(BertForQuestionAnswering, self).__init__(config) 1089 | self.bert = BertModel(config) 1090 | # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version 1091 | # self.dropout = nn.Dropout(config.hidden_dropout_prob) 1092 | self.qa_outputs = nn.Linear(config.hidden_size, 2) 1093 | self.apply(self.init_bert_weights) 1094 | 1095 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None): 1096 | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) 1097 | logits = self.qa_outputs(sequence_output) 1098 | start_logits, end_logits = logits.split(1, dim=-1) 1099 | start_logits = start_logits.squeeze(-1) 1100 | end_logits = end_logits.squeeze(-1) 1101 | 1102 | if start_positions is not None and end_positions is not None: 1103 | # If we are on multi-GPU, split add a dimension 1104 | if len(start_positions.size()) > 1: 1105 | start_positions = start_positions.squeeze(-1) 1106 | if len(end_positions.size()) > 1: 1107 | end_positions = end_positions.squeeze(-1) 1108 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 1109 | ignored_index = start_logits.size(1) 1110 | start_positions.clamp_(0, ignored_index) 1111 | end_positions.clamp_(0, ignored_index) 1112 | 1113 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 1114 | start_loss = loss_fct(start_logits, start_positions) 1115 | end_loss = loss_fct(end_logits, end_positions) 1116 | total_loss = (start_loss + end_loss) / 2 1117 | return total_loss 1118 | else: 1119 | return start_logits, end_logits 1120 | -------------------------------------------------------------------------------- /pyernie/model/ernie/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | 23 | def warmup_cosine(x, warmup=0.002): 24 | if x < warmup: 25 | return x/warmup 26 | return 0.5 * (1.0 + torch.cos(math.pi * x)) 27 | 28 | def warmup_constant(x, warmup=0.002): 29 | if x < warmup: 30 | return x/warmup 31 | return 1.0 32 | 33 | def warmup_linear(x, warmup=0.002): 34 | if x < warmup: 35 | return x/warmup 36 | return 1.0 - x 37 | 38 | SCHEDULES = { 39 | 'warmup_cosine':warmup_cosine, 40 | 'warmup_constant':warmup_constant, 41 | 'warmup_linear':warmup_linear, 42 | } 43 | 44 | 45 | class BertAdam(Optimizer): 46 | """Implements BERT version of Adam algorithm with weight decay fix. 47 | Params: 48 | lr: learning rate 49 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 50 | t_total: total number of training steps for the learning 51 | rate schedule, -1 means constant learning rate. Default: -1 52 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 53 | b1: Adams b1. Default: 0.9 54 | b2: Adams b2. Default: 0.999 55 | e: Adams epsilon. Default: 1e-6 56 | weight_decay: Weight decay. Default: 0.01 57 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0 58 | """ 59 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', 60 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, 61 | max_grad_norm=1.0): 62 | if lr is not required and lr < 0.0: 63 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 64 | if schedule not in SCHEDULES: 65 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 66 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 67 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 68 | if not 0.0 <= b1 < 1.0: 69 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 70 | if not 0.0 <= b2 < 1.0: 71 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 72 | if not e >= 0.0: 73 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 74 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 75 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 76 | max_grad_norm=max_grad_norm) 77 | super(BertAdam, self).__init__(params, defaults) 78 | 79 | def get_lr(self): 80 | lr = [] 81 | for group in self.param_groups: 82 | for p in group['params']: 83 | state = self.state[p] 84 | if len(state) == 0: 85 | return [0] 86 | if group['t_total'] != -1: 87 | schedule_fct = SCHEDULES[group['schedule']] 88 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 89 | else: 90 | lr_scheduled = group['lr'] 91 | lr.append(lr_scheduled) 92 | return lr 93 | 94 | def step(self, closure=None): 95 | """Performs a single optimization step. 
96 | 97 | Arguments: 98 | closure (callable, optional): A closure that reevaluates the model 99 | and returns the loss. 100 | """ 101 | loss = None 102 | if closure is not None: 103 | loss = closure() 104 | 105 | for group in self.param_groups: 106 | for p in group['params']: 107 | if p.grad is None: 108 | continue 109 | grad = p.grad.data 110 | if grad.is_sparse: 111 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 112 | 113 | state = self.state[p] 114 | 115 | # State initialization 116 | if len(state) == 0: 117 | state['step'] = 0 118 | # Exponential moving average of gradient values 119 | state['next_m'] = torch.zeros_like(p.data) 120 | # Exponential moving average of squared gradient values 121 | state['next_v'] = torch.zeros_like(p.data) 122 | 123 | next_m, next_v = state['next_m'], state['next_v'] 124 | beta1, beta2 = group['b1'], group['b2'] 125 | 126 | # Add grad clipping 127 | if group['max_grad_norm'] > 0: 128 | clip_grad_norm_(p, group['max_grad_norm']) 129 | 130 | # Decay the first and second moment running average coefficient 131 | # In-place operations to update the averages at the same time 132 | next_m.mul_(beta1).add_(1 - beta1, grad) 133 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 134 | update = next_m / (next_v.sqrt() + group['e']) 135 | 136 | # Just adding the square of the weights to the loss function is *not* 137 | # the correct way of using L2 regularization/weight decay with Adam, 138 | # since that will interact with the m and v parameters in strange ways. 139 | # 140 | # Instead we want to decay the weights in a manner that doesn't interact 141 | # with the m/v parameters. This is equivalent to adding the square 142 | # of the weights to the loss with plain (non-momentum) SGD. 143 | if group['weight_decay'] > 0.0: 144 | update += group['weight_decay'] * p.data 145 | 146 | if group['t_total'] != -1: 147 | schedule_fct = SCHEDULES[group['schedule']] 148 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 149 | else: 150 | lr_scheduled = group['lr'] 151 | 152 | update_with_lr = lr_scheduled * update 153 | p.data.add_(-update_with_lr) 154 | 155 | state['step'] += 1 156 | 157 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 158 | # No bias correction 159 | # bias_correction1 = 1 - beta1 ** state['step'] 160 | # bias_correction2 = 1 - beta2 ** state['step'] 161 | 162 | return loss 163 | -------------------------------------------------------------------------------- /pyernie/model/ernie/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | import os 24 | import logging 25 | 26 | from .file_utils import cached_path 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 31 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 32 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 33 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 34 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 35 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 36 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 37 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 38 | } 39 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 40 | 'bert-base-uncased': 512, 41 | 'bert-large-uncased': 512, 42 | 'bert-base-cased': 512, 43 | 'bert-large-cased': 512, 44 | 'bert-base-multilingual-uncased': 512, 45 | 'bert-base-multilingual-cased': 512, 46 | 'bert-base-chinese': 512, 47 | } 48 | VOCAB_NAME = 'vocab.txt' 49 | 50 | 51 | def load_vocab(vocab_file): 52 | """Loads a vocabulary file into a dictionary.""" 53 | vocab = collections.OrderedDict() 54 | index = 0 55 | with open(vocab_file, "r", encoding="utf-8") as reader: 56 | while True: 57 | token = reader.readline() 58 | if not token: 59 | break 60 | token = token.strip() 61 | vocab[token] = index 62 | index += 1 63 | return vocab 64 | 65 | 66 | def whitespace_tokenize(text): 67 | """Runs basic whitespace cleaning and splitting on a peice of text.""" 68 | text = text.strip() 69 | if not text: 70 | return [] 71 | tokens = text.split() 72 | return tokens 73 | 74 | 75 | class BertTokenizer(object): 76 | """Runs end-to-end tokenization: punctuation splitting + wordpiece""" 77 | 78 | def __init__(self, vocab_file, do_lower_case=True, max_len=None): 79 | if not os.path.isfile(vocab_file): 80 | raise ValueError( 81 | "Can't find a vocabulary file at path '{}'. 
To load the vocabulary from a Google pretrained " 82 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) 83 | self.vocab = load_vocab(vocab_file) 84 | self.ids_to_tokens = collections.OrderedDict( 85 | [(ids, tok) for tok, ids in self.vocab.items()]) 86 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 87 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 88 | self.max_len = max_len if max_len is not None else int(1e12) 89 | 90 | def tokenize(self, text): 91 | split_tokens = [] 92 | for token in self.basic_tokenizer.tokenize(text): 93 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 94 | split_tokens.append(sub_token) 95 | return split_tokens 96 | 97 | def convert_tokens_to_ids(self, tokens): 98 | """Converts a sequence of tokens into ids using the vocab.""" 99 | ids = [] 100 | for token in tokens: 101 | ids.append(self.vocab[token]) 102 | if len(ids) > self.max_len: 103 | raise ValueError( 104 | "Token indices sequence length is longer than the specified maximum " 105 | " sequence length for this BERT model ({} > {}). Running this" 106 | " sequence through BERT will result in indexing errors".format(len(ids), self.max_len) 107 | ) 108 | return ids 109 | 110 | def convert_ids_to_tokens(self, ids): 111 | """Converts a sequence of ids in wordpiece tokens using the vocab.""" 112 | tokens = [] 113 | for i in ids: 114 | tokens.append(self.ids_to_tokens[i]) 115 | return tokens 116 | 117 | @classmethod 118 | def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs): 119 | """ 120 | Instantiate a PreTrainedBertModel from a pre-trained model file. 121 | Download and cache the pre-trained model file if needed. 122 | """ 123 | if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP: 124 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name] 125 | else: 126 | vocab_file = pretrained_model_name 127 | if os.path.isdir(vocab_file): 128 | vocab_file = os.path.join(vocab_file, VOCAB_NAME) 129 | # redirect to the cache, if necessary 130 | try: 131 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 132 | except FileNotFoundError: 133 | logger.error( 134 | "Model name '{}' was not found in model name list ({}). " 135 | "We assumed '{}' was a path or url but couldn't find any file " 136 | "associated to this path or url.".format( 137 | pretrained_model_name, 138 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 139 | vocab_file)) 140 | return None 141 | if resolved_vocab_file == vocab_file: 142 | logger.info("loading vocabulary file {}".format(vocab_file)) 143 | else: 144 | logger.info("loading vocabulary file {} from cache at {}".format( 145 | vocab_file, resolved_vocab_file)) 146 | if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 147 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 148 | # than the number of positional embeddings 149 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name] 150 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 151 | # Instantiate tokenizer. 152 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) 153 | return tokenizer 154 | 155 | 156 | class BasicTokenizer(object): 157 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 158 | 159 | def __init__(self, do_lower_case=True): 160 | """Constructs a BasicTokenizer. 161 | 162 | Args: 163 | do_lower_case: Whether to lower case the input. 
164 | """ 165 | self.do_lower_case = do_lower_case 166 | 167 | def tokenize(self, text): 168 | """Tokenizes a piece of text.""" 169 | text = self._clean_text(text) 170 | # This was added on November 1st, 2018 for the multilingual and Chinese 171 | # models. This is also applied to the English models now, but it doesn't 172 | # matter since the English models were not trained on any Chinese data 173 | # and generally don't have any Chinese data in them (there are Chinese 174 | # characters in the vocabulary because Wikipedia does have some Chinese 175 | # words in the English Wikipedia.). 176 | text = self._tokenize_chinese_chars(text) 177 | orig_tokens = whitespace_tokenize(text) 178 | split_tokens = [] 179 | for token in orig_tokens: 180 | if self.do_lower_case: 181 | token = token.lower() 182 | token = self._run_strip_accents(token) 183 | split_tokens.extend(self._run_split_on_punc(token)) 184 | 185 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 186 | return output_tokens 187 | 188 | def _run_strip_accents(self, text): 189 | """Strips accents from a piece of text.""" 190 | text = unicodedata.normalize("NFD", text) 191 | output = [] 192 | for char in text: 193 | cat = unicodedata.category(char) 194 | if cat == "Mn": 195 | continue 196 | output.append(char) 197 | return "".join(output) 198 | 199 | def _run_split_on_punc(self, text): 200 | """Splits punctuation on a piece of text.""" 201 | chars = list(text) 202 | i = 0 203 | start_new_word = True 204 | output = [] 205 | while i < len(chars): 206 | char = chars[i] 207 | if _is_punctuation(char): 208 | output.append([char]) 209 | start_new_word = True 210 | else: 211 | if start_new_word: 212 | output.append([]) 213 | start_new_word = False 214 | output[-1].append(char) 215 | i += 1 216 | 217 | return ["".join(x) for x in output] 218 | 219 | def _tokenize_chinese_chars(self, text): 220 | """Adds whitespace around any CJK character.""" 221 | output = [] 222 | for char in text: 223 | cp = ord(char) 224 | if self._is_chinese_char(cp): 225 | output.append(" ") 226 | output.append(char) 227 | output.append(" ") 228 | else: 229 | output.append(char) 230 | return "".join(output) 231 | 232 | def _is_chinese_char(self, cp): 233 | """Checks whether CP is the codepoint of a CJK character.""" 234 | # This defines a "chinese character" as anything in the CJK Unicode block: 235 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 236 | # 237 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 238 | # despite its name. The modern Korean Hangul alphabet is a different block, 239 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 240 | # space-separated words, so they are not treated specially and handled 241 | # like the all of the other languages. 
242 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 243 | (cp >= 0x3400 and cp <= 0x4DBF) or # 244 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 245 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 246 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 247 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 248 | (cp >= 0xF900 and cp <= 0xFAFF) or # 249 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 250 | return True 251 | 252 | return False 253 | 254 | def _clean_text(self, text): 255 | """Performs invalid character removal and whitespace cleanup on text.""" 256 | output = [] 257 | for char in text: 258 | cp = ord(char) 259 | if cp == 0 or cp == 0xfffd or _is_control(char): 260 | continue 261 | if _is_whitespace(char): 262 | output.append(" ") 263 | else: 264 | output.append(char) 265 | return "".join(output) 266 | 267 | 268 | class WordpieceTokenizer(object): 269 | """Runs WordPiece tokenization.""" 270 | 271 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 272 | self.vocab = vocab 273 | self.unk_token = unk_token 274 | self.max_input_chars_per_word = max_input_chars_per_word 275 | 276 | def tokenize(self, text): 277 | """Tokenizes a piece of text into its word pieces. 278 | 279 | This uses a greedy longest-match-first algorithm to perform tokenization 280 | using the given vocabulary. 281 | 282 | For example: 283 | input = "unaffable" 284 | output = ["un", "##aff", "##able"] 285 | 286 | Args: 287 | text: A single token or whitespace separated tokens. This should have 288 | already been passed through `BasicTokenizer`. 289 | 290 | Returns: 291 | A list of wordpiece tokens. 292 | """ 293 | 294 | output_tokens = [] 295 | for token in whitespace_tokenize(text): 296 | chars = list(token) 297 | if len(chars) > self.max_input_chars_per_word: 298 | output_tokens.append(self.unk_token) 299 | continue 300 | 301 | is_bad = False 302 | start = 0 303 | sub_tokens = [] 304 | while start < len(chars): 305 | end = len(chars) 306 | cur_substr = None 307 | while start < end: 308 | substr = "".join(chars[start:end]) 309 | if start > 0: 310 | substr = "##" + substr 311 | if substr in self.vocab: 312 | cur_substr = substr 313 | break 314 | end -= 1 315 | if cur_substr is None: 316 | is_bad = True 317 | break 318 | sub_tokens.append(cur_substr) 319 | start = end 320 | 321 | if is_bad: 322 | output_tokens.append(self.unk_token) 323 | else: 324 | output_tokens.extend(sub_tokens) 325 | return output_tokens 326 | 327 | 328 | def _is_whitespace(char): 329 | """Checks whether `chars` is a whitespace character.""" 330 | # \t, \n, and \r are technically contorl characters but we treat them 331 | # as whitespace since they are generally considered as such. 332 | if char == " " or char == "\t" or char == "\n" or char == "\r": 333 | return True 334 | cat = unicodedata.category(char) 335 | if cat == "Zs": 336 | return True 337 | return False 338 | 339 | 340 | def _is_control(char): 341 | """Checks whether `chars` is a control character.""" 342 | # These are technically control characters but we count them as whitespace 343 | # characters. 344 | if char == "\t" or char == "\n" or char == "\r": 345 | return False 346 | cat = unicodedata.category(char) 347 | if cat.startswith("C"): 348 | return True 349 | return False 350 | 351 | 352 | def _is_punctuation(char): 353 | """Checks whether `chars` is a punctuation character.""" 354 | cp = ord(char) 355 | # We treat all non-letter/number ASCII as punctuation. 
356 | # Characters such as "^", "$", and "`" are not in the Unicode 357 | # Punctuation class but we treat them as punctuation anyways, for 358 | # consistency. 359 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 360 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 361 | return True 362 | cat = unicodedata.category(char) 363 | if cat.startswith("P"): 364 | return True 365 | return False 366 | -------------------------------------------------------------------------------- /pyernie/model/nn/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 -------------------------------------------------------------------------------- /pyernie/model/nn/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lonePatient/ERNIE-text-classification-pytorch/01d93c594b0d102ea0a8c1d310a162d6fe5a5328/pyernie/model/nn/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /pyernie/model/nn/__pycache__/ernie_fine.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lonePatient/ERNIE-text-classification-pytorch/01d93c594b0d102ea0a8c1d310a162d6fe5a5328/pyernie/model/nn/__pycache__/ernie_fine.cpython-36.pyc -------------------------------------------------------------------------------- /pyernie/model/nn/ernie_fine.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | import torch.nn as nn 3 | from ..ernie.modeling import PreTrainedBertModel, BertModel 4 | 5 | class ErnieFine(PreTrainedBertModel): 6 | def __init__(self,bertConfig,num_classes): 7 | super(ErnieFine ,self).__init__(bertConfig) 8 | self.bert = BertModel(bertConfig) # bert模型 9 | # 默认情况下,bert encoder模型所有的参数都是参与训练的, 10 | # 可以通过以下设置为将其设为不训练,只将classifier这一层进行反响传播, 11 | # for p in self.bert.parameters(): 12 | # p.requires_grad = False 13 | self.dropout = nn.Dropout(bertConfig.hidden_dropout_prob) 14 | self.classifier = nn.Linear(in_features=bertConfig.hidden_size, out_features=num_classes) 15 | self.apply(self.init_bert_weights) 16 | 17 | def forward(self, input_ids, token_type_ids, attention_mask, label_ids=None, output_all_encoded_layers=False): 18 | _, pooled_output = self.bert(input_ids, 19 | token_type_ids, 20 | attention_mask, 21 | output_all_encoded_layers=output_all_encoded_layers) 22 | pooled_output = self.dropout(pooled_output) 23 | logits = self.classifier(pooled_output) 24 | return logits 25 | 26 | -------------------------------------------------------------------------------- /pyernie/model/pretrain/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 -------------------------------------------------------------------------------- /pyernie/model/pretrain/ernie_base/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 -------------------------------------------------------------------------------- /pyernie/output/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 -------------------------------------------------------------------------------- /pyernie/output/checkpoints/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 
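For the `ErnieFine` head defined in `pyernie/model/nn/ernie_fine.py` above, a minimal usage sketch with randomly initialized weights; the import paths, the class count, and the toy tensors are assumptions for illustration (in practice the weights would be loaded from a pretrained checkpoint rather than left at their random initialization):

```python
import torch
import torch.nn as nn

# Assumed import paths within this repository
from pyernie.model.ernie.modeling import BertConfig
from pyernie.model.nn.ernie_fine import ErnieFine

config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                    num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = ErnieFine(config, num_classes=10)   # illustrative number of target classes

input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
segment_ids = torch.zeros_like(input_ids)

# forward(input_ids, token_type_ids, attention_mask) returns raw logits, no softmax applied
logits = model(input_ids, segment_ids, input_mask)            # [batch_size, num_classes]
loss = nn.CrossEntropyLoss()(logits, torch.LongTensor([3, 7]))
```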
-------------------------------------------------------------------------------- /pyernie/output/embedding/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 -------------------------------------------------------------------------------- /pyernie/output/feature/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 -------------------------------------------------------------------------------- /pyernie/output/figure/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 -------------------------------------------------------------------------------- /pyernie/output/log/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 -------------------------------------------------------------------------------- /pyernie/output/result/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 -------------------------------------------------------------------------------- /pyernie/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 -------------------------------------------------------------------------------- /pyernie/preprocessing/augmentation.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | import numpy as np 3 | import random 4 | 5 | class Augmentator(object): 6 | def __init__(self,is_train_mode = True, proba = 0.5): 7 | self.mode = is_train_mode 8 | self.proba = proba 9 | self.augs = [] 10 | self._reset() 11 | 12 | # 总的增强列表 13 | def _reset(self): 14 | self.augs.append(lambda text: self._shuffle(text)) 15 | self.augs.append(lambda text: self._dropout(text,p = 0.5)) 16 | 17 | # 打乱 18 | def _shuffle(self, text): 19 | text = np.random.permutation(text.strip().split()) 20 | return ' '.join(text) 21 | 22 | #随机删除一些 23 | def _dropout(self, text, p=0.5): 24 | # random delete some text 25 | text = text.strip().split() 26 | len_ = len(text) 27 | indexs = np.random.choice(len_, int(len_ * p)) 28 | for i in indexs: 29 | text[i] = '' 30 | return ' '.join(text) 31 | 32 | def __call__(self,text,aug_type): 33 | ''' 34 | 用aug_type区分数据 35 | ''' 36 | # TTA模式 37 | if 0 <= aug_type <= 2: 38 | pass 39 | # 训练模式 40 | if self.mode and random.random() < self.proba: 41 | aug = random.choice(self.augs) 42 | text = aug(text) 43 | return text 44 | -------------------------------------------------------------------------------- /pyernie/preprocessing/preprocessor.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | import re 3 | import jieba 4 | 5 | class Preprocessor(object): 6 | def __init__(self,min_len = 2,stopwords_path = None): 7 | self.min_len = min_len 8 | self.stopwords_path = stopwords_path 9 | self.reset() 10 | 11 | # jieba分词 12 | def jieba_cut(self,sentence): 13 | seg_list = jieba.cut(sentence,cut_all=False) 14 | return ' '.join(seg_list) 15 | 16 | # 加载停用词 17 | def reset(self): 18 | if self.stopwords_path: 19 | with open(self.stopwords_path,'r') as fr: 20 | self.stopwords = {} 21 | for line in fr: 22 | word = line.strip(' ').strip('\n') 23 | self.stopwords[word] = 1 24 | 25 | # 去除长度小于min_len的文本 26 | def clean_length(self,sentence): 27 | if len([x for x in sentence]) >= self.min_len: 28 | return sentence 29 | 30 | # 全角转化为半角 31 | def 
full2half(self,sentence): 32 | ret_str = '' 33 | for i in sentence: 34 | if ord(i) >= 33 + 65248 and ord(i) <= 126 + 65248: 35 | ret_str += chr(ord(i) - 65248) 36 | else: 37 | ret_str += i 38 | return ret_str 39 | 40 | #去除停用词 41 | def remove_stopword(self,sentence): 42 | words = sentence.split() 43 | x = [word for word in words if word not in self.stopwords] 44 | return " ".join(x) 45 | 46 | # 提取中文 47 | def get_china(self,sentence): 48 | zhmodel = re.compile("[\u4e00-\u9fa5]") 49 | words = [x for x in sentence if zhmodel.search(x)] 50 | return ''.join(words) 51 | # 移除数字 52 | def remove_numbers(self,sentence): 53 | words = sentence.split() 54 | x = [re.sub('\d+','',word) for word in words] 55 | return ' '.join([w for w in x if w !='']) 56 | 57 | def remove_whitespace(self,sentence): 58 | x = ''.join([x for x in sentence if x !=' ' or x !='' or x!=' ']) 59 | return x 60 | # 主函数 61 | def __call__(self, sentence): 62 | x = sentence.strip('\n') 63 | x = self.full2half(x) 64 | # x = self.jieba_cut(x) 65 | # if self.stopwords_path: 66 | # x = self.remove_stopword(x) 67 | x = self.remove_whitespace(x) 68 | x = self.get_china(x) 69 | x = self.clean_length(x) 70 | 71 | return x 72 | 73 | -------------------------------------------------------------------------------- /pyernie/test/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 -------------------------------------------------------------------------------- /pyernie/train/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 -------------------------------------------------------------------------------- /pyernie/train/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lonePatient/ERNIE-text-classification-pytorch/01d93c594b0d102ea0a8c1d310a162d6fe5a5328/pyernie/train/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /pyernie/train/__pycache__/losses.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lonePatient/ERNIE-text-classification-pytorch/01d93c594b0d102ea0a8c1d310a162d6fe5a5328/pyernie/train/__pycache__/losses.cpython-36.pyc -------------------------------------------------------------------------------- /pyernie/train/__pycache__/metrics.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lonePatient/ERNIE-text-classification-pytorch/01d93c594b0d102ea0a8c1d310a162d6fe5a5328/pyernie/train/__pycache__/metrics.cpython-36.pyc -------------------------------------------------------------------------------- /pyernie/train/__pycache__/trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lonePatient/ERNIE-text-classification-pytorch/01d93c594b0d102ea0a8c1d310a162d6fe5a5328/pyernie/train/__pycache__/trainer.cpython-36.pyc -------------------------------------------------------------------------------- /pyernie/train/losses.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | from torch.nn import CrossEntropyLoss 3 | from torch.nn import BCEWithLogitsLoss 4 | 5 | __call__ = ['CrossEntropy','BCEWithLogLoss'] 6 | 7 | class CrossEntropy(object): 8 | def __init__(self): 
9 | self.loss_f = CrossEntropyLoss() 10 | 11 | def __call__(self, output, target): 12 | loss = self.loss_f(input=output, target=target) 13 | return loss 14 | 15 | class BCEWithLogLoss(object): 16 | def __init__(self): 17 | self.loss_fn = BCEWithLogitsLoss() 18 | 19 | def __call__(self,output,target): 20 | loss = self.loss_fn(input = output,target = target) 21 | return loss 22 | -------------------------------------------------------------------------------- /pyernie/train/metrics.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | import torch 3 | from tqdm import tqdm 4 | import numpy as np 5 | from collections import Counter 6 | from sklearn.metrics import roc_auc_score 7 | from sklearn.metrics import f1_score, classification_report 8 | 9 | __call__ = ['Accuracy','AUC','F1Score','EntityScore','ClassReport','MultiLabelReport','AccuracyThresh'] 10 | 11 | class Metric: 12 | def __init__(self): 13 | pass 14 | 15 | def __call__(self, outputs, target): 16 | raise NotImplementedError 17 | 18 | def reset(self): 19 | raise NotImplementedError 20 | 21 | def value(self): 22 | raise NotImplementedError 23 | 24 | def name(self): 25 | raise NotImplementedError 26 | 27 | class Accuracy(Metric): 28 | ''' 29 | 计算准确度 30 | 可以使用topK参数设定计算K准确度 31 | Example: 32 | >>> metric = Accuracy(**) 33 | >>> for epoch in range(epochs): 34 | >>> metric.reset() 35 | >>> for batch in batchs: 36 | >>> logits = model() 37 | >>> metric(logits,target) 38 | >>> print(metric.name(),metric.value()) 39 | ''' 40 | def __init__(self,topK): 41 | super(Accuracy,self).__init__() 42 | self.topK = topK 43 | self.reset() 44 | 45 | def __call__(self, logits, target): 46 | _, pred = logits.topk(self.topK, 1, True, True) 47 | pred = pred.t() 48 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 49 | self.correct_k = correct[:self.topK].view(-1).float().sum(0) 50 | self.total = target.size(0) 51 | 52 | def reset(self): 53 | self.correct_k = 0 54 | self.total = 0 55 | 56 | def value(self): 57 | return float(self.correct_k) / self.total 58 | 59 | def name(self): 60 | return 'accuracy' 61 | 62 | 63 | class AccuracyThresh(Metric): 64 | ''' 65 | 计算准确度 66 | 可以使用topK参数设定计算K准确度 67 | Example: 68 | >>> metric = AccuracyThresh(**) 69 | >>> for epoch in range(epochs): 70 | >>> metric.reset() 71 | >>> for batch in batchs: 72 | >>> logits = model() 73 | >>> metric(logits,target) 74 | >>> print(metric.name(),metric.value()) 75 | ''' 76 | def __init__(self,thresh = 0.5): 77 | super(AccuracyThresh,self).__init__() 78 | self.thresh = thresh 79 | self.reset() 80 | 81 | def __call__(self, logits, target): 82 | self.y_pred = logits.sigmoid() 83 | self.y_true = target 84 | 85 | def reset(self): 86 | self.correct_k = 0 87 | self.total = 0 88 | 89 | def value(self): 90 | data_size = self.y_pred.size(0) 91 | acc = np.mean(((self.y_pred>self.thresh)==self.y_true.byte()).float().cpu().numpy(), axis=1).sum() 92 | return acc / data_size 93 | 94 | def name(self): 95 | return 'accuracy' 96 | 97 | 98 | class AUC(Metric): 99 | ''' 100 | AUC score 101 | micro: 102 | Calculate metrics globally by considering each element of the label 103 | indicator matrix as a label. 104 | macro: 105 | Calculate metrics for each label, and find their unweighted 106 | mean. This does not take label imbalance into account. 107 | weighted: 108 | Calculate metrics for each label, and find their average, weighted 109 | by support (the number of true instances for each label). 
110 | samples: 111 | Calculate metrics for each instance, and find their average. 112 | Example: 113 | >>> metric = AUC(**) 114 | >>> for epoch in range(epochs): 115 | >>> metric.reset() 116 | >>> for batch in batchs: 117 | >>> logits = model() 118 | >>> metric(logits,target) 119 | >>> print(metric.name(),metric.value()) 120 | ''' 121 | 122 | def __init__(self,task_type = 'binary',average = 'binary'): 123 | super(AUC, self).__init__() 124 | 125 | assert task_type in ['binary','multiclass'] 126 | assert average in ['binary','micro', 'macro', 'samples', 'weighted'] 127 | 128 | self.task_type = task_type 129 | self.average = average 130 | 131 | def __call__(self,logits,target): 132 | ''' 133 | 计算整个结果 134 | ''' 135 | if self.task_type == 'binary': 136 | self.y_prob = logits.sigmoid().data.cpu().numpy() 137 | else: 138 | self.y_prob = logits.softmax(-1).data.cpu().detach().numpy() 139 | self.y_true = target.cpu().numpy() 140 | 141 | def reset(self): 142 | self.y_prob = 0 143 | self.y_true = 0 144 | 145 | def value(self): 146 | ''' 147 | 计算指标得分 148 | ''' 149 | auc = roc_auc_score(y_score=self.y_prob, y_true=self.y_true, average=self.average) 150 | return auc 151 | 152 | def name(self): 153 | return 'auc' 154 | 155 | class F1Score(Metric): 156 | ''' 157 | F1 Score 158 | binary: 159 | Only report results for the class specified by ``pos_label``. 160 | This is applicable only if targets (``y_{true,pred}``) are binary. 161 | micro: 162 | Calculate metrics globally by considering each element of the label 163 | indicator matrix as a label. 164 | macro: 165 | Calculate metrics for each label, and find their unweighted 166 | mean. This does not take label imbalance into account. 167 | weighted: 168 | Calculate metrics for each label, and find their average, weighted 169 | by support (the number of true instances for each label). 170 | samples: 171 | Calculate metrics for each instance, and find their average. 
172 | Example: 173 | >>> metric = F1Score(**) 174 | >>> for epoch in range(epochs): 175 | >>> metric.reset() 176 | >>> for batch in batchs: 177 | >>> logits = model() 178 | >>> metric(logits,target) 179 | >>> print(metric.name(),metric.value()) 180 | ''' 181 | def __init__(self,thresh = 0.5, normalizate = True,task_type = 'binary',average = 'binary',search_thresh = False): 182 | super(F1Score).__init__() 183 | assert task_type in ['binary','multiclass'] 184 | assert average in ['binary','micro', 'macro', 'samples', 'weighted'] 185 | 186 | self.thresh = thresh 187 | self.task_type = task_type 188 | self.normalizate = normalizate 189 | self.search_thresh = search_thresh 190 | self.average = average 191 | 192 | def thresh_search(self,y_prob): 193 | ''' 194 | 对于f1评分的指标,一般我们需要对阈值进行调整,一般不会使用默认的0.5值,因此 195 | 这里我们队Thresh进行优化 196 | :return: 197 | ''' 198 | best_threshold = 0 199 | best_score = 0 200 | for threshold in tqdm([i * 0.01 for i in range(100)], disable=True): 201 | self.y_pred = y_prob > threshold 202 | score = self.value() 203 | if score > best_score: 204 | best_threshold = threshold 205 | best_score = score 206 | return best_threshold,best_score 207 | 208 | def __call__(self,logits,target): 209 | ''' 210 | 计算整个结果 211 | :return: 212 | ''' 213 | self.y_true = target.cpu().numpy() 214 | if self.normalizate and self.task_type == 'binary': 215 | y_prob = logits.sigmoid().data.cpu().numpy() 216 | elif self.normalizate and self.task_type == 'multiclass': 217 | y_prob = logits.softmax(-1).data.cpu().detach().numpy() 218 | else: 219 | y_prob = logits.cpu().detach().numpy() 220 | 221 | if self.task_type == 'binary': 222 | if self.thresh and self.search_thresh == False: 223 | self.y_pred = (y_prob > self.thresh ).astype(int) 224 | self.value() 225 | else: 226 | thresh,f1 = self.thresh_search(y_prob = y_prob) 227 | print(f"Best thresh: {thresh:.4f} - F1 Score: {f1:.4f}") 228 | 229 | if self.task_type == 'multiclass': 230 | self.y_pred = np.argmax(y_prob, 1) 231 | 232 | def reset(self): 233 | self.y_pred = 0 234 | self.y_true = 0 235 | 236 | def value(self): 237 | ''' 238 | 计算指标得分 239 | ''' 240 | if self.task_type == 'binary': 241 | f1 = f1_score(y_true=self.y_true, y_pred=self.y_pred, average=self.average) 242 | return f1 243 | if self.task_type == 'multiclass': 244 | f1 = f1_score(y_true=self.y_true, y_pred=self.y_pred, average=self.average) 245 | return f1 246 | 247 | def name(self): 248 | return 'f1' 249 | 250 | class ClassReport(Metric): 251 | ''' 252 | class report 253 | ''' 254 | def __init__(self,target_names = None): 255 | super(ClassReport).__init__() 256 | self.target_names = target_names 257 | 258 | def reset(self): 259 | self.y_pred = 0 260 | self.y_true = 0 261 | 262 | def value(self): 263 | ''' 264 | 计算指标得分 265 | ''' 266 | score = classification_report(y_true = self.y_true, y_pred = self.y_pred, target_names=self.target_names) 267 | print(f"\n\n classification report: {score}") 268 | 269 | def __call__(self,logits,target): 270 | _, y_pred = torch.max(logits.data, 1) 271 | self.y_pred = y_pred.cpu().numpy() 272 | self.y_true = target.cpu().numpy() 273 | 274 | def name(self): 275 | return "class_report" 276 | 277 | class MultiLabelReport(Metric): 278 | ''' 279 | multi label report 280 | ''' 281 | def __init__(self,id2label = None): 282 | super(MultiLabelReport).__init__() 283 | self.id2label = id2label 284 | 285 | def reset(self): 286 | self.y_prob = 0 287 | self.y_true = 0 288 | 289 | def __call__(self,logits,target): 290 | 291 | self.y_prob = 
292 |         self.y_true = target.cpu().numpy()
293 | 
294 |     def value(self):
295 |         '''
296 |         Compute and print the per-label AUC scores.
297 |         '''
298 |         for i, label in self.id2label.items():
299 |             auc = roc_auc_score(y_score=self.y_prob[:, i], y_true=self.y_true[:, i])
300 |             print(f"label:{label} - auc: {auc:.4f}")
301 | 
302 |     def name(self):
303 |         return "multilabel_report"
304 | 
--------------------------------------------------------------------------------
/pyernie/train/train_utils.py:
--------------------------------------------------------------------------------
1 | #encoding:utf-8
2 | import os
3 | import torch
4 | # gpu
5 | def prepare_device(n_gpu_use,logger):
6 |     """
7 |     setup GPU device if available, move model into configured device
8 |     # if n_gpu_use is an int, range() is used to build the device list
9 |     # if a list is passed, list[0] is used as the controller (main) device by default
10 |     """
11 |     if isinstance(n_gpu_use,int):
12 |         n_gpu_use = range(n_gpu_use)
13 |     n_gpu = torch.cuda.device_count()
14 |     if len(n_gpu_use) > 0 and n_gpu == 0:
15 |         logger.warning("Warning: There\'s no GPU available on this machine, training will be performed on CPU.")
16 |         n_gpu_use = range(0)
17 |     if len(n_gpu_use) > n_gpu:
18 |         msg = "Warning: The number of GPU\'s configured to use is {}, but only {} are available on this machine.".format(n_gpu_use, n_gpu)
19 |         logger.warning(msg)
20 |         n_gpu_use = range(n_gpu)
21 |     device = torch.device('cuda:%d'%n_gpu_use[0] if len(n_gpu_use) > 0 else 'cpu')
22 |     list_ids = n_gpu_use
23 |     return device, list_ids
24 | 
25 | # load a saved checkpoint
26 | def restore_checkpoint(resume_path,model = None,optimizer = None):
27 |     checkpoint = torch.load(resume_path)
28 |     best = checkpoint['best']
29 |     start_epoch = checkpoint['epoch'] + 1
30 |     if model:
31 |         model.load_state_dict(checkpoint['state_dict'])
32 |     if optimizer:
33 |         optimizer.load_state_dict(checkpoint['optimizer'])
34 |     return [model,optimizer,best,start_epoch]
35 | 
36 | # decide whether to run on CPU or GPU
37 | def model_device(n_gpu,model,logger):
38 |     device, device_ids = prepare_device(n_gpu,logger)
39 |     if len(device_ids) > 1:
40 |         logger.info("current {} GPUs".format(len(device_ids)))
41 |         model = torch.nn.DataParallel(model, device_ids=device_ids)
42 |     if len(device_ids) == 1:
43 |         os.environ['CUDA_VISIBLE_DEVICES'] = str(device_ids[0])
44 |     model = model.to(device)
45 |     return model,device
46 | 
--------------------------------------------------------------------------------
/pyernie/train/trainer.py:
--------------------------------------------------------------------------------
1 | #encoding:utf-8
2 | import time
3 | import torch
4 | from ..callback.progressbar import ProgressBar
5 | from ..utils.utils import restore_checkpoint,model_device
6 | from ..utils.utils import summary
7 | # training wrapper
8 | class Trainer(object):
9 |     def __init__(self,train_configs):
10 | 
11 |         self.start_epoch = 1
12 |         self.global_step = 0
13 |         self.n_gpu = train_configs['n_gpu']
14 |         self.model = train_configs['model']
15 |         self.epochs = train_configs['epochs']
16 |         self.logger = train_configs['logger']
17 |         self.verbose = train_configs['verbose']
18 |         self.criterion = train_configs['criterion']
19 |         self.optimizer = train_configs['optimizer']
20 |         self.lr_scheduler = train_configs['lr_scheduler']
21 |         self.early_stopping = train_configs['early_stopping']
22 |         self.epoch_metrics = train_configs['epoch_metrics']
23 |         self.batch_metrics = train_configs['batch_metrics']
24 |         self.model_checkpoint = train_configs['model_checkpoint']
25 |         self.training_monitor = train_configs['training_monitor']
26 |         self.gradient_accumulation_steps = train_configs['gradient_accumulation_steps']
27 | 
28 |         self.model, self.device = model_device(n_gpu = self.n_gpu, model=self.model, logger=self.logger)
29 |         # reload a checkpoint and resume training
30 |         if train_configs['resume']:
31 |             self.logger.info(f"\nLoading checkpoint: {train_configs['resume']}")
32 |             resume_list = restore_checkpoint(resume_path =train_configs['resume'],model = self.model,optimizer = self.optimizer)
33 |             best = resume_list[2]
34 |             self.model = resume_list[0]
35 |             self.optimizer = resume_list[1]
36 |             self.start_epoch = resume_list[3]
37 |             if self.model_checkpoint:
38 |                 self.model_checkpoint.best = best
39 |             self.logger.info(f"\nCheckpoint '{train_configs['resume']}' and epoch {self.start_epoch} loaded")
40 | 
41 |     def epoch_reset(self):
42 |         self.outputs = []
43 |         self.targets = []
44 |         self.result = {}
45 |         for metric in self.epoch_metrics:
46 |             metric.reset()
47 | 
48 |     def batch_reset(self):
49 |         self.info = {}
50 |         for metric in self.batch_metrics:
51 |             metric.reset()
52 | 
53 | 
54 |     def _save_info(self,epoch,valid_loss):
55 |         '''
56 |         Build the checkpoint state to be saved.
57 |         '''
58 |         state = {
59 |             'epoch': epoch,
60 |             'arch': self.model_checkpoint.arch,
61 |             'state_dict': self.model.state_dict(),
62 |             'optimizer': self.optimizer.state_dict(),
63 |             'valid_loss': round(valid_loss,4)
64 |         }
65 |         return state
66 | 
67 |     def _valid_epoch(self,data):
68 |         '''
69 |         Evaluate on the validation set.
70 |         '''
71 |         self.epoch_reset()
72 |         self.model.eval()
73 |         with torch.no_grad():
74 |             for step, (input_ids, input_mask, segment_ids, label_ids) in enumerate(data):
75 |                 input_ids = input_ids.to(self.device)
76 |                 input_mask = input_mask.to(self.device)
77 |                 segment_ids = segment_ids.to(self.device)
78 |                 label = label_ids.to(self.device)
79 |                 logits = self.model(input_ids, segment_ids,input_mask)
80 |                 self.outputs.append(logits.cpu().detach())
81 |                 self.targets.append(label.cpu().detach())
82 | 
83 |         self.outputs = torch.cat(self.outputs, dim = 0).cpu().detach()
84 |         self.targets = torch.cat(self.targets, dim = 0).cpu().detach()
85 |         loss = self.criterion(target = self.targets, output=self.outputs)
86 |         self.result['valid_loss'] = loss.item()
87 |         print("\n--------------------------valid result ------------------------------")
88 |         if self.epoch_metrics:
89 |             for metric in self.epoch_metrics:
90 |                 metric(logits=self.outputs, target=self.targets)
91 |                 value = metric.value()
92 |                 if value:
93 |                     self.result[f'valid_{metric.name()}'] = value
94 |         if len(self.n_gpu) > 0:
95 |             torch.cuda.empty_cache()
96 |         return self.result
97 | 
98 |     def _train_epoch(self,data):
99 |         '''
100 |         Train for one epoch.
101 |         :param data:
102 |         :return:
103 |         '''
104 |         self.epoch_reset()
105 |         self.model.train()
106 |         for step, (input_ids, input_mask, segment_ids, label_ids) in enumerate(data):
107 |             start = time.time()
108 |             self.batch_reset()
109 |             input_ids = input_ids.to(self.device)
110 |             input_mask = input_mask.to(self.device)
111 |             segment_ids = segment_ids.to(self.device)
112 |             label = label_ids.to(self.device)
113 |             logits = self.model(input_ids, segment_ids,input_mask)
114 |             # compute the batch loss
115 |             loss = self.criterion(output=logits,target=label)
116 |             if len(self.n_gpu) >= 2:
117 |                 loss = loss.mean()
118 |             # when accumulating gradients over several steps, scale the loss accordingly
119 |             if self.gradient_accumulation_steps > 1:
120 |                 loss = loss / self.gradient_accumulation_steps
121 |             loss.backward()
122 |             # learning-rate schedule and optimizer step
123 |             if (step + 1) % self.gradient_accumulation_steps == 0:
124 |                 self.lr_scheduler.batch_step(training_step = self.global_step)
125 |                 self.optimizer.step()
126 |                 self.optimizer.zero_grad()
127 |                 self.global_step += 1
128 | 
129 |             if self.batch_metrics:
130 |                 for metric in self.batch_metrics:
131 |                     metric(logits = logits,target = label)
132 |                     self.info[metric.name()] = metric.value()
133 | 
134 |             self.info['loss'] = loss.item()
135 |             if self.verbose >= 1:
136 |                 self.progressbar.batch_step(batch_idx= step,info = self.info,use_time=time.time() - start)
137 |             # keep outputs on the CPU to reduce GPU memory usage
138 |             self.outputs.append(logits.cpu().detach())
139 |             self.targets.append(label.cpu().detach())
140 | 
141 |         print("\n------------------------- train result ------------------------------")
142 |         # epoch metric
143 |         self.outputs = torch.cat(self.outputs, dim =0).cpu().detach()
144 |         self.targets = torch.cat(self.targets, dim =0).cpu().detach()
145 |         loss = self.criterion(target=self.targets, output=self.outputs)
146 |         self.result['loss'] = loss.item()
147 | 
148 |         if self.epoch_metrics:
149 |             for metric in self.epoch_metrics:
150 |                 metric(logits=self.outputs, target=self.targets)
151 |                 value = metric.value()
152 |                 if value:
153 |                     self.result[f'{metric.name()}'] = value
154 |         if len(self.n_gpu) > 0:
155 |             torch.cuda.empty_cache()
156 |         return self.result
157 | 
158 |     def train(self,train_data,valid_data):
159 |         self.batch_num = len(train_data)
160 |         self.progressbar = ProgressBar(n_batch=self.batch_num)
161 | 
162 |         print("model summary info: ")
163 |         for step, (input_ids, input_mask, segment_ids, label_ids) in enumerate(train_data):
164 |             input_ids = input_ids.to(self.device)
165 |             input_mask = input_mask.to(self.device)
166 |             segment_ids = segment_ids.to(self.device)
167 |             summary(self.model,*(input_ids, segment_ids,input_mask),show_input=True)
168 |             break
169 |         # ***************************************************************
170 |         for epoch in range(self.start_epoch,self.start_epoch+self.epochs):
171 |             print(f"--------------------Epoch {epoch}/{self.epochs}------------------------")
172 |             train_log = self._train_epoch(train_data)
173 |             valid_log = self._valid_epoch(valid_data)
174 | 
175 |             logs = dict(train_log,**valid_log)
176 |             show_info = f'\nEpoch: {epoch} - ' + "-".join([f' {key}: {value:.4f} ' for key,value in logs.items()])
177 |             self.logger.info(show_info)
178 |             print("-----------------------------------------------------------------------")
179 |             # record how the metrics evolve during training
180 |             if self.training_monitor:
181 |                 self.training_monitor.epoch_step(logs)
182 | 
183 |             # save model
184 |             if self.model_checkpoint:
185 |                 state = self._save_info(epoch,valid_loss = logs['valid_loss'])
186 |                 self.model_checkpoint.epoch_step(current=logs[self.model_checkpoint.monitor],state = state)
187 | 
188 |             # early_stopping
189 |             if self.early_stopping:
190 |                 self.early_stopping.epoch_step(epoch=epoch, current=logs[self.early_stopping.monitor])
191 |                 if self.early_stopping.stop_training:
192 |                     break
193 | 
194 | 
195 | 
--------------------------------------------------------------------------------
/pyernie/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #encoding:utf-8
--------------------------------------------------------------------------------
/pyernie/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lonePatient/ERNIE-text-classification-pytorch/01d93c594b0d102ea0a8c1d310a162d6fe5a5328/pyernie/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
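Note on usage: the `Trainer` in `pyernie/train/trainer.py` above is configured entirely through the `train_configs` dictionary and is driven by `train(train_data, valid_data)`, where each batch is an `(input_ids, input_mask, segment_ids, label_ids)` tuple. The sketch below only illustrates that wiring and is not a runnable script: the lower-case names `model`, `criterion`, `optimizer`, `lr_scheduler`, `early_stopping`, `model_checkpoint`, `training_monitor`, `label_names`, `train_loader` and `valid_loader` are placeholders for objects your own configuration (see `pyernie/config/basic_config.py`, `pyernie/callback` and `fine_tune_ernie.py`) has to supply.

```python
# Wiring sketch only -- all lower-case names except the imported helpers are placeholders.
from pyernie.train.trainer import Trainer
from pyernie.train.metrics import F1Score, ClassReport
from pyernie.utils.logginger import init_logger
from pyernie.utils.utils import seed_everything

logger = init_logger(log_name='ernie', log_dir='pyernie/output/log')
seed_everything(seed=1029, device='cuda')

trainer = Trainer({
    'n_gpu': [0],                          # list of GPU ids; the first one acts as the controller
    'model': model,                        # placeholder: the fine-tuning head from pyernie/model/nn
    'epochs': 4,
    'logger': logger,
    'verbose': 1,
    'criterion': criterion,                # placeholder: a loss from pyernie/train/losses.py
    'optimizer': optimizer,                # placeholder: the optimizer built for fine-tuning
    'lr_scheduler': lr_scheduler,          # placeholder: must expose batch_step(training_step=...)
    'early_stopping': early_stopping,      # placeholder callbacks from pyernie/callback
    'model_checkpoint': model_checkpoint,
    'training_monitor': training_monitor,
    'epoch_metrics': [F1Score(task_type='multiclass', average='macro'),
                      ClassReport(target_names=label_names)],
    'batch_metrics': [F1Score(task_type='multiclass', average='macro')],
    'gradient_accumulation_steps': 1,
    'resume': False,                       # or a checkpoint path to resume training from
})
trainer.train(train_data=train_loader, valid_data=valid_loader)
```
--------------------------------------------------------------------------------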
/pyernie/utils/__pycache__/logginger.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lonePatient/ERNIE-text-classification-pytorch/01d93c594b0d102ea0a8c1d310a162d6fe5a5328/pyernie/utils/__pycache__/logginger.cpython-36.pyc
--------------------------------------------------------------------------------
/pyernie/utils/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lonePatient/ERNIE-text-classification-pytorch/01d93c594b0d102ea0a8c1d310a162d6fe5a5328/pyernie/utils/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/pyernie/utils/logginger.py:
--------------------------------------------------------------------------------
1 | #encoding:utf-8
2 | import os
3 | import logging
4 | from pathlib import Path
5 | from logging import Logger
6 | from logging.handlers import TimedRotatingFileHandler
7 | 
8 | '''
9 | Logging module
10 | 1. Writes logs to both the console and a log file
11 | 2. Keeps the last 30 days of log files by default
12 | '''
13 | 
14 | def init_logger(log_name,log_dir):
15 |     if not isinstance(log_dir,Path):
16 |         log_dir = Path(log_dir)
17 |     if not log_dir.exists():
18 |         log_dir.mkdir(exist_ok=True)
19 |     if log_name not in Logger.manager.loggerDict:
20 |         logger = logging.getLogger(log_name)
21 |         logger.setLevel(logging.DEBUG)
22 |         handler = TimedRotatingFileHandler(filename=str(log_dir / f"{log_name}.log"),when='D',backupCount = 30)
23 |         datefmt = '%Y-%m-%d %H:%M:%S'
24 |         format_str = '[%(asctime)s]: %(name)s %(filename)s[line:%(lineno)s] %(levelname)s %(message)s'
25 |         formatter = logging.Formatter(format_str,datefmt)
26 |         handler.setFormatter(formatter)
27 |         handler.setLevel(logging.INFO)
28 |         logger.addHandler(handler)
29 |         console= logging.StreamHandler()
30 |         console.setLevel(logging.INFO)
31 |         console.setFormatter(formatter)
32 |         logger.addHandler(console)
33 | 
34 |         handler = TimedRotatingFileHandler(filename=str(log_dir / "ERROR.log"),when='D',backupCount= 30)
35 |         datefmt = '%Y-%m-%d %H:%M:%S'
36 |         format_str = '[%(asctime)s]: %(name)s %(filename)s[line:%(lineno)s] %(levelname)s %(message)s'
37 |         formatter = logging.Formatter(format_str,datefmt)
38 |         handler.setFormatter(formatter)
39 |         handler.setLevel(logging.ERROR)
40 |         logger.addHandler(handler)
41 |     logger = logging.getLogger(log_name)
42 |     return logger
43 | 
--------------------------------------------------------------------------------
/pyernie/utils/utils.py:
--------------------------------------------------------------------------------
1 | #encoding:utf-8
2 | import os
3 | import random
4 | import json
5 | import pickle
6 | import torch
7 | import numpy as np
8 | from tqdm import tqdm
9 | from pathlib import Path
10 | import torch.nn as nn
11 | from collections import OrderedDict
12 | 
13 | def prepare_device(n_gpu_use,logger):
14 |     """
15 |     setup GPU device if available, move model into configured device
16 |     # if n_gpu_use is an int, range() is used to build the device list
17 |     # if a list is passed, list[0] is used as the controller (main) device by default
18 |     """
19 |     if isinstance(n_gpu_use,int):
20 |         n_gpu_use = range(n_gpu_use)
21 |     n_gpu = torch.cuda.device_count()
22 |     if len(n_gpu_use) > 0 and n_gpu == 0:
23 |         logger.warning("Warning: There\'s no GPU available on this machine, training will be performed on CPU.")
24 |         n_gpu_use = range(0)
25 |     if len(n_gpu_use) > n_gpu:
26 |         msg = "Warning: The number of GPU\'s configured to use is {}, but only {} are available on this machine.".format(n_gpu_use, n_gpu)
27 |         logger.warning(msg)
28 |         n_gpu_use = range(n_gpu)
29 |     device = torch.device('cuda:%d'%n_gpu_use[0] if len(n_gpu_use) > 0 else 'cpu')
30 |     list_ids = n_gpu_use
31 |     return device, list_ids
32 | 
33 | def model_device(n_gpu,model,logger):
34 |     '''
35 |     Decide whether to run on CPU or GPU and move the model accordingly.
36 |     :param n_gpu:
37 |     :param model:
38 |     :param logger:
39 |     :return:
40 |     '''
41 |     device, device_ids = prepare_device(n_gpu,logger)
42 |     if len(device_ids) > 1:
43 |         logger.info("current {} GPUs".format(len(device_ids)))
44 |         model = torch.nn.DataParallel(model, device_ids=device_ids)
45 |     if len(device_ids) == 1:
46 |         os.environ['CUDA_VISIBLE_DEVICES'] = str(device_ids[0])
47 |     model = model.to(device)
48 |     return model,device
49 | 
50 | def restore_checkpoint(resume_path,model = None,optimizer = None):
51 |     '''
52 |     Load a saved checkpoint.
53 |     :param resume_path:
54 |     :param model:
55 |     :param optimizer:
56 |     :return:
57 |     Note: this scheme should not be used as-is for loading a BERT-style model;
58 |     use the built-in Bert_model.from_pretrained(state_dict = your saved state_dict) instead.
59 |     '''
60 |     if isinstance(resume_path,Path):
61 |         resume_path = str(resume_path)
62 |     checkpoint = torch.load(resume_path)
63 |     best = checkpoint['best']
64 |     start_epoch = checkpoint['epoch'] + 1
65 |     if model:
66 |         model.load_state_dict(checkpoint['state_dict'])
67 |     if optimizer:
68 |         optimizer.load_state_dict(checkpoint['optimizer'])
69 |     return [model,optimizer,best,start_epoch]
70 | 
71 | 
72 | def load_bert(model_path,model = None,optimizer = None):
73 |     '''
74 |     Load a saved BERT/ERNIE checkpoint.
75 |     :param model_path:
76 |     :param model:
77 |     :param optimizer:
78 |     :return:
79 |     '''
80 |     if isinstance(model_path,Path):
81 |         model_path = str(model_path)
82 |     checkpoint = torch.load(model_path)
83 |     state_dict = checkpoint['state_dict']
84 |     # new_state_dict = {}
85 |     # for key,value in state_dict.items():
86 |     #     if "module" in key:
87 |     #         new_state_dict[key.replace("module.","")] = value
88 |     #     else:
89 |     #         new_state_dict[key] = value
90 |     best = checkpoint['best']
91 |     start_epoch = checkpoint['epoch'] + 1
92 |     if model:
93 |         model.load_state_dict(state_dict)
94 |     if optimizer:
95 |         optimizer.load_state_dict(checkpoint['optimizer'])
96 |     return [model,optimizer,best,start_epoch]
97 | 
98 | def seed_everything(seed = 1029,device='cpu'):
99 |     '''
100 |     Seed every random number generator for reproducible runs.
101 |     :param seed:
102 |     :param device:
103 |     :return:
104 |     '''
105 |     random.seed(seed)
106 |     os.environ['PYTHONHASHSEED'] = str(seed)
107 |     np.random.seed(seed)
108 |     torch.manual_seed(seed)
109 |     if 'cuda' in device:
110 |         torch.cuda.manual_seed(seed)
111 |         torch.cuda.manual_seed_all(seed)
112 |         torch.backends.cudnn.deterministic = True
113 | 
114 | def collate_fn(batch):
115 |     '''
116 |     Collate a batch of samples into tensors.
117 |     :param batch:
118 |     :return:
119 |     '''
120 |     r"""Puts each data field into a tensor with outer dimension batch size"""
121 |     transposed = zip(*batch)
122 |     lbd = lambda batch:torch.cat([torch.from_numpy(b).long() for b in batch])
123 |     return [lbd(samples) for samples in transposed]
124 | 
125 | class AverageMeter(object):
126 |     '''
127 |     computes and stores the average and current value
128 |     '''
129 |     def __init__(self):
130 |         self.reset()
131 |     def reset(self):
132 |         self.val = 0
133 |         self.avg = 0
134 |         self.sum = 0
135 |         self.count = 0
136 | 
137 |     def update(self,val,n = 1):
138 |         self.val = val
139 |         self.sum += val * n
140 |         self.count += n
141 |         self.avg = self.sum / self.count
142 | 
143 | def summary(model, *inputs, batch_size=-1, show_input=True):
144 | 
145 |     def register_hook(module):
146 |         def hook(module, input, output=None):
147 |             class_name = str(module.__class__).split(".")[-1].split("'")[0]
148 |             module_idx = len(summary)
149 | 
150 |             m_key = f"{class_name}-{module_idx + 1}"
151 |             summary[m_key] = OrderedDict()
152 |             summary[m_key]["input_shape"] = list(input[0].size())
153 |             summary[m_key]["input_shape"][0] = batch_size
154 | 
155 |             if show_input is False and output is not None:
156 |                 if isinstance(output, (list, tuple)):
157 |                     for out in output:
158 |                         if isinstance(out, torch.Tensor):
159 |                             summary[m_key]["output_shape"] = [
160 |                                 [-1] + list(out.size())[1:]
161 |                             ][0]
162 |                         else:
163 |                             summary[m_key]["output_shape"] = [
164 |                                 [-1] + list(out[0].size())[1:]
165 |                             ][0]
166 |                 else:
167 |                     summary[m_key]["output_shape"] = list(output.size())
168 |                     summary[m_key]["output_shape"][0] = batch_size
169 | 
170 |             params = 0
171 |             if hasattr(module, "weight") and hasattr(module.weight, "size"):
172 |                 params += torch.prod(torch.LongTensor(list(module.weight.size())))
173 |                 summary[m_key]["trainable"] = module.weight.requires_grad
174 |             if hasattr(module, "bias") and hasattr(module.bias, "size"):
175 |                 params += torch.prod(torch.LongTensor(list(module.bias.size())))
176 |             summary[m_key]["nb_params"] = params
177 | 
178 |         if (not isinstance(module, nn.Sequential) and not isinstance(module, nn.ModuleList) and not (module == model)):
179 |             if show_input is True:
180 |                 hooks.append(module.register_forward_pre_hook(hook))
181 |             else:
182 |                 hooks.append(module.register_forward_hook(hook))
183 | 
184 |     # create properties
185 |     summary = OrderedDict()
186 |     hooks = []
187 | 
188 |     # register hook
189 |     model.apply(register_hook)
190 |     model(*inputs)
191 | 
192 |     # remove these hooks
193 |     for h in hooks:
194 |         h.remove()
195 | 
196 |     print("-----------------------------------------------------------------------")
197 |     if show_input is True:
198 |         line_new = f"{'Layer (type)':>25} {'Input Shape':>25} {'Param #':>15}"
199 |     else:
200 |         line_new = f"{'Layer (type)':>25} {'Output Shape':>25} {'Param #':>15}"
201 |     print(line_new)
202 |     print("=======================================================================")
203 | 
204 |     total_params = 0
205 |     total_output = 0
206 |     trainable_params = 0
207 |     for layer in summary:
208 |         # input_shape, output_shape, trainable, nb_params
209 |         if show_input is True:
210 |             line_new = "{:>25} {:>25} {:>15}".format(
211 |                 layer,
212 |                 str(summary[layer]["input_shape"]),
213 |                 "{0:,}".format(summary[layer]["nb_params"]),
214 |             )
215 |         else:
216 |             line_new = "{:>25} {:>25} {:>15}".format(
217 |                 layer,
218 |                 str(summary[layer]["output_shape"]),
219 |                 "{0:,}".format(summary[layer]["nb_params"]),
220 |             )
221 | 
222 |         total_params += summary[layer]["nb_params"]
223 |         if show_input is True:
224 |             total_output += np.prod(summary[layer]["input_shape"])
225 |         else:
226 |             total_output += np.prod(summary[layer]["output_shape"])
227 |         if "trainable" in summary[layer]:
228 |             if summary[layer]["trainable"] == True:
229 |                 trainable_params += summary[layer]["nb_params"]
230 | 
231 |         print(line_new)
232 | 
233 |     print("=======================================================================")
234 |     print(f"Total params: {total_params:0,}")
235 |     print(f"Trainable params: {trainable_params:0,}")
236 |     print(f"Non-trainable params: {(total_params - trainable_params):0,}")
237 |     print("-----------------------------------------------------------------------")
238 | 
239 | def ensure_dir(path):
240 |     if not os.path.exists(path):
241 |         os.makedirs(path)
242 | 
243 | def json_write(data,filename):
244 |     with open(filename,'w') as f:
245 |         json.dump(data,f)
246 | 
247 | def json_read(filename):
248 |     with open(filename,'r') as f:
249 |         return json.load(f)
250 | 
251 | def pkl_read(filename):
252 |     with open(filename,'rb') as f:
253 |         return pickle.load(f)
254 | 
255 | def pkl_write(filename,data):
256 |     with open(filename, 'wb') as f:
257 |         pickle.dump(data, f)
258 | 
259 | def text_write(filename,data):
260 |     with open(filename,'w') as fw:
261 |         for sentence,target in tqdm(data,desc = 'write data to disk'):
262 |             target = [str(x) for x in target]
263 |             line = '\t'.join([sentence,",".join(target)])
264 |             fw.write(line +'\n')
--------------------------------------------------------------------------------
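For completeness, here is a small, self-contained usage example of the helpers defined in `pyernie/utils`. The file names below are illustrative only; `pyernie/output/log` and `pyernie/output/result` are existing directories in the project layout, and the example is assumed to run from the repository root.

```python
# Quick sanity check of the utility helpers; paths and file names are illustrative only.
from pyernie.utils.logginger import init_logger
from pyernie.utils.utils import AverageMeter, json_read, json_write, seed_everything

seed_everything(seed=1029, device='cpu')            # make CPU runs reproducible
logger = init_logger(log_name='pyernie', log_dir='pyernie/output/log')
logger.info("logging to the console and to pyernie/output/log/pyernie.log")

meter = AverageMeter()                              # running average of a batch statistic
for batch_loss in (0.9, 0.7, 0.5):
    meter.update(batch_loss, n=32)                  # n = number of samples in the batch
logger.info(f"running mean loss: {meter.avg:.4f}")

json_write({'best_f1': 0.98}, 'pyernie/output/result/metrics.json')   # argument order is (data, filename)
print(json_read('pyernie/output/result/metrics.json'))
```

Note that `json_write(data, filename)` takes the data first, which is the opposite of `pkl_write(filename, data)`; this is worth double-checking when switching between the two helpers.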