├── README.md
├── data
│   └── NewsTitle
│       ├── class.txt
│       ├── dev.txt
│       ├── test.txt
│       └── train.txt
├── main.py
├── models
│   └── BERT.py
├── predict.py
├── pretrained
│   └── bert-base-chinese
│       └── occupied.txt
├── requirements.txt
├── test.py
├── train.py
└── utils.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## PyTorch BERT Text Classification

### Pretrained Model Download Path

```text
https://huggingface.co/bert-base-chinese/tree/main
```
Download config.json, pytorch_model.bin and vocab.txt, and place them in the pretrained/bert-base-chinese/ directory:
```text
pretrained
│
└─bert-base-chinese
        config.json
        pytorch_model.bin
        vocab.txt

```

**Environment**
```text
torch==1.10.1+cu113
transformers==4.15.0
numpy==1.22.3
scikit-learn==1.0.2
tqdm==4.62.3
```

**Usage**

1. Install the dependencies
```shell
pip install -r requirements.txt
```
2. Run the code
```shell
python main.py
```
### Stars are welcome

--------------------------------------------------------------------------------
/data/NewsTitle/class.txt:
--------------------------------------------------------------------------------
finance
realty
stocks
education
science
society
politics
sports
game
entertainment

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @File  : main.py
# @Author: LauTrueYes
# @Date  : 2020/12/27
import time
import torch
import numpy as np
import utils
import argparse
from train import train
from test import test
from predict import predict
from importlib import import_module

parser = argparse.ArgumentParser(description='TextClassification')
parser.add_argument('--model', type=str, default='BERT', help='BERT')  # set the desired model via `default`
args = parser.parse_args()


if __name__ == '__main__':
    dataset = './data/NewsTitle'  # dataset directory
    model_name = args.model
    lib = import_module('models.' + model_name)
    config = lib.Config(dataset)
    model = lib.Model(config).to(config.device)

    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    torch.backends.cudnn.deterministic = True  # make results reproducible across runs

    start_time = time.time()
    print('Loading dataset')
    train_data, dev_data, test_data = utils.build_dataset(config)
    train_loader = utils.build_data_loader(train_data, config)
    dev_loader = utils.build_data_loader(dev_data, config)
    test_loader = utils.build_data_loader(test_data, config)

    time_dif = utils.get_time_dif(start_time)
    print("Data preparation time before training:", time_dif)

    train(config, model, train_loader, dev_loader)
    predict(config, model, test_loader)
    test(config, model, test_loader)
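main.py resolves the model class dynamically: `import_module('models.' + model_name)` means a new model can be added by dropping a file into `models/` that exposes a `Config(dataset)` and a `Model(config)` class, with no change to the entry point. A minimal sketch of that loading pattern (the `load_model` helper is illustrative, not part of the repo):

```python
from importlib import import_module

def load_model(model_name, dataset):
    # Resolve models/<model_name>.py at runtime; the module must expose
    # Config (paths + hyperparameters) and Model (an nn.Module subclass).
    lib = import_module('models.' + model_name)
    config = lib.Config(dataset)
    model = lib.Model(config).to(config.device)
    return config, model

# Usage mirrors main.py:
# config, model = load_model('BERT', './data/NewsTitle')
```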
--------------------------------------------------------------------------------
/models/BERT.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @File  : BERT.py
# @Author: LauTrueYes
# @Date  : 2020/12/27
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import BertModel, BertTokenizer

class Config(object):
    """
    Configuration parameters.
    """
    def __init__(self, dataset):
        self.model_name = 'BERT'  # model name

        self.train_path = dataset + '/train.txt'  # training set
        self.test_path = dataset + '/test.txt'  # test set
        self.dev_path = dataset + '/dev.txt'  # validation set
        self.predict_path = dataset + '/' + self.model_name + '_predict.txt'  # prediction output

        self.class_list = [x.strip() for x in open(dataset + '/class.txt', encoding='utf-8').readlines()]  # class names
        self.class2id = {cls: id for id, cls in enumerate(self.class_list)}
        self.id2class = {j: i for i, j in self.class2id.items()}
        self.save_path = dataset + '/saved_data/' + self.model_name + '.ckpt'  # saved model checkpoint
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # device

        self.require_improvement = 1000  # stop early if no improvement after 1000 batches (currently unused in train.py)

        self.num_classes = len(self.class_list)  # number of classes
        self.num_epochs = 1  # number of epochs
        self.batch_size = 64  # batch size
        self.pad_size = 60  # sequence length per sentence (shorter ones padded, longer ones truncated)
        self.learning_rate = 1e-5  # learning rate
        self.bert_path = './pretrained/bert-base-chinese'  # location of pretrained BERT
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)  # BERT tokenizer
        self.hidden_size = 768  # BERT hidden size; fixed by the pretrained config.json, do not change
        self.hidden_dropout_prob = 0.1


class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.num_classes = config.num_classes
        self.bert = BertModel.from_pretrained(config.bert_path)  # load pretrained model
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        for param in self.bert.parameters():
            param.requires_grad = True  # fine-tune all BERT parameters; usually set to True
        # The above is the vanilla BERT encoder.

        self.classifier = nn.Linear(config.hidden_size, config.num_classes)

    def forward(self, inputs):
        # inputs is a tuple: (input_ids, attention_mask, labels)
        input_ids, attention_mask, labels = inputs[0], inputs[1], inputs[2]
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # We do not need the per-token hidden states, only the pooled [CLS]
        # representation, shape [batch_size, hidden_size].
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)  # shape [batch_size, num_classes]
        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
        if loss is not None:
            return loss, logits.argmax(dim=-1)
        else:
            return logits.argmax(dim=-1)
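The forward contract above takes an (input_ids, attention_mask, labels) tuple and returns predicted class ids when labels is None. A minimal single-sentence inference sketch built on that contract, assuming the pretrained weights are in place and a checkpoint has been trained; the `classify_title` helper is illustrative, not part of the repo:

```python
import torch
from models.BERT import Config, Model

def classify_title(title, config, model):
    # Tokenize one title the same way utils.load_dataset does:
    # prepend [CLS], truncate/pad to pad_size, mark real tokens in the mask.
    tokens = ['[CLS]'] + config.tokenizer.tokenize(title)
    ids = config.tokenizer.convert_tokens_to_ids(tokens)[:config.pad_size]
    mask = [1] * len(ids) + [0] * (config.pad_size - len(ids))
    ids = ids + [0] * (config.pad_size - len(ids))

    input_ids = torch.LongTensor([ids]).to(config.device)
    attention_mask = torch.LongTensor([mask]).to(config.device)
    with torch.no_grad():
        pred = model((input_ids, attention_mask, None))  # labels=None -> predicted ids only
    return config.id2class[pred.item()]

# config = Config('./data/NewsTitle')
# model = Model(config).to(config.device)
# model.load_state_dict(torch.load(config.save_path))
# model.eval()
# print(classify_title('...', config, model))
```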
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @File  : predict.py
# @Author: LauTrueYes
# @Date  : 2020/12/27

import torch

def predict(config, model, test_iter):
    """
    Run the trained model over the test set and write tab-separated
    predictions to config.predict_path.
    :param config:
    :param model:
    :param test_iter:
    :return:
    """
    model.load_state_dict(torch.load(config.save_path))
    model.eval()
    predict_labels = []
    with torch.no_grad():  # no gradients needed
        for i, (input_ids, attention_mask, labels) in enumerate(test_iter):
            input_ids, attention_mask, labels = input_ids.to(config.device), attention_mask.to(config.device), labels.to(config.device)
            inputs = (input_ids, attention_mask, None)
            label_predict = model(inputs)

            predict = [config.id2class[idx] for idx in label_predict.tolist()]
            predict_labels.append(predict)
    with open(config.predict_path, 'w', encoding='utf-8') as p:
        with open(config.test_path, 'r', encoding='utf-8') as t:
            i, j = 0, 0  # i indexes the batch, j the position within the batch
            for line in t:
                line = line.strip()
                if not line:
                    continue
                content, label = line.split('\t')
                predict_label = predict_labels[i][j]
                predict_data = str(content) + '\t' + predict_label + '\n'
                j += 1
                if j == config.batch_size:
                    i += 1
                    j = 0
                p.write(predict_data)

--------------------------------------------------------------------------------
/pretrained/bert-base-chinese/occupied.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FreeRotate/TextClassificationBERT/5f37f5624e4c70bc90f4a6e1ca9f3220a54cb334/pretrained/bert-base-chinese/occupied.txt

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch==1.10.1+cu113
transformers==4.15.0
numpy==1.22.3
scikit-learn==1.0.2
tqdm==4.62.3

--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @File  : test.py
# @Author: LauTrueYes
# @Date  : 2020/12/27
import torch
from train import evaluate

def test(config, model, test_loader):
    """
    Model testing.
    :param config:
    :param model:
    :param test_loader:
    :return:
    """
    model.load_state_dict(torch.load(config.save_path))
    model.eval()
    test_loss, test_acc, test_f1, test_report, test_confusion = evaluate(config, model, test_loader)
    msg = "Test Loss:{}--------Test Acc:{}--------Test F1:{}"
    print(msg.format(test_loss, test_acc, test_f1))
    print("Test Report")
    print(test_report)
    print("Test Confusion")
    print(test_confusion)
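Both test.py and predict.py assume train() has already written the best checkpoint to config.save_path. Since torch.save(model.state_dict(), ...) stores weights only, loading requires constructing the same architecture first; note also that torch.save does not create missing directories, so data/NewsTitle/saved_data/ must exist before the first checkpoint is written. A minimal sketch of the round trip on a toy module (paths are illustrative):

```python
import torch
import torch.nn as nn

# A toy linear layer stands in for the BERT classifier; the mechanics are identical.
net = nn.Linear(4, 2)
torch.save(net.state_dict(), 'demo.ckpt')          # weights only, no architecture

restored = nn.Linear(4, 2)                         # rebuild the same shapes first
restored.load_state_dict(torch.load('demo.ckpt'))  # then load the tensors into place
restored.eval()                                    # disable dropout before evaluating
```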
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @File  : train.py
# @Author: LauTrueYes
# @Date  : 2020/12/27
import torch
import numpy as np
from sklearn import metrics
from torch.optim import AdamW

def train(config, model, train_loader, dev_loader):

    dev_best_f1 = float('-inf')
    avg_loss = []
    param_optimizer = list(model.named_parameters())  # all named parameters of the model
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']  # parameters exempt from weight decay
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]

    optimizer = AdamW(params=optimizer_grouped_parameters, lr=config.learning_rate)

    for epoch in range(config.num_epochs):
        train_right, train_total = 0, 0
        model.train()
        model.to(config.device)
        print('Epoch:{}/{}'.format(epoch + 1, config.num_epochs))
        for batch_idx, (input_ids, attention_mask, label_ids) in enumerate(train_loader):
            input_ids, attention_mask, label_ids = input_ids.to(config.device), attention_mask.to(config.device), label_ids.to(config.device)
            inputs = (input_ids, attention_mask, label_ids)
            loss, predicts = model(inputs)

            avg_loss.append(loss.data.item())

            batch_right = (predicts == label_ids).sum().item()
            train_right += batch_right
            train_total += len(predicts)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % 10 == 0:
                print("Epoch:{}--------Iter:{}--------train_loss:{:.3f}--------train_acc:{:.3f}".format(
                    epoch + 1,
                    batch_idx + 1,
                    np.array(avg_loss).mean(),
                    train_right / train_total))
        dev_loss, dev_acc, dev_f1, dev_report, dev_confusion = evaluate(config, model, dev_loader)
        msg = "Dev Loss:{}--------Dev Acc:{}--------Dev F1:{}"
        print(msg.format(dev_loss, dev_acc, dev_f1))
        print("Dev Report")
        print(dev_report)
        print("Dev Confusion")
        print(dev_confusion)

        if dev_best_f1 < dev_f1:
            dev_best_f1 = dev_f1
            torch.save(model.state_dict(), config.save_path)
            print("***************************** Save Model *****************************")


def evaluate(config, model, dev_loader):

    loss_all = np.array([], dtype=float)
    predict_all = np.array([], dtype=int)
    label_all = np.array([], dtype=int)
    with torch.no_grad():  # no gradients needed
        model.eval()  # switch to evaluation mode
        for i, (input_ids, attention_mask, label_ids) in enumerate(dev_loader):
            input_ids, attention_mask, label_ids = input_ids.to(config.device), attention_mask.to(config.device), label_ids.to(config.device)
            inputs = (input_ids, attention_mask, label_ids)
            loss, label_predict = model(inputs)

            loss_all = np.append(loss_all, loss.data.item())
            predict_all = np.append(predict_all, label_predict.data.cpu().numpy())
            label_all = np.append(label_all, label_ids.data.cpu().numpy())
    acc = metrics.accuracy_score(label_all, predict_all)
    f1 = metrics.f1_score(label_all, predict_all, average='macro')
    report = metrics.classification_report(label_all, predict_all, target_names=config.class_list, digits=3)
    confusion = metrics.confusion_matrix(label_all, predict_all)

    return loss_all.mean(), acc, f1, report, confusion
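train() excludes biases and LayerNorm parameters from weight decay, a standard practice when fine-tuning BERT: decaying those parameters adds regularization pressure without improving generalization. The grouping relies on substring matching against parameter names. A self-contained sketch of the same technique on a toy model:

```python
import torch.nn as nn
from torch.optim import AdamW

class Toy(nn.Module):
    # The submodule is named LayerNorm on purpose: Hugging Face BERT uses the
    # same attribute name, which is what the 'LayerNorm.weight' substring
    # test in train.py relies on.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)
        self.LayerNorm = nn.LayerNorm(8)
        self.out = nn.Linear(8, 2)

model = Toy()
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
decay_params = [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)]
nodecay_params = [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)]

optimizer = AdamW([
    {'params': decay_params, 'weight_decay': 0.01},   # fc.weight, out.weight
    {'params': nodecay_params, 'weight_decay': 0.0},  # all biases + LayerNorm parameters
], lr=1e-5)
```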
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @File  : utils.py
# @Author: LauTrueYes
# @Date  : 2020/12/27
from tqdm import tqdm
import torch
import time
from datetime import timedelta
from torch.utils.data import TensorDataset, DataLoader

PAD, CLS = '[PAD]', '[CLS]'

def load_dataset(file_path, config):
    """
    Returns a list of (token_ids, mask, label) tuples.
    :param file_path:
    :param config:
    :return:
    """
    contents = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in tqdm(f):
            line = line.strip()
            if not line:
                continue
            content, original_label = line.split('\t')
            # label = config.class2id[original_label]
            label = int(original_label)
            token = config.tokenizer.tokenize(content)
            token = [CLS] + token
            seq_len = len(token)
            mask = []
            token_ids = config.tokenizer.convert_tokens_to_ids(token)

            pad_size = config.pad_size
            if pad_size:
                if len(token) < pad_size:
                    mask = [1] * len(token_ids) + [0] * (pad_size - len(token))
                    token_ids = token_ids + ([0] * (pad_size - len(token)))
                else:
                    mask = [1] * pad_size
                    token_ids = token_ids[:pad_size]
                    seq_len = pad_size
            contents.append((token_ids, mask, int(label)))
    return contents

def build_dataset(config):
    """
    Returns the train, dev and test sets, each a list of
    (token_ids, mask, label) tuples.
    :param config:
    :return:
    """
    train = load_dataset(config.train_path, config)
    dev = load_dataset(config.dev_path, config)
    test = load_dataset(config.test_path, config)
    return train, dev, test


def build_data_loader(dataset, config):
    token_ids = [i[0] for i in dataset]
    mask = [i[1] for i in dataset]
    label_ids = [i[2] for i in dataset]
    tensor_set = TensorDataset(torch.LongTensor(token_ids).to(config.device),
                               torch.LongTensor(mask).to(config.device),
                               torch.LongTensor(label_ids).to(config.device))
    loader = DataLoader(tensor_set, batch_size=config.batch_size, shuffle=False)
    return loader

def get_time_dif(start_time):
    """
    Elapsed time since start_time.
    :param start_time:
    :return:
    """
    end_time = time.time()
    time_dif = end_time - start_time
    return timedelta(seconds=int(round(time_dif)))
--------------------------------------------------------------------------------
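load_dataset pads every title to pad_size and builds an attention mask marking which positions hold real tokens. The logic is easy to check in isolation; a minimal sketch with plain integer ids, no tokenizer needed (101 is BERT's [CLS] id, used here only as a realistic value):

```python
def pad_and_mask(token_ids, pad_size):
    """Mirror of the padding logic in load_dataset: short sequences are
    zero-padded, long ones truncated; mask is 1 for real tokens, 0 for pad."""
    if len(token_ids) < pad_size:
        mask = [1] * len(token_ids) + [0] * (pad_size - len(token_ids))
        token_ids = token_ids + [0] * (pad_size - len(token_ids))
    else:
        mask = [1] * pad_size
        token_ids = token_ids[:pad_size]
    return token_ids, mask

ids, mask = pad_and_mask([101, 2769, 4263], pad_size=6)
assert ids == [101, 2769, 4263, 0, 0, 0]
assert mask == [1, 1, 1, 0, 0, 0]
```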