├── requirements.txt
├── utils
│   └── radom_seed.py
├── datasets
│   ├── sst_dataset.py
│   ├── snli_dataset.py
│   └── collate_functions.py
├── README.md
├── explain
│   ├── model.py
│   └── trainer.py
├── check
│   └── trainer.py
└── LICENSE

/requirements.txt:
--------------------------------------------------------------------------------
1 | pytorch-lightning == 1.0.6
2 | transformers == 3.4.0
3 | torch == 1.6.0
4 | numpy >= 1.19.2
--------------------------------------------------------------------------------
/utils/radom_seed.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: Yuxian Meng
4 | @contact: yuxian_meng@shannonai.com
5 |
6 | @version: 1.0
7 | @file: radom_seed
8 | @time: 2020/7/9 15:53
9 | """
10 |
11 | import numpy as np
12 | import torch
13 |
14 |
15 | def set_random_seed(seed: int):
16 |     """set seeds for reproducibility"""
17 |     np.random.seed(seed)
18 |     torch.manual_seed(seed)
19 |     torch.backends.cudnn.deterministic = True
20 |     torch.backends.cudnn.benchmark = False
21 |
22 |
23 | if __name__ == '__main__':
24 |     # without this line, x would be different in every execution.
25 |     set_random_seed(0)
26 |
27 |     x = np.random.random()
28 |     print(x)
29 |
--------------------------------------------------------------------------------
/datasets/sst_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | @file : sst_dataset.py
5 | @author: zijun
6 | @contact : zijun_sun@shannonai.com
7 | @date : 2020/11/17 11:45
8 | @version: 1.0
9 | @desc : the SST-5 and IMDB tasks use the same dataset class
10 | """
11 | import os
12 | from functools import partial
13 |
14 | import torch
15 | from transformers import RobertaTokenizer
16 | from torch.utils.data import Dataset, DataLoader
17 |
18 | from datasets.collate_functions import collate_to_max_length
19 |
20 |
21 | class SSTDataset(Dataset):
22 |
23 |     def __init__(self, directory, prefix, bert_path, max_length: int = 512):
24 |         super().__init__()
25 |         self.max_length = max_length
26 |         with open(os.path.join(directory, prefix + '.txt'), 'r', encoding='utf8') as f:
27 |             lines = f.readlines()
28 |         self.lines = lines
29 |         self.tokenizer = RobertaTokenizer.from_pretrained(bert_path)
30 |
31 |     def __len__(self):
32 |         return len(self.lines)
33 |
34 |     def __getitem__(self, idx):
35 |         line = self.lines[idx]
36 |         label, sentence = line.split('\t', 1)
37 |         # drop the trailing period
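        # below: the sentence is encoded without special tokens, truncated to
        # max_length - 2, and then wrapped with RoBERTa's special ids
        # (0 = <s>, 2 = </s>); `length` counts the two special tokens as well.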
38 | sentence = sentence.strip() 39 | if sentence.endswith("."): 40 | sentence = sentence[:-1] 41 | input_ids = self.tokenizer.encode(sentence, add_special_tokens=False) 42 | if len(input_ids) > self.max_length - 2: 43 | input_ids = input_ids[:self.max_length - 2] 44 | # convert list to tensor 45 | length = torch.LongTensor([len(input_ids) + 2]) 46 | input_ids = torch.LongTensor([0] + input_ids + [2]) 47 | label = torch.LongTensor([int(label)]) 48 | return input_ids, label, length 49 | 50 | 51 | def unit_test(): 52 | root_path = "/data/nfsdata2/sunzijun/sstc/imdb_data" 53 | bert_path = "/data/nfsdata2/sunzijun/loop/roberta-base" 54 | prefix = "train" 55 | dataset = SSTDataset(directory=root_path, prefix=prefix, bert_path=bert_path) 56 | 57 | dataloader = DataLoader( 58 | dataset=dataset, 59 | batch_size=10, 60 | num_workers=0, 61 | shuffle=False, 62 | collate_fn=partial(collate_to_max_length, fill_values=[1, 0, 0]) 63 | ) 64 | for input_ids, label, length, start_index, end_index, span_mask in dataloader: 65 | print(input_ids.shape) 66 | print(start_index.shape) 67 | print(end_index.shape) 68 | print(span_mask.shape) 69 | print(label.view(-1).shape) 70 | print() 71 | 72 | 73 | if __name__ == '__main__': 74 | unit_test() 75 | -------------------------------------------------------------------------------- /datasets/snli_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @file : snli_dataset.py 5 | @author: zijun 6 | @contact : zijun_sun@shannonai.com 7 | @date : 2020/11/26 14:16 8 | @version: 1.0 9 | @desc : 10 | """ 11 | 12 | import json 13 | import os 14 | from functools import partial 15 | 16 | import torch 17 | from torch.utils.data import Dataset, DataLoader 18 | from transformers import RobertaTokenizer 19 | 20 | from datasets.collate_functions import collate_to_max_length 21 | 22 | 23 | class SNLIDataset(Dataset): 24 | 25 | def __init__(self, directory, prefix, bert_path, max_length: int = 512): 26 | super().__init__() 27 | self.max_length = max_length 28 | label_map = {"contradiction": 0, 'neutral': 1, "entailment": 2} 29 | with open(os.path.join(directory, 'snli_1.0_' + prefix + '.jsonl'), 'r', encoding='utf8') as f: 30 | lines = f.readlines() 31 | self.result = [] 32 | for line in lines: 33 | line_json = json.loads(line) 34 | if line_json['gold_label'] not in label_map: 35 | # print(line_json['gold_label']) 36 | continue 37 | self.result.append((line_json['sentence1'], line_json['sentence2'], label_map[line_json['gold_label']])) 38 | self.tokenizer = RobertaTokenizer.from_pretrained(bert_path) 39 | 40 | def __len__(self): 41 | return len(self.result) 42 | 43 | def __getitem__(self, idx): 44 | sentence_1, sentence_2, label = self.result[idx] 45 | # remove . 
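        # below: both sentences drop a trailing period, are encoded without special
        # tokens, and are joined with a single </s> (id 2) separator, giving
        # [0] + premise + [2] + hypothesis + [2]; collate_to_max_length later uses
        # this first id-2 position to mask out spans that cross the
        # premise/hypothesis boundary.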
46 | if sentence_1.endswith("."): 47 | sentence_1 = sentence_1[:-1] 48 | if sentence_2.endswith("."): 49 | sentence_2 = sentence_2[:-1] 50 | sentence_1_input_ids = self.tokenizer.encode(sentence_1, add_special_tokens=False) 51 | sentence_2_input_ids = self.tokenizer.encode(sentence_2, add_special_tokens=False) 52 | input_ids = sentence_1_input_ids + [2] + sentence_2_input_ids 53 | if len(input_ids) > self.max_length - 2: 54 | input_ids = input_ids[:self.max_length - 2] 55 | # convert list to tensor 56 | length = torch.LongTensor([len(input_ids) + 2]) 57 | input_ids = torch.LongTensor([0] + input_ids + [2]) 58 | label = torch.LongTensor([label]) 59 | return input_ids, label, length 60 | 61 | 62 | def unit_test(): 63 | root_path = "/data/nfsdata2/sunzijun/explain/snli_1.0" 64 | bert_path = "/data/nfsdata2/sunzijun/loop/roberta-base" 65 | prefix = "dev" 66 | dataset = SNLIDataset(directory=root_path, prefix=prefix, bert_path=bert_path) 67 | 68 | dataloader = DataLoader( 69 | dataset=dataset, 70 | batch_size=10, 71 | num_workers=0, 72 | shuffle=False, 73 | collate_fn=partial(collate_to_max_length, fill_values=[1, 0, 0]) 74 | ) 75 | for input_ids, label, length, start_index, end_index, span_mask in dataloader: 76 | print(input_ids.shape) 77 | print(start_index.shape) 78 | print(end_index.shape) 79 | print(span_mask.shape) 80 | print(label.view(-1).shape) 81 | print() 82 | 83 | 84 | if __name__ == '__main__': 85 | unit_test() 86 | -------------------------------------------------------------------------------- /datasets/collate_functions.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: Yuxian Meng 4 | @contact: yuxian_meng@shannonai.com 5 | 6 | @version: 1.0 7 | @file: collate_functions 8 | @time: 2020/6/17 19:18 9 | 10 | collate functions 11 | """ 12 | 13 | from typing import List 14 | 15 | import numpy as np 16 | import torch 17 | 18 | 19 | def collate_to_max_length(batch: List[List[torch.Tensor]], max_len: int = None, fill_values: List[float] = None) -> \ 20 | List[torch.Tensor]: 21 | """ 22 | pad to maximum length of this batch 23 | Args: 24 | batch: a batch of samples, each contains a list of field data(Tensor), which shape is [seq_length] 25 | max_len: specify max length 26 | fill_values: specify filled values of each field 27 | Returns: 28 | output: list of field batched data, which shape is [batch, max_length] 29 | """ 30 | # [batch, num_fields] 31 | lengths = np.array([[len(field_data) for field_data in sample] for sample in batch]) 32 | batch_size, num_fields = lengths.shape 33 | fill_values = fill_values or [0.0] * num_fields 34 | # [num_fields] 35 | max_lengths = lengths.max(axis=0) 36 | if max_len: 37 | assert max_lengths.max() <= max_len 38 | max_lengths = np.ones_like(max_lengths) * max_len 39 | 40 | output = [torch.full([batch_size, max_lengths[field_idx]], 41 | fill_value=fill_values[field_idx], 42 | dtype=batch[0][field_idx].dtype) 43 | for field_idx in range(num_fields)] 44 | for sample_idx in range(batch_size): 45 | for field_idx in range(num_fields): 46 | # seq_length 47 | data = batch[sample_idx][field_idx] 48 | output[field_idx][sample_idx][: data.shape[0]] = data 49 | # generate span_index and span_mask 50 | max_sentence_length = max_lengths[0] 51 | start_indexs = [] 52 | end_indexs = [] 53 | for i in range(1, max_sentence_length - 1): 54 | for j in range(i, max_sentence_length - 1): 55 | # # span大小为10 56 | # if j - i > 10: 57 | # continue 58 | start_indexs.append(i) 59 | end_indexs.append(j) 
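    # the nested loop above enumerates every candidate span (i, j) with
    # 1 <= i <= j <= max_sentence_length - 2, so position 0 (<s>) and the final
    # position (</s>) never start or end a span; the same start/end index lists are
    # shared by every sample in the batch. The masks built below assign 0 to legal
    # spans and 1e6 to illegal ones; the interpretation layer subtracts the mask
    # from the span logits before the softmax, driving illegal spans' weights
    # towards zero.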
60 |     # generate span mask
61 |     span_masks = []
62 |     for input_ids, label, length in batch:
63 |         span_mask = []
64 |         middle_index = input_ids.tolist().index(2)
65 |         for start_index, end_index in zip(start_indexs, end_indexs):
66 |             if 1 <= start_index <= length.item() - 2 and 1 <= end_index <= length.item() - 2 and (
67 |                     start_index > middle_index or end_index < middle_index):
68 |                 span_mask.append(0)
69 |             else:
70 |                 span_mask.append(1e6)
71 |         span_masks.append(span_mask)
72 |     # add to output
73 |     output.append(torch.LongTensor(start_indexs))
74 |     output.append(torch.LongTensor(end_indexs))
75 |     output.append(torch.LongTensor(span_masks))
76 |     return output  # (input_ids, labels, length, start_indexs, end_indexs, span_masks)
77 |
78 |
79 | def unit_test():
80 |     input_id_1 = torch.LongTensor([0, 3, 2, 5, 6, 2])
81 |     input_id_2 = torch.LongTensor([0, 3, 2, 4, 2])
82 |     input_id_3 = torch.LongTensor([0, 3, 2])
83 |     batch = [(input_id_1, torch.LongTensor([1]), torch.LongTensor([6])),
84 |              (input_id_2, torch.LongTensor([1]), torch.LongTensor([5])),
85 |              (input_id_3, torch.LongTensor([1]), torch.LongTensor([3]))]
86 |
87 |     output = collate_to_max_length(batch=batch, fill_values=[1, 0, 0])
88 |     print(output)
89 |
90 |
91 | if __name__ == '__main__':
92 |     unit_test()
93 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # self-explaining-NLP
2 | Code, models and datasets for [Self-Explaining Structures Improve NLP Models](http://arxiv.org/abs/2012.01786).
3 |
4 | ## Installation
5 | `pip install -r requirements.txt`
6 |
7 | ## Prepare Datasets and Models
8 | - Download the SST-5 dataset; the official corpus can be found [HERE](https://nlp.stanford.edu/sentiment/index.html).
9 | We also provide processed raw text, which you can download [HERE](https://drive.google.com/drive/folders/1TYR-yRw3NXqfXnMSvFDxGTdf1urGfrPY?usp=sharing).
10 | Save the processed dataset at `[SST_DATA_PATH]`.
11 | - Download the SNLI dataset; the official corpus can be found [HERE](https://nlp.stanford.edu/projects/snli/).
12 | Save the SNLI dataset at `[SNLI_DATA_PATH]`.
13 | - Download the vanilla RoBERTa-base model released by HuggingFace; it can be found [HERE](https://huggingface.co/roberta-base).
14 | Save the model at `[ROBERTA_BASE_PATH]`.
15 | - Download the model checkpoints we trained for the different tasks; they can be downloaded [HERE](https://drive.google.com/drive/folders/1RV5OJSzN_7p-YkjkmAhq2vzhouZEtzSS?usp=sharing).
16 | You can use our checkpoints for evaluation.
17 |
18 | ## Reproduce paper results step by step
19 | In the paper, we apply self-explaining structures to different NLP tasks. This repo contains all training
20 | and evaluation code, but here we only provide the commands for the SST-5 task as an example.
21 | For other tasks, you can reproduce the results simply by modifying these commands.
22 |
23 | ### 1. Train the self-explaining model
24 | SST-5 is a five-class task, so the RoBERTa-base config file has to be modified first.
25 | Open `[ROBERTA_BASE_PATH]/config.json` and set `num_labels=5` (by hand, or with the small sketch below), then run the following training command.
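If you prefer to script this one-line change, a minimal sketch could look like the following. It assumes the standard HuggingFace layout in which `config.json` sits directly under `[ROBERTA_BASE_PATH]`; the path is a placeholder you should replace.

```python
import json
import os

ROBERTA_BASE_PATH = "/path/to/roberta-base"  # replace with your [ROBERTA_BASE_PATH]

config_file = os.path.join(ROBERTA_BASE_PATH, "config.json")
with open(config_file, "r", encoding="utf8") as f:
    config = json.load(f)

# SST-5 is a five-way classification task
config["num_labels"] = 5

with open(config_file, "w", encoding="utf8") as f:
    json.dump(config, f, indent=2)
```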
26 | ```bash
27 | cd explain
28 | python trainer.py \
29 | --bert_path [ROBERTA_BASE_PATH] \
30 | --data_dir [SST_DATA_PATH] \
31 | --task sst5 \
32 | --save_path [SELF_EXPLAINING_MODEL_CHECKPOINTS] \
33 | --gpus=0,1,2,3 \
34 | --precision 16 \
35 | --lr=2e-5 \
36 | --batch_size=10 \
37 | --lamb=1.0 \
38 | --workers=4 \
39 | --max_epochs=20
40 | ```
41 | After training, the checkpoints and the training log will be saved at `[SELF_EXPLAINING_MODEL_CHECKPOINTS]`.
42 | ### 2. Evaluate the self-explaining model
43 | Run the following evaluation command to get the performance on the test dataset.
44 | You can use the checkpoint you trained yourself, or download our checkpoint and evaluate with it.
45 | After evaluation, you will find two output files at `[SPAN_SAVE_PATH]`: `output.txt` and `test.txt`.
46 | `output.txt` records the extracted spans and the prediction results in a human-readable form.
47 | `test.txt` records only the top-ranked span of each example, which serves as the span-based test data for the next stage.
48 | ```bash
49 | cd explain
50 | python trainer.py \
51 | --bert_path [ROBERTA_BASE_PATH] \
52 | --data_dir [SST_DATA_PATH] \
53 | --task sst5 \
54 | --checkpoint_path [SELF_EXPLAINING_MODEL_CHECKPOINTS]/***.ckpt \
55 | --save_path [SPAN_SAVE_PATH] \
56 | --gpus=0, \
57 | --mode eval
58 | ```
59 |
60 | ### 3. Check the extracted spans
61 | In the previous stage, we obtained span-based test data. You can use the same method to get span-based train data.
62 | To check the extracted spans, we set up four experiments: full-full mode, full-span mode, span-full
63 | mode and span-span mode. For example, full-span mode means we use the original SST-5 train data for training
64 | and the span-based test data for testing.
65 | For this mode, save the original SST-5 train data and the span-based test data at `[FULL_SPAN_PATH]`:
66 | ```bash
67 | scp [SST_DATA_PATH]/train.txt [FULL_SPAN_PATH]
68 | scp [SPAN_SAVE_PATH]/test.txt [FULL_SPAN_PATH]
69 | cd check
70 | python trainer.py \
71 | --bert_path [ROBERTA_BASE_PATH] \
72 | --data_dir [FULL_SPAN_PATH] \
73 | --task sst5 \
74 | --save_path [CHECK_MODEL_CHECKPOINTS] \
75 | --gpus=0,1,2,3 \
76 | --precision 16 \
77 | --lr=2e-5 \
78 | --batch_size=10 \
79 | --workers=4 \
80 | --max_epochs=20
81 | ```
82 |
--------------------------------------------------------------------------------
/explain/model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | @file : model.py
5 | @author: zijun
6 | @contact : zijun_sun@shannonai.com
7 | @date : 2020/11/17 14:57
8 | @version: 1.0
9 | @desc :
10 | """
11 | import torch
12 | from torch import nn
13 | from transformers.modeling_roberta import RobertaModel, RobertaConfig
14 |
15 | from datasets.collate_functions import collate_to_max_length
16 |
17 |
18 | class ExplainableModel(nn.Module):
19 |     def __init__(self, bert_dir):
20 |         super().__init__()
21 |         self.bert_config = RobertaConfig.from_pretrained(bert_dir, output_hidden_states=False)
22 |         self.intermediate = RobertaModel.from_pretrained(bert_dir)
23 |         self.span_info_collect = SICModel(self.bert_config.hidden_size)
24 |         self.interpretation = InterpretationModel(self.bert_config.hidden_size)
25 |         self.output = nn.Linear(self.bert_config.hidden_size, self.bert_config.num_labels)
26 |
27 |     def forward(self, input_ids, start_indexs, end_indexs, span_masks):
28 |         # generate mask
29 |         attention_mask = (input_ids != 1).long()
30 |         # intermediate layer
31 |         hidden_states, first_token = self.intermediate(input_ids, attention_mask=attention_mask)  # output.shape =
(bs, length, hidden_size) 32 | # span info collecting layer(SIC) 33 | h_ij = self.span_info_collect(hidden_states, start_indexs, end_indexs) 34 | # interpretation layer 35 | H, a_ij = self.interpretation(h_ij, span_masks) 36 | # output layer 37 | out = self.output(H) 38 | return out, a_ij 39 | 40 | 41 | class SICModel(nn.Module): 42 | def __init__(self, hidden_size): 43 | super().__init__() 44 | self.hidden_size = hidden_size 45 | 46 | self.W_1 = nn.Linear(hidden_size, hidden_size) 47 | self.W_2 = nn.Linear(hidden_size, hidden_size) 48 | self.W_3 = nn.Linear(hidden_size, hidden_size) 49 | self.W_4 = nn.Linear(hidden_size, hidden_size) 50 | 51 | def forward(self, hidden_states, start_indexs, end_indexs): 52 | W1_h = self.W_1(hidden_states) # (bs, length, hidden_size) 53 | W2_h = self.W_2(hidden_states) 54 | W3_h = self.W_3(hidden_states) 55 | W4_h = self.W_4(hidden_states) 56 | 57 | W1_hi_emb = torch.index_select(W1_h, 1, start_indexs) # (bs, span_num, hidden_size) 58 | W2_hj_emb = torch.index_select(W2_h, 1, end_indexs) 59 | W3_hi_start_emb = torch.index_select(W3_h, 1, start_indexs) 60 | W3_hi_end_emb = torch.index_select(W3_h, 1, end_indexs) 61 | W4_hj_start_emb = torch.index_select(W4_h, 1, start_indexs) 62 | W4_hj_end_emb = torch.index_select(W4_h, 1, end_indexs) 63 | 64 | # [w1*hi, w2*hj, w3(hi-hj), w4(hi⊗hj)] 65 | span = W1_hi_emb + W2_hj_emb + (W3_hi_start_emb - W3_hi_end_emb) + torch.mul(W4_hj_start_emb, W4_hj_end_emb) 66 | h_ij = torch.tanh(span) 67 | return h_ij 68 | 69 | 70 | class InterpretationModel(nn.Module): 71 | def __init__(self, hidden_size): 72 | super().__init__() 73 | self.h_t = nn.Linear(hidden_size, 1) 74 | 75 | def forward(self, h_ij, span_masks): 76 | o_ij = self.h_t(h_ij).squeeze(-1) # (ba, span_num) 77 | # mask illegal span 78 | o_ij = o_ij - span_masks 79 | # normalize all a_ij, a_ij sum = 1 80 | a_ij = nn.functional.softmax(o_ij, dim=1) 81 | # weight average span representation to get H 82 | H = (a_ij.unsqueeze(-1) * h_ij).sum(dim=1) # (bs, hidden_size) 83 | return H, a_ij 84 | 85 | 86 | def main(): 87 | # data 88 | input_id_1 = torch.LongTensor([0, 4, 5, 6, 7, 2]) 89 | input_id_2 = torch.LongTensor([0, 4, 5, 2]) 90 | input_id_3 = torch.LongTensor([0, 4, 2]) 91 | batch = [(input_id_1, torch.LongTensor([1]), torch.LongTensor([6])), 92 | (input_id_2, torch.LongTensor([1]), torch.LongTensor([4])), 93 | (input_id_3, torch.LongTensor([1]), torch.LongTensor([3]))] 94 | 95 | output = collate_to_max_length(batch=batch, fill_values=[1, 0, 0]) 96 | input_ids, labels, length, start_indexs, end_indexs, span_masks = output 97 | 98 | # model 99 | bert_path = "/data/nfsdata2/sunzijun/loop/roberta-base" 100 | model = ExplainableModel(bert_path) 101 | print(model) 102 | 103 | output = model(input_ids, start_indexs, end_indexs, span_masks) 104 | print(output) 105 | 106 | 107 | if __name__ == '__main__': 108 | main() 109 | -------------------------------------------------------------------------------- /check/trainer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @file : trainer.py 5 | @author: zijun 6 | @contact : zijun_sun@shannonai.com 7 | @date : 2020/11/16 21:55 8 | @version: 1.0 9 | @desc : 10 | """ 11 | 12 | import argparse 13 | import json 14 | import os 15 | from functools import partial 16 | 17 | import pytorch_lightning as pl 18 | import torch 19 | from pytorch_lightning import Trainer 20 | from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint 21 
| from pytorch_lightning.loggers import TensorBoardLogger 22 | from torch.nn import functional as F 23 | from torch.nn.modules import CrossEntropyLoss 24 | from torch.utils.data.dataloader import DataLoader 25 | from transformers import AdamW, get_linear_schedule_with_warmup, RobertaTokenizer 26 | from transformers.modeling_roberta import RobertaForSequenceClassification 27 | 28 | from datasets.collate_functions import collate_to_max_length 29 | from datasets.sst_dataset import SSTDataset 30 | from datasets.snli_dataset import SNLIDataset 31 | from utils.radom_seed import set_random_seed 32 | 33 | set_random_seed(0) 34 | 35 | 36 | class CheckExplainNLP(pl.LightningModule): 37 | 38 | def __init__( 39 | self, 40 | args: argparse.Namespace 41 | ): 42 | """Initialize a model, tokenizer and config.""" 43 | super().__init__() 44 | self.args = args 45 | if isinstance(args, argparse.Namespace): 46 | self.save_hyperparameters(args) 47 | self.bert_dir = args.bert_path 48 | self.model = RobertaForSequenceClassification.from_pretrained(self.bert_dir) 49 | self.tokenizer = RobertaTokenizer.from_pretrained(self.bert_dir) 50 | self.loss_fn = CrossEntropyLoss() 51 | self.train_acc = pl.metrics.Accuracy() 52 | self.valid_acc = pl.metrics.Accuracy() 53 | 54 | def configure_optimizers(self): 55 | """Prepare optimizer and schedule (linear warmup and decay)""" 56 | model = self.model 57 | no_decay = ["bias", "LayerNorm.weight"] 58 | optimizer_grouped_parameters = [ 59 | { 60 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 61 | "weight_decay": self.args.weight_decay, 62 | }, 63 | { 64 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 65 | "weight_decay": 0.0, 66 | }, 67 | ] 68 | optimizer = AdamW(optimizer_grouped_parameters, 69 | betas=(0.9, 0.98), # according to RoBERTa paper 70 | lr=self.args.lr, 71 | eps=self.args.adam_epsilon) 72 | t_total = len(self.train_dataloader()) // self.args.accumulate_grad_batches * self.args.max_epochs 73 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.args.warmup_steps, 74 | num_training_steps=t_total) 75 | 76 | return [optimizer], [{"scheduler": scheduler, "interval": "step"}] 77 | 78 | def forward(self, input_ids): 79 | attention_mask = (input_ids != 1).long() 80 | return self.model(input_ids, attention_mask=attention_mask) 81 | 82 | def compute_loss_and_acc(self, batch, mode='train'): 83 | input_ids, labels, length, start_indexs, end_indexs, span_masks = batch 84 | y = labels.view(-1) 85 | y_hat = self.forward(input_ids)[0] 86 | # compute loss 87 | loss = self.loss_fn(y_hat, y) 88 | # compute acc 89 | predict_scores = F.softmax(y_hat, dim=1) 90 | predict_labels = torch.argmax(predict_scores, dim=-1) 91 | if mode == 'train': 92 | acc = self.train_acc(predict_labels, y) 93 | else: 94 | acc = self.valid_acc(predict_labels, y) 95 | return loss, acc 96 | 97 | def train_dataloader(self) -> DataLoader: 98 | return self.get_dataloader("train") 99 | 100 | def training_step(self, batch, batch_idx): 101 | loss, acc = self.compute_loss_and_acc(batch) 102 | self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr']) 103 | self.log('train_acc', acc, on_step=True, on_epoch=False) 104 | self.log('train_loss', loss) 105 | return loss 106 | 107 | def val_dataloader(self): 108 | return self.get_dataloader("test") 109 | 110 | def validation_step(self, batch, batch_idx): 111 | loss, acc = self.compute_loss_and_acc(batch, mode='dev') 112 | self.log('valid_acc', acc, 
on_step=False, on_epoch=True) 113 | self.log('valid_loss', loss) 114 | return loss 115 | 116 | def validation_epoch_end(self, outs): 117 | # log epoch metric 118 | self.valid_acc.compute() 119 | self.log('valid_acc_end', self.valid_acc.compute()) 120 | 121 | def get_dataloader(self, prefix="train") -> DataLoader: 122 | """get training dataloader""" 123 | if self.args.task == 'snli': 124 | dataset = SNLIDataset(directory=self.args.data_dir, prefix=prefix, 125 | bert_path=self.bert_dir, 126 | max_length=self.args.max_length) 127 | else: 128 | dataset = SSTDataset(directory=self.args.data_dir, prefix=prefix, 129 | bert_path=self.bert_dir, 130 | max_length=self.args.max_length) 131 | dataloader = DataLoader( 132 | dataset=dataset, 133 | batch_size=self.args.batch_size, 134 | num_workers=self.args.workers, 135 | collate_fn=partial(collate_to_max_length, fill_values=[1, 0, 0]), 136 | drop_last=False 137 | ) 138 | return dataloader 139 | 140 | 141 | def get_parser(): 142 | parser = argparse.ArgumentParser(description="Training") 143 | parser.add_argument("--bert_path", required=True, type=str, help="bert config file") 144 | parser.add_argument("--batch_size", type=int, default=10, help="batch size") 145 | parser.add_argument("--lr", type=float, default=2e-5, help="learning rate") 146 | parser.add_argument("--workers", type=int, default=0, help="num workers for dataloader") 147 | parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") 148 | parser.add_argument("--adam_epsilon", default=1e-9, type=float, help="Epsilon for Adam optimizer.") 149 | parser.add_argument("--warmup_steps", default=0, type=int, help="warmup steps") 150 | parser.add_argument("--use_memory", action="store_true", help="load dataset to memory to accelerate.") 151 | parser.add_argument("--max_length", default=512, type=int, help="max length of dataset") 152 | parser.add_argument("--data_dir", required=True, type=str, help="train data path") 153 | parser.add_argument("--save_path", required=True, type=str, help="path to save checkpoints") 154 | parser.add_argument("--save_topk", default=5, type=int, help="save topk checkpoint") 155 | parser.add_argument("--checkpoint_path", type=str, help="checkpoint path on test step") 156 | parser.add_argument("--span_topk", type=int, default=5, help="save topk spans on test step") 157 | parser.add_argument("--lamb", default=1.0, type=float, help="regularizer lambda") 158 | parser.add_argument("--task", default='sst5', type=str, help="nlp tasks") 159 | 160 | return parser 161 | 162 | 163 | def main(): 164 | """main""" 165 | parser = get_parser() 166 | parser = Trainer.add_argparse_args(parser) 167 | args = parser.parse_args() 168 | # 如果save path不存在,则创建 169 | if not os.path.exists(args.save_path): 170 | os.mkdir(args.save_path) 171 | 172 | model = CheckExplainNLP(args) 173 | 174 | checkpoint_callback = ModelCheckpoint( 175 | filepath=os.path.join(args.save_path, '{epoch}-{valid_loss:.4f}-{valid_acc_end:.4f}'), 176 | save_top_k=args.save_topk, 177 | save_last=True, 178 | monitor="valid_acc_end", 179 | mode="max", 180 | ) 181 | logger = TensorBoardLogger( 182 | save_dir=args.save_path, 183 | name='log' 184 | ) 185 | 186 | # save args 187 | with open(os.path.join(args.save_path, "args.json"), 'w') as f: 188 | args_dict = args.__dict__ 189 | del args_dict['tpu_cores'] 190 | json.dump(args_dict, f, indent=4) 191 | 192 | trainer = Trainer.from_argparse_args(args, 193 | checkpoint_callback=checkpoint_callback, 194 | distributed_backend="ddp", 195 | 
logger=logger) 196 | 197 | trainer.fit(model) 198 | 199 | 200 | if __name__ == '__main__': 201 | from multiprocessing import freeze_support 202 | 203 | freeze_support() 204 | main() 205 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
-------------------------------------------------------------------------------- /explain/trainer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @file : trainer.py 5 | @author: zijun 6 | @contact : zijun_sun@shannonai.com 7 | @date : 2020/11/16 21:55 8 | @version: 1.0 9 | @desc : 10 | """ 11 | 12 | import argparse 13 | import json 14 | import os 15 | from functools import partial 16 | 17 | import pytorch_lightning as pl 18 | import torch 19 | from pytorch_lightning import Trainer 20 | from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint 21 | from pytorch_lightning.loggers import TensorBoardLogger 22 | from torch.nn import functional as F 23 | from torch.nn.modules import CrossEntropyLoss 24 | from torch.utils.data.dataloader import DataLoader 25 | from transformers import AdamW, get_linear_schedule_with_warmup, RobertaTokenizer 26 | 27 | from datasets.collate_functions import collate_to_max_length 28 | from datasets.sst_dataset import SSTDataset 29 | from datasets.snli_dataset import SNLIDataset 30 | from explain.model import ExplainableModel 31 | from utils.radom_seed import set_random_seed 32 | 33 | set_random_seed(0) 34 | 35 | 36 | class ExplainNLP(pl.LightningModule): 37 | 38 | def __init__( 39 | self, 40 | args: argparse.Namespace 41 | ): 42 | """Initialize a model, tokenizer and config.""" 43 | super().__init__() 44 | self.args = args 45 | if isinstance(args, argparse.Namespace): 46 | self.save_hyperparameters(args) 47 | self.bert_dir = args.bert_path 48 | self.model = ExplainableModel(self.bert_dir) 49 | self.tokenizer = RobertaTokenizer.from_pretrained(self.bert_dir) 50 | self.loss_fn = CrossEntropyLoss() 51 | self.train_acc = pl.metrics.Accuracy() 52 | self.valid_acc = pl.metrics.Accuracy() 53 | self.output = [] 54 | self.check_data = [] 55 | 56 | def configure_optimizers(self): 57 | """Prepare optimizer and schedule (linear warmup and decay)""" 58 | model = self.model 59 | no_decay = ["bias", "LayerNorm.weight"] 60 | optimizer_grouped_parameters = [ 61 | { 62 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 63 | "weight_decay": self.args.weight_decay, 64 | }, 65 | { 66 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 67 | "weight_decay": 0.0, 68 | }, 69 | ] 70 | optimizer = AdamW(optimizer_grouped_parameters, 71 | betas=(0.9, 0.98), # according to RoBERTa paper 72 | lr=self.args.lr, 73 | eps=self.args.adam_epsilon) 74 | t_total = len(self.train_dataloader()) // self.args.accumulate_grad_batches * self.args.max_epochs 75 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.args.warmup_steps, 76 | num_training_steps=t_total) 77 | 78 | return [optimizer], [{"scheduler": scheduler, "interval": "step"}] 79 | 80 | def forward(self, input_ids, start_indexs, end_indexs, span_masks): 81 | return self.model(input_ids, start_indexs, end_indexs, span_masks) 82 | 83 | def compute_loss_and_acc(self, batch, mode='train'): 84 | input_ids, labels, length, start_indexs, end_indexs, span_masks = batch 85 | y = labels.view(-1) 86 | y_hat, a_ij = self.forward(input_ids, start_indexs, end_indexs, span_masks) 87 | # compute loss 88 | ce_loss = self.loss_fn(y_hat, y) 89 | reg_loss = self.args.lamb * a_ij.pow(2).sum(dim=1).mean() 90 | loss = ce_loss - reg_loss 91 | # compute acc 92 | predict_scores = F.softmax(y_hat, dim=1) 93 | predict_labels = torch.argmax(predict_scores, 
dim=-1) 94 | if mode == 'train': 95 | acc = self.train_acc(predict_labels, y) 96 | else: 97 | acc = self.valid_acc(predict_labels, y) 98 | # if test, save extract spans 99 | if mode == 'test': 100 | values, indices = torch.topk(a_ij, self.args.span_topk) 101 | values = values.tolist() 102 | indices = indices.tolist() 103 | for i in range(len(values)): 104 | input_ids_list = input_ids[i].tolist() 105 | origin_sentence = self.tokenizer.decode(input_ids_list, skip_special_tokens=True) 106 | self.output.append( 107 | str(labels[i].item()) + '<->' + str(predict_labels[i].item()) + '<->' + origin_sentence + '\n') 108 | # print() 109 | for j, span_idx in enumerate(indices[i]): 110 | score = values[i][j] 111 | start_index = start_indexs[span_idx] 112 | end_index = end_indexs[span_idx] 113 | pre = self.tokenizer.decode(input_ids_list[:start_index], skip_special_tokens=True) 114 | high_light = self.tokenizer.decode(input_ids_list[start_index:end_index + 1], 115 | skip_special_tokens=True) 116 | post = self.tokenizer.decode(input_ids_list[end_index + 1:], skip_special_tokens=True) 117 | span_sentence = pre + '【' + high_light + '】' + post 118 | self.output.append(format('%.4f' % score) + "->" + span_sentence + '\n') 119 | # print(format('%.4f' % score), "->", span_sentence) 120 | if j == 0: 121 | # generate data for check progress 122 | self.check_data.append(str(labels[i].item()) + '\t' + high_light + '\n') 123 | self.output.append('\n') 124 | # print('='*30) 125 | 126 | return loss, acc 127 | 128 | def validation_epoch_end(self, outs): 129 | # log epoch metric 130 | self.valid_acc.compute() 131 | self.log('valid_acc_end', self.valid_acc.compute()) 132 | 133 | def train_dataloader(self) -> DataLoader: 134 | return self.get_dataloader("train") 135 | 136 | def training_step(self, batch, batch_idx): 137 | loss, acc = self.compute_loss_and_acc(batch) 138 | self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr']) 139 | self.log('train_acc', acc, on_step=True, on_epoch=False) 140 | self.log('train_loss', loss) 141 | return loss 142 | 143 | def val_dataloader(self): 144 | return self.get_dataloader("dev") 145 | 146 | def validation_step(self, batch, batch_idx): 147 | loss, acc = self.compute_loss_and_acc(batch, mode='dev') 148 | self.log('valid_acc', acc, on_step=False, on_epoch=True) 149 | self.log('valid_loss', loss) 150 | return loss 151 | 152 | def get_dataloader(self, prefix="train") -> DataLoader: 153 | """get training dataloader""" 154 | if self.args.task == 'snli': 155 | dataset = SNLIDataset(directory=self.args.data_dir, prefix=prefix, 156 | bert_path=self.bert_dir, 157 | max_length=self.args.max_length) 158 | else: 159 | dataset = SSTDataset(directory=self.args.data_dir, prefix=prefix, 160 | bert_path=self.bert_dir, 161 | max_length=self.args.max_length) 162 | dataloader = DataLoader( 163 | dataset=dataset, 164 | batch_size=self.args.batch_size, 165 | num_workers=self.args.workers, 166 | collate_fn=partial(collate_to_max_length, fill_values=[1, 0, 0]), 167 | drop_last=False 168 | ) 169 | return dataloader 170 | 171 | def test_dataloader(self): 172 | return self.get_dataloader("test") 173 | 174 | def test_step(self, batch, batch_idx): 175 | loss, acc = self.compute_loss_and_acc(batch, mode='test') 176 | return {'test_loss': loss, "test_acc": acc} 177 | 178 | def test_epoch_end(self, outputs): 179 | with open(os.path.join(self.args.save_path, 'output.txt'), 'w', encoding='utf8') as f: 180 | f.writelines(self.output) 181 | with open(os.path.join(self.args.save_path, 'test.txt'), 'w', 
encoding='utf8') as f: 182 | f.writelines(self.check_data) 183 | avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean() 184 | avg_acc = torch.stack([x['test_acc'] for x in outputs]).mean() 185 | tensorboard_logs = {'test_loss': avg_loss, 'test_acc': avg_acc} 186 | print(avg_loss, avg_acc) 187 | return {'val_loss': avg_loss, 'log': tensorboard_logs} 188 | 189 | 190 | def get_parser(): 191 | parser = argparse.ArgumentParser(description="Training") 192 | parser.add_argument("--bert_path", required=True, type=str, help="bert config file") 193 | parser.add_argument("--batch_size", type=int, default=10, help="batch size") 194 | parser.add_argument("--lr", type=float, default=2e-5, help="learning rate") 195 | parser.add_argument("--workers", type=int, default=0, help="num workers for dataloader") 196 | parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") 197 | parser.add_argument("--adam_epsilon", default=1e-9, type=float, help="Epsilon for Adam optimizer.") 198 | parser.add_argument("--warmup_steps", default=0, type=int, help="warmup steps") 199 | parser.add_argument("--use_memory", action="store_true", help="load dataset to memory to accelerate.") 200 | parser.add_argument("--max_length", default=512, type=int, help="max length of dataset") 201 | parser.add_argument("--data_dir", required=True, type=str, help="train data path") 202 | parser.add_argument("--save_path", required=True, type=str, help="path to save checkpoints") 203 | parser.add_argument("--save_topk", default=5, type=int, help="save topk checkpoint") 204 | parser.add_argument("--checkpoint_path", type=str, help="checkpoint path on test step") 205 | parser.add_argument("--span_topk", type=int, default=5, help="save topk spans on test step") 206 | parser.add_argument("--lamb", default=1.0, type=float, help="regularizer lambda") 207 | parser.add_argument("--task", default='sst5', type=str, help="nlp tasks") 208 | parser.add_argument("--mode", default='train', type=str, help="either train or eval") 209 | 210 | return parser 211 | 212 | 213 | def train(args): 214 | # if save path does not exits, create it 215 | if not os.path.exists(args.save_path): 216 | os.mkdir(args.save_path) 217 | 218 | model = ExplainNLP(args) 219 | 220 | checkpoint_callback = ModelCheckpoint( 221 | filepath=os.path.join(args.save_path, '{epoch}-{valid_loss:.4f}-{valid_acc_end:.4f}'), 222 | save_top_k=args.save_topk, 223 | save_last=True, 224 | monitor="valid_acc_end", 225 | mode="max", 226 | ) 227 | logger = TensorBoardLogger( 228 | save_dir=args.save_path, 229 | name='log' 230 | ) 231 | 232 | # save args 233 | with open(os.path.join(args.save_path, "args.json"), 'w') as f: 234 | args_dict = args.__dict__ 235 | del args_dict['tpu_cores'] 236 | json.dump(args_dict, f, indent=4) 237 | 238 | trainer = Trainer.from_argparse_args(args, 239 | checkpoint_callback=checkpoint_callback, 240 | distributed_backend="ddp", 241 | logger=logger) 242 | trainer.fit(model) 243 | 244 | 245 | def evaluate(args): 246 | model = ExplainNLP(args) 247 | checkpoint = torch.load(args.checkpoint_path, map_location=torch.device('cpu')) 248 | model.load_state_dict(checkpoint['state_dict']) 249 | trainer = Trainer.from_argparse_args(args, distributed_backend="ddp") 250 | trainer.test(model) 251 | 252 | 253 | def main(): 254 | parser = get_parser() 255 | parser = Trainer.add_argparse_args(parser) 256 | args = parser.parse_args() 257 | if args.mode == 'train': 258 | train(args) 259 | elif args.mode == 'eval': 260 | evaluate(args) 261 | 
else: 262 | raise Exception("unexpected mode!!!") 263 | 264 | 265 | if __name__ == '__main__': 266 | from multiprocessing import freeze_support 267 | 268 | freeze_support() 269 | main() 270 | --------------------------------------------------------------------------------
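A minimal inference sketch, assuming the pinned versions in `requirements.txt` and that it is run from the repository root: it encodes one sentence the way `SSTDataset` does, batches it with `collate_to_max_length`, runs `ExplainableModel`, and prints the predicted label together with the highest-weighted span. `BERT_PATH` and the example sentence are placeholders, and a trained checkpoint should be loaded (as `evaluate()` in `explain/trainer.py` does) before the prediction is meaningful.

```python
import torch
from transformers import RobertaTokenizer

from datasets.collate_functions import collate_to_max_length
from explain.model import ExplainableModel

BERT_PATH = "/path/to/roberta-base"  # placeholder for [ROBERTA_BASE_PATH]

tokenizer = RobertaTokenizer.from_pretrained(BERT_PATH)
model = ExplainableModel(BERT_PATH)
# for meaningful predictions, load trained weights here first,
# as evaluate() in explain/trainer.py does with a saved checkpoint.
model.eval()

sentence = "a gorgeous and deeply moving film"
ids = tokenizer.encode(sentence, add_special_tokens=False)
input_ids = torch.LongTensor([0] + ids + [2])  # 0 = <s>, 2 = </s>
length = torch.LongTensor([len(ids) + 2])
label = torch.LongTensor([0])                  # dummy label, unused here

batch = [(input_ids, label, length)]
input_ids, labels, lengths, start_indexs, end_indexs, span_masks = \
    collate_to_max_length(batch=batch, fill_values=[1, 0, 0])

with torch.no_grad():
    logits, a_ij = model(input_ids, start_indexs, end_indexs, span_masks)

best = torch.argmax(a_ij, dim=1).item()
start, end = start_indexs[best].item(), end_indexs[best].item()
print("predicted label:", torch.argmax(logits, dim=1).item())
print("top span:", tokenizer.decode(input_ids[0].tolist()[start:end + 1]))
```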