├── src
│   ├── others
│   │   ├── __init__.py
│   │   ├── logging.py
│   │   ├── utils.py
│   │   ├── recadam.py
│   │   ├── optimizer.py
│   │   └── my_pyrouge.py
│   ├── inference.py
│   ├── cal_rouge.py
│   ├── preprocessing.py
│   ├── sdpt_pretraining.py
│   ├── run.py
│   ├── trainer.py
│   ├── dapt_pretraining.py
│   └── tapt_pretraining.py
├── scripts
│   ├── finetuning.sh
│   ├── dapt_pretraining.sh
│   ├── tapt_pretraining.sh
│   └── sdpt_pretraining.sh
├── requirements.txt
├── image
│   ├── HKUST.jpg
│   └── pytorch-logo-dark.png
├── .gitignore
├── README.md
└── LICENSE
/src/others/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scripts/finetuning.sh: -------------------------------------------------------------------------------- 1 | python ./src/run.py -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | tqdm 3 | transformers>=3.0.2 4 | nltk -------------------------------------------------------------------------------- /image/HKUST.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TysonYu/AdaptSum/HEAD/image/HKUST.jpg -------------------------------------------------------------------------------- /image/pytorch-logo-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TysonYu/AdaptSum/HEAD/image/pytorch-logo-dark.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | dataset 3 | savings 4 | .vscode 5 | processed_datasets/ 6 | *.pt 7 | logs 8 | pyrouge 9 | save -------------------------------------------------------------------------------- /scripts/dapt_pretraining.sh: -------------------------------------------------------------------------------- 1 | python ./src/dapt_pretraining.py -path=./dataset/debate/debateorg.txt \ 2 | -dm=debate \ 3 | -visible_gpu=0 \ 4 | -save_interval=10000 \ 5 | # -recadam \ 6 | # -logging_Euclid_dist -------------------------------------------------------------------------------- /scripts/tapt_pretraining.sh: -------------------------------------------------------------------------------- 1 | python ./src/tapt_pretraining.py -path=./dataset/debate/TAPT-data/train.source \ 2 | -dm=debate \ 3 | -visible_gpu=0 \ 4 | -save_interval=100 \ 5 | # -recadam \ 6 | # -logging_Euclid_dist -------------------------------------------------------------------------------- /scripts/sdpt_pretraining.sh: -------------------------------------------------------------------------------- 1 | python ./src/sdpt_pretraining.py -data_name=SDPT-cnn_dm \ 2 | -visible_gpu=0 \ 3 | -saving_path=SDPT_save \ 4 | -start_to_save_iter=0 \ 5 | -save_interval=10000 \ 6 | # -recadam \ 7 | # -logging_Euclid_dist \ -------------------------------------------------------------------------------- /src/others/logging.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import logging 4 | 5 | logger = logging.getLogger() 6 | 7 | 8 | def init_logger(log_file=None, log_file_level=logging.NOTSET): 9 | log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s") 10 | logger = logging.getLogger() 11 | logger.setLevel(logging.INFO) 12 | 13 |
console_handler = logging.StreamHandler() 14 | console_handler.setFormatter(log_format) 15 | logger.handlers = [console_handler] 16 | 17 | if log_file and log_file != '': 18 | file_handler = logging.FileHandler(log_file) 19 | file_handler.setLevel(log_file_level) 20 | file_handler.setFormatter(log_format) 21 | logger.addHandler(file_handler) 22 | 23 | return logger 24 | 25 | 26 | -------------------------------------------------------------------------------- /src/others/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import pickle 4 | import numpy as np 5 | import torch.nn as nn 6 | import torch 7 | import math 8 | import random 9 | NEAR_INF = 1e20 10 | NEAR_INF_FP16 = 65504 11 | 12 | def save(toBeSaved, filename, mode='wb'): 13 | ''' 14 | save data to pickle file 15 | ''' 16 | dirname = os.path.dirname(filename) 17 | if not os.path.exists(dirname): 18 | os.makedirs(dirname) 19 | file = open(filename, mode) 20 | pickle.dump(toBeSaved, file, protocol=4) # protocol 4 allows large size object, it's the default since python 3.8 21 | file.close() 22 | 23 | def load(filename, mode='rb'): 24 | ''' 25 | load pickle file 26 | ''' 27 | file = open(filename, mode) 28 | loaded = pickle.load(file) 29 | file.close() 30 | return loaded 31 | 32 | def pad_sents(sents, pad_token=0, max_len=512): 33 | ''' 34 | pad input to max length 35 | ''' 36 | sents_padded = [] 37 | lens = get_lens(sents) 38 | max_len = min(max(lens), max_len) 39 | sents_padded = [] 40 | new_len = [] 41 | for i, l in enumerate(lens): 42 | if l > max_len: 43 | l = max_len 44 | new_len.append(l) 45 | sents_padded.append(sents[i][:l] + [pad_token] * (max_len - l)) 46 | return sents_padded, new_len 47 | 48 | def get_mask(sents, unmask_idx=1, mask_idx=0, max_len=512): 49 | ''' 50 | make mask for padded input 51 | ''' 52 | lens = get_lens(sents) 53 | max_len = min(max(lens), max_len) 54 | mask = [] 55 | for l in lens: 56 | if l > max_len: 57 | l = max_len 58 | mask.append([unmask_idx] * l + [mask_idx] * (max_len - l)) 59 | return mask 60 | 61 | def get_lens(sents): 62 | return [len(sent) for sent in sents] 63 | 64 | def get_max_len(sents): 65 | max_len = max([len(sent) for sent in sents]) 66 | return max_len 67 | 68 | def fix_random_seed(random_seed): 69 | random.seed(random_seed) 70 | np.random.seed(random_seed) 71 | torch.manual_seed(random_seed) 72 | torch.cuda.manual_seed(random_seed) 73 | torch.backends.cudnn.deterministic = True 74 | 75 | # ----- for model ----------------------------------------------------------- 76 | 77 | def count_parameters(model): 78 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 79 | 80 | def initialize_weights(m): 81 | if hasattr(m, 'weight') and m.weight.dim() > 1: 82 | nn.init.xavier_uniform_(m.weight.data) 83 | 84 | def tile(x, count, dim=0): 85 | """ 86 | Tiles x on dimension dim count times. 87 | """ 88 | perm = list(range(len(x.size()))) 89 | if dim != 0: 90 | perm[0], perm[dim] = perm[dim], perm[0] 91 | x = x.permute(perm).contiguous() 92 | out_size = list(x.size()) 93 | out_size[0] *= count 94 | batch = x.size(0) 95 | x = x.view(batch, -1) \ 96 | .transpose(0, 1) \ 97 | .repeat(count, 1) \ 98 | .transpose(0, 1) \ 99 | .contiguous() \ 100 | .view(*out_size) 101 | if dim != 0: 102 | x = x.permute(perm).contiguous() 103 | return x 104 | 105 | def neginf(dtype: torch.dtype) -> float: 106 | """ 107 | Return a representable finite number near -inf for a dtype. 
108 | """ 109 | if dtype is torch.float16: 110 | return -NEAR_INF_FP16 111 | else: 112 | return -NEAR_INF 113 | 114 | def batch_generator(dataloader): 115 | while True: 116 | for data_sample in dataloader: 117 | yield data_sample 118 | -------------------------------------------------------------------------------- /src/inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from tqdm import tqdm 5 | from transformers import BartTokenizer 6 | from transformers import BartForConditionalGeneration 7 | from others.logging import init_logger, logger 8 | from others.utils import load, count_parameters, initialize_weights, fix_random_seed 9 | from preprocessing import BartDataset, DataReader 10 | from others.optimizer import build_optim 11 | from trainer import train 12 | 13 | def tokenize(data, tokenizer): 14 | tokenized_text = [tokenizer.encode(i) for i in data] 15 | return tokenized_text 16 | 17 | if __name__ == '__main__': 18 | # for training 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('-visible_gpu', default='1', type=str) 21 | parser.add_argument('-log_file', default='./logs/inference/', type=str) 22 | parser.add_argument('-train_from', default='', type=str) 23 | parser.add_argument('-random_seed', type=int, default=199744) 24 | parser.add_argument('-lr', default=0.001, type=float) 25 | parser.add_argument('-max_grad_norm', default=0, type=float) 26 | parser.add_argument('-epoch', type=int, default=50) 27 | parser.add_argument('-saving_path', default='./save/', type=str) 28 | parser.add_argument('-data_name', default='debate', type=str) 29 | # for learning, optimizer 30 | parser.add_argument('-optim', default='adam', type=str) 31 | parser.add_argument('-beta1', default=0.9, type=float) 32 | parser.add_argument('-beta2', default=0.998, type=float) 33 | parser.add_argument('-warmup_steps', default=1000, type=int) 34 | parser.add_argument('-decay_method', default='noam', type=str) 35 | parser.add_argument('-enc_hidden_size', default=768, type=int) 36 | parser.add_argument('-clip', default=1.0, type=float) 37 | parser.add_argument('-accumulation_steps', default=10, type=int) 38 | 39 | args = parser.parse_args() 40 | 41 | # initial logger 42 | init_logger(args.log_file+args.data_name+'.log') 43 | logger.info(args) 44 | 45 | # set gpu 46 | os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_gpu 47 | 48 | # set random seed 49 | fix_random_seed(args.random_seed) 50 | 51 | # loading data 52 | # it's faster to load data from pre_build data 53 | logger.info('starting to read dataloader') 54 | data_file_name = './dataset/' + args.data_name + '/testloader.pt' 55 | train_loader = load(data_file_name) 56 | 57 | # initial tokenizer 58 | tokenizer = BartTokenizer.from_pretrained('facebook/bart-base') 59 | 60 | # initial model 61 | logger.info('starting to build model') 62 | if args.train_from != '': 63 | logger.info("train from : {}".format(args.train_from)) 64 | # checkpoint = torch.load(args.train_from, map_location='cpu') 65 | # model.load_state_dict(checkpoint['model']) 66 | model = torch.load(args.train_from) 67 | else: 68 | model = BartForConditionalGeneration.from_pretrained('facebook/bart-base') 69 | model.cuda() 70 | model.eval() 71 | 72 | # inference and save model 73 | save_file_name = './dataset/' + args.data_name + '/summaries' 74 | summaries = open(save_file_name, 'w') 75 | outputs = [] 76 | for src_ids, decoder_ids, mask, label_ids in tqdm(train_loader): 77 | src_ids = 
src_ids.cuda() 78 | summary_ids = model.generate(src_ids, num_beams=4, max_length=256, early_stopping=True) 79 | output = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids] 80 | outputs += output 81 | outputs = [i+'\n' for i in outputs] 82 | print(outputs[:4]) 83 | summaries.writelines(outputs) 84 | summaries.close() 85 | 86 | -------------------------------------------------------------------------------- /src/cal_rouge.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | from multiprocessing import Pool 5 | 6 | import shutil 7 | import sys 8 | import codecs 9 | # import pyrouge 10 | from others.my_pyrouge import Rouge155 11 | 12 | def process(data): 13 | candidates, references, pool_id = data 14 | cnt = len(candidates) 15 | current_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) 16 | tmp_dir = "rouge-tmp-{}-{}".format(current_time,pool_id) 17 | if not os.path.isdir(tmp_dir): 18 | os.mkdir(tmp_dir) 19 | os.mkdir(tmp_dir + "/candidate") 20 | os.mkdir(tmp_dir + "/reference") 21 | try: 22 | for i in range(cnt): 23 | if len(references[i]) < 1: 24 | continue 25 | with open(tmp_dir + "/candidate/cand.{}.txt".format(i), "w", 26 | encoding="utf-8") as f: 27 | f.write(candidates[i]) 28 | with open(tmp_dir + "/reference/ref.{}.txt".format(i), "w", 29 | encoding="utf-8") as f: 30 | f.write(references[i]) 31 | r = Rouge155() 32 | r.model_dir = tmp_dir + "/reference/" 33 | r.system_dir = tmp_dir + "/candidate/" 34 | r.model_filename_pattern = 'ref.#ID#.txt' 35 | r.system_filename_pattern = r'cand.(\d+).txt' 36 | rouge_results = r.convert_and_evaluate() 37 | # print(rouge_results) 38 | results_dict = r.output_to_dict(rouge_results) 39 | finally: 40 | pass 41 | if os.path.isdir(tmp_dir): 42 | shutil.rmtree(tmp_dir) 43 | return results_dict 44 | 45 | def chunks(l, n): 46 | """Yield successive n-sized chunks from l.""" 47 | for i in range(0, len(l), n): 48 | yield l[i:i + n] 49 | 50 | def test_rouge(cand, ref,num_processes): 51 | """Calculate ROUGE scores of sequences passed as an iterator 52 | e.g. 
a list of str, an open file, StringIO or even sys.stdin 53 | """ 54 | candidates = [line.strip() for line in cand] 55 | references = [line.strip() for line in ref] 56 | 57 | # print(len(candidates)) 58 | # print(len(references)) 59 | assert len(candidates) == len(references) 60 | candidates_chunks = list(chunks(candidates, int(len(candidates)/num_processes))) 61 | references_chunks = list(chunks(references, int(len(references)/num_processes))) 62 | n_pool = len(candidates_chunks) 63 | arg_lst = [] 64 | for i in range(n_pool): 65 | arg_lst.append((candidates_chunks[i],references_chunks[i],i)) 66 | pool = Pool(n_pool) 67 | results = pool.map(process,arg_lst) 68 | final_results = {} 69 | for i,r in enumerate(results): 70 | for k in r: 71 | if(k not in final_results): 72 | final_results[k] = r[k]*len(candidates_chunks[i]) 73 | else: 74 | final_results[k] += r[k] * len(candidates_chunks[i]) 75 | for k in final_results: 76 | final_results[k] = final_results[k]/len(candidates) 77 | return final_results 78 | 79 | def rouge_results_to_str(results_dict): 80 | return ">> ROUGE-F(1/2/3/l): {:.2f}/{:.2f}/{:.2f}\nROUGE-R(1/2/3/l): {:.2f}/{:.2f}/{:.2f}\nROUGE-P(1/2/3/l): {:.2f}/{:.2f}/{:.2f}\n".format( 81 | results_dict["rouge_1_f_score"] * 100, 82 | results_dict["rouge_2_f_score"] * 100, 83 | # results_dict["rouge_3_f_score"] * 100, 84 | results_dict["rouge_l_f_score"] * 100, 85 | 86 | results_dict["rouge_1_recall"] * 100, 87 | results_dict["rouge_2_recall"] * 100, 88 | # results_dict["rouge_3_f_score"] * 100, 89 | results_dict["rouge_l_recall"] * 100, 90 | 91 | results_dict["rouge_1_precision"] * 100, 92 | results_dict["rouge_2_precision"] * 100, 93 | # results_dict["rouge_3_f_score"] * 100, 94 | results_dict["rouge_l_precision"] * 100 95 | ) 96 | 97 | 98 | if __name__ == "__main__": 99 | # init_logger('test_rouge.log') 100 | parser = argparse.ArgumentParser() 101 | parser.add_argument('-c', type=str, default="../datasets/debate/test.source", 102 | help='candidate file') 103 | parser.add_argument('-r', type=str, default="../datasets/debate/test.target", 104 | help='reference file') 105 | parser.add_argument('-p', type=int, default=1, 106 | help='number of processes') 107 | args = parser.parse_args() 108 | print(args.c) 109 | print(args.r) 110 | print(args.p) 111 | 112 | # read file 113 | c_data = open(args.c) 114 | candidates = c_data.readlines() 115 | candidates = [line.strip('\n') for line in candidates] 116 | c_data = open(args.r) 117 | references = c_data.readlines() 118 | references = [line.strip('\n') for line in references] 119 | 120 | # calculate rouge 121 | results_dict = test_rouge(candidates, references, args.p) 122 | print(time.strftime('%H:%M:%S', time.localtime())) 123 | print(rouge_results_to_str(results_dict)) 124 | -------------------------------------------------------------------------------- /src/preprocessing.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import torch 3 | from transformers import BartTokenizer 4 | from torch.utils.data import Dataset, DataLoader 5 | import random 6 | import numpy as np 7 | import os 8 | 9 | from others.utils import pad_sents, save, get_mask, fix_random_seed 10 | 11 | class BartDataset(Dataset): 12 | ''' 13 | Attributes: 14 | src: it's a list, each line is a sample for source text. 15 | tgt: it's a list, each line is a sample for target text. 16 | src_ids: it's a list, each line is a sample for source index after tokenized. 
17 | tgt_ids: it's a list, each line is a sample for target index after tokenized. 18 | ''' 19 | def __init__(self, tokenizer, multi_news_reader, args): 20 | self.tokenizer = tokenizer 21 | self.multi_news_reader = multi_news_reader 22 | self.src = multi_news_reader.data_src 23 | self.tgt = multi_news_reader.data_tgt 24 | self.src = [i.strip('\n') for i in self.src] 25 | self.tgt = [i.strip('\n') for i in self.tgt] 26 | self.src_ids = self.tokenize(self.src) 27 | self.tgt_ids = self.tokenize(self.tgt) 28 | 29 | def __len__(self): 30 | return len(self.src) 31 | 32 | def __getitem__(self, idx): 33 | return self.src_ids[idx], self.tgt_ids[idx] 34 | 35 | def tokenize(self, data): 36 | tokenized_text = [self.tokenizer.encode(i, add_special_tokens=False) for i in tqdm(data)] 37 | return tokenized_text 38 | 39 | def collate_fn(self, data): 40 | # rebuld the raw text and truncate to max length 41 | max_input_len = 1024 42 | max_output_len = 256 43 | raw_src = [pair[0] for pair in data] 44 | raw_tgt = [pair[1] for pair in data] 45 | raw_src = [i[:max_input_len-1] for i in raw_src] 46 | raw_tgt = [i[:max_output_len-1] for i in raw_tgt] 47 | src = [] 48 | tgt = [] 49 | # remove blank data 50 | for i in range(len(raw_src)): 51 | if (raw_src[i] != []) and (raw_tgt[i] != []): 52 | src.append(raw_src[i]) 53 | tgt.append(raw_tgt[i]) 54 | # make input mask 55 | mask = torch.tensor(get_mask(src, max_len=max_input_len)) 56 | # make input ids 57 | src_ids = torch.tensor(pad_sents(src, 1, max_len=max_input_len)[0]) 58 | # make output ids 59 | decoder_ids = [[0]+i for i in tgt] 60 | # make output labels 61 | label_ids = [i+[2] for i in tgt] 62 | decoder_ids = torch.tensor(pad_sents(decoder_ids, 1, max_len=max_output_len)[0]) 63 | label_ids = torch.tensor(pad_sents(label_ids, -100, max_len=max_output_len)[0]) 64 | 65 | return src_ids, decoder_ids, mask, label_ids 66 | 67 | 68 | class DataReader(object): 69 | ''' 70 | Attributes: 71 | data_src: source text 72 | data_tgt: target text 73 | ''' 74 | def __init__ (self, args): 75 | self.args = args 76 | self.raw_data = self.load_multinews_data(self.args) 77 | self.data_src = self.raw_data[0] 78 | self.data_tgt = self.raw_data[1] 79 | 80 | def file_reader(self, file_path): 81 | file = open(file_path, 'r') 82 | lines = file.readlines() 83 | return lines 84 | 85 | def load_multinews_data(self, args): 86 | train_src_path = args.data_path + args.data_name + '/' + args.mode + '.source' 87 | train_tgt_path = args.data_path + args.data_name + '/' + args.mode + '.target' 88 | train_src_lines = self.file_reader(train_src_path) 89 | train_tgt_lines = self.file_reader(train_tgt_path) 90 | return (train_src_lines, train_tgt_lines) 91 | 92 | def data_builder(args): 93 | save_path = args.data_path + args.data_name + '/' + args.mode + 'loader' + '.pt' 94 | data_reader = DataReader(args) 95 | tokenizer = BartTokenizer.from_pretrained('facebook/bart-base') 96 | train_set = BartDataset(tokenizer, data_reader, args) 97 | if args.mode == 'train': 98 | data_loader = DataLoader(dataset=train_set, 99 | batch_size=args.batch_size, 100 | shuffle=True, 101 | collate_fn=train_set.collate_fn) 102 | else: 103 | data_loader = DataLoader(dataset=train_set, 104 | batch_size=args.batch_size, 105 | shuffle=False, 106 | collate_fn=train_set.collate_fn) 107 | save(data_loader, save_path) 108 | 109 | if __name__ == '__main__': 110 | import argparse 111 | parser = argparse.ArgumentParser() 112 | parser.add_argument('-data_path', default='dataset/', type=str) 113 | parser.add_argument('-data_name', 
default='debate', type=str) 114 | parser.add_argument('-mode', default='train', type=str) 115 | parser.add_argument('-batch_size', default=4, type=int) 116 | parser.add_argument('-random_seed', type=int, default=0) 117 | args = parser.parse_args() 118 | 119 | # set random seed 120 | fix_random_seed(args.random_seed) 121 | data_builder(args) 122 | 123 | -------------------------------------------------------------------------------- /src/others/recadam.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """RecAdam optimizer""" 16 | 17 | import logging 18 | import math 19 | import numpy as np 20 | 21 | import torch 22 | from torch.optim import Optimizer 23 | 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | def anneal_function(function, step, k, t0, weight): 29 | if function == 'sigmoid': 30 | return float(1 / (1 + np.exp(-k * (step - t0)))) * weight 31 | elif function == 'linear': 32 | return min(1, step / t0) * weight 33 | elif function == 'constant': 34 | return weight 35 | else: 36 | ValueError 37 | 38 | 39 | class RecAdam(Optimizer): 40 | """ Implementation of RecAdam optimizer, a variant of Adam optimizer. 41 | Parameters: 42 | lr (float): learning rate. Default 1e-3. 43 | betas (tuple of 2 floats): Adams beta parameters (b1, b2). Default: (0.9, 0.999) 44 | eps (float): Adams epsilon. Default: 1e-6 45 | weight_decay (float): Weight decay. Default: 0.0 46 | correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. like in Bert TF repository). Default True. 47 | anneal_fun (str): a hyperparam for the anneal function, decide the function of the curve. Default 'sigmoid'. 48 | anneal_k (float): a hyperparam for the anneal function, decide the slop of the curve. Choice: [0.05, 0.1, 0.2, 0.5, 1] 49 | anneal_t0 (float): a hyperparam for the anneal function, decide the middle point of the curve. Choice: [100, 250, 500, 1000] 50 | anneal_w (float): a hyperparam for the anneal function, decide the scale of the curve. Default 1.0. 51 | pretrain_cof (float): the coefficient of the quadratic penalty. Default 5000.0. 52 | pretrain_params (list of tensors): the corresponding group of params in the pretrained model. 
53 | """ 54 | 55 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, correct_bias=True, anneal_fun='sigmoid', anneal_k=0, anneal_t0=0, anneal_w=1.0, pretrain_cof=5000.0, pretrain_params=None): 56 | if lr < 0.0: 57 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 58 | if not 0.0 <= betas[0] < 1.0: 59 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0])) 60 | if not 0.0 <= betas[1] < 1.0: 61 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1])) 62 | if not 0.0 <= eps: 63 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps)) 64 | print("anneal_k: {:4f}, anneal_t0: {:4f}".format(anneal_k, anneal_t0)) 65 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias, anneal_fun=anneal_fun, anneal_k=anneal_k, anneal_t0=anneal_t0, anneal_w=anneal_w, pretrain_cof=pretrain_cof, pretrain_params=pretrain_params) 66 | super().__init__(params, defaults) 67 | 68 | def step(self, closure=None): 69 | """Performs a single optimization step. 70 | Arguments: 71 | closure (callable, optional): A closure that reevaluates the model 72 | and returns the loss. 73 | """ 74 | loss = None 75 | if closure is not None: 76 | loss = closure() 77 | for group in self.param_groups: 78 | for p, pp in zip(group["params"], group["pretrain_params"]): 79 | if p.grad is None: 80 | continue 81 | grad = p.grad.data 82 | if grad.is_sparse: 83 | raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead") 84 | 85 | state = self.state[p] 86 | 87 | # State initialization 88 | if len(state) == 0: 89 | state["step"] = 0 90 | # Exponential moving average of gradient values 91 | state["exp_avg"] = torch.zeros_like(p.data) 92 | # Exponential moving average of squared gradient values 93 | state["exp_avg_sq"] = torch.zeros_like(p.data) 94 | 95 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 96 | beta1, beta2 = group["betas"] 97 | 98 | state["step"] += 1 99 | 100 | # Decay the first and second moment running average coefficient 101 | # In-place operations to update the averages at the same time 102 | exp_avg.mul_(beta1).add_(1.0 - beta1, grad) 103 | exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad) 104 | denom = exp_avg_sq.sqrt().add_(group["eps"]) 105 | 106 | step_size = group["lr"] 107 | if group["correct_bias"]: 108 | bias_correction1 = 1.0 - beta1 ** state["step"] 109 | bias_correction2 = 1.0 - beta2 ** state["step"] 110 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 111 | 112 | # With RecAdam method, the optimization objective is 113 | # Loss = lambda(t)*Loss_T + (1-lambda(t))*Loss_S 114 | # Loss = lambda(t)*Loss_T + (1-lambda(t))*\gamma/2*\sum((\theta_i-\theta_i^*)^2) 115 | if group['anneal_w'] > 0.0: 116 | # print("anneal_fun: {}, anneal_w: {:4f}, anneal_k: {:4f}, anneal_t0: {:4f}".format(group['anneal_fun'], group['anneal_w'], group['anneal_k'], group['anneal_t0'])) 117 | # We calculate the lambda as the annealing function 118 | anneal_lambda = anneal_function(group['anneal_fun'], state["step"], group['anneal_k'], group['anneal_t0'], group['anneal_w']) 119 | assert anneal_lambda <= group['anneal_w'] 120 | # print(anneal_lambda, step_size) 121 | # The loss of the target task is multiplied by lambda(t) 122 | p.data.addcdiv_(-step_size * anneal_lambda, exp_avg, denom) 123 | # Add the quadratic penalty to simulate the pretraining tasks 124 | 
p.data.add_(-group["lr"] * (group['anneal_w'] - anneal_lambda) * group["pretrain_cof"], p.data - pp.data) 125 | else: 126 | p.data.addcdiv_(-step_size, exp_avg, denom) 127 | 128 | # Just adding the square of the weights to the loss function is *not* 129 | # the correct way of using L2 regularization/weight decay with Adam, 130 | # since that will interact with the m and v parameters in strange ways. 131 | # 132 | # Instead we want to decay the weights in a manner that doesn't interact 133 | # with the m/v parameters. This is equivalent to adding the square 134 | # of the weights to the loss with plain (non-momentum) SGD. 135 | # Add weight decay at the end (fixed version) 136 | if group["weight_decay"] > 0.0: 137 | p.data.add_(-group["lr"] * group["weight_decay"], p.data) 138 | 139 | return loss -------------------------------------------------------------------------------- /src/sdpt_pretraining.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from transformers import BartForConditionalGeneration, get_linear_schedule_with_warmup 5 | import numpy as np 6 | from preprocessing import BartDataset, DataReader 7 | from others.logging import init_logger, logger 8 | from others.utils import load, fix_random_seed 9 | from others.optimizer import build_optim 10 | 11 | 12 | 13 | def load_dataloader(args): 14 | train_file_name = './dataset/' + args.data_name + '/trainloader.pt' 15 | train_loader = load(train_file_name) 16 | logger.info('train loader has {} samples'.format(len(train_loader.dataset))) 17 | return train_loader 18 | 19 | def train(model, training_data, optimizer, checkpoint, args, pretrained_model): 20 | ''' Start training ''' 21 | if args.logging_Euclid_dist: 22 | t_total = len(training_data) // args.accumulation_steps * 10 23 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) 24 | logger.info('Start training') 25 | iteration = 0 26 | if args.break_point_continue: 27 | iteration = checkpoint['iteration'] 28 | total_loss = 0 29 | F1 = 0 30 | for epoch_i in range(args.epoch): 31 | logger.info('[ Epoch : {}]'.format(epoch_i)) 32 | dist_sum, dist_num = 0.0, 0 33 | # training part 34 | model.train() 35 | for src_ids, decoder_ids, mask, label_ids in training_data: 36 | iteration += 1 37 | src_ids = src_ids.cuda() 38 | decoder_ids = decoder_ids.cuda() 39 | mask = mask.cuda() 40 | label_ids = label_ids.cuda() 41 | # forward 42 | # optimizer.optimizer.zero_grad() 43 | loss = model(input_ids=src_ids, attention_mask=mask, decoder_input_ids=decoder_ids, labels=label_ids)[0] 44 | total_loss += loss.item() 45 | loss = loss / args.accumulation_steps 46 | # backward 47 | loss.backward() 48 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) 49 | # loss accumulation 50 | if (iteration+1) % args.accumulation_steps == 0: 51 | optimizer.step() 52 | if args.recadam: 53 | scheduler.step() 54 | model.zero_grad() 55 | 56 | if args.logging_Euclid_dist: 57 | dist = torch.sum(torch.abs(torch.cat( 58 | [p.view(-1) for n, p in model.named_parameters()]) - torch.cat( 59 | [p.view(-1) for n, p in pretrained_model.named_parameters()])) ** 2).item() 60 | 61 | dist_sum += dist 62 | dist_num += 1 63 | # write to log file 64 | if iteration % 20 == 0: 65 | if args.logging_Euclid_dist: 66 | logger.info("iteration: {} loss_per_word: {:4f} Euclid dist: {:.6f}".format(iteration, total_loss/20, dist_sum / dist_num)) 67 | else: 68 | logger.info("iteration: {} 
loss_per_word: {:4f} learning rate: {:4f} ".format(iteration, total_loss/20, optimizer.learning_rate)) 69 | total_loss = 0 70 | # save model 71 | if iteration % args.save_interval == 0 and iteration > args.start_to_save_iter: 72 | print('=====Saving checkpoint=====') 73 | model_name = args.saving_path + "/{}_{}.chkpt".format(args.data_name, iteration) 74 | torch.save(model, model_name) 75 | else: 76 | pass 77 | 78 | if __name__ == '__main__': 79 | # for training 80 | parser = argparse.ArgumentParser() 81 | parser.add_argument('-visible_gpu', default='1', type=str) 82 | parser.add_argument('-log_file', default='./logs/', type=str) 83 | parser.add_argument('-train_from', default='', type=str) 84 | parser.add_argument('-random_seed', type=int, default=0) 85 | parser.add_argument('-lr', default=0.05, type=float) 86 | parser.add_argument('-max_grad_norm', default=0, type=float) 87 | parser.add_argument('-epoch', type=int, default=50) 88 | parser.add_argument('-max_iter', type=int, default=800000) 89 | parser.add_argument('-saving_path', default='./save/', type=str) 90 | parser.add_argument('-data_name', default='debate', type=str) 91 | parser.add_argument('-minor_data', action='store_true') 92 | parser.add_argument('-pre_trained_lm', default='', type=str) 93 | parser.add_argument('-pre_trained_src', action='store_true') 94 | parser.add_argument('-break_point_continue', action='store_true') 95 | parser.add_argument('-corpus_path', type=str, default="", help="target domain corpus path") 96 | parser.add_argument('-mask_prob', type=float, default=0.15, help="mask probability") 97 | # for learning, optimizer 98 | parser.add_argument('-mtl', action='store_true', help='multitask learning') 99 | parser.add_argument('-optim', default='adam', type=str) 100 | parser.add_argument('-beta1', default=0.9, type=float) 101 | parser.add_argument('-beta2', default=0.998, type=float) 102 | parser.add_argument('-warmup_steps', default=1000, type=int) 103 | parser.add_argument('-decay_method', default='noam', type=str) 104 | parser.add_argument('-enc_hidden_size', default=768, type=int) 105 | parser.add_argument('-clip', default=1.0, type=float) 106 | parser.add_argument('-accumulation_steps', default=10, type=int) 107 | parser.add_argument('-bsz', default=4, type=int, help='batch size') 108 | # for evaluation 109 | parser.add_argument('-process_num', default=4, type=int) 110 | parser.add_argument('-start_to_save_iter', default=3000, type=int) 111 | parser.add_argument('-save_interval', default=10000, type=int) 112 | # using RecAdam 113 | parser.add_argument("-adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") 114 | parser.add_argument('-recadam', default=False, action='store_true') 115 | parser.add_argument("-weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") 116 | parser.add_argument("-anneal_w", type=float, default=1.0, help="Weight for the annealing function in RecAdam. Default 1.0.") 117 | parser.add_argument("-anneal_fun", type=str, default='sigmoid', choices=["sigmoid", "linear", 'constant'], help="the type of annealing function in RecAdam. Default sigmoid") 118 | parser.add_argument("-anneal_t0", type=int, default=1000, help="t0 for the annealing function in RecAdam.") 119 | parser.add_argument("-anneal_k", type=float, default=0.1, help="k for the annealing function in RecAdam.") 120 | parser.add_argument("-pretrain_cof", type=float, default=5000.0, help="Coefficient of the quadratic penalty in RecAdam. 
Default 5000.0.") 121 | parser.add_argument("-logging_Euclid_dist", action="store_true", help="Whether to log the Euclidean distance between the pretrained model and fine-tuning model") 122 | parser.add_argument("-max_steps", default=-1, type=int, help="If > 0: set total number of training steps to perform. Override num_train_epochs.") 123 | parser.add_argument("-model_type", type=str, default="layers") 124 | args = parser.parse_args() 125 | 126 | # initial logger 127 | if not os.path.exists(args.log_file + args.data_name): 128 | os.makedirs(args.log_file + args.data_name) 129 | init_logger(args.log_file + args.data_name + '/DAPT_pretraining.log') 130 | logger.info(args) 131 | # set gpu 132 | os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_gpu 133 | # set random seed 134 | fix_random_seed(args.random_seed) 135 | 136 | # loading data 137 | # it's faster to load data from pre_build data 138 | logger.info('starting to read dataloader') 139 | train_loader = load_dataloader(args) 140 | 141 | # initial model and optimizer 142 | logger.info('starting to build model') 143 | model = BartForConditionalGeneration.from_pretrained('facebook/bart-base') 144 | model.cuda() 145 | checkpoint = None 146 | optim = build_optim(args, model, None, model) 147 | pretrained_model = BartForConditionalGeneration.from_pretrained('facebook/bart-base') 148 | if args.recadam: 149 | pretrained_model.cuda() 150 | optim = build_optim(args, model, None, pretrained_model) 151 | 152 | # training 153 | train(model, train_loader, optim, checkpoint, args, pretrained_model) 154 | -------------------------------------------------------------------------------- /src/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from torch.utils.data import DataLoader 5 | from transformers import BartForConditionalGeneration 6 | from others.logging import init_logger, logger 7 | from others.utils import load, count_parameters, initialize_weights, batch_generator, fix_random_seed 8 | from preprocessing import BartDataset, DataReader 9 | from others.optimizer import build_optim 10 | from trainer import train, multitask_train 11 | from dapt_pretraining import CorpusDataset 12 | import random 13 | import numpy as np 14 | 15 | def make_log_file_name(args): 16 | if args.pre_trained_lm != '': 17 | log_file_name = args.log_file + args.data_name + '/train_' + args.pre_trained_lm.split('/')[-1][:-3] + '_' + args.percentage + '%_pretrain_lm.log' 18 | elif args.pre_trained_src: 19 | log_file_name = args.log_file + args.data_name + '/train_' + args.train_from.split('/')[-1][:-6] + '_' + args.percentage + '_' + '%_pretrain_src.log' 20 | else: 21 | log_file_name = args.log_file + args.data_name + '/train_' + args.percentage + '%.log' 22 | return log_file_name 23 | 24 | def load_dataloader(args): 25 | train_file_name = './dataset/' + args.data_name + '/trainloader.pt' 26 | train_loader = load(train_file_name) 27 | valid_file_name = './dataset/' + args.data_name + '/validloader.pt' 28 | valid_loader = load(valid_file_name) 29 | logger.info('train loader has {} samples'.format(len(train_loader.dataset))) 30 | return train_loader, valid_loader 31 | 32 | def load_dataloader_for_domain_corpus(args): 33 | corpus_dataset = CorpusDataset(args.corpus_path, denoising_flag=True) 34 | dataloader = DataLoader(dataset=corpus_dataset, batch_size=args.bsz, shuffle=True) 35 | data_generator = batch_generator(dataloader) 36 | return data_generator 37 | 38 | def load_model(args): 39 | 
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base') 40 | if args.pre_trained_lm != '': 41 | model = torch.load(args.pre_trained_lm, map_location='cpu') 42 | # load from saved model 43 | if args.train_from != '': 44 | logger.info("train from : {}".format(args.train_from)) 45 | if "mtl_pre_trained_lm" in args.train_from: 46 | checkpoint = torch.load(args.train_from, map_location='cpu') 47 | model.load_state_dict(checkpoint['model_lm']) 48 | elif "xsum" in args.train_from: 49 | checkpoint = None 50 | model = BartForConditionalGeneration.from_pretrained('VictorSanh/bart-base-finetuned-xsum') 51 | else: 52 | checkpoint = torch.load(args.train_from, map_location='cpu') 53 | model.load_state_dict(checkpoint['model']) 54 | if args.train_from == '': 55 | checkpoint = None 56 | if args.mtl: 57 | model_lm = BartForConditionalGeneration.from_pretrained('facebook/bart-base') 58 | if args.pre_trained_lm != '': 59 | model_lm = torch.load(args.pre_trained_lm, map_location='cpu') 60 | model_cnn = BartForConditionalGeneration.from_pretrained('facebook/bart-base') 61 | # shared part 62 | model_cnn.model.shared = model_lm.model.shared 63 | # encoder part 64 | model_cnn.model.encoder = model_lm.model.encoder 65 | 66 | print("dont share decoder!") 67 | model = None 68 | return model_lm, model_cnn, checkpoint 69 | return model, checkpoint 70 | 71 | if __name__ == '__main__': 72 | # for training 73 | parser = argparse.ArgumentParser() 74 | parser.add_argument('-visible_gpu', default='1', type=str) 75 | parser.add_argument('-log_file', default='./logs/', type=str) 76 | parser.add_argument('-train_from', default='', type=str) 77 | parser.add_argument('-random_seed', type=int, default=0) 78 | parser.add_argument('-lr', default=0.05, type=float) 79 | parser.add_argument('-max_grad_norm', default=0, type=float) 80 | parser.add_argument('-epoch', type=int, default=50) 81 | parser.add_argument('-max_iter', type=int, default=800000) 82 | parser.add_argument('-saving_path', default='./save/', type=str) 83 | parser.add_argument('-data_name', default='debate', type=str) 84 | parser.add_argument('-pre_trained_lm', default='', type=str) 85 | parser.add_argument('-pre_trained_src', action='store_true') 86 | parser.add_argument('-break_point_continue', action='store_true') 87 | parser.add_argument('-percentage', default='100', type=str) 88 | parser.add_argument('-corpus_path', type=str, default="", help="target domain corpus path") 89 | parser.add_argument('-mask_prob', type=float, default=0.15, help="mask probability") 90 | # for learning, optimizer 91 | parser.add_argument('-mtl', action='store_true', help='multitask learning') 92 | parser.add_argument('-optim', default='adam', type=str) 93 | parser.add_argument('-beta1', default=0.9, type=float) 94 | parser.add_argument('-beta2', default=0.998, type=float) 95 | parser.add_argument('-warmup_steps', default=1000, type=int) 96 | parser.add_argument('-decay_method', default='noam', type=str) 97 | parser.add_argument('-enc_hidden_size', default=768, type=int) 98 | parser.add_argument('-clip', default=1.0, type=float) 99 | parser.add_argument('-accumulation_steps', default=10, type=int) 100 | parser.add_argument('-bsz', default=4, type=int, help='batch size') 101 | # for evaluation 102 | parser.add_argument('-process_num', default=4, type=int) 103 | parser.add_argument('-save_interval', default=100, type=int) 104 | parser.add_argument('-start_to_save_iter', default=100, type=int) 105 | # using RecAdam 106 | parser.add_argument("-adam_epsilon", default=1e-8, 
type=float, help="Epsilon for Adam optimizer.") 107 | parser.add_argument('-recadam', default=False, action='store_true') 108 | parser.add_argument("-weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") 109 | parser.add_argument("-anneal_w", type=float, default=1.0, help="Weight for the annealing function in RecAdam. Default 1.0.") 110 | parser.add_argument("-anneal_fun", type=str, default='sigmoid', choices=["sigmoid", "linear", 'constant'], help="the type of annealing function in RecAdam. Default sigmoid") 111 | parser.add_argument("-anneal_t0", type=int, default=1000, help="t0 for the annealing function in RecAdam.") 112 | parser.add_argument("-anneal_k", type=float, default=0.1, help="k for the annealing function in RecAdam.") 113 | parser.add_argument("-pretrain_cof", type=float, default=5000.0, help="Coefficient of the quadratic penalty in RecAdam. Default 5000.0.") 114 | parser.add_argument("-logging_Euclid_dist", action="store_true", help="Whether to log the Euclidean distance between the pretrained model and fine-tuning model") 115 | parser.add_argument("-max_steps", default=-1, type=int, help="If > 0: set total number of training steps to perform. Override num_train_epochs.") 116 | parser.add_argument("-model_type", type=str, default="layers") 117 | args = parser.parse_args() 118 | 119 | # initial logger 120 | init_logger(make_log_file_name(args)) 121 | logger.info(args) 122 | # set gpu 123 | os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_gpu 124 | # set random seed 125 | fix_random_seed(args.random_seed) 126 | 127 | # loading data 128 | # it's faster to load data from pre_build data 129 | logger.info('starting to read dataloader') 130 | train_loader, valid_loader = load_dataloader(args) 131 | 132 | # initial model and optimizer 133 | logger.info('starting to build model') 134 | if not args.mtl: 135 | model, checkpoint = load_model(args) 136 | model.cuda() 137 | optim = build_optim(args, model, None, model) 138 | pretrained_model = BartForConditionalGeneration.from_pretrained('facebook/bart-base') 139 | if args.recadam: 140 | pretrained_model.cuda() 141 | optim = build_optim(args, model, None, pretrained_model) 142 | else: 143 | model_lm, model_cnn, checkpoint = load_model(args) 144 | model_lm.cuda() 145 | model_cnn.cuda() 146 | optim_lm = build_optim(args, model_lm, checkpoint) 147 | optim_cnn = build_optim(args, model_cnn, checkpoint) 148 | 149 | # training 150 | if args.mtl: 151 | assert args.data_name == "cnn_dm" or args.data_name == "xsum" 152 | tgtdomain_data = load_dataloader_for_domain_corpus(args) 153 | cnn_train_data = batch_generator(train_loader) 154 | cnn_valid_data = batch_generator(valid_loader) 155 | multitask_train(model_lm, model_cnn, cnn_train_data, cnn_valid_data, tgtdomain_data, optim_lm, optim_cnn, checkpoint, args) 156 | else: 157 | train(model, train_loader, valid_loader, optim, checkpoint, args, pretrained_model) 158 | -------------------------------------------------------------------------------- /src/others/optimizer.py: -------------------------------------------------------------------------------- 1 | from others.recadam import RecAdam, anneal_function 2 | import torch 3 | import torch.optim as optim 4 | 5 | def build_optim(args, model, checkpoint, pretrained_model=None): 6 | """ Build optimizer """ 7 | if args.recadam: 8 | print("Using RecAdam") 9 | no_decay = ["bias", "layer_norm.weight", "layernorm_embedding.weight"] 10 | optimizer_grouped_parameters = [ 11 | { 12 | "params": [p for n, p in model.named_parameters() if 
13 | not any(nd in n for nd in no_decay) and args.model_type in n], 14 | "weight_decay": args.weight_decay, 15 | "anneal_w": args.anneal_w, 16 | "pretrain_params": [p_p for p_n, p_p in pretrained_model.named_parameters() if 17 | not any(nd in p_n for nd in no_decay) and args.model_type in p_n] 18 | }, 19 | { 20 | "params": [p for n, p in model.named_parameters() if 21 | not any(nd in n for nd in no_decay) and args.model_type not in n], 22 | "weight_decay": args.weight_decay, 23 | "anneal_w": 0.0, 24 | "pretrain_params": [p_p for p_n, p_p in pretrained_model.named_parameters() if 25 | not any(nd in p_n for nd in no_decay) and args.model_type not in p_n] 26 | }, 27 | { 28 | "params": [p for n, p in model.named_parameters() if 29 | any(nd in n for nd in no_decay) and args.model_type in n], 30 | "weight_decay": 0.0, 31 | "anneal_w": args.anneal_w, 32 | "pretrain_params": [p_p for p_n, p_p in pretrained_model.named_parameters() if 33 | any(nd in p_n for nd in no_decay) and args.model_type in p_n] 34 | }, 35 | { 36 | "params": [p for n, p in model.named_parameters() if 37 | any(nd in n for nd in no_decay) and args.model_type not in n], 38 | "weight_decay": 0.0, 39 | "anneal_w": 0.0, 40 | "pretrain_params": [p_p for p_n, p_p in pretrained_model.named_parameters() if 41 | any(nd in p_n for nd in no_decay) and args.model_type not in p_n] 42 | } 43 | ] 44 | optim = RecAdam(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_epsilon, anneal_fun=args.anneal_fun, anneal_k=args.anneal_k, anneal_t0=args.anneal_t0, pretrain_cof=args.pretrain_cof) 45 | else: 46 | optim = Optimizer( 47 | args.optim, args.lr, args.max_grad_norm, 48 | beta1=args.beta1, beta2=args.beta2, 49 | decay_method=args.decay_method, 50 | warmup_steps=args.warmup_steps, model_size=args.enc_hidden_size) 51 | 52 | optim.set_parameters(list(model.named_parameters())) 53 | 54 | if args.train_from != '' and 'xsum' not in args.train_from: 55 | optim.optimizer.load_state_dict(checkpoint['optim']) 56 | if args.visible_gpu != '-1': 57 | for state in optim.optimizer.state.values(): 58 | for k, v in state.items(): 59 | if torch.is_tensor(v): 60 | state[k] = v.cuda() 61 | 62 | if (optim.method == 'adam') and (len(optim.optimizer.state) < 1): 63 | raise RuntimeError( 64 | "Error: loaded Adam optimizer from existing model" + 65 | " but optimizer state is empty") 66 | 67 | 68 | return optim 69 | 70 | class Optimizer(object): 71 | """ 72 | Controller class for optimization. Mostly a thin 73 | wrapper for `optim`, but also useful for implementing 74 | rate scheduling beyond what is currently available. 75 | Also implements necessary methods for training RNNs such 76 | as grad manipulations. 77 | 78 | Args: 79 | method (:obj:`str`): one of [sgd, adagrad, adadelta, adam] 80 | lr (float): learning rate 81 | lr_decay (float, optional): learning rate decay multiplier 82 | start_decay_steps (int, optional): step to start learning rate decay 83 | beta1, beta2 (float, optional): parameters for adam 84 | adagrad_accum (float, optional): initialization parameter for adagrad 85 | decay_method (str, option): custom decay options 86 | warmup_steps (int, option): parameter for `noam` decay 87 | model_size (int, option): parameter for `noam` decay 88 | 89 | We use the default parameters for Adam that are suggested by 90 | the original paper https://arxiv.org/pdf/1412.6980.pdf 91 | These values are also used by other established implementations, 92 | e.g. 
https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer 93 | https://keras.io/optimizers/ 94 | Recently there are slightly different values used in the paper 95 | "Attention is all you need" 96 | https://arxiv.org/pdf/1706.03762.pdf, particularly the value beta2=0.98 97 | was used there however, beta2=0.999 is still arguably the more 98 | established value, so we use that here as well 99 | """ 100 | 101 | def __init__(self, method, learning_rate, max_grad_norm, 102 | lr_decay=1, start_decay_steps=None, decay_steps=None, 103 | beta1=0.9, beta2=0.999, 104 | adagrad_accum=0.0, 105 | decay_method=None, 106 | warmup_steps=4000, 107 | model_size=None): 108 | self.last_ppl = None 109 | self.learning_rate = learning_rate 110 | self.original_lr = learning_rate 111 | self.max_grad_norm = max_grad_norm 112 | self.method = method 113 | self.lr_decay = lr_decay 114 | self.start_decay_steps = start_decay_steps 115 | self.decay_steps = decay_steps 116 | self.start_decay = False 117 | self._step = 0 118 | self.betas = [beta1, beta2] 119 | self.adagrad_accum = adagrad_accum 120 | self.decay_method = decay_method 121 | self.warmup_steps = warmup_steps 122 | self.model_size = model_size 123 | 124 | def set_parameters(self, params): 125 | """ ? """ 126 | self.params = [] 127 | self.sparse_params = [] 128 | for k, p in params: 129 | if p.requires_grad: 130 | if self.method != 'sparseadam' or "embed" not in k: 131 | self.params.append(p) 132 | else: 133 | self.sparse_params.append(p) 134 | if self.method == 'sgd': 135 | self.optimizer = optim.SGD(self.params, lr=self.learning_rate) 136 | elif self.method == 'adagrad': 137 | self.optimizer = optim.Adagrad(self.params, lr=self.learning_rate) 138 | for group in self.optimizer.param_groups: 139 | for p in group['params']: 140 | self.optimizer.state[p]['sum'] = self.optimizer\ 141 | .state[p]['sum'].fill_(self.adagrad_accum) 142 | elif self.method == 'adadelta': 143 | self.optimizer = optim.Adadelta(self.params, lr=self.learning_rate) 144 | elif self.method == 'adam': 145 | self.optimizer = optim.Adam(self.params, lr=self.learning_rate, 146 | betas=self.betas, eps=1e-9) 147 | elif self.method == 'sparseadam': 148 | self.optimizer = MultipleOptimizer( 149 | [optim.Adam(self.params, lr=self.learning_rate, 150 | betas=self.betas, eps=1e-8), 151 | optim.SparseAdam(self.sparse_params, lr=self.learning_rate, 152 | betas=self.betas, eps=1e-8)]) 153 | else: 154 | raise RuntimeError("Invalid optim method: " + self.method) 155 | 156 | def _set_rate(self, learning_rate): 157 | self.learning_rate = learning_rate 158 | if self.method != 'sparseadam': 159 | self.optimizer.param_groups[0]['lr'] = self.learning_rate 160 | else: 161 | for op in self.optimizer.optimizers: 162 | op.param_groups[0]['lr'] = self.learning_rate 163 | 164 | def step(self): 165 | """Update the model parameters based on current gradients. 166 | 167 | Optionally, will employ gradient modification or update learning 168 | rate. 169 | """ 170 | self._step += 1 171 | 172 | # Decay method used in tensor2tensor. 
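# The 'noam' decay implemented below is the Transformer learning-rate schedule:
#   lr = original_lr * model_size**(-0.5) * min(step**(-0.5), step * warmup_steps**(-1.5))
# i.e. a linear warm-up during the first `warmup_steps` steps followed by
# inverse-square-root decay, with the peak learning rate reached at step == warmup_steps.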
173 | if self.decay_method == "noam": 174 | self._set_rate( 175 | self.original_lr * 176 | ( self.model_size ** -0.5*min(self._step ** (-0.5), 177 | self._step * self.warmup_steps**(-1.5)))) 178 | else: 179 | if ((self.start_decay_steps is not None) and ( 180 | self._step >= self.start_decay_steps)): 181 | self.start_decay = True 182 | if self.start_decay: 183 | if ((self._step - self.start_decay_steps) 184 | % self.decay_steps == 0): 185 | self.learning_rate = self.learning_rate * self.lr_decay 186 | 187 | if self.method != 'sparseadam': 188 | self.optimizer.param_groups[0]['lr'] = self.learning_rate 189 | 190 | if self.max_grad_norm: 191 | clip_grad_norm_(self.params, self.max_grad_norm) 192 | self.optimizer.step() -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | import os 4 | from transformers import BartTokenizer, get_linear_schedule_with_warmup 5 | from others.logging import logger 6 | from others.utils import pad_sents, get_mask 7 | from cal_rouge import test_rouge, rouge_results_to_str 8 | from dapt_pretraining import text_infilling, sent_permutation, add_noise 9 | 10 | def train(model, training_data, validation_data, optimizer, checkpoint, args, pretrained_model): 11 | ''' Start training ''' 12 | if args.logging_Euclid_dist: 13 | t_total = len(training_data) // args.accumulation_steps * 10 14 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) 15 | logger.info('Start training') 16 | iteration = 0 17 | if args.break_point_continue: 18 | iteration = checkpoint['iteration'] 19 | total_loss = 0 20 | F1 = 0 21 | for epoch_i in range(args.epoch): 22 | logger.info('[ Epoch : {}]'.format(epoch_i)) 23 | dist_sum, dist_num = 0.0, 0 24 | # training part 25 | model.train() 26 | for src_ids, decoder_ids, mask, label_ids in training_data: 27 | iteration += 1 28 | src_ids = src_ids.cuda() 29 | decoder_ids = decoder_ids.cuda() 30 | mask = mask.cuda() 31 | label_ids = label_ids.cuda() 32 | # forward 33 | # optimizer.optimizer.zero_grad() 34 | loss = model(input_ids=src_ids, attention_mask=mask, decoder_input_ids=decoder_ids, labels=label_ids)[0] 35 | total_loss += loss.item() 36 | loss = loss / args.accumulation_steps 37 | # backward 38 | loss.backward() 39 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) 40 | # loss accumulation 41 | if (iteration+1) % args.accumulation_steps == 0: 42 | optimizer.step() 43 | if args.recadam: 44 | scheduler.step() 45 | model.zero_grad() 46 | 47 | if args.logging_Euclid_dist: 48 | dist = torch.sum(torch.abs(torch.cat( 49 | [p.view(-1) for n, p in model.named_parameters()]) - torch.cat( 50 | [p.view(-1) for n, p in pretrained_model.named_parameters()])) ** 2).item() 51 | 52 | dist_sum += dist 53 | dist_num += 1 54 | # write to log file 55 | if iteration % 20 == 0: 56 | if args.logging_Euclid_dist: 57 | logger.info("iteration: {} loss_per_word: {:4f} Euclid dist: {:.6f}".format(iteration, total_loss/20, dist_sum / dist_num)) 58 | else: 59 | logger.info("iteration: {} loss_per_word: {:4f} learning rate: {:4f} ".format(iteration, total_loss/20, optimizer.learning_rate)) 60 | total_loss = 0 61 | # save model 62 | if iteration % args.save_interval == 0 and iteration > args.start_to_save_iter: 63 | temp_F1 = evaluation(model, validation_data, args) 64 | model.train() 65 | if temp_F1 > F1: 66 | logger.info("saving 
model") 67 | if not os.path.exists(args.saving_path + args.data_name): 68 | os.makedirs(args.saving_path + args.data_name) 69 | model_name = make_file_name(args, iteration) 70 | # checkpoint = {'iteration': iteration, 'settings': args, 'optim': optimizer.optimizer.state_dict(), 'model': model.state_dict()} 71 | torch.save(model, model_name) 72 | F1 = temp_F1 73 | else: 74 | pass 75 | 76 | 77 | def multitask_train(model_lm, model_cnn, cnn_train_data, cnn_valid_data, tgtdomain_data, optimizer_lm, optimizer_cnn, checkpoint, args): 78 | tokenizer = BartTokenizer.from_pretrained('facebook/bart-base') 79 | ''' Start training ''' 80 | logger.info('Start multitask training') 81 | iteration = 0 82 | if args.train_from!= '': 83 | iteration = checkpoint['iteration'] 84 | cnn_loss = 0 85 | lm_loss = 0 86 | F1 = 0 87 | while iteration < args.max_iter: 88 | iteration += 1 89 | model_lm.train() 90 | model_cnn.train() 91 | 92 | ## cnn news summarization training part 93 | src_ids, decoder_ids, mask, label_ids = next(cnn_train_data) 94 | src_ids = src_ids.cuda() 95 | decoder_ids = decoder_ids.cuda() 96 | mask = mask.cuda() 97 | label_ids = label_ids.cuda() 98 | 99 | loss = model_cnn(input_ids=src_ids, attention_mask=mask, decoder_input_ids=decoder_ids, labels=label_ids)[0] 100 | cnn_loss += loss.item() 101 | loss = loss / args.accumulation_steps 102 | # backward 103 | loss.backward() 104 | # torch.nn.utils.clip_grad_norm_(model_cnn.parameters(), args.clip) 105 | 106 | ## denoising language modeling part 107 | sents = next(tgtdomain_data) 108 | tokenized_sents = [tokenizer.encode(sent, add_special_tokens=False) for sent in sents] 109 | decoder_ids = [[tokenizer.bos_token_id] + item for item in tokenized_sents] 110 | label_ids = [item + [tokenizer.eos_token_id] for item in tokenized_sents] 111 | 112 | noisy_text = add_noise(sents, args.mask_prob) 113 | inputs_ids = [tokenizer.encode(sent, add_special_tokens=False) for sent in noisy_text] 114 | 115 | # prepare data for training 116 | inputs_ids = torch.tensor(pad_sents(inputs_ids, pad_token=tokenizer.pad_token_id)[0]).cuda() 117 | mask = torch.tensor(get_mask(inputs_ids)).cuda() 118 | decoder_ids = torch.tensor(pad_sents(decoder_ids, pad_token=tokenizer.pad_token_id)[0]).cuda() 119 | label_ids = torch.tensor(pad_sents(label_ids, pad_token=-100)[0]).cuda() 120 | 121 | # optimize model 122 | loss = model_lm(input_ids=inputs_ids, attention_mask=mask, decoder_input_ids=decoder_ids, labels=label_ids)[0] 123 | lm_loss += loss.item() 124 | loss = loss / args.accumulation_steps 125 | loss.backward() 126 | # torch.nn.utils.clip_grad_norm_(model_lm.parameters(), args.clip) 127 | 128 | # loss accumulation 129 | if (iteration+1) % args.accumulation_steps == 0: 130 | optimizer_lm.step() 131 | optimizer_cnn.step() 132 | model_lm.zero_grad() 133 | model_cnn.zero_grad() 134 | # write to log file 135 | if iteration % 20 == 0: 136 | logger.info("iteration: {} loss_per_word: {:.6f} loss_lm: {:.6f} learning rate lm: {:.9f} learning rate cnn: {:.9f}".format(iteration, cnn_loss/20, lm_loss/20, optimizer_lm.learning_rate, optimizer_cnn.learning_rate)) 137 | cnn_loss = 0 138 | lm_loss = 0 139 | 140 | if iteration % 50000 == 0: 141 | # eval_F1 = evaluation(model, cnn_valid_data, args) 142 | # logger.info("Iteration: {}. 
F1 score: {:.4f}".format(iteration, eval_F1)) 143 | logger.info("saving model") 144 | if not os.path.exists(args.saving_path + args.data_name): 145 | os.makedirs(args.saving_path + args.data_name) 146 | model_name = make_file_name(args, iteration) 147 | # checkpoint_1 = {'iteration': iteration, 'settings': args, 'optim': optimizer_lm.optimizer.state_dict(), 'model_lm': model_lm.state_dict()} 148 | # checkpoint_2 = {'iteration': iteration, 'settings': args, 'optim': optimizer_cnn.optimizer.state_dict(), 'model': model_cnn.state_dict()} 149 | torch.save(model_lm, model_name[0]) 150 | torch.save(model_cnn, model_name[1]) 151 | 152 | 153 | def evaluation(model, validation_data, args): 154 | model.eval() 155 | valid_reference_path = './dataset/' + args.data_name + '/valid.target' 156 | valid_data = open(valid_reference_path,'r') 157 | valid_list = valid_data.readlines() 158 | valid_list = [i.strip('\n') for i in valid_list] 159 | # inference 160 | tokenizer = BartTokenizer.from_pretrained('facebook/bart-base') 161 | outputs = [] 162 | for src_ids, decoder_ids, mask, label_ids in tqdm(validation_data): 163 | src_ids = src_ids.cuda() 164 | summary_ids = model.generate(src_ids, num_beams=4, max_length=256, early_stopping=True) 165 | output = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids] 166 | outputs += output 167 | # calculate rouge 168 | final_results = test_rouge(outputs, valid_list, args.process_num) 169 | R1_F1 = final_results["rouge_1_f_score"] * 100 170 | logger.info('[ Validation ]') 171 | logger.info(rouge_results_to_str(final_results)) 172 | return R1_F1 173 | 174 | def make_file_name(args, iteration): 175 | model_name = args.saving_path + "{}/{}_75%sample.chkpt".format(args.data_name,iteration,args.percentage) 176 | if args.pre_trained_lm != '': 177 | # model_name = args.saving_path + "{}/{}_{}%DAPT_REG_EPOCH_1.chkpt".format(args.data_name,iteration,args.percentage) 178 | model_name = args.saving_path + "{}/{}_{}%DAPT_reg_160000.chkpt".format(args.data_name,iteration,args.percentage) 179 | if args.pre_trained_src: 180 | model_name = args.saving_path + "{}/{}_{}%cnn_pretrain_400000.chkpt".format(args.data_name,iteration,args.percentage) 181 | # model_name = args.saving_path + "{}/{}_{}%_pre_trained_src.chkpt".format(args.data_name,iteration,args.percentage) 182 | if args.mtl: 183 | model_name_1 = args.saving_path + "{}/{}_mtl_pre_trained_lm.chkpt".format('social_media_short',iteration) 184 | model_name_2 = args.saving_path + "{}/{}_mtl_pre_trained_src.chkpt".format('social_media_short',iteration) 185 | model_name = [model_name_1, model_name_2] 186 | return model_name 187 | 188 | def evaluation_loss(model, validation_data, args): 189 | model.eval() 190 | total_loss = 0 191 | for src_ids, decoder_ids, mask, label_ids in tqdm(validation_data): 192 | src_ids = src_ids.cuda() 193 | decoder_ids = decoder_ids.cuda() 194 | mask = mask.cuda() 195 | label_ids = label_ids.cuda() 196 | loss = model(input_ids=src_ids, attention_mask=mask, decoder_input_ids=decoder_ids, labels=label_ids)[0] 197 | total_loss += loss.item() 198 | loss = None 199 | logger.info(total_loss) 200 | return total_loss -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AdaptSum: Towards Low-Resource Domain Adaptation for Abstractive Summarization 2 | 3 | 
[![](https://img.shields.io/badge/python-3.6+-blue.svg)](https://www.python.org/downloads/) [![CC BY 4.0][cc-by-shield]][cc-by] 4 | 5 | 6 | 7 | 8 | [cc-by]: http://creativecommons.org/licenses/by/4.0/ 9 | [cc-by-shield]: https://img.shields.io/badge/License-CC%20BY%204.0-lightgrey.svg 10 | 11 | 12 | Paper accepted at the [NAACL-HLT 2021](https://2021.naacl.org): 13 | 14 | **[AdaptSum: Towards Low-Resource Domain Adaptation for Abstractive Summarization](https://arxiv.org/pdf/2103.11332)**, by **[Tiezheng Yu*](https://tysonyu.github.io/)**, **[Zihan Liu*](https://zliucr.github.io/)**, [Pascale Fung](https://pascale.home.ece.ust.hk). 15 | 16 | ## Abstract 17 | State-of-the-art abstractive summarization models generally rely on extensive labeled data, which lowers their generalization ability on domains where such data are not available. In this paper, we present a study of domain adaptation for the abstractive summarization task across six diverse target domains in a low-resource setting. Specifically, we investigate the second phase of pre-training on large-scale generative models under three different settings: 1) source domain pre-training; 2) domain-adaptive pre-training; and 3) task-adaptive pre-training. Experiments show that the effectiveness of pre-training is correlated with the similarity between the pre-training data and the target domain task. Moreover, we find that continuing pre-training could lead to the pre-trained model's catastrophic forgetting, and a learning method with less forgetting can alleviate this issue. Furthermore, results illustrate that a huge gap still exists between the low-resource and high-resource settings, which highlights the need for more advanced domain adaptation methods for the abstractive summarization task. 18 | 19 | ## Dataset 20 | We release the AdaptSum dataset, which contains the summarization datasets across six target domains as well as the corpora for SDPT, DAPT and TAPT. You can download AdaptSum from [Here](https://drive.google.com/drive/folders/1qdkavIQonTAepkJhGpo3TZpU4LUW44sp?usp=sharing). 21 | 22 | ## Preparation for running 23 | 1. Create a new folder named `dataset` at the root of this project 24 | 2. Download the data from google drive and then put it in the `dataset` folder 25 | 3. Create the conda environment 26 | ``` 27 | conda create -n adaptsum python=3.6 28 | ``` 29 | 4. Activate the conda environment 30 | ``` 31 | conda activate adaptsum 32 | ``` 33 | 5. Install pytorch. Please check your CUDA version before the installation and modify it accordingly, or you can refer to [pytorch website](https://pytorch.org) 34 | ``` 35 | conda install pytorch cudatoolkit=11.0 -c pytorch 36 | ``` 37 | 6. Install requirements 38 | ``` 39 | pip install -r requirements.txt 40 | ``` 41 | 7. Create a new folder named `logs` at the root of this project 42 | ## SDPT pretraining 43 | - We take `cnn_dm` as an example 44 | 1. Create a new folder named `SDPT_save` at the root of this project 45 | 2. Prepare dataloader: 46 | ``` 47 | python ./src/preprocessing.py -data_path=dataset/ \ 48 | -data_name=SDPT-cnn_dm \ 49 | -mode=train \ 50 | -batch_size=4 51 | ``` 52 | 3. Run `./scripts/sdpt_pretraining.sh`. You can add `-recadam` and `-logging_Euclid_dist` to use RecAdam. 53 | 54 | ## DAPT pretraining 55 | - We take `debate domain` as an example 56 | 1. Create a new folder named `DAPT_save` at the root of this project 57 | 2. Run `./scripts/dapt_pretraining.sh`. You can add `-recadam` and `-logging_Euclid_dist` to use RecAdam. 
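For intuition, DAPT (and TAPT) continue BART's denoising pre-training: each document is sentence-permuted and token spans are masked before the model reconstructs the original text (see `sent_permutation`, `text_infilling` and `add_noise` in `src/dapt_pretraining.py`). Below is a minimal sketch of that noising step only, assuming it is run from the `src` directory with the project requirements installed and NLTK's punkt data available; the sample text is purely illustrative.
```
# illustrative sketch of the denoising corruption used for DAPT/TAPT
from dapt_pretraining import sent_permutation, text_infilling

doc = "The first sentence. The second sentence. The third sentence."
shuffled = sent_permutation(doc)                              # shuffle sentence order (sentence permutation)
corrupted = text_infilling(shuffled, mask_probability=0.15)   # randomly mask token spans (text infilling)
print(" ".join(corrupted))                                    # noisy text fed to the BART encoder
```
The trainer pairs this corrupted text (encoder input) with the original text (decoder target), which is also how the denoising branch of the multitask SDPT setup is trained.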
58 | 59 | ## TAPT pretraining 60 | - We take `debate domain` as an example 61 | 1. Create a new folder named `TAPT_save` at the root of this project 62 | 2. Run `./scripts/tapt_pretraining.sh`. You can add `-recadam` and `-logging_Euclid_dist` to use RecAdam. 63 | 64 | ## Fine-tuning 65 | - We take `debate domain` as an example 66 | 1. Create a new folder named `debate` at `logs` 67 | 68 | 2. Prepare dataloader: 69 | ``` 70 | python ./src/preprocessing.py -data_path=dataset/ \ 71 | -data_name=debate \ 72 | -mode=train \ 73 | -batch_size=4 74 | python ./src/preprocessing.py -data_path=dataset/ \ 75 | -data_name=debate \ 76 | -mode=valid \ 77 | -batch_size=4 78 | python ./src/preprocessing.py -data_path=dataset/ \ 79 | -data_name=debate \ 80 | -mode=test \ 81 | -batch_size=4 82 | ``` 83 | 84 | 3. Install `pyrouge` package (You can skip this if you have already installed `pyrouge`) 85 | - Step 1 : Install Pyrouge from source (not from pip) 86 | ``` 87 | git clone https://github.com/bheinzerling/pyrouge 88 | cd pyrouge 89 | pip install -e . 90 | ``` 91 | - Step 2 : Install official ROUGE script 92 | ``` 93 | git clone https://github.com/andersjo/pyrouge.git rouge 94 | ``` 95 | - Step 3 : Point Pyrouge to official rouge script (the path given to pyrouge should be an absolute path!) 96 | ``` 97 | pyrouge_set_rouge_path ~/pyrouge/rouge/tools/ROUGE-1.5.5/ 98 | ``` 99 | - Step 4 : Install libxml parser 100 | As mentioned in this [issue](https://github.com/bheinzerling/pyrouge/issues/27), you need to install libxml parser 101 | ``` 102 | sudo apt-get install libxml-parser-perl 103 | ``` 104 | - Step 5 : Regenerate the Exceptions DB 105 | As mentioned in this [issue](https://github.com/bheinzerling/pyrouge/issues/8), you need to regenerate the Exceptions DB 106 | ``` 107 | cd rouge/tools/ROUGE-1.5.5/data 108 | rm WordNet-2.0.exc.db 109 | ./WordNet-2.0-Exceptions/buildExeptionDB.pl ./WordNet-2.0-Exceptions ./smart_common_words.txt ./WordNet-2.0.exc.db 110 | ``` 111 | - Step 6 : Run the tests 112 | ``` 113 | python -m pyrouge.test 114 | ``` 115 | 116 | 4. Run Finetuning 117 | - If you don't want to use any second phase of pre-training, run: 118 | ``` 119 | python ./src/run.py -visible_gpu=0 \ 120 | -data_name=debate \ 121 | -save_interval=100 \ 122 | -start_to_save_iter=3000 123 | ``` 124 | - If you want to use pretrained checkpoints from **SDPT**, run: 125 | ``` 126 | python ./src/run.py -visible_gpu=0 \ 127 | -data_name=debate \ 128 | -save_interval=100 \ 129 | -start_to_save_iter=3000 \ 130 | -pre_trained_src \ 131 | -train_from=YOUR_SAVED_CHECKPOINTS 132 | ``` 133 | - If you want to use pretrained checkpoints from **DAPT** or **TAPT**, run: 134 | ``` 135 | python ./src/run.py -visible_gpu=0 \ 136 | -data_name=debate \ 137 | -save_interval=100 \ 138 | -start_to_save_iter=3000 \ 139 | -pre_trained_lm=YOUR_SAVED_CHECKPOINTS 140 | ``` 141 | 142 | 5. Evaluate the performance 143 | 1) Make a folder named `inference` at `logs` 144 | 2) You can do inference by 145 | ``` 146 | python ./src/inference.py -visible_gpu=0 -train_from=YOUR_SAVED_CHECKPOINT 147 | ``` 148 | 3) You can calculate rouge scores by 149 | ``` 150 | python ./src/cal_rouge.py -c=CANDIDATE_FILE -r=REFERENCE_FILE -p=NUMBER_OF_PROCESS 151 | ``` 152 | 153 | ## References 154 | If you use our benchmark or the code in this repo, please cite our paper. 155 | 156 |
157 | @article{Yu2021AdaptSum,
158 |   title={AdaptSum: Towards Low-Resource Domain Adaptation for Abstractive Summarization},
159 |   author={Tiezheng Yu and Zihan Liu and Pascale Fung},
160 |   journal={arXiv preprint arXiv:2103.11332},
161 |   year={2021}
162 | }
163 | 
164 | 165 | Also, please consider citing all the individual datasets in your paper. 166 | 167 | Dialog domain: 168 |
169 | @inproceedings{gliwa2019samsum,
170 |   title={SAMSum Corpus: A Human-annotated Dialogue Dataset for Abstractive Summarization},
171 |   author={Gliwa, Bogdan and Mochol, Iwona and Biesek, Maciej and Wawer, Aleksander},
172 |   booktitle={Proceedings of the 2nd Workshop on New Frontiers in Summarization},
173 |   pages={70--79},
174 |   year={2019}
175 | }
176 | 
177 | 178 | Email domain: 179 |
180 | @inproceedings{zhang2019email,
181 |   title={This Email Could Save Your Life: Introducing the Task of Email Subject Line Generation},
182 |   author={Zhang, Rui and Tetreault, Joel},
183 |   booktitle={Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics},
184 |   pages={446--456},
185 |   year={2019}
186 | }
187 | 
188 | 189 | Movie and debate domains: 190 |
191 | @inproceedings{wang2016neural,
192 |   title={Neural Network-Based Abstract Generation for Opinions and Arguments},
193 |   author={Wang, Lu and Ling, Wang},
194 |   booktitle={Proceedings of the 2016 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies},
195 |   pages={47--57},
196 |   year={2016}
197 | }
198 | 
199 | 200 | Social media domain: 201 |
202 | @inproceedings{kim2019abstractive,
203 |   title={Abstractive Summarization of Reddit Posts with Multi-level Memory Networks},
204 |   author={Kim, Byeongchang and Kim, Hyunwoo and Kim, Gunhee},
205 |   booktitle={Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers)},
206 |   pages={2519--2531},
207 |   year={2019}
208 | }
209 | 
210 | 211 | Science domain: 212 |
213 | @inproceedings{yasunaga2019scisummnet,
214 |   title={ScisummNet: A Large Annotated Corpus and Content-Impact Models for Scientific Paper Summarization with Citation Networks},
215 |   author={Yasunaga, Michihiro and Kasai, Jungo and Zhang, Rui and Fabbri, Alexander R and Li, Irene and Friedman, Dan and Radev, Dragomir R},
216 |   booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
217 |   volume={33},
218 |   pages={7386--7393},
219 |   year={2019}
220 | }
221 | 
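## Note on saved checkpoints
The pretraining and fine-tuning scripts save whole pickled models with `torch.save(model, path)`, so a saved `.chkpt` can be restored directly and then passed to `-train_from` / `-pre_trained_lm` / `-pre_trained_src` as shown above. A minimal sketch (the checkpoint path is only a placeholder):
```
import torch
model = torch.load('DAPT_save/debate_160000.chkpt', map_location='cpu')  # placeholder path
model.eval()
```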
222 | 223 | -------------------------------------------------------------------------------- /src/dapt_pretraining.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch.utils.data import Dataset, DataLoader 4 | from transformers import BartForConditionalGeneration, BartTokenizer, get_linear_schedule_with_warmup 5 | from others.logging import logger 6 | from others.utils import pad_sents, get_mask 7 | from others.optimizer import build_optim 8 | from tqdm import tqdm 9 | import numpy as np 10 | import argparse 11 | import random 12 | import os 13 | from nltk.tokenize import sent_tokenize 14 | 15 | 16 | def text_infilling(sent, mask_probability=0.05, lamda=3): 17 | ''' 18 | inputs: 19 | sent: a sentence string 20 | mask_probability: probability for masking tokens 21 | lamda: lamda for poission distribution 22 | outputs: 23 | sent: a list of tokens with masked tokens 24 | ''' 25 | sent = sent.split() 26 | length = len(sent) 27 | mask_indices = (np.random.uniform(0, 1, length) < mask_probability) * 1 28 | span_list = np.random.poisson(lamda, length) # lamda for poission distribution 29 | nonzero_idx = np.nonzero(mask_indices)[0] 30 | for item in nonzero_idx: 31 | span = min(span_list[item], 5) # maximum mask 5 continuous tokens 32 | for i in range(span): 33 | if item+i >= length: 34 | continue 35 | mask_indices[item+i] = 1 36 | for i in range(length): 37 | if mask_indices[i] == 1: 38 | sent[i] = '' 39 | 40 | # merge the s to one 41 | final_sent = [] 42 | mask_flag = 0 43 | for word in sent: 44 | if word != '': 45 | mask_flag = 0 46 | final_sent.append(word) 47 | else: 48 | if mask_flag == 0: 49 | final_sent.append(word) 50 | mask_flag = 1 51 | return final_sent 52 | 53 | def sent_permutation(sent): 54 | ''' 55 | inputs: 56 | sent: a sentence string 57 | outputs: 58 | shuffle_sent: a string after sentence permutations 59 | ''' 60 | # split sentences based on '.' 
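# sent_tokenize uses NLTK's Punkt models rather than a literal '.' split; the 'punkt' data must be downloaded once (nltk.download('punkt'))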
61 | splits = sent_tokenize(sent) 62 | random.shuffle(splits) 63 | 64 | return " ".join(splits) 65 | 66 | 67 | def add_noise(sents, mask_probability): 68 | noisy_sent_list = [] 69 | for sent in sents: 70 | noisy_sent = sent_permutation(sent) 71 | noisy_sent = text_infilling(noisy_sent, mask_probability) 72 | 73 | noisy_sent = " ".join(noisy_sent) 74 | noisy_sent_list.append(noisy_sent) 75 | 76 | return noisy_sent_list 77 | 78 | 79 | class CorpusDataset(Dataset): 80 | def __init__(self, data_path, denoising_flag=False): 81 | self.data = [] 82 | with open(data_path, "r", ) as f: 83 | for i, line in enumerate(f): 84 | line = line.strip() 85 | if denoising_flag: 86 | line = "denoising: " + line 87 | self.data.append(line) # append a list of tokens each time 88 | 89 | def __len__(self): 90 | return len(self.data) 91 | 92 | def __getitem__(self, idx): 93 | return self.data[idx] 94 | 95 | 96 | class BartLMTrainer(object): 97 | def __init__(self, model, dataloader, tokenizer, args, pretrained_model=None): 98 | self.args = args 99 | self.model = model 100 | self.pretrained_model = pretrained_model 101 | self.optimizer = build_optim(args, model, None, pretrained_model) 102 | self.dataloader = dataloader 103 | self.tokenizer = tokenizer 104 | self.epoch = args.epoch 105 | self.mask_probability = args.mask_prob 106 | self.accumulation_steps = args.accum_step 107 | self.clip = args.clip 108 | self.domain = args.dm 109 | self.path = args.path 110 | if args.recadam: 111 | if args.max_steps > 0: 112 | t_total = args.max_steps 113 | self.epoch = args.max_steps // (len(self.dataloader) // self.accumulation_steps) + 1 114 | else: 115 | t_total = len(self.dataloader) // self.accumulation_steps * self.epoch 116 | self.scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) 117 | 118 | def train(self): 119 | print('Start finetuning BART language model') 120 | iteration = 0 121 | for epoch_i in range(self.epoch): 122 | self.model.train() 123 | if self.pretrained_model is not None: 124 | self.pretrained_model.eval() 125 | print('[ Epoch : {}]'.format(epoch_i)) 126 | loss_list = [] 127 | dist_sum, dist_num = 0.0, 0 128 | pbar = tqdm(self.dataloader, total=len(self.dataloader)) 129 | for sents in pbar: 130 | sents = [self.shorten_sent(sent) for sent in sents] 131 | iteration += 1 132 | tokenized_sents = self.tokenize(sents) 133 | decoder_ids = [[self.tokenizer.bos_token_id] + item for item in tokenized_sents] 134 | label_ids = [item + [self.tokenizer.eos_token_id] for item in tokenized_sents] 135 | # print("before:") 136 | # print(sents[0]) 137 | # print("tokenized sents:") 138 | # print(tokenized_sents[0]) 139 | # sents: a list of sentence, each item inside is a string 140 | noisy_text = add_noise(sents, self.mask_probability) 141 | # noisy_text: a list of sentence, each item inside is a string 142 | # print("after:") 143 | # print(noisy_text[0]) 144 | inputs_ids = self.tokenize(noisy_text) 145 | # print("tokenized noisy text:") 146 | # print(inputs_ids[0]) 147 | 148 | # prepare data for training 149 | mask = torch.tensor(get_mask(inputs_ids, max_len=512)).cuda() 150 | inputs_ids = torch.tensor(pad_sents(inputs_ids, pad_token=self.tokenizer.pad_token_id, max_len=512)[0]).cuda() 151 | decoder_ids = torch.tensor(pad_sents(decoder_ids, pad_token=self.tokenizer.pad_token_id, max_len=512)[0]).cuda() 152 | label_ids = torch.tensor(pad_sents(label_ids, pad_token=-100, max_len=512)[0]).cuda() 153 | #optimize model 154 | loss = 
self.model(input_ids=inputs_ids, attention_mask=mask, decoder_input_ids=decoder_ids, labels=label_ids)[0] 155 | loss_list.append(loss.item()) 156 | loss = loss / self.accumulation_steps 157 | loss.backward() 158 | if self.args.logging_Euclid_dist: 159 | dist = torch.sum(torch.abs(torch.cat( 160 | [p.view(-1) for n, p in self.model.named_parameters()]) - torch.cat( 161 | [p.view(-1) for n, p in self.pretrained_model.named_parameters()])) ** 2).item() 162 | 163 | dist_sum += dist 164 | dist_num += 1 165 | 166 | if iteration % self.accumulation_steps == 0: 167 | if self.args.recadam: 168 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) 169 | self.optimizer.step() 170 | if self.args.recadam: 171 | self.scheduler.step() 172 | self.model.zero_grad() 173 | loss_list = [np.mean(loss_list)] 174 | 175 | if self.args.logging_Euclid_dist: 176 | # pbar.set_description("(Epoch {}) LOSS: {:.6f} Euclid dist: {:.6f} LR: {:.6f}".format(epoch_i, np.mean(loss_list), dist_sum / dist_num, self.scheduler.get_last_lr()[0])) 177 | pbar.set_description("(Epoch {}) LOSS: {:.6f} Euclid dist: {:.6f}".format(epoch_i, np.mean(loss_list), dist_sum / dist_num)) 178 | else: 179 | pbar.set_description("(Epoch {}) LOSS: {:.6f} LearningRate: {:.10f}".format(epoch_i, np.mean(loss_list), self.optimizer.learning_rate)) 180 | if iteration % args.save_interval == 0: 181 | self.save_model(iteration) 182 | 183 | def shorten_sent(self, sent): 184 | split_sent = sent.split() 185 | if len(split_sent) > 400: 186 | sent = ' '.join(split_sent[:400]) 187 | return sent 188 | 189 | def tokenize(self, sents): 190 | tokenized_text = [self.tokenizer.encode(sent, add_special_tokens=False) for sent in sents] 191 | return tokenized_text 192 | 193 | def save_model(self, iter_num): 194 | print("saving model") 195 | saved_path = os.path.join('DAPT_save/{}_{}.chkpt'.format(args.dm, iter_num)) 196 | torch.save(self.model, saved_path) 197 | 198 | if __name__ == "__main__": 199 | # configuration 200 | parser = argparse.ArgumentParser() 201 | parser.add_argument('-visible_gpu', default='1', type=str) 202 | parser.add_argument('-bsz', type=int, default=4, help="batch size") 203 | parser.add_argument('-path', type=str, default="", help="data path") 204 | parser.add_argument('-epoch', type=int, default=10, help="epoch size") 205 | parser.add_argument('-mask_prob', type=float, default=0.15, help="mask probability") 206 | parser.add_argument('-dm', type=str, default="", help="domain name") 207 | parser.add_argument('-random_seed', type=int, default=0) 208 | parser.add_argument('-save_interval', default=10000, type=int) 209 | # optimizer configuration 210 | parser.add_argument('-lr', default=0.05, type=float) 211 | parser.add_argument('-optim', default='adam', type=str) 212 | parser.add_argument('-max_grad_norm', default=0, type=float) 213 | parser.add_argument('-beta1', default=0.9, type=float) 214 | parser.add_argument('-beta2', default=0.998, type=float) 215 | parser.add_argument('-warmup_steps', default=10000, type=int) 216 | parser.add_argument('-decay_method', default='noam', type=str) 217 | parser.add_argument('-enc_hidden_size', default=768, type=int) 218 | parser.add_argument('-clip', type=float, default=1.0, help="gradient clip") 219 | parser.add_argument('-accum_step', type=int, default=10, help="accumulation steps") 220 | parser.add_argument('-train_from', default='', type=str) 221 | # using RecAdam 222 | parser.add_argument("-adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") 223 | 
parser.add_argument('-recadam', default=False, action='store_true') 224 | parser.add_argument("-weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") 225 | parser.add_argument("-anneal_w", type=float, default=1.0, help="Weight for the annealing function in RecAdam. Default 1.0.") 226 | parser.add_argument("-anneal_fun", type=str, default='sigmoid', choices=["sigmoid", "linear", 'constant'], help="the type of annealing function in RecAdam. Default sigmoid") 227 | parser.add_argument("-anneal_t0", type=int, default=1000, help="t0 for the annealing function in RecAdam.") 228 | parser.add_argument("-anneal_k", type=float, default=0.1, help="k for the annealing function in RecAdam.") 229 | parser.add_argument("-pretrain_cof", type=float, default=5000.0, help="Coefficient of the quadratic penalty in RecAdam. Default 5000.0.") 230 | parser.add_argument("-logging_Euclid_dist", action="store_true", help="Whether to log the Euclidean distance between the pretrained model and fine-tuning model") 231 | parser.add_argument("-max_steps", default=-1, type=int, help="If > 0: set total number of training steps to perform. Override num_train_epochs.") 232 | parser.add_argument("-model_type", type=str, default="layers") 233 | 234 | args = parser.parse_args() 235 | 236 | # set random seed 237 | random.seed(args.random_seed) 238 | np.random.seed(args.random_seed) 239 | torch.manual_seed(args.random_seed) 240 | torch.cuda.manual_seed(args.random_seed) 241 | torch.backends.cudnn.deterministic = True 242 | 243 | # set gpu 244 | os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_gpu 245 | 246 | print("Loading datasets ...") 247 | dataset = CorpusDataset(args.path) 248 | dataloader = DataLoader(dataset=dataset, batch_size=args.bsz, shuffle=True) 249 | 250 | if args.train_from: 251 | model = torch.load(args.train_from, map_location='cpu') 252 | else: 253 | model = BartForConditionalGeneration.from_pretrained('facebook/bart-base') 254 | model.cuda() 255 | 256 | tokenizer = BartTokenizer.from_pretrained('facebook/bart-base') 257 | 258 | if args.recadam: 259 | pretrained_model = BartForConditionalGeneration.from_pretrained('facebook/bart-base') 260 | pretrained_model.cuda() 261 | else: 262 | pretrained_model = None 263 | 264 | bart_lm_trainer = BartLMTrainer(model, dataloader, tokenizer, args, pretrained_model) 265 | 266 | bart_lm_trainer.train() -------------------------------------------------------------------------------- /src/tapt_pretraining.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch.utils.data import Dataset, DataLoader 4 | from transformers import BartForConditionalGeneration, BartTokenizer, get_linear_schedule_with_warmup 5 | from others.logging import logger 6 | from others.utils import pad_sents, get_mask 7 | from others.optimizer import build_optim 8 | from tqdm import tqdm 9 | import numpy as np 10 | import argparse 11 | import random 12 | import os 13 | from nltk.tokenize import sent_tokenize 14 | 15 | 16 | def text_infilling(sent, mask_probability=0.05, lamda=3): 17 | ''' 18 | inputs: 19 | sent: a sentence string 20 | mask_probability: probability for masking tokens 21 | lamda: lamda for poission distribution 22 | outputs: 23 | sent: a list of tokens with masked tokens 24 | ''' 25 | sent = sent.split() 26 | length = len(sent) 27 | mask_indices = (np.random.uniform(0, 1, length) < mask_probability) * 1 28 | span_list = np.random.poisson(lamda, length) # lamda for poission distribution 29 | nonzero_idx = 
np.nonzero(mask_indices)[0] 30 | for item in nonzero_idx: 31 | span = min(span_list[item], 5) # maximum mask 5 continuous tokens 32 | for i in range(span): 33 | if item+i >= length: 34 | continue 35 | mask_indices[item+i] = 1 36 | for i in range(length): 37 | if mask_indices[i] == 1: 38 | sent[i] = '' 39 | 40 | # merge the s to one 41 | final_sent = [] 42 | mask_flag = 0 43 | for word in sent: 44 | if word != '': 45 | mask_flag = 0 46 | final_sent.append(word) 47 | else: 48 | if mask_flag == 0: 49 | final_sent.append(word) 50 | mask_flag = 1 51 | return final_sent 52 | 53 | def sent_permutation(sent): 54 | ''' 55 | inputs: 56 | sent: a sentence string 57 | outputs: 58 | shuffle_sent: a string after sentence permutations 59 | ''' 60 | # split sentences based on '.' 61 | splits = sent_tokenize(sent) 62 | random.shuffle(splits) 63 | 64 | return " ".join(splits) 65 | 66 | 67 | def add_noise(sents, mask_probability): 68 | noisy_sent_list = [] 69 | for sent in sents: 70 | noisy_sent = sent_permutation(sent) 71 | noisy_sent = text_infilling(noisy_sent, mask_probability) 72 | 73 | noisy_sent = " ".join(noisy_sent) 74 | noisy_sent_list.append(noisy_sent) 75 | 76 | return noisy_sent_list 77 | 78 | 79 | class CorpusDataset(Dataset): 80 | def __init__(self, data_path, denoising_flag=False): 81 | self.data = [] 82 | with open(data_path, "r", ) as f: 83 | for i, line in enumerate(f): 84 | line = line.strip() 85 | if denoising_flag: 86 | line = "denoising: " + line 87 | self.data.append(line) # append a list of tokens each time 88 | 89 | def __len__(self): 90 | return len(self.data) 91 | 92 | def __getitem__(self, idx): 93 | return self.data[idx] 94 | 95 | 96 | class BartLMTrainer(object): 97 | def __init__(self, model, dataloader, tokenizer, args, pretrained_model=None): 98 | self.args = args 99 | self.model = model 100 | self.pretrained_model = pretrained_model 101 | self.optimizer = build_optim(args, model, None, pretrained_model) 102 | self.dataloader = dataloader 103 | self.tokenizer = tokenizer 104 | self.epoch = args.epoch 105 | self.mask_probability = args.mask_prob 106 | self.accumulation_steps = args.accum_step 107 | self.clip = args.clip 108 | self.domain = args.dm 109 | self.path = args.path 110 | if args.recadam: 111 | if args.max_steps > 0: 112 | t_total = args.max_steps 113 | self.epoch = args.max_steps // (len(self.dataloader) // self.accumulation_steps) + 1 114 | else: 115 | t_total = len(self.dataloader) // self.accumulation_steps * self.epoch 116 | self.scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) 117 | 118 | def train(self): 119 | print('Start finetuning BART language model') 120 | iteration = 0 121 | for epoch_i in range(self.epoch): 122 | self.model.train() 123 | if self.pretrained_model is not None: 124 | self.pretrained_model.eval() 125 | print('[ Epoch : {}]'.format(epoch_i)) 126 | loss_list = [] 127 | dist_sum, dist_num = 0.0, 0 128 | pbar = tqdm(self.dataloader, total=len(self.dataloader)) 129 | for sents in pbar: 130 | sents = [self.shorten_sent(sent) for sent in sents] 131 | iteration += 1 132 | tokenized_sents = self.tokenize(sents) 133 | decoder_ids = [[self.tokenizer.bos_token_id] + item for item in tokenized_sents] 134 | label_ids = [item + [self.tokenizer.eos_token_id] for item in tokenized_sents] 135 | # print("before:") 136 | # print(sents[0]) 137 | # print("tokenized sents:") 138 | # print(tokenized_sents[0]) 139 | # sents: a list of sentence, each item inside is a string 140 | noisy_text = 
add_noise(sents, self.mask_probability) 141 | # noisy_text: a list of sentence, each item inside is a string 142 | # print("after:") 143 | # print(noisy_text[0]) 144 | inputs_ids = self.tokenize(noisy_text) 145 | # print("tokenized noisy text:") 146 | # print(inputs_ids[0]) 147 | 148 | # prepare data for training 149 | mask = torch.tensor(get_mask(inputs_ids, max_len=512)).cuda() 150 | inputs_ids = torch.tensor(pad_sents(inputs_ids, pad_token=self.tokenizer.pad_token_id, max_len=512)[0]).cuda() 151 | decoder_ids = torch.tensor(pad_sents(decoder_ids, pad_token=self.tokenizer.pad_token_id, max_len=512)[0]).cuda() 152 | label_ids = torch.tensor(pad_sents(label_ids, pad_token=-100, max_len=512)[0]).cuda() 153 | #optimize model 154 | loss = self.model(input_ids=inputs_ids, attention_mask=mask, decoder_input_ids=decoder_ids, labels=label_ids)[0] 155 | loss_list.append(loss.item()) 156 | loss = loss / self.accumulation_steps 157 | loss.backward() 158 | if self.args.logging_Euclid_dist: 159 | dist = torch.sum(torch.abs(torch.cat( 160 | [p.view(-1) for n, p in self.model.named_parameters()]) - torch.cat( 161 | [p.view(-1) for n, p in self.pretrained_model.named_parameters()])) ** 2).item() 162 | 163 | dist_sum += dist 164 | dist_num += 1 165 | 166 | if iteration % self.accumulation_steps == 0: 167 | if self.args.recadam: 168 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) 169 | self.optimizer.step() 170 | if self.args.recadam: 171 | self.scheduler.step() 172 | self.model.zero_grad() 173 | loss_list = [np.mean(loss_list)] 174 | 175 | if self.args.logging_Euclid_dist: 176 | # pbar.set_description("(Epoch {}) LOSS: {:.6f} Euclid dist: {:.6f} LR: {:.6f}".format(epoch_i, np.mean(loss_list), dist_sum / dist_num, self.scheduler.get_last_lr()[0])) 177 | pbar.set_description("(Epoch {}) LOSS: {:.6f} Euclid dist: {:.6f}".format(epoch_i, np.mean(loss_list), dist_sum / dist_num)) 178 | else: 179 | pbar.set_description("(Epoch {}) LOSS: {:.6f} LearningRate: {:.10f}".format(epoch_i, np.mean(loss_list), self.optimizer.learning_rate)) 180 | if iteration % args.save_interval == 0: 181 | self.save_model(iteration) 182 | 183 | def shorten_sent(self, sent): 184 | split_sent = sent.split() 185 | if len(split_sent) > 400: 186 | sent = ' '.join(split_sent[:400]) 187 | return sent 188 | 189 | def tokenize(self, sents): 190 | tokenized_text = [self.tokenizer.encode(sent, add_special_tokens=False) for sent in sents] 191 | return tokenized_text 192 | 193 | def save_model(self, iter_num): 194 | print("saving model") 195 | saved_path = os.path.join('TAPT_save/{}_{}.chkpt'.format(args.dm, iter_num)) 196 | torch.save(self.model, saved_path) 197 | 198 | if __name__ == "__main__": 199 | # configuration 200 | parser = argparse.ArgumentParser() 201 | parser.add_argument('-visible_gpu', default='1', type=str) 202 | parser.add_argument('-bsz', type=int, default=4, help="batch size") 203 | parser.add_argument('-path', type=str, default="", help="data path") 204 | parser.add_argument('-epoch', type=int, default=10, help="epoch size") 205 | parser.add_argument('-mask_prob', type=float, default=0.15, help="mask probability") 206 | parser.add_argument('-dm', type=str, default="", help="domain name") 207 | parser.add_argument('-random_seed', type=int, default=0) 208 | parser.add_argument('-save_interval', default=10000, type=int) 209 | # optimizer configuration 210 | parser.add_argument('-lr', default=0.05, type=float) 211 | parser.add_argument('-optim', default='adam', type=str) 212 | parser.add_argument('-max_grad_norm', 
default=0, type=float) 213 | parser.add_argument('-beta1', default=0.9, type=float) 214 | parser.add_argument('-beta2', default=0.998, type=float) 215 | parser.add_argument('-warmup_steps', default=10000, type=int) 216 | parser.add_argument('-decay_method', default='noam', type=str) 217 | parser.add_argument('-enc_hidden_size', default=768, type=int) 218 | parser.add_argument('-clip', type=float, default=1.0, help="gradient clip") 219 | parser.add_argument('-accum_step', type=int, default=10, help="accumulation steps") 220 | parser.add_argument('-train_from', default='', type=str) 221 | # using RecAdam 222 | parser.add_argument("-adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") 223 | parser.add_argument('-recadam', default=False, action='store_true') 224 | parser.add_argument("-weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") 225 | parser.add_argument("-anneal_w", type=float, default=1.0, help="Weight for the annealing function in RecAdam. Default 1.0.") 226 | parser.add_argument("-anneal_fun", type=str, default='sigmoid', choices=["sigmoid", "linear", 'constant'], help="the type of annealing function in RecAdam. Default sigmoid") 227 | parser.add_argument("-anneal_t0", type=int, default=1000, help="t0 for the annealing function in RecAdam.") 228 | parser.add_argument("-anneal_k", type=float, default=0.1, help="k for the annealing function in RecAdam.") 229 | parser.add_argument("-pretrain_cof", type=float, default=5000.0, help="Coefficient of the quadratic penalty in RecAdam. Default 5000.0.") 230 | parser.add_argument("-logging_Euclid_dist", action="store_true", help="Whether to log the Euclidean distance between the pretrained model and fine-tuning model") 231 | parser.add_argument("-max_steps", default=-1, type=int, help="If > 0: set total number of training steps to perform. Override num_train_epochs.") 232 | parser.add_argument("-model_type", type=str, default="layers") 233 | 234 | args = parser.parse_args() 235 | 236 | # set random seed 237 | random.seed(args.random_seed) 238 | np.random.seed(args.random_seed) 239 | torch.manual_seed(args.random_seed) 240 | torch.cuda.manual_seed(args.random_seed) 241 | torch.backends.cudnn.deterministic = True 242 | 243 | # set gpu 244 | os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_gpu 245 | 246 | print("Loading datasets ...") 247 | dataset = CorpusDataset(args.path) 248 | dataloader = DataLoader(dataset=dataset, batch_size=args.bsz, shuffle=True) 249 | 250 | if args.train_from: 251 | model = torch.load(args.train_from, map_location='cpu') 252 | else: 253 | model = BartForConditionalGeneration.from_pretrained('facebook/bart-base') 254 | model.cuda() 255 | 256 | tokenizer = BartTokenizer.from_pretrained('facebook/bart-base') 257 | 258 | if args.recadam: 259 | pretrained_model = BartForConditionalGeneration.from_pretrained('facebook/bart-base') 260 | pretrained_model.cuda() 261 | else: 262 | pretrained_model = None 263 | 264 | bart_lm_trainer = BartLMTrainer(model, dataloader, tokenizer, args, pretrained_model) 265 | 266 | bart_lm_trainer.train() -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. 
Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution 4.0 International Public License 58 | 59 | By exercising the Licensed Rights (defined below), You accept and agree 60 | to be bound by the terms and conditions of this Creative Commons 61 | Attribution 4.0 International Public License ("Public License"). To the 62 | extent this Public License may be interpreted as a contract, You are 63 | granted the Licensed Rights in consideration of Your acceptance of 64 | these terms and conditions, and the Licensor grants You such rights in 65 | consideration of benefits the Licensor receives from making the 66 | Licensed Material available under these terms and conditions. 67 | 68 | 69 | Section 1 -- Definitions. 
70 | 71 | a. Adapted Material means material subject to Copyright and Similar 72 | Rights that is derived from or based upon the Licensed Material 73 | and in which the Licensed Material is translated, altered, 74 | arranged, transformed, or otherwise modified in a manner requiring 75 | permission under the Copyright and Similar Rights held by the 76 | Licensor. For purposes of this Public License, where the Licensed 77 | Material is a musical work, performance, or sound recording, 78 | Adapted Material is always produced where the Licensed Material is 79 | synched in timed relation with a moving image. 80 | 81 | b. Adapter's License means the license You apply to Your Copyright 82 | and Similar Rights in Your contributions to Adapted Material in 83 | accordance with the terms and conditions of this Public License. 84 | 85 | c. Copyright and Similar Rights means copyright and/or similar rights 86 | closely related to copyright including, without limitation, 87 | performance, broadcast, sound recording, and Sui Generis Database 88 | Rights, without regard to how the rights are labeled or 89 | categorized. For purposes of this Public License, the rights 90 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 91 | Rights. 92 | 93 | d. Effective Technological Measures means those measures that, in the 94 | absence of proper authority, may not be circumvented under laws 95 | fulfilling obligations under Article 11 of the WIPO Copyright 96 | Treaty adopted on December 20, 1996, and/or similar international 97 | agreements. 98 | 99 | e. Exceptions and Limitations means fair use, fair dealing, and/or 100 | any other exception or limitation to Copyright and Similar Rights 101 | that applies to Your use of the Licensed Material. 102 | 103 | f. Licensed Material means the artistic or literary work, database, 104 | or other material to which the Licensor applied this Public 105 | License. 106 | 107 | g. Licensed Rights means the rights granted to You subject to the 108 | terms and conditions of this Public License, which are limited to 109 | all Copyright and Similar Rights that apply to Your use of the 110 | Licensed Material and that the Licensor has authority to license. 111 | 112 | h. Licensor means the individual(s) or entity(ies) granting rights 113 | under this Public License. 114 | 115 | i. Share means to provide material to the public by any means or 116 | process that requires permission under the Licensed Rights, such 117 | as reproduction, public display, public performance, distribution, 118 | dissemination, communication, or importation, and to make material 119 | available to the public including in ways that members of the 120 | public may access the material from a place and at a time 121 | individually chosen by them. 122 | 123 | j. Sui Generis Database Rights means rights other than copyright 124 | resulting from Directive 96/9/EC of the European Parliament and of 125 | the Council of 11 March 1996 on the legal protection of databases, 126 | as amended and/or succeeded, as well as other essentially 127 | equivalent rights anywhere in the world. 128 | 129 | k. You means the individual or entity exercising the Licensed Rights 130 | under this Public License. Your has a corresponding meaning. 131 | 132 | 133 | Section 2 -- Scope. 134 | 135 | a. License grant. 136 | 137 | 1. 
Subject to the terms and conditions of this Public License, 138 | the Licensor hereby grants You a worldwide, royalty-free, 139 | non-sublicensable, non-exclusive, irrevocable license to 140 | exercise the Licensed Rights in the Licensed Material to: 141 | 142 | a. reproduce and Share the Licensed Material, in whole or 143 | in part; and 144 | 145 | b. produce, reproduce, and Share Adapted Material. 146 | 147 | 2. Exceptions and Limitations. For the avoidance of doubt, where 148 | Exceptions and Limitations apply to Your use, this Public 149 | License does not apply, and You do not need to comply with 150 | its terms and conditions. 151 | 152 | 3. Term. The term of this Public License is specified in Section 153 | 6(a). 154 | 155 | 4. Media and formats; technical modifications allowed. The 156 | Licensor authorizes You to exercise the Licensed Rights in 157 | all media and formats whether now known or hereafter created, 158 | and to make technical modifications necessary to do so. The 159 | Licensor waives and/or agrees not to assert any right or 160 | authority to forbid You from making technical modifications 161 | necessary to exercise the Licensed Rights, including 162 | technical modifications necessary to circumvent Effective 163 | Technological Measures. For purposes of this Public License, 164 | simply making modifications authorized by this Section 2(a) 165 | (4) never produces Adapted Material. 166 | 167 | 5. Downstream recipients. 168 | 169 | a. Offer from the Licensor -- Licensed Material. Every 170 | recipient of the Licensed Material automatically 171 | receives an offer from the Licensor to exercise the 172 | Licensed Rights under the terms and conditions of this 173 | Public License. 174 | 175 | b. No downstream restrictions. You may not offer or impose 176 | any additional or different terms or conditions on, or 177 | apply any Effective Technological Measures to, the 178 | Licensed Material if doing so restricts exercise of the 179 | Licensed Rights by any recipient of the Licensed 180 | Material. 181 | 182 | 6. No endorsement. Nothing in this Public License constitutes or 183 | may be construed as permission to assert or imply that You 184 | are, or that Your use of the Licensed Material is, connected 185 | with, or sponsored, endorsed, or granted official status by, 186 | the Licensor or others designated to receive attribution as 187 | provided in Section 3(a)(1)(A)(i). 188 | 189 | b. Other rights. 190 | 191 | 1. Moral rights, such as the right of integrity, are not 192 | licensed under this Public License, nor are publicity, 193 | privacy, and/or other similar personality rights; however, to 194 | the extent possible, the Licensor waives and/or agrees not to 195 | assert any such rights held by the Licensor to the limited 196 | extent necessary to allow You to exercise the Licensed 197 | Rights, but not otherwise. 198 | 199 | 2. Patent and trademark rights are not licensed under this 200 | Public License. 201 | 202 | 3. To the extent possible, the Licensor waives any right to 203 | collect royalties from You for the exercise of the Licensed 204 | Rights, whether directly or through a collecting society 205 | under any voluntary or waivable statutory or compulsory 206 | licensing scheme. In all other cases the Licensor expressly 207 | reserves any right to collect such royalties. 208 | 209 | 210 | Section 3 -- License Conditions. 211 | 212 | Your exercise of the Licensed Rights is expressly made subject to the 213 | following conditions. 214 | 215 | a. Attribution. 
216 | 217 | 1. If You Share the Licensed Material (including in modified 218 | form), You must: 219 | 220 | a. retain the following if it is supplied by the Licensor 221 | with the Licensed Material: 222 | 223 | i. identification of the creator(s) of the Licensed 224 | Material and any others designated to receive 225 | attribution, in any reasonable manner requested by 226 | the Licensor (including by pseudonym if 227 | designated); 228 | 229 | ii. a copyright notice; 230 | 231 | iii. a notice that refers to this Public License; 232 | 233 | iv. a notice that refers to the disclaimer of 234 | warranties; 235 | 236 | v. a URI or hyperlink to the Licensed Material to the 237 | extent reasonably practicable; 238 | 239 | b. indicate if You modified the Licensed Material and 240 | retain an indication of any previous modifications; and 241 | 242 | c. indicate the Licensed Material is licensed under this 243 | Public License, and include the text of, or the URI or 244 | hyperlink to, this Public License. 245 | 246 | 2. You may satisfy the conditions in Section 3(a)(1) in any 247 | reasonable manner based on the medium, means, and context in 248 | which You Share the Licensed Material. For example, it may be 249 | reasonable to satisfy the conditions by providing a URI or 250 | hyperlink to a resource that includes the required 251 | information. 252 | 253 | 3. If requested by the Licensor, You must remove any of the 254 | information required by Section 3(a)(1)(A) to the extent 255 | reasonably practicable. 256 | 257 | 4. If You Share Adapted Material You produce, the Adapter's 258 | License You apply must not prevent recipients of the Adapted 259 | Material from complying with this Public License. 260 | 261 | 262 | Section 4 -- Sui Generis Database Rights. 263 | 264 | Where the Licensed Rights include Sui Generis Database Rights that 265 | apply to Your use of the Licensed Material: 266 | 267 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 268 | to extract, reuse, reproduce, and Share all or a substantial 269 | portion of the contents of the database; 270 | 271 | b. if You include all or a substantial portion of the database 272 | contents in a database in which You have Sui Generis Database 273 | Rights, then the database in which You have Sui Generis Database 274 | Rights (but not its individual contents) is Adapted Material; and 275 | 276 | c. You must comply with the conditions in Section 3(a) if You Share 277 | all or a substantial portion of the contents of the database. 278 | 279 | For the avoidance of doubt, this Section 4 supplements and does not 280 | replace Your obligations under this Public License where the Licensed 281 | Rights include other Copyright and Similar Rights. 282 | 283 | 284 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 285 | 286 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 287 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 288 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 289 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 290 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 291 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 292 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 293 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 294 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 295 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 
296 | 297 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 298 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 299 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 300 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 301 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 302 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 303 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 304 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 305 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 306 | 307 | c. The disclaimer of warranties and limitation of liability provided 308 | above shall be interpreted in a manner that, to the extent 309 | possible, most closely approximates an absolute disclaimer and 310 | waiver of all liability. 311 | 312 | 313 | Section 6 -- Term and Termination. 314 | 315 | a. This Public License applies for the term of the Copyright and 316 | Similar Rights licensed here. However, if You fail to comply with 317 | this Public License, then Your rights under this Public License 318 | terminate automatically. 319 | 320 | b. Where Your right to use the Licensed Material has terminated under 321 | Section 6(a), it reinstates: 322 | 323 | 1. automatically as of the date the violation is cured, provided 324 | it is cured within 30 days of Your discovery of the 325 | violation; or 326 | 327 | 2. upon express reinstatement by the Licensor. 328 | 329 | For the avoidance of doubt, this Section 6(b) does not affect any 330 | right the Licensor may have to seek remedies for Your violations 331 | of this Public License. 332 | 333 | c. For the avoidance of doubt, the Licensor may also offer the 334 | Licensed Material under separate terms or conditions or stop 335 | distributing the Licensed Material at any time; however, doing so 336 | will not terminate this Public License. 337 | 338 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 339 | License. 340 | 341 | 342 | Section 7 -- Other Terms and Conditions. 343 | 344 | a. The Licensor shall not be bound by any additional or different 345 | terms or conditions communicated by You unless expressly agreed. 346 | 347 | b. Any arrangements, understandings, or agreements regarding the 348 | Licensed Material not stated herein are separate from and 349 | independent of the terms and conditions of this Public License. 350 | 351 | 352 | Section 8 -- Interpretation. 353 | 354 | a. For the avoidance of doubt, this Public License does not, and 355 | shall not be interpreted to, reduce, limit, restrict, or impose 356 | conditions on any use of the Licensed Material that could lawfully 357 | be made without permission under this Public License. 358 | 359 | b. To the extent possible, if any provision of this Public License is 360 | deemed unenforceable, it shall be automatically reformed to the 361 | minimum extent necessary to make it enforceable. If the provision 362 | cannot be reformed, it shall be severed from this Public License 363 | without affecting the enforceability of the remaining terms and 364 | conditions. 365 | 366 | c. No term or condition of this Public License will be waived and no 367 | failure to comply consented to unless expressly agreed to by the 368 | Licensor. 369 | 370 | d. 
Nothing in this Public License constitutes or may be interpreted 371 | as a limitation upon, or waiver of, any privileges and immunities 372 | that apply to the Licensor or You, including from the legal 373 | processes of any jurisdiction or authority. 374 | 375 | 376 | ======================================================================= 377 | 378 | Creative Commons is not a party to its public 379 | licenses. Notwithstanding, Creative Commons may elect to apply one of 380 | its public licenses to material it publishes and in those instances 381 | will be considered the “Licensor.” The text of the Creative Commons 382 | public licenses is dedicated to the public domain under the CC0 Public 383 | Domain Dedication. Except for the limited purpose of indicating that 384 | material is shared under a Creative Commons public license or as 385 | otherwise permitted by the Creative Commons policies published at 386 | creativecommons.org/policies, Creative Commons does not authorize the 387 | use of the trademark "Creative Commons" or any other trademark or logo 388 | of Creative Commons without its prior written consent including, 389 | without limitation, in connection with any unauthorized modifications 390 | to any of its public licenses or any other arrangements, 391 | understandings, or agreements concerning use of licensed material. For 392 | the avoidance of doubt, this paragraph does not form part of the 393 | public licenses. 394 | 395 | Creative Commons may be contacted at creativecommons.org. 396 | -------------------------------------------------------------------------------- /src/others/my_pyrouge.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, unicode_literals, division 2 | 3 | import os 4 | import re 5 | import codecs 6 | import platform 7 | 8 | from subprocess import check_output 9 | from tempfile import mkdtemp 10 | from functools import partial 11 | 12 | try: 13 | from configparser import ConfigParser 14 | except ImportError: 15 | from ConfigParser import ConfigParser 16 | 17 | from pyrouge.utils import log 18 | from pyrouge.utils.file_utils import verify_dir 19 | 20 | 21 | class DirectoryProcessor: 22 | 23 | @staticmethod 24 | def process(input_dir, output_dir, function): 25 | """ 26 | Apply function to all files in input_dir and save the resulting ouput 27 | files in output_dir. 28 | 29 | """ 30 | if not os.path.exists(output_dir): 31 | os.makedirs(output_dir) 32 | logger = log.get_global_console_logger() 33 | logger.info("Processing files in {}.".format(input_dir)) 34 | input_file_names = os.listdir(input_dir) 35 | for input_file_name in input_file_names: 36 | input_file = os.path.join(input_dir, input_file_name) 37 | with codecs.open(input_file, "r", encoding="UTF-8") as f: 38 | input_string = f.read() 39 | output_string = function(input_string) 40 | output_file = os.path.join(output_dir, input_file_name) 41 | with codecs.open(output_file, "w", encoding="UTF-8") as f: 42 | f.write(output_string.lower()) 43 | logger.info("Saved processed files to {}.".format(output_dir)) 44 | 45 | 46 | class Rouge155(object): 47 | """ 48 | This is a wrapper for the ROUGE 1.5.5 summary evaluation package. 49 | This class is designed to simplify the evaluation process by: 50 | 51 | 1) Converting summaries into a format ROUGE understands. 52 | 2) Generating the ROUGE configuration file automatically based 53 | on filename patterns. 
54 | 55 | This class can be used within Python like this: 56 | 57 | rouge = Rouge155() 58 | rouge.system_dir = 'test/systems' 59 | rouge.model_dir = 'test/models' 60 | 61 | # The system filename pattern should contain one group that 62 | # matches the document ID. 63 | rouge.system_filename_pattern = 'SL.P.10.R.11.SL062003-(\d+).html' 64 | 65 | # The model filename pattern has '#ID#' as a placeholder for the 66 | # document ID. If there are multiple model summaries, pyrouge 67 | # will use the provided regex to automatically match them with 68 | # the corresponding system summary. Here, [A-Z] matches 69 | # multiple model summaries for a given #ID#. 70 | rouge.model_filename_pattern = 'SL.P.10.R.[A-Z].SL062003-#ID#.html' 71 | 72 | rouge_output = rouge.evaluate() 73 | print(rouge_output) 74 | output_dict = rouge.output_to_dict(rouge_ouput) 75 | print(output_dict) 76 | -> {'rouge_1_f_score': 0.95652, 77 | 'rouge_1_f_score_cb': 0.95652, 78 | 'rouge_1_f_score_ce': 0.95652, 79 | 'rouge_1_precision': 0.95652, 80 | [...] 81 | 82 | 83 | To evaluate multiple systems: 84 | 85 | rouge = Rouge155() 86 | rouge.system_dir = '/PATH/TO/systems' 87 | rouge.model_dir = 'PATH/TO/models' 88 | for system_id in ['id1', 'id2', 'id3']: 89 | rouge.system_filename_pattern = \ 90 | 'SL.P/.10.R.{}.SL062003-(\d+).html'.format(system_id) 91 | rouge.model_filename_pattern = \ 92 | 'SL.P.10.R.[A-Z].SL062003-#ID#.html' 93 | rouge_output = rouge.evaluate(system_id) 94 | print(rouge_output) 95 | 96 | """ 97 | 98 | def __init__(self, rouge_dir=None, rouge_args=None): 99 | """ 100 | Create a Rouge155 object. 101 | 102 | rouge_dir: Directory containing Rouge-1.5.5.pl 103 | rouge_args: Arguments to pass through to ROUGE if you 104 | don't want to use the default pyrouge 105 | arguments. 106 | 107 | """ 108 | self.log = log.get_global_console_logger() 109 | self.__set_dir_properties() 110 | self._config_file = None 111 | self._settings_file = self.__get_config_path() 112 | self.__set_rouge_dir(rouge_dir) 113 | self.args = self.__clean_rouge_args(rouge_args) 114 | self._system_filename_pattern = None 115 | self._model_filename_pattern = None 116 | 117 | def save_home_dir(self): 118 | config = ConfigParser() 119 | section = 'pyrouge settings' 120 | config.add_section(section) 121 | config.set(section, 'home_dir', self._home_dir) 122 | with open(self._settings_file, 'w') as f: 123 | config.write(f) 124 | self.log.info("Set ROUGE home directory to {}.".format(self._home_dir)) 125 | 126 | @property 127 | def settings_file(self): 128 | """ 129 | Path of the setttings file, which stores the ROUGE home dir. 130 | 131 | """ 132 | return self._settings_file 133 | 134 | @property 135 | def bin_path(self): 136 | """ 137 | The full path of the ROUGE binary (although it's technically 138 | a script), i.e. rouge_home_dir/ROUGE-1.5.5.pl 139 | 140 | """ 141 | if self._bin_path is None: 142 | raise Exception( 143 | "ROUGE path not set. Please set the ROUGE home directory " 144 | "and ensure that ROUGE-1.5.5.pl exists in it.") 145 | return self._bin_path 146 | 147 | @property 148 | def system_filename_pattern(self): 149 | """ 150 | The regular expression pattern for matching system summary 151 | filenames. The regex string. 152 | 153 | E.g. "SL.P.10.R.11.SL062003-(\d+).html" will match the system 154 | filenames in the SPL2003/system folder of the ROUGE SPL example 155 | in the "sample-test" folder. 156 | 157 | Currently, there is no support for multiple systems. 
158 | 159 | """ 160 | return self._system_filename_pattern 161 | 162 | @system_filename_pattern.setter 163 | def system_filename_pattern(self, pattern): 164 | self._system_filename_pattern = pattern 165 | 166 | @property 167 | def model_filename_pattern(self): 168 | """ 169 | The regular expression pattern for matching model summary 170 | filenames. The pattern needs to contain the string "#ID#", 171 | which is a placeholder for the document ID. 172 | 173 | E.g. "SL.P.10.R.[A-Z].SL062003-#ID#.html" will match the model 174 | filenames in the SPL2003/system folder of the ROUGE SPL 175 | example in the "sample-test" folder. 176 | 177 | "#ID#" is a placeholder for the document ID which has been 178 | matched by the "(\d+)" part of the system filename pattern. 179 | The different model summaries for a given document ID are 180 | matched by the "[A-Z]" part. 181 | 182 | """ 183 | return self._model_filename_pattern 184 | 185 | @model_filename_pattern.setter 186 | def model_filename_pattern(self, pattern): 187 | self._model_filename_pattern = pattern 188 | 189 | @property 190 | def config_file(self): 191 | return self._config_file 192 | 193 | @config_file.setter 194 | def config_file(self, path): 195 | config_dir, _ = os.path.split(path) 196 | verify_dir(config_dir, "configuration file") 197 | self._config_file = path 198 | 199 | def split_sentences(self): 200 | """ 201 | ROUGE requires texts split into sentences. In case the texts 202 | are not already split, this method can be used. 203 | 204 | """ 205 | from pyrouge.utils.sentence_splitter import PunktSentenceSplitter 206 | self.log.info("Splitting sentences.") 207 | ss = PunktSentenceSplitter() 208 | sent_split_to_string = lambda s: "\n".join(ss.split(s)) 209 | process_func = partial( 210 | DirectoryProcessor.process, function=sent_split_to_string) 211 | self.__process_summaries(process_func) 212 | 213 | @staticmethod 214 | def convert_summaries_to_rouge_format(input_dir, output_dir): 215 | """ 216 | Convert all files in input_dir into a format ROUGE understands 217 | and saves the files to output_dir. The input files are assumed 218 | to be plain text with one sentence per line. 219 | 220 | input_dir: Path of directory containing the input files. 221 | output_dir: Path of directory in which the converted files 222 | will be saved. 223 | 224 | """ 225 | DirectoryProcessor.process( 226 | input_dir, output_dir, Rouge155.convert_text_to_rouge_format) 227 | 228 | @staticmethod 229 | def convert_text_to_rouge_format(text, title="dummy title"): 230 | """ 231 | Convert a text to a format ROUGE understands. The text is 232 | assumed to contain one sentence per line. 233 | 234 | text: The text to convert, containg one sentence per line. 235 | title: Optional title for the text. The title will appear 236 | in the converted file, but doesn't seem to have 237 | any other relevance. 238 | 239 | Returns: The converted text as string. 
240 | 241 | """ 242 | sentences = text.split("\n") 243 | sent_elems = [ 244 | "[{i}] " 245 | "{text}".format(i=i, text=sent) 246 | for i, sent in enumerate(sentences, start=1)] 247 | html = """ 248 | 249 | {title} 250 | 251 | 252 | {elems} 253 | 254 | """.format(title=title, elems="\n".join(sent_elems)) 255 | 256 | return html 257 | 258 | @staticmethod 259 | def write_config_static(system_dir, system_filename_pattern, 260 | model_dir, model_filename_pattern, 261 | config_file_path, system_id=None): 262 | """ 263 | Write the ROUGE configuration file, which is basically a list 264 | of system summary files and their corresponding model summary 265 | files. 266 | 267 | pyrouge uses regular expressions to automatically find the 268 | matching model summary files for a given system summary file 269 | (cf. docstrings for system_filename_pattern and 270 | model_filename_pattern). 271 | 272 | system_dir: Path of directory containing 273 | system summaries. 274 | system_filename_pattern: Regex string for matching 275 | system summary filenames. 276 | model_dir: Path of directory containing 277 | model summaries. 278 | model_filename_pattern: Regex string for matching model 279 | summary filenames. 280 | config_file_path: Path of the configuration file. 281 | system_id: Optional system ID string which 282 | will appear in the ROUGE output. 283 | 284 | """ 285 | system_filenames = [f for f in os.listdir(system_dir)] 286 | system_models_tuples = [] 287 | 288 | system_filename_pattern = re.compile(system_filename_pattern) 289 | for system_filename in sorted(system_filenames): 290 | match = system_filename_pattern.match(system_filename) 291 | if match: 292 | id = match.groups(0)[0] 293 | model_filenames = [model_filename_pattern.replace('#ID#',id)] 294 | # model_filenames = Rouge155.__get_model_filenames_for_id( 295 | # id, model_dir, model_filename_pattern) 296 | system_models_tuples.append( 297 | (system_filename, sorted(model_filenames))) 298 | if not system_models_tuples: 299 | raise Exception( 300 | "Did not find any files matching the pattern {} " 301 | "in the system summaries directory {}.".format( 302 | system_filename_pattern.pattern, system_dir)) 303 | 304 | with codecs.open(config_file_path, 'w', encoding='utf-8') as f: 305 | f.write('') 306 | for task_id, (system_filename, model_filenames) in enumerate( 307 | system_models_tuples, start=1): 308 | 309 | eval_string = Rouge155.__get_eval_string( 310 | task_id, system_id, 311 | system_dir, system_filename, 312 | model_dir, model_filenames) 313 | f.write(eval_string) 314 | f.write("") 315 | 316 | def write_config(self, config_file_path=None, system_id=None): 317 | """ 318 | Write the ROUGE configuration file, which is basically a list 319 | of system summary files and their matching model summary files. 320 | 321 | This is a non-static version of write_config_file_static(). 322 | 323 | config_file_path: Path of the configuration file. 324 | system_id: Optional system ID string which will 325 | appear in the ROUGE output. 
326 | 327 | """ 328 | if not system_id: 329 | system_id = 1 330 | if (not config_file_path) or (not self._config_dir): 331 | self._config_dir = mkdtemp() 332 | config_filename = "rouge_conf.xml" 333 | else: 334 | config_dir, config_filename = os.path.split(config_file_path) 335 | verify_dir(config_dir, "configuration file") 336 | self._config_file = os.path.join(self._config_dir, config_filename) 337 | Rouge155.write_config_static( 338 | self._system_dir, self._system_filename_pattern, 339 | self._model_dir, self._model_filename_pattern, 340 | self._config_file, system_id) 341 | self.log.info( 342 | "Written ROUGE configuration to {}".format(self._config_file)) 343 | 344 | def evaluate(self, system_id=1, rouge_args=None): 345 | """ 346 | Run ROUGE to evaluate the system summaries in system_dir against 347 | the model summaries in model_dir. The summaries are assumed to 348 | be in the one-sentence-per-line HTML format ROUGE understands. 349 | 350 | system_id: Optional system ID which will be printed in 351 | ROUGE's output. 352 | 353 | Returns: Rouge output as string. 354 | 355 | """ 356 | self.write_config(system_id=system_id) 357 | options = self.__get_options(rouge_args) 358 | command = [self._bin_path] + options 359 | self.log.info( 360 | "Running ROUGE with command {}".format(" ".join(command))) 361 | rouge_output = check_output(command).decode("UTF-8") 362 | return rouge_output 363 | 364 | def convert_and_evaluate(self, system_id=1, 365 | split_sentences=False, rouge_args=None): 366 | """ 367 | Convert plain text summaries to ROUGE format and run ROUGE to 368 | evaluate the system summaries in system_dir against the model 369 | summaries in model_dir. Optionally split texts into sentences 370 | in case they aren't already. 371 | 372 | This is just a convenience method combining 373 | convert_summaries_to_rouge_format() and evaluate(). 374 | 375 | split_sentences: Optional argument specifying if 376 | sentences should be split. 377 | system_id: Optional system ID which will be printed 378 | in ROUGE's output. 379 | 380 | Returns: ROUGE output as string. 381 | 382 | """ 383 | if split_sentences: 384 | self.split_sentences() 385 | self.__write_summaries() 386 | rouge_output = self.evaluate(system_id, rouge_args) 387 | return rouge_output 388 | 389 | def output_to_dict(self, output): 390 | """ 391 | Convert the ROUGE output into python dictionary for further 392 | processing. 393 | 394 | """ 395 | #0 ROUGE-1 Average_R: 0.02632 (95%-conf.int. 0.02632 - 0.02632) 396 | pattern = re.compile( 397 | r"(\d+) (ROUGE-\S+) (Average_\w): (\d.\d+) " 398 | r"\(95%-conf.int. (\d.\d+) - (\d.\d+)\)") 399 | results = {} 400 | for line in output.split("\n"): 401 | match = pattern.match(line) 402 | if match: 403 | sys_id, rouge_type, measure, result, conf_begin, conf_end = \ 404 | match.groups() 405 | measure = { 406 | 'Average_R': 'recall', 407 | 'Average_P': 'precision', 408 | 'Average_F': 'f_score' 409 | }[measure] 410 | rouge_type = rouge_type.lower().replace("-", '_') 411 | key = "{}_{}".format(rouge_type, measure) 412 | results[key] = float(result) 413 | results["{}_cb".format(key)] = float(conf_begin) 414 | results["{}_ce".format(key)] = float(conf_end) 415 | return results 416 | 417 | ################################################################### 418 | # Private methods 419 | 420 | def __set_rouge_dir(self, home_dir=None): 421 | """ 422 | Verfify presence of ROUGE-1.5.5.pl and data folder, and set 423 | those paths. 
424 | 425 | """ 426 | if not home_dir: 427 | self._home_dir = self.__get_rouge_home_dir_from_settings() 428 | else: 429 | self._home_dir = home_dir 430 | self.save_home_dir() 431 | self._bin_path = os.path.join(self._home_dir, 'ROUGE-1.5.5.pl') 432 | self.data_dir = os.path.join(self._home_dir, 'data') 433 | if not os.path.exists(self._bin_path): 434 | raise Exception( 435 | "ROUGE binary not found at {}. Please set the " 436 | "correct path by running pyrouge_set_rouge_path " 437 | "/path/to/rouge/home.".format(self._bin_path)) 438 | 439 | def __get_rouge_home_dir_from_settings(self): 440 | config = ConfigParser() 441 | with open(self._settings_file) as f: 442 | if hasattr(config, "read_file"): 443 | config.read_file(f) 444 | else: 445 | # use deprecated python 2.x method 446 | config.readfp(f) 447 | rouge_home_dir = config.get('pyrouge settings', 'home_dir') 448 | return rouge_home_dir 449 | 450 | @staticmethod 451 | def __get_eval_string( 452 | task_id, system_id, 453 | system_dir, system_filename, 454 | model_dir, model_filenames): 455 | """ 456 | ROUGE can evaluate several system summaries for a given text 457 | against several model summaries, i.e. there is an m-to-n 458 | relation between system and model summaries. The system 459 | summaries are listed in the tag and the model summaries 460 | in the tag. pyrouge currently only supports one system 461 | summary per text, i.e. it assumes a 1-to-n relation between 462 | system and model summaries. 463 | 464 | """ 465 | peer_elems = "

{name}

".format( 466 | id=system_id, name=system_filename) 467 | 468 | model_elems = ["{name}".format( 469 | id=chr(65 + i), name=name) 470 | for i, name in enumerate(model_filenames)] 471 | 472 | model_elems = "\n\t\t\t".join(model_elems) 473 | eval_string = """ 474 | 475 | {model_root} 476 | {peer_root} 477 | 478 | 479 | 480 | {peer_elems} 481 | 482 | 483 | {model_elems} 484 | 485 | 486 | """.format( 487 | task_id=task_id, 488 | model_root=model_dir, model_elems=model_elems, 489 | peer_root=system_dir, peer_elems=peer_elems) 490 | return eval_string 491 | 492 | def __process_summaries(self, process_func): 493 | """ 494 | Helper method that applies process_func to the files in the 495 | system and model folders and saves the resulting files to new 496 | system and model folders. 497 | 498 | """ 499 | temp_dir = mkdtemp() 500 | new_system_dir = os.path.join(temp_dir, "system") 501 | os.mkdir(new_system_dir) 502 | new_model_dir = os.path.join(temp_dir, "model") 503 | os.mkdir(new_model_dir) 504 | self.log.info( 505 | "Processing summaries. Saving system files to {} and " 506 | "model files to {}.".format(new_system_dir, new_model_dir)) 507 | process_func(self._system_dir, new_system_dir) 508 | process_func(self._model_dir, new_model_dir) 509 | self._system_dir = new_system_dir 510 | self._model_dir = new_model_dir 511 | 512 | def __write_summaries(self): 513 | self.log.info("Writing summaries.") 514 | self.__process_summaries(self.convert_summaries_to_rouge_format) 515 | 516 | @staticmethod 517 | def __get_model_filenames_for_id(id, model_dir, model_filenames_pattern): 518 | pattern = re.compile(model_filenames_pattern.replace('#ID#', id)) 519 | model_filenames = [ 520 | f for f in os.listdir(model_dir) if pattern.match(f)] 521 | if not model_filenames: 522 | raise Exception( 523 | "Could not find any model summaries for the system" 524 | " summary with ID {}. Specified model filename pattern was: " 525 | "{}".format(id, model_filenames_pattern)) 526 | return model_filenames 527 | 528 | def __get_options(self, rouge_args=None): 529 | """ 530 | Get supplied command line arguments for ROUGE or use default 531 | ones. 532 | 533 | """ 534 | if self.args: 535 | options = self.args.split() 536 | elif rouge_args: 537 | options = rouge_args.split() 538 | else: 539 | options = [ 540 | '-e', self._data_dir, 541 | '-c', 95, 542 | # '-2', 543 | # '-1', 544 | # '-U', 545 | '-m', 546 | # '-v', 547 | '-r', 1000, 548 | '-n', 2, 549 | # '-w', 1.2, 550 | '-a', 551 | ] 552 | options = list(map(str, options)) 553 | 554 | 555 | 556 | 557 | options = self.__add_config_option(options) 558 | return options 559 | 560 | def __create_dir_property(self, dir_name, docstring): 561 | """ 562 | Generate getter and setter for a directory property. 563 | 564 | """ 565 | property_name = "{}_dir".format(dir_name) 566 | private_name = "_" + property_name 567 | setattr(self, private_name, None) 568 | 569 | def fget(self): 570 | return getattr(self, private_name) 571 | 572 | def fset(self, path): 573 | verify_dir(path, dir_name) 574 | setattr(self, private_name, path) 575 | 576 | p = property(fget=fget, fset=fset, doc=docstring) 577 | setattr(self.__class__, property_name, p) 578 | 579 | def __set_dir_properties(self): 580 | """ 581 | Automatically generate the properties for directories. 
582 | 583 | """ 584 | directories = [ 585 | ("home", "The ROUGE home directory."), 586 | ("data", "The path of the ROUGE 'data' directory."), 587 | ("system", "Path of the directory containing system summaries."), 588 | ("model", "Path of the directory containing model summaries."), 589 | ] 590 | for (dirname, docstring) in directories: 591 | self.__create_dir_property(dirname, docstring) 592 | 593 | def __clean_rouge_args(self, rouge_args): 594 | """ 595 | Remove enclosing quotation marks, if any. 596 | 597 | """ 598 | if not rouge_args: 599 | return 600 | quot_mark_pattern = re.compile('"(.+)"') 601 | match = quot_mark_pattern.match(rouge_args) 602 | if match: 603 | cleaned_args = match.group(1) 604 | return cleaned_args 605 | else: 606 | return rouge_args 607 | 608 | def __add_config_option(self, options): 609 | return options + [self._config_file] 610 | 611 | def __get_config_path(self): 612 | if platform.system() == "Windows": 613 | parent_dir = os.getenv("APPDATA") 614 | config_dir_name = "pyrouge" 615 | elif os.name == "posix": 616 | parent_dir = os.path.expanduser("~") 617 | config_dir_name = ".pyrouge" 618 | else: 619 | parent_dir = os.path.dirname(__file__) 620 | config_dir_name = "" 621 | config_dir = os.path.join(parent_dir, config_dir_name) 622 | if not os.path.exists(config_dir): 623 | os.makedirs(config_dir) 624 | return os.path.join(config_dir, 'settings.ini') 625 | 626 | 627 | if __name__ == "__main__": 628 | import argparse 629 | from utils.argparsers import rouge_path_parser 630 | 631 | parser = argparse.ArgumentParser(parents=[rouge_path_parser]) 632 | args = parser.parse_args() 633 | 634 | rouge = Rouge155(args.rouge_home) 635 | rouge.save_home_dir() --------------------------------------------------------------------------------