├── readme.md
├── main.py
├── data_loader.py
├── .gitignore
└── train.py

/readme.md:
--------------------------------------------------------------------------------

PyTorch implementation of R-Drop
* Paper: https://arxiv.org/pdf/2106.14448.pdf
* Keras version (by 苏神, Su Jianlin): https://github.com/bojone/r-drop
* Official implementation: https://github.com/dropreg/R-Drop

--------------------------------------------------------------------------------
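For context, R-Drop (the paper linked above) feeds every training example through the model twice so that the two forward passes see different dropout masks, then adds a symmetric KL-divergence penalty between the two predicted distributions on top of the usual cross-entropy loss. A minimal sketch of that objective (illustrative only; the function name and exact weighting are assumptions, not code from this repo):

import torch
import torch.nn.functional as F

def r_drop_loss(logits1, logits2, labels, alpha=4.0):
    """Cross-entropy on both passes plus a symmetric KL penalty (sketch)."""
    ce = F.cross_entropy(logits1, labels) + F.cross_entropy(logits2, labels)
    kl = F.kl_div(F.log_softmax(logits1, dim=-1), F.softmax(logits2, dim=-1), reduction='batchmean') + \
         F.kl_div(F.log_softmax(logits2, dim=-1), F.softmax(logits1, dim=-1), reduction='batchmean')
    return ce + alpha * kl / 2

logits1, logits2 = torch.randn(8, 15), torch.randn(8, 15)   # two dropout passes, 15 classes
labels = torch.randint(0, 15, (8,))
print(r_drop_loss(logits1, logits2, labels))

In this repo the two passes are obtained differently: each sample is duplicated inside the batch (see data_loader.py), so a single forward pass produces both predictions.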
/main.py:
--------------------------------------------------------------------------------
# author: sunshine
# datetime: 2021/7/2 2:06 PM
from argparse import Namespace
from train import Trainer
from data_loader import load_data, NewsDataset
from transformers import BertTokenizer
import json


def get_args():
    params = dict(
        max_len=128,
        batch_size=4,
        drop=0.3,
        epoch_num=10,
        learning_rate=2e-5,
        warmup_proportion=0.1,
        device_id=0,  # GPU index used by Trainer when CUDA is available
        data_path='/home/sunshine/datasets/tnews/',
        output='output',
        bert_path='/home/sunshine/pre_models/pytorch/bert-base-chinese/',
        train_mode='train'
    )
    return Namespace(**params)


def build_dataset(args, tokenizer):
    """
    Build the train and dev data loaders.
    """
    labels = [
        "100", "101", "102", "103", "104", "106", "107", "108", "109", "110", "112",
        "113", "114", "115", "116"
    ]

    train_data = load_data(args.data_path + '/train.json', labels)
    valid_data = load_data(args.data_path + '/dev.json', labels)
    print(len(train_data))
    train_loader = NewsDataset(train_data, tokenizer, max_len=args.max_len).get_data_loader(batch_size=args.batch_size, shuffle=True)
    valid_loader = NewsDataset(valid_data, tokenizer, max_len=args.max_len).get_data_loader(batch_size=args.batch_size, shuffle=False)
    print(len(train_loader))
    return [train_loader, valid_loader], labels


def main():
    # Prepare arguments
    args = get_args()

    tokenizer = BertTokenizer.from_pretrained(args.bert_path)

    # Build the data loaders
    data_loader, labels = build_dataset(args, tokenizer)

    # Build the trainer
    trainer = Trainer(
        args=args,
        data_loaders=data_loader,
        tokenizer=tokenizer,
        num_labels=len(labels)
    )

    trainer.train(args)


if __name__ == '__main__':
    main()
    # args = get_args()
    # tokenizer = BertTokenizer.from_pretrained(args.bert_path)
    # a = tokenizer(['我是中国人', '张三十上班张三是水电费是否'], padding='longest', max_length=30, truncation='longest_first')
    # print(a)
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
# author: sunshine
# datetime: 2021/7/2 2:06 PM

# Data processor: each sample is duplicated inside the batch for R-Drop

from torch.utils.data import DataLoader, Dataset
import torch
import json


def load_data(path, labels):
    """
    Each sample: (text, label_id)
    """
    D = []
    with open(path, 'r', encoding='utf-8') as f:
        for l in f:
            l = json.loads(l)
            text, label = l['sentence'], l['label']
            D.append((text, labels.index(label)))
    return D


class NewsDataset(Dataset):
    def __init__(self, data, tokenizer, max_len=128):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __getitem__(self, item):
        return self.data[item]

    def __len__(self):
        return len(self.data)

    def create_collate_fn(self):
        def collate(examples):
            inputs = self.tokenizer([e[0] for e in examples], padding='longest', max_length=self.max_len,
                                    truncation='longest_first')

            # Duplicate every sample so that rows 2i and 2i+1 hold the same example;
            # the R-Drop loss in train.py pairs them via the [::2] / [1::2] slices.
            input_ids = sum([[item, item] for item in inputs['input_ids']], [])
            attention_mask = sum([[item, item] for item in inputs['attention_mask']], [])
            token_type_ids = sum([[item, item] for item in inputs['token_type_ids']], [])

            # Alternative (not interleaved, so it would not match the [::2]/[1::2] pairing):
            # input_ids = torch.tensor(inputs['input_ids'], dtype=torch.long).repeat(2, 1)
            # attention_mask = torch.tensor(inputs['attention_mask'], dtype=torch.long).repeat(2, 1)
            # token_type_ids = torch.tensor(inputs['token_type_ids'], dtype=torch.long).repeat(2, 1)

            input_ids = torch.tensor(input_ids, dtype=torch.long)
            attention_mask = torch.tensor(attention_mask, dtype=torch.long)
            token_type_ids = torch.tensor(token_type_ids, dtype=torch.long)

            label = sum([[item, item] for item in [e[1] for e in examples]], [])
            label = torch.tensor(label, dtype=torch.long)

            return input_ids, attention_mask, token_type_ids, label

        return collate

    def get_data_loader(self, batch_size, shuffle=True, num_workers=0):
        return DataLoader(self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
                          collate_fn=self.create_collate_fn())


if __name__ == '__main__':
    a = [[1, 2, 3], [4, 5, 6]]

    b = [[i, i] for i in a]
    print(b)
    print(sum(b, []))
--------------------------------------------------------------------------------
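Because the collate function above interleaves the two copies of each example ([e0, e0, e1, e1, ...]), the loss in train.py can pull the paired predictions apart with simple slicing. A small standalone check of that convention (plain tensors, no tokenizer; the variable names are illustrative):

import torch

ids = [[101, 2769, 102], [101, 2476, 102]]          # two tokenized "examples"
interleaved = sum([[row, row] for row in ids], [])  # [e0, e0, e1, e1]
batch = torch.tensor(interleaved)                   # shape (4, 3)

first_pass = batch[::2]    # rows 0, 2 -> one copy of each example
second_pass = batch[1::2]  # rows 1, 3 -> the other copy
assert torch.equal(first_pass, second_pass)
print(batch.shape)         # torch.Size([4, 3])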
/.gitignore:
--------------------------------------------------------------------------------
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

RDrop_sup_keras.py
RDrop_unsup_keras.py
t.py
.idea/
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
# author: sunshine
# datetime: 2021/7/2 2:06 PM


import torch
import torch.nn as nn
from transformers import BertModel
from transformers import AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm
import torch.nn.functional as F
import logging

logger = logging.getLogger(__name__)


class NewsClassifier(nn.Module):
    def __init__(self, args, num_label):
        super(NewsClassifier, self).__init__()

        self.bert = BertModel.from_pretrained(args.bert_path)
        self.fc = torch.nn.Linear(768, num_label)

    def forward(self, input_ids, attention_mask, token_type_ids):
        x = self.bert(input_ids, attention_mask, token_type_ids)
        x = x[0][:, 0, :]  # take the [CLS] vector
        x = self.fc(x)
        return x


class Trainer(object):

    def __init__(self, args, data_loaders, tokenizer, num_labels):

        self.args = args
        self.num_labels = num_labels

        self.tokenizer = tokenizer
        self.device = torch.device("cuda:{}".format(args.device_id) if torch.cuda.is_available() else "cpu")

        self.model = NewsClassifier(args, num_labels)

        self.model.to(self.device)
        if args.train_mode == "eval":
            self.resume()

        self.train_dataloader, self.dev_dataloader = data_loaders

        # Optimizer and learning-rate schedule
        # Total optimization steps = batches per epoch * epochs (len(dataloader) is already the batch count).
        train_steps = len(self.train_dataloader) * args.epoch_num
        self.optimizer, self.schedule = self.set_optimizer(args=args,
                                                           model=self.model,
                                                           train_steps=train_steps)

        self.ce = torch.nn.CrossEntropyLoss()
        self.kld = torch.nn.KLDivLoss(reduction="none")

    def loss_fnc(self, y_pred, y_true, alpha=4):
        """Cross-entropy loss combined with the R-Drop regularizer.

        y_pred holds two predictions per sample in consecutive rows (see the
        collate function); the [::2] / [1::2] slices recover the two passes.
        """

        loss1 = self.ce(y_pred, y_true)
        loss2 = self.kld(torch.log_softmax(y_pred[::2], dim=1), y_pred[1::2].softmax(dim=-1)) + \
                self.kld(torch.log_softmax(y_pred[1::2], dim=1), y_pred[::2].softmax(dim=-1))

        return loss1 + torch.mean(loss2) / 4 * alpha

    def set_optimizer(self, args, model, train_steps=None):
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': 0.01},
            {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0}
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

        # optimizer, num_warmup_steps, num_training_steps
        schedule = get_linear_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=int(train_steps * args.warmup_proportion),
            num_training_steps=train_steps
        )
        return optimizer, schedule

    def train(self, args):

        best_f1 = 0.0
        self.model.train()
        step_gap = 10
        step_eval = 500
        for epoch in range(int(args.epoch_num)):

            for step, batch in tqdm(enumerate(self.train_dataloader)):

                loss = self.forward(batch, is_eval=False)
                if step % step_gap == 0:
                    print(u"step {} / {} of epoch {}, train/loss: {}".format(step,
                                                                             len(self.train_dataloader),
                                                                             epoch, loss.item()))

                if step % step_eval == 0:

                    acc = self.evaluate(self.dev_dataloader)
                    print("acc: {}".format(acc))
                    if acc >= best_f1:
                        best_f1 = acc

                        # Save the model
                        self.save()

    def forward(self, batch, is_eval=False):
        batch = tuple(t.to(self.device) for t in batch)
        if not is_eval:
            input_ids, attention_mask, token_type_ids, label = batch
            self.optimizer.zero_grad()
            span_logits = self.model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

            loss = self.loss_fnc(y_pred=span_logits, y_true=label)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            # loss = loss.item()
            self.optimizer.step()
            self.schedule.step()

            return loss
        else:
            input_ids, attention_mask, token_type_ids, label = batch
            span_logits = self.model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

            y_pred = torch.argmax(span_logits, dim=-1)
            y_true = label
            tmp_total = len(y_true)
            tmp_right = (y_true == y_pred).sum().cpu().numpy()
            return tmp_total, tmp_right

    def resume(self):
        resume_model_file = self.args.output + "/pytorch_model.bin"
        logging.info("=> loading checkpoint '{}'".format(resume_model_file))
        checkpoint = torch.load(resume_model_file, map_location='cpu')
        self.model.load_state_dict(checkpoint)

    def save(self):
        logger.info("** ** * Saving fine-tuned model ** ** * ")
        model_to_save = self.model.module if hasattr(self.model,
                                                     'module') else self.model  # Only save the model itself
        output_model_file = self.args.output + "/pytorch_model.bin"
        torch.save(model_to_save.state_dict(), str(output_model_file))

    def evaluate(self, dataloader):
        """Validation: accuracy over the dev set.
        """
        self.model.eval()
        total, right = 0.0, 0.0
        with torch.no_grad():
            for batch in dataloader:
                tmp_total, tmp_right = self.forward(batch=batch, is_eval=True)
                total += tmp_total
                right += tmp_right
        self.model.train()
        return right / total


if __name__ == '__main__':
    KL_criterion = torch.nn.KLDivLoss(reduction='sum')
    a = torch.tensor([0.2, 0.1, 0.3, 0.4])
    b = torch.tensor([0.1, 0.2, 0.3, 0.4])

    loss = F.kl_div(a.log(), b, reduction='sum')  # KL(b || a)
    print(loss)
--------------------------------------------------------------------------------
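As a quick sanity check of the pairing used by Trainer.loss_fnc, the standalone sketch below (not part of the repo; symmetric_kl is an illustrative name) builds an interleaved logit batch and confirms that the symmetric KL term is zero when the two passes agree and positive when they differ:

import torch

kld = torch.nn.KLDivLoss(reduction="none")

def symmetric_kl(y_pred):
    # Same pairing as loss_fnc: rows 0, 2, 4, ... vs rows 1, 3, 5, ...
    p, q = y_pred[::2], y_pred[1::2]
    loss = kld(torch.log_softmax(p, dim=-1), q.softmax(dim=-1)) + \
           kld(torch.log_softmax(q, dim=-1), p.softmax(dim=-1))
    return torch.mean(loss)

base = torch.randn(2, 15)                         # 2 examples, 15 classes
agree = base.repeat_interleave(2, dim=0)          # both passes identical
disagree = agree + torch.randn_like(agree) * 0.5  # passes differ
print(symmetric_kl(agree).item())     # 0.0
print(symmetric_kl(disagree).item())  # > 0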