import time
import math
import pandas as pd
import numpy as np
import torch
from torch import nn
from tqdm import tqdm
from sklearn.metrics import classification_report
from concurrent.futures import ThreadPoolExecutor
from torch.utils.data import TensorDataset, DataLoader
from pytorch_pretrained_bert import BertTokenizer, BertModel
from pytorch_pretrained_bert.optimization import BertAdam


class ClassifyModel(nn.Module):
    """BERT encoder + dropout + linear head for single-sentence classification.

    Produces raw (un-softmaxed) logits; pair with ``nn.CrossEntropyLoss``.
    """

    def __init__(self, pretrained_model_name_or_path, num_labels, is_lock=False):
        """
        :param pretrained_model_name_or_path: model name or path forwarded to
            ``BertModel.from_pretrained``
        :param num_labels: number of output classes
        :param is_lock: when True, freeze the pretrained BERT weights
            (the pooler stays trainable, since the classifier reads the
            pooled output)
        """
        super(ClassifyModel, self).__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name_or_path)
        config = self.bert.config
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Use the configured hidden size rather than a hard-coded 768 so the
        # head also works for BERT variants with a different hidden size.
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        if is_lock:
            # Freeze everything except the pooler.
            for name, param in self.bert.named_parameters():
                if not name.startswith('pooler'):
                    param.requires_grad_(False)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        """Return logits of shape ``(batch, num_labels)``."""
        _, pooled = self.bert(input_ids, token_type_ids, attention_mask,
                              output_all_encoded_layers=False)
        pooled = self.dropout(pooled)
        return self.classifier(pooled)


class DataProcessForSingleSentence(object):
    """Convert a labelled-sentence DataFrame into a ``TensorDataset``."""

    def __init__(self, bert_tokenizer, max_workers=10):
        """
        :param bert_tokenizer: BERT tokenizer used to split sentences
        :param max_workers: size of the thread pool used to tokenize and pad
            sentences in parallel
        """
        self.bert_tokenizer = bert_tokenizer
        # Tokenization is pure-Python/CPU work; a thread pool speeds up
        # preprocessing of large frames.
        self.pool = ThreadPoolExecutor(max_workers=max_workers)

    def get_input(self, dataset, max_seq_len=30):
        """Build (ids, mask, segment, label) tensors for every row.

        :param dataset: DataFrame whose SECOND column holds the raw sentences
            and whose THIRD column holds integer labels (columns are accessed
            by position, not by name)
        :param max_seq_len: fixed output length, including [CLS]/[SEP]
        :return: ``TensorDataset(seqs, seq_masks, seq_segments, labels)``
        """
        sentences = dataset.iloc[:, 1].tolist()
        labels = dataset.iloc[:, 2].tolist()
        # Tokenize in parallel.
        token_seqs = list(self.pool.map(self.bert_tokenizer.tokenize, sentences))
        # Truncate/pad to max_seq_len and build masks, in parallel.
        results = list(self.pool.map(self.trunate_and_pad, token_seqs,
                                     [max_seq_len] * len(token_seqs)))
        seqs = [r[0] for r in results]
        seq_masks = [r[1] for r in results]
        seq_segments = [r[2] for r in results]

        t_seqs = torch.tensor(seqs, dtype=torch.long)
        t_seq_masks = torch.tensor(seq_masks, dtype=torch.long)
        t_seq_segments = torch.tensor(seq_segments, dtype=torch.long)
        t_labels = torch.tensor(labels, dtype=torch.long)

        return TensorDataset(t_seqs, t_seq_masks, t_seq_segments, t_labels)

    # NOTE: "trunate" is a historical typo ("truncate") kept so existing
    # callers keep working.
    def trunate_and_pad(self, seq, max_seq_len):
        """Truncate, wrap in [CLS]/[SEP], convert to ids and zero-pad.

        :param seq: list of tokens for one sentence
        :param max_seq_len: fixed output length
        :return: (token_ids, attention_mask, segment_ids), each exactly
            ``max_seq_len`` long
        """
        # Reserve two slots for the [CLS] and [SEP] special tokens.
        if len(seq) > (max_seq_len - 2):
            seq = seq[0: (max_seq_len - 2)]
        seq = ['[CLS]'] + seq + ['[SEP]']
        seq = self.bert_tokenizer.convert_tokens_to_ids(seq)
        # Zero-padding out to max_seq_len.
        padding = [0] * (max_seq_len - len(seq))
        # 1 for real tokens, 0 for padding.
        seq_mask = [1] * len(seq) + padding
        # Single-sentence input: every position belongs to segment 0.
        seq_segment = [0] * len(seq) + padding
        seq += padding
        assert len(seq) == max_seq_len
        assert len(seq_mask) == max_seq_len
        assert len(seq_segment) == max_seq_len
        return seq, seq_mask, seq_segment


def load_data(filepath, pretrained_model_name_or_path, max_seq_len, batch_size):
    """Load an Excel workbook with 'train' and 'test' sheets into DataLoaders.

    :param filepath: path to the workbook
    :param pretrained_model_name_or_path: BERT model used for the tokenizer
    :param max_seq_len: maximum BERT sequence length (must not exceed 512)
    :param batch_size: mini-batch size
    :return: (train_iter, test_iter, total_train_batch, total_test_batch)
    """
    # try/finally guarantees the file handle is released even if a sheet
    # is missing or unreadable.
    io = pd.ExcelFile(filepath)
    try:
        raw_train_data = pd.read_excel(io, sheet_name='train')
        raw_test_data = pd.read_excel(io, sheet_name='test')
    finally:
        io.close()
    bert_tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path,
                                                   do_lower_case=True)
    processor = DataProcessForSingleSentence(bert_tokenizer=bert_tokenizer)
    train_data = processor.get_input(raw_train_data, max_seq_len)
    test_data = processor.get_input(raw_test_data, max_seq_len)

    train_iter = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
    test_iter = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=True)
    # Number of batches per epoch, used only for progress bars.
    total_train_batch = math.ceil(len(raw_train_data) / batch_size)
    total_test_batch = math.ceil(len(raw_test_data) / batch_size)
    return train_iter, test_iter, total_train_batch, total_test_batch


def evaluate_accuracy(data_iter, net, device, batch_count):
    """Run ``net`` over ``data_iter`` and return a classification report.

    The caller is responsible for putting ``net`` in eval mode first.

    :param data_iter: DataLoader yielding (ids, mask, segment, label) batches
    :param net: the classification model
    :param device: device the batches are moved to
    :param batch_count: total batch count, used for the progress bar
    :return: ``sklearn.metrics.classification_report`` string
    """
    prediction_labels, true_labels = [], []
    with torch.no_grad():
        for batch_data in tqdm(data_iter, desc='eval', total=batch_count):
            batch_data = tuple(t.to(device) for t in batch_data)
            labels = batch_data[-1]
            output = net(*batch_data[:-1])
            predictions = output.softmax(dim=1).argmax(dim=1)
            prediction_labels.append(predictions.detach().cpu().numpy())
            true_labels.append(labels.detach().cpu().numpy())

    return classification_report(np.concatenate(true_labels),
                                 np.concatenate(prediction_labels))


if __name__ == '__main__':
    batch_size, max_seq_len = 32, 200
    train_iter, test_iter, train_batch_count, test_batch_count = load_data(
        'dianping_train_test.xls', 'bert-base-chinese', max_seq_len, batch_size)
    model = ClassifyModel('bert-base-chinese', num_labels=2, is_lock=True)
    print(model)

    optimizer = BertAdam(model.parameters(), lr=5e-05)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    loss_func = nn.CrossEntropyLoss()

    epochs = 4
    for epoch in range(epochs):
        start = time.time()
        model.train()
        # Running loss, running correct-prediction count, and sample count.
        train_loss_sum, train_acc_sum, n = 0.0, 0.0, 0
        batches_seen = 0
        for step, batch_data in tqdm(enumerate(train_iter),
                                     desc='train epoch:{}/{}'.format(epoch + 1, epochs),
                                     total=train_batch_count):
            batch_data = tuple(t.to(device) for t in batch_data)
            batch_seqs, batch_seq_masks, batch_seq_segments, batch_labels = batch_data

            logits = model(batch_seqs, batch_seq_masks, batch_seq_segments)
            loss = loss_func(logits, batch_labels)
            loss.backward()
            train_loss_sum += loss.item()
            batches_seen += 1
            train_acc_sum += (logits.argmax(dim=1) == batch_labels).sum().item()
            n += batch_labels.shape[0]
            optimizer.step()
            optimizer.zero_grad()
        # Evaluate on the test split after every epoch.
        model.eval()
        result = evaluate_accuracy(test_iter, model, device, test_batch_count)
        # CrossEntropyLoss returns a per-batch MEAN, so average over batches
        # (not over samples) to report a meaningful epoch loss.
        print('epoch %d, loss %.4f, train acc %.3f, time: %.3f' %
              (epoch + 1, train_loss_sum / max(batches_seen, 1), train_acc_sum / n,
               (time.time() - start)))
        print(result)

    # Saving the full module object keeps backward compatibility with
    # existing loaders; prefer state_dict() for new code.
    torch.save(model, 'fine_tuned_chinese_bert.bin')
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky94520/binary-classification/55696c3dd105688e4adbed16328003b88ffe2d39/financial_data.xlsx -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.18.2 2 | pandas==0.24.2 3 | pytorch-pretrained-bert==0.6.2 4 | scikit-learn==0.24 5 | torch==1.4.0 6 | torchvision==0.5.0 7 | tqdm==4.43.0 8 | --------------------------------------------------------------------------------