├── src
│   ├── others
│   │   ├── __init__.py
│   │   ├── logging.py
│   │   ├── utils.py
│   │   ├── recadam.py
│   │   ├── optimizer.py
│   │   └── my_pyrouge.py
│   ├── inference.py
│   ├── cal_rouge.py
│   ├── preprocessing.py
│   ├── sdpt_pretraining.py
│   ├── run.py
│   ├── trainer.py
│   ├── dapt_pretraining.py
│   └── tapt_pretraining.py
├── scripts
│   ├── finetuning.sh
│   ├── dapt_pretraining.sh
│   ├── tapt_pretraining.sh
│   └── sdpt_pretraining.sh
├── requirements.txt
├── image
│   ├── HKUST.jpg
│   └── pytorch-logo-dark.png
├── .gitignore
├── README.md
└── LICENSE
/src/others/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/scripts/finetuning.sh:
--------------------------------------------------------------------------------
1 | python ./src/run.py
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | tqdm
3 | transformers>=3.0.2
4 | nltk
--------------------------------------------------------------------------------
/image/HKUST.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TysonYu/AdaptSum/HEAD/image/HKUST.jpg
--------------------------------------------------------------------------------
/image/pytorch-logo-dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TysonYu/AdaptSum/HEAD/image/pytorch-logo-dark.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | dataset
3 | savings
4 | .vscode
5 | processed_datasets/
6 | *.pt
7 | logs
8 | pyrouge
9 | save
--------------------------------------------------------------------------------
/scripts/dapt_pretraining.sh:
--------------------------------------------------------------------------------
1 | python ./src/dapt_pretraining.py -path=./dataset/debate/debateorg.txt \
2 | -dm=debate \
3 | -visible_gpu=0 \
4 | -save_interval=10000 \
5 | # -recadam \
6 | # -logging_Euclid_dist
--------------------------------------------------------------------------------
/scripts/tapt_pretraining.sh:
--------------------------------------------------------------------------------
1 | python ./src/tapt_pretraining.py -path=./dataset/debate/TAPT-data/train.source \
2 | -dm=debate \
3 | -visible_gpu=0 \
4 | -save_interval=100 \
5 | # -recadam \
6 | # -logging_Euclid_dist
--------------------------------------------------------------------------------
/scripts/sdpt_pretraining.sh:
--------------------------------------------------------------------------------
1 | python ./src/sdpt_pretraining.py -data_name=SDPT-cnn_dm \
2 | -visible_gpu=0 \
3 | -saving_path=SDPT_save \
4 | -start_to_save_iter=0 \
5 | -save_interval=10000 \
6 | # -recadam \
7 | # -logging_Euclid_dist \
--------------------------------------------------------------------------------
/src/others/logging.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import logging
4 |
5 | logger = logging.getLogger()
6 |
7 |
8 | def init_logger(log_file=None, log_file_level=logging.NOTSET):
9 | log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s")
10 | logger = logging.getLogger()
11 | logger.setLevel(logging.INFO)
12 |
13 | console_handler = logging.StreamHandler()
14 | console_handler.setFormatter(log_format)
15 | logger.handlers = [console_handler]
16 |
17 | if log_file and log_file != '':
18 | file_handler = logging.FileHandler(log_file)
19 | file_handler.setLevel(log_file_level)
20 | file_handler.setFormatter(log_format)
21 | logger.addHandler(file_handler)
22 |
23 | return logger
24 |
25 |
26 |
--------------------------------------------------------------------------------
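
Note: a minimal usage sketch of init_logger (the log path below is illustrative, not part of the repo, and assumes the ./logs directory already exists):

    from others.logging import init_logger, logger

    # Attaches a console handler and, because a path is given, a file handler.
    init_logger('./logs/example.log')
    logger.info('visible on the console and appended to ./logs/example.log')
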
/src/others/utils.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import pickle
4 | import numpy as np
5 | import torch.nn as nn
6 | import torch
7 | import math
8 | import random
9 | NEAR_INF = 1e20
10 | NEAR_INF_FP16 = 65504
11 |
12 | def save(toBeSaved, filename, mode='wb'):
13 | '''
14 | save data to pickle file
15 | '''
16 | dirname = os.path.dirname(filename)
17 | if dirname and not os.path.exists(dirname):
18 | os.makedirs(dirname)
19 | file = open(filename, mode)
20 | pickle.dump(toBeSaved, file, protocol=4) # protocol 4 supports large objects and has been the default since Python 3.8
21 | file.close()
22 |
23 | def load(filename, mode='rb'):
24 | '''
25 | load pickle file
26 | '''
27 | file = open(filename, mode)
28 | loaded = pickle.load(file)
29 | file.close()
30 | return loaded
31 |
32 | def pad_sents(sents, pad_token=0, max_len=512):
33 | '''
34 | pad input to max length
35 | '''
36 | sents_padded = []
37 | lens = get_lens(sents)
38 | max_len = min(max(lens), max_len)
39 | sents_padded = []
40 | new_len = []
41 | for i, l in enumerate(lens):
42 | if l > max_len:
43 | l = max_len
44 | new_len.append(l)
45 | sents_padded.append(sents[i][:l] + [pad_token] * (max_len - l))
46 | return sents_padded, new_len
47 |
48 | def get_mask(sents, unmask_idx=1, mask_idx=0, max_len=512):
49 | '''
50 | make mask for padded input
51 | '''
52 | lens = get_lens(sents)
53 | max_len = min(max(lens), max_len)
54 | mask = []
55 | for l in lens:
56 | if l > max_len:
57 | l = max_len
58 | mask.append([unmask_idx] * l + [mask_idx] * (max_len - l))
59 | return mask
60 |
61 | def get_lens(sents):
62 | return [len(sent) for sent in sents]
63 |
64 | def get_max_len(sents):
65 | max_len = max([len(sent) for sent in sents])
66 | return max_len
67 |
68 | def fix_random_seed(random_seed):
69 | random.seed(random_seed)
70 | np.random.seed(random_seed)
71 | torch.manual_seed(random_seed)
72 | torch.cuda.manual_seed(random_seed)
73 | torch.backends.cudnn.deterministic = True
74 |
75 | # ----- for model -----------------------------------------------------------
76 |
77 | def count_parameters(model):
78 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
79 |
80 | def initialize_weights(m):
81 | if hasattr(m, 'weight') and m.weight.dim() > 1:
82 | nn.init.xavier_uniform_(m.weight.data)
83 |
84 | def tile(x, count, dim=0):
85 | """
86 | Tiles x on dimension dim count times.
87 | """
88 | perm = list(range(len(x.size())))
89 | if dim != 0:
90 | perm[0], perm[dim] = perm[dim], perm[0]
91 | x = x.permute(perm).contiguous()
92 | out_size = list(x.size())
93 | out_size[0] *= count
94 | batch = x.size(0)
95 | x = x.view(batch, -1) \
96 | .transpose(0, 1) \
97 | .repeat(count, 1) \
98 | .transpose(0, 1) \
99 | .contiguous() \
100 | .view(*out_size)
101 | if dim != 0:
102 | x = x.permute(perm).contiguous()
103 | return x
104 |
105 | def neginf(dtype: torch.dtype) -> float:
106 | """
107 | Return a representable finite number near -inf for a dtype.
108 | """
109 | if dtype is torch.float16:
110 | return -NEAR_INF_FP16
111 | else:
112 | return -NEAR_INF
113 |
114 | def batch_generator(dataloader):
115 | while True:
116 | for data_sample in dataloader:
117 | yield data_sample
118 |
--------------------------------------------------------------------------------
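
Note: a self-contained sketch of how pad_sents and get_mask behave on a toy batch (the token ids below are made up for illustration):

    from others.utils import pad_sents, get_mask

    batch = [[5, 6, 7], [8, 9]]  # two token-id sequences of unequal length
    padded, lengths = pad_sents(batch, pad_token=1, max_len=512)
    mask = get_mask(batch, max_len=512)
    # padded  -> [[5, 6, 7], [8, 9, 1]]   (padded to the longest sequence, capped at max_len)
    # lengths -> [3, 2]
    # mask    -> [[1, 1, 1], [1, 1, 0]]   (1 = real token, 0 = padding)
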
/src/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | from tqdm import tqdm
5 | from transformers import BartTokenizer
6 | from transformers import BartForConditionalGeneration
7 | from others.logging import init_logger, logger
8 | from others.utils import load, count_parameters, initialize_weights, fix_random_seed
9 | from preprocessing import BartDataset, DataReader
10 | from others.optimizer import build_optim
11 | from trainer import train
12 |
13 | def tokenize(data, tokenizer):
14 | tokenized_text = [tokenizer.encode(i) for i in data]
15 | return tokenized_text
16 |
17 | if __name__ == '__main__':
18 | # for training
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument('-visible_gpu', default='1', type=str)
21 | parser.add_argument('-log_file', default='./logs/inference/', type=str)
22 | parser.add_argument('-train_from', default='', type=str)
23 | parser.add_argument('-random_seed', type=int, default=199744)
24 | parser.add_argument('-lr', default=0.001, type=float)
25 | parser.add_argument('-max_grad_norm', default=0, type=float)
26 | parser.add_argument('-epoch', type=int, default=50)
27 | parser.add_argument('-saving_path', default='./save/', type=str)
28 | parser.add_argument('-data_name', default='debate', type=str)
29 | # for learning, optimizer
30 | parser.add_argument('-optim', default='adam', type=str)
31 | parser.add_argument('-beta1', default=0.9, type=float)
32 | parser.add_argument('-beta2', default=0.998, type=float)
33 | parser.add_argument('-warmup_steps', default=1000, type=int)
34 | parser.add_argument('-decay_method', default='noam', type=str)
35 | parser.add_argument('-enc_hidden_size', default=768, type=int)
36 | parser.add_argument('-clip', default=1.0, type=float)
37 | parser.add_argument('-accumulation_steps', default=10, type=int)
38 |
39 | args = parser.parse_args()
40 |
41 | # initial logger
42 | init_logger(args.log_file+args.data_name+'.log')
43 | logger.info(args)
44 |
45 | # set gpu
46 | os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_gpu
47 |
48 | # set random seed
49 | fix_random_seed(args.random_seed)
50 |
51 | # loading data
52 | # it's faster to load data from the pre-built dataloader
53 | logger.info('starting to read dataloader')
54 | data_file_name = './dataset/' + args.data_name + '/testloader.pt'
55 | test_loader = load(data_file_name)
56 |
57 | # initial tokenizer
58 | tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
59 |
60 | # initial model
61 | logger.info('starting to build model')
62 | if args.train_from != '':
63 | logger.info("train from : {}".format(args.train_from))
64 | # checkpoint = torch.load(args.train_from, map_location='cpu')
65 | # model.load_state_dict(checkpoint['model'])
66 | model = torch.load(args.train_from)
67 | else:
68 | model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
69 | model.cuda()
70 | model.eval()
71 |
72 | # inference and save model
73 | save_file_name = './dataset/' + args.data_name + '/summaries'
74 | summaries = open(save_file_name, 'w')
75 | outputs = []
76 | for src_ids, decoder_ids, mask, label_ids in tqdm(test_loader):
77 | src_ids = src_ids.cuda()
78 | summary_ids = model.generate(src_ids, num_beams=4, max_length=256, early_stopping=True)
79 | output = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids]
80 | outputs += output
81 | outputs = [i+'\n' for i in outputs]
82 | print(outputs[:4])
83 | summaries.writelines(outputs)
84 | summaries.close()
85 |
86 |
--------------------------------------------------------------------------------
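
Note: the decoding loop above reduces to the following CPU-only sketch on a single made-up sentence; inference.py itself iterates over the prebuilt test dataloader and runs on GPU:

    from transformers import BartTokenizer, BartForConditionalGeneration

    tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
    model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
    model.eval()

    src = "The two sides debated whether renewable energy subsidies should be expanded."
    src_ids = tokenizer(src, return_tensors='pt').input_ids
    summary_ids = model.generate(src_ids, num_beams=4, max_length=256, early_stopping=True)
    print(tokenizer.decode(summary_ids[0], skip_special_tokens=True,
                           clean_up_tokenization_spaces=False))
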
/src/cal_rouge.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | from multiprocessing import Pool
5 |
6 | import shutil
7 | import sys
8 | import codecs
9 | # import pyrouge
10 | from others.my_pyrouge import Rouge155
11 |
12 | def process(data):
13 | candidates, references, pool_id = data
14 | cnt = len(candidates)
15 | current_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
16 | tmp_dir = "rouge-tmp-{}-{}".format(current_time,pool_id)
17 | if not os.path.isdir(tmp_dir):
18 | os.mkdir(tmp_dir)
19 | os.mkdir(tmp_dir + "/candidate")
20 | os.mkdir(tmp_dir + "/reference")
21 | try:
22 | for i in range(cnt):
23 | if len(references[i]) < 1:
24 | continue
25 | with open(tmp_dir + "/candidate/cand.{}.txt".format(i), "w",
26 | encoding="utf-8") as f:
27 | f.write(candidates[i])
28 | with open(tmp_dir + "/reference/ref.{}.txt".format(i), "w",
29 | encoding="utf-8") as f:
30 | f.write(references[i])
31 | r = Rouge155()
32 | r.model_dir = tmp_dir + "/reference/"
33 | r.system_dir = tmp_dir + "/candidate/"
34 | r.model_filename_pattern = 'ref.#ID#.txt'
35 | r.system_filename_pattern = r'cand.(\d+).txt'
36 | rouge_results = r.convert_and_evaluate()
37 | # print(rouge_results)
38 | results_dict = r.output_to_dict(rouge_results)
39 | finally:
40 | pass
41 | if os.path.isdir(tmp_dir):
42 | shutil.rmtree(tmp_dir)
43 | return results_dict
44 |
45 | def chunks(l, n):
46 | """Yield successive n-sized chunks from l."""
47 | for i in range(0, len(l), n):
48 | yield l[i:i + n]
49 |
50 | def test_rouge(cand, ref,num_processes):
51 | """Calculate ROUGE scores of sequences passed as an iterator
52 | e.g. a list of str, an open file, StringIO or even sys.stdin
53 | """
54 | candidates = [line.strip() for line in cand]
55 | references = [line.strip() for line in ref]
56 |
57 | # print(len(candidates))
58 | # print(len(references))
59 | assert len(candidates) == len(references)
60 | candidates_chunks = list(chunks(candidates, int(len(candidates)/num_processes)))
61 | references_chunks = list(chunks(references, int(len(references)/num_processes)))
62 | n_pool = len(candidates_chunks)
63 | arg_lst = []
64 | for i in range(n_pool):
65 | arg_lst.append((candidates_chunks[i],references_chunks[i],i))
66 | pool = Pool(n_pool)
67 | results = pool.map(process,arg_lst)
68 | final_results = {}
69 | for i,r in enumerate(results):
70 | for k in r:
71 | if(k not in final_results):
72 | final_results[k] = r[k]*len(candidates_chunks[i])
73 | else:
74 | final_results[k] += r[k] * len(candidates_chunks[i])
75 | for k in final_results:
76 | final_results[k] = final_results[k]/len(candidates)
77 | return final_results
78 |
79 | def rouge_results_to_str(results_dict):
80 | return ">> ROUGE-F(1/2/l): {:.2f}/{:.2f}/{:.2f}\nROUGE-R(1/2/l): {:.2f}/{:.2f}/{:.2f}\nROUGE-P(1/2/l): {:.2f}/{:.2f}/{:.2f}\n".format(
81 | results_dict["rouge_1_f_score"] * 100,
82 | results_dict["rouge_2_f_score"] * 100,
83 | # results_dict["rouge_3_f_score"] * 100,
84 | results_dict["rouge_l_f_score"] * 100,
85 |
86 | results_dict["rouge_1_recall"] * 100,
87 | results_dict["rouge_2_recall"] * 100,
88 | # results_dict["rouge_3_f_score"] * 100,
89 | results_dict["rouge_l_recall"] * 100,
90 |
91 | results_dict["rouge_1_precision"] * 100,
92 | results_dict["rouge_2_precision"] * 100,
93 | # results_dict["rouge_3_f_score"] * 100,
94 | results_dict["rouge_l_precision"] * 100
95 | )
96 |
97 |
98 | if __name__ == "__main__":
99 | # init_logger('test_rouge.log')
100 | parser = argparse.ArgumentParser()
101 | parser.add_argument('-c', type=str, default="../datasets/debate/test.source",
102 | help='candidate file')
103 | parser.add_argument('-r', type=str, default="../datasets/debate/test.target",
104 | help='reference file')
105 | parser.add_argument('-p', type=int, default=1,
106 | help='number of processes')
107 | args = parser.parse_args()
108 | print(args.c)
109 | print(args.r)
110 | print(args.p)
111 |
112 | # read file
113 | c_data = open(args.c)
114 | candidates = c_data.readlines()
115 | candidates = [line.strip('\n') for line in candidates]
116 | r_data = open(args.r)
117 | references = r_data.readlines()
118 | references = [line.strip('\n') for line in references]
119 |
120 | # calculate rouge
121 | results_dict = test_rouge(candidates, references, args.p)
122 | print(time.strftime('%H:%M:%S', time.localtime()))
123 | print(rouge_results_to_str(results_dict))
124 |
--------------------------------------------------------------------------------
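
Note: a minimal sketch of calling test_rouge directly on toy strings; it requires a working ROUGE-1.5.5 / pyrouge installation (which is why a local pyrouge directory is listed in .gitignore):

    from cal_rouge import test_rouge, rouge_results_to_str

    candidates = ["the cat sat on the mat"]        # system summaries, aligned by index
    references = ["a cat was sitting on the mat"]  # gold summaries
    results = test_rouge(candidates, references, num_processes=1)
    print(rouge_results_to_str(results))
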
/src/preprocessing.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import torch
3 | from transformers import BartTokenizer
4 | from torch.utils.data import Dataset, DataLoader
5 | import random
6 | import numpy as np
7 | import os
8 |
9 | from others.utils import pad_sents, save, get_mask, fix_random_seed
10 |
11 | class BartDataset(Dataset):
12 | '''
13 | Attributes:
14 | src: a list of raw source-text samples, one entry per line of the data file.
15 | tgt: a list of raw target-text samples, one entry per line of the data file.
16 | src_ids: a list of tokenized source token-id sequences, aligned with src.
17 | tgt_ids: a list of tokenized target token-id sequences, aligned with tgt.
18 | '''
19 | def __init__(self, tokenizer, multi_news_reader, args):
20 | self.tokenizer = tokenizer
21 | self.multi_news_reader = multi_news_reader
22 | self.src = multi_news_reader.data_src
23 | self.tgt = multi_news_reader.data_tgt
24 | self.src = [i.strip('\n') for i in self.src]
25 | self.tgt = [i.strip('\n') for i in self.tgt]
26 | self.src_ids = self.tokenize(self.src)
27 | self.tgt_ids = self.tokenize(self.tgt)
28 |
29 | def __len__(self):
30 | return len(self.src)
31 |
32 | def __getitem__(self, idx):
33 | return self.src_ids[idx], self.tgt_ids[idx]
34 |
35 | def tokenize(self, data):
36 | tokenized_text = [self.tokenizer.encode(i, add_special_tokens=False) for i in tqdm(data)]
37 | return tokenized_text
38 |
39 | def collate_fn(self, data):
40 | # rebuild the raw text and truncate to max length
41 | max_input_len = 1024
42 | max_output_len = 256
43 | raw_src = [pair[0] for pair in data]
44 | raw_tgt = [pair[1] for pair in data]
45 | raw_src = [i[:max_input_len-1] for i in raw_src]
46 | raw_tgt = [i[:max_output_len-1] for i in raw_tgt]
47 | src = []
48 | tgt = []
49 | # remove blank data
50 | for i in range(len(raw_src)):
51 | if (raw_src[i] != []) and (raw_tgt[i] != []):
52 | src.append(raw_src[i])
53 | tgt.append(raw_tgt[i])
54 | # make input mask
55 | mask = torch.tensor(get_mask(src, max_len=max_input_len))
56 | # make input ids
57 | src_ids = torch.tensor(pad_sents(src, 1, max_len=max_input_len)[0])
58 | # make decoder input ids (prepend 0, BART's <s>/BOS token id)
59 | decoder_ids = [[0]+i for i in tgt]
60 | # make output labels (append 2, BART's </s>/EOS token id)
61 | label_ids = [i+[2] for i in tgt]
62 | decoder_ids = torch.tensor(pad_sents(decoder_ids, 1, max_len=max_output_len)[0])
63 | label_ids = torch.tensor(pad_sents(label_ids, -100, max_len=max_output_len)[0])
64 |
65 | return src_ids, decoder_ids, mask, label_ids
66 |
67 |
68 | class DataReader(object):
69 | '''
70 | Attributes:
71 | data_src: source text
72 | data_tgt: target text
73 | '''
74 | def __init__ (self, args):
75 | self.args = args
76 | self.raw_data = self.load_multinews_data(self.args)
77 | self.data_src = self.raw_data[0]
78 | self.data_tgt = self.raw_data[1]
79 |
80 | def file_reader(self, file_path):
81 | file = open(file_path, 'r')
82 | lines = file.readlines()
83 | return lines
84 |
85 | def load_multinews_data(self, args):
86 | train_src_path = args.data_path + args.data_name + '/' + args.mode + '.source'
87 | train_tgt_path = args.data_path + args.data_name + '/' + args.mode + '.target'
88 | train_src_lines = self.file_reader(train_src_path)
89 | train_tgt_lines = self.file_reader(train_tgt_path)
90 | return (train_src_lines, train_tgt_lines)
91 |
92 | def data_builder(args):
93 | save_path = args.data_path + args.data_name + '/' + args.mode + 'loader' + '.pt'
94 | data_reader = DataReader(args)
95 | tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
96 | train_set = BartDataset(tokenizer, data_reader, args)
97 | if args.mode == 'train':
98 | data_loader = DataLoader(dataset=train_set,
99 | batch_size=args.batch_size,
100 | shuffle=True,
101 | collate_fn=train_set.collate_fn)
102 | else:
103 | data_loader = DataLoader(dataset=train_set,
104 | batch_size=args.batch_size,
105 | shuffle=False,
106 | collate_fn=train_set.collate_fn)
107 | save(data_loader, save_path)
108 |
109 | if __name__ == '__main__':
110 | import argparse
111 | parser = argparse.ArgumentParser()
112 | parser.add_argument('-data_path', default='dataset/', type=str)
113 | parser.add_argument('-data_name', default='debate', type=str)
114 | parser.add_argument('-mode', default='train', type=str)
115 | parser.add_argument('-batch_size', default=4, type=int)
116 | parser.add_argument('-random_seed', type=int, default=0)
117 | args = parser.parse_args()
118 |
119 | # set random seed
120 | fix_random_seed(args.random_seed)
121 | data_builder(args)
122 |
123 |
--------------------------------------------------------------------------------
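
Note: a sketch of building the pickled dataloader programmatically instead of via the CLI; the paths are illustrative and assume dataset/debate/train.source and dataset/debate/train.target exist:

    from argparse import Namespace
    from preprocessing import data_builder

    args = Namespace(data_path='dataset/', data_name='debate', mode='train',
                     batch_size=4, random_seed=0)
    data_builder(args)  # tokenizes with facebook/bart-base and writes dataset/debate/trainloader.pt
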
/src/others/recadam.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """RecAdam optimizer"""
16 |
17 | import logging
18 | import math
19 | import numpy as np
20 |
21 | import torch
22 | from torch.optim import Optimizer
23 |
24 |
25 | logger = logging.getLogger(__name__)
26 |
27 |
28 | def anneal_function(function, step, k, t0, weight):
29 | if function == 'sigmoid':
30 | return float(1 / (1 + np.exp(-k * (step - t0)))) * weight
31 | elif function == 'linear':
32 | return min(1, step / t0) * weight
33 | elif function == 'constant':
34 | return weight
35 | else:
36 | raise ValueError("Unknown anneal function: {}".format(function))
37 |
38 |
39 | class RecAdam(Optimizer):
40 | """ Implementation of RecAdam optimizer, a variant of Adam optimizer.
41 | Parameters:
42 | lr (float): learning rate. Default 1e-3.
43 | betas (tuple of 2 floats): Adam's beta parameters (b1, b2). Default: (0.9, 0.999)
44 | eps (float): Adam's epsilon. Default: 1e-6
45 | weight_decay (float): Weight decay. Default: 0.0
46 | correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. like in the BERT TF repository). Default True.
47 | anneal_fun (str): annealing-function hyperparameter; decides the shape of the curve. Default 'sigmoid'.
48 | anneal_k (float): annealing-function hyperparameter; decides the slope of the curve. Choices: [0.05, 0.1, 0.2, 0.5, 1]
49 | anneal_t0 (float): annealing-function hyperparameter; decides the midpoint of the curve. Choices: [100, 250, 500, 1000]
50 | anneal_w (float): annealing-function hyperparameter; decides the scale of the curve. Default 1.0.
51 | pretrain_cof (float): the coefficient of the quadratic penalty. Default 5000.0.
52 | pretrain_params (list of tensors): the corresponding group of params in the pretrained model.
53 | """
54 |
55 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, correct_bias=True, anneal_fun='sigmoid', anneal_k=0, anneal_t0=0, anneal_w=1.0, pretrain_cof=5000.0, pretrain_params=None):
56 | if lr < 0.0:
57 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
58 | if not 0.0 <= betas[0] < 1.0:
59 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
60 | if not 0.0 <= betas[1] < 1.0:
61 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
62 | if not 0.0 <= eps:
63 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
64 | print("anneal_k: {:4f}, anneal_t0: {:4f}".format(anneal_k, anneal_t0))
65 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias, anneal_fun=anneal_fun, anneal_k=anneal_k, anneal_t0=anneal_t0, anneal_w=anneal_w, pretrain_cof=pretrain_cof, pretrain_params=pretrain_params)
66 | super().__init__(params, defaults)
67 |
68 | def step(self, closure=None):
69 | """Performs a single optimization step.
70 | Arguments:
71 | closure (callable, optional): A closure that reevaluates the model
72 | and returns the loss.
73 | """
74 | loss = None
75 | if closure is not None:
76 | loss = closure()
77 | for group in self.param_groups:
78 | for p, pp in zip(group["params"], group["pretrain_params"]):
79 | if p.grad is None:
80 | continue
81 | grad = p.grad.data
82 | if grad.is_sparse:
83 | raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
84 |
85 | state = self.state[p]
86 |
87 | # State initialization
88 | if len(state) == 0:
89 | state["step"] = 0
90 | # Exponential moving average of gradient values
91 | state["exp_avg"] = torch.zeros_like(p.data)
92 | # Exponential moving average of squared gradient values
93 | state["exp_avg_sq"] = torch.zeros_like(p.data)
94 |
95 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
96 | beta1, beta2 = group["betas"]
97 |
98 | state["step"] += 1
99 |
100 | # Decay the first and second moment running average coefficient
101 | # In-place operations to update the averages at the same time
102 | exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
103 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
104 | denom = exp_avg_sq.sqrt().add_(group["eps"])
105 |
106 | step_size = group["lr"]
107 | if group["correct_bias"]:
108 | bias_correction1 = 1.0 - beta1 ** state["step"]
109 | bias_correction2 = 1.0 - beta2 ** state["step"]
110 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
111 |
112 | # With RecAdam method, the optimization objective is
113 | # Loss = lambda(t)*Loss_T + (1-lambda(t))*Loss_S
114 | # Loss = lambda(t)*Loss_T + (1-lambda(t))*\gamma/2*\sum((\theta_i-\theta_i^*)^2)
115 | if group['anneal_w'] > 0.0:
116 | # print("anneal_fun: {}, anneal_w: {:4f}, anneal_k: {:4f}, anneal_t0: {:4f}".format(group['anneal_fun'], group['anneal_w'], group['anneal_k'], group['anneal_t0']))
117 | # We calculate the lambda as the annealing function
118 | anneal_lambda = anneal_function(group['anneal_fun'], state["step"], group['anneal_k'], group['anneal_t0'], group['anneal_w'])
119 | assert anneal_lambda <= group['anneal_w']
120 | # print(anneal_lambda, step_size)
121 | # The loss of the target task is multiplied by lambda(t)
122 | p.data.addcdiv_(exp_avg, denom, value=-step_size * anneal_lambda)
123 | # Add the quadratic penalty to simulate the pretraining tasks
124 | p.data.add_(p.data - pp.data, alpha=-group["lr"] * (group['anneal_w'] - anneal_lambda) * group["pretrain_cof"])
125 | else:
126 | p.data.addcdiv_(exp_avg, denom, value=-step_size)
127 |
128 | # Just adding the square of the weights to the loss function is *not*
129 | # the correct way of using L2 regularization/weight decay with Adam,
130 | # since that will interact with the m and v parameters in strange ways.
131 | #
132 | # Instead we want to decay the weights in a manner that doesn't interact
133 | # with the m/v parameters. This is equivalent to adding the square
134 | # of the weights to the loss with plain (non-momentum) SGD.
135 | # Add weight decay at the end (fixed version)
136 | if group["weight_decay"] > 0.0:
137 | p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"])
138 |
139 | return loss
--------------------------------------------------------------------------------
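
Note: the annealing weight lambda(t) that RecAdam uses to blend the target-task update with the quadratic pull toward the pretrained weights can be inspected on its own; the sketch below uses the repo's default anneal_k=0.1 and anneal_t0=1000:

    from others.recadam import anneal_function

    # Sigmoid annealing: lambda(t) rises from near 0 toward `weight` as step passes t0.
    for step in (0, 500, 1000, 1500, 2000):
        lam = anneal_function('sigmoid', step, k=0.1, t0=1000, weight=1.0)
        print(step, round(lam, 4))
    # step 1000 (= t0) gives exactly 0.5; earlier steps are close to 0, later steps approach 1.0
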
/src/sdpt_pretraining.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | from transformers import BartForConditionalGeneration, get_linear_schedule_with_warmup
5 | import numpy as np
6 | from preprocessing import BartDataset, DataReader
7 | from others.logging import init_logger, logger
8 | from others.utils import load, fix_random_seed
9 | from others.optimizer import build_optim
10 |
11 |
12 |
13 | def load_dataloader(args):
14 | train_file_name = './dataset/' + args.data_name + '/trainloader.pt'
15 | train_loader = load(train_file_name)
16 | logger.info('train loader has {} samples'.format(len(train_loader.dataset)))
17 | return train_loader
18 |
19 | def train(model, training_data, optimizer, checkpoint, args, pretrained_model):
20 | ''' Start training '''
21 | if args.recadam:
22 | t_total = len(training_data) // args.accumulation_steps * 10
23 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
24 | logger.info('Start training')
25 | iteration = 0
26 | if args.break_point_continue:
27 | iteration = checkpoint['iteration']
28 | total_loss = 0
29 | F1 = 0
30 | for epoch_i in range(args.epoch):
31 | logger.info('[ Epoch : {}]'.format(epoch_i))
32 | dist_sum, dist_num = 0.0, 0
33 | # training part
34 | model.train()
35 | for src_ids, decoder_ids, mask, label_ids in training_data:
36 | iteration += 1
37 | src_ids = src_ids.cuda()
38 | decoder_ids = decoder_ids.cuda()
39 | mask = mask.cuda()
40 | label_ids = label_ids.cuda()
41 | # forward
42 | # optimizer.optimizer.zero_grad()
43 | loss = model(input_ids=src_ids, attention_mask=mask, decoder_input_ids=decoder_ids, labels=label_ids)[0]
44 | total_loss += loss.item()
45 | loss = loss / args.accumulation_steps
46 | # backward
47 | loss.backward()
48 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
49 | # loss accumulation
50 | if (iteration+1) % args.accumulation_steps == 0:
51 | optimizer.step()
52 | if args.recadam:
53 | scheduler.step()
54 | model.zero_grad()
55 |
56 | if args.logging_Euclid_dist:
57 | dist = torch.sum(torch.abs(torch.cat(
58 | [p.view(-1) for n, p in model.named_parameters()]) - torch.cat(
59 | [p.view(-1) for n, p in pretrained_model.named_parameters()])) ** 2).item()
60 |
61 | dist_sum += dist
62 | dist_num += 1
63 | # write to log file
64 | if iteration % 20 == 0:
65 | if args.logging_Euclid_dist:
66 | logger.info("iteration: {} loss_per_word: {:4f} Euclid dist: {:.6f}".format(iteration, total_loss/20, dist_sum / dist_num))
67 | else:
68 | logger.info("iteration: {} loss_per_word: {:4f} learning rate: {:4f} ".format(iteration, total_loss/20, optimizer.learning_rate))
69 | total_loss = 0
70 | # save model
71 | if iteration % args.save_interval == 0 and iteration > args.start_to_save_iter:
72 | print('=====Saving checkpoint=====')
73 | model_name = args.saving_path + "/{}_{}.chkpt".format(args.data_name, iteration)
74 | torch.save(model, model_name)
75 | else:
76 | pass
77 |
78 | if __name__ == '__main__':
79 | # for training
80 | parser = argparse.ArgumentParser()
81 | parser.add_argument('-visible_gpu', default='1', type=str)
82 | parser.add_argument('-log_file', default='./logs/', type=str)
83 | parser.add_argument('-train_from', default='', type=str)
84 | parser.add_argument('-random_seed', type=int, default=0)
85 | parser.add_argument('-lr', default=0.05, type=float)
86 | parser.add_argument('-max_grad_norm', default=0, type=float)
87 | parser.add_argument('-epoch', type=int, default=50)
88 | parser.add_argument('-max_iter', type=int, default=800000)
89 | parser.add_argument('-saving_path', default='./save/', type=str)
90 | parser.add_argument('-data_name', default='debate', type=str)
91 | parser.add_argument('-minor_data', action='store_true')
92 | parser.add_argument('-pre_trained_lm', default='', type=str)
93 | parser.add_argument('-pre_trained_src', action='store_true')
94 | parser.add_argument('-break_point_continue', action='store_true')
95 | parser.add_argument('-corpus_path', type=str, default="", help="target domain corpus path")
96 | parser.add_argument('-mask_prob', type=float, default=0.15, help="mask probability")
97 | # for learning, optimizer
98 | parser.add_argument('-mtl', action='store_true', help='multitask learning')
99 | parser.add_argument('-optim', default='adam', type=str)
100 | parser.add_argument('-beta1', default=0.9, type=float)
101 | parser.add_argument('-beta2', default=0.998, type=float)
102 | parser.add_argument('-warmup_steps', default=1000, type=int)
103 | parser.add_argument('-decay_method', default='noam', type=str)
104 | parser.add_argument('-enc_hidden_size', default=768, type=int)
105 | parser.add_argument('-clip', default=1.0, type=float)
106 | parser.add_argument('-accumulation_steps', default=10, type=int)
107 | parser.add_argument('-bsz', default=4, type=int, help='batch size')
108 | # for evaluation
109 | parser.add_argument('-process_num', default=4, type=int)
110 | parser.add_argument('-start_to_save_iter', default=3000, type=int)
111 | parser.add_argument('-save_interval', default=10000, type=int)
112 | # using RecAdam
113 | parser.add_argument("-adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
114 | parser.add_argument('-recadam', default=False, action='store_true')
115 | parser.add_argument("-weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
116 | parser.add_argument("-anneal_w", type=float, default=1.0, help="Weight for the annealing function in RecAdam. Default 1.0.")
117 | parser.add_argument("-anneal_fun", type=str, default='sigmoid', choices=["sigmoid", "linear", 'constant'], help="the type of annealing function in RecAdam. Default sigmoid")
118 | parser.add_argument("-anneal_t0", type=int, default=1000, help="t0 for the annealing function in RecAdam.")
119 | parser.add_argument("-anneal_k", type=float, default=0.1, help="k for the annealing function in RecAdam.")
120 | parser.add_argument("-pretrain_cof", type=float, default=5000.0, help="Coefficient of the quadratic penalty in RecAdam. Default 5000.0.")
121 | parser.add_argument("-logging_Euclid_dist", action="store_true", help="Whether to log the Euclidean distance between the pretrained model and fine-tuning model")
122 | parser.add_argument("-max_steps", default=-1, type=int, help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
123 | parser.add_argument("-model_type", type=str, default="layers")
124 | args = parser.parse_args()
125 |
126 | # initial logger
127 | if not os.path.exists(args.log_file + args.data_name):
128 | os.makedirs(args.log_file + args.data_name)
129 | init_logger(args.log_file + args.data_name + '/SDPT_pretraining.log')
130 | logger.info(args)
131 | # set gpu
132 | os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_gpu
133 | # set random seed
134 | fix_random_seed(args.random_seed)
135 |
136 | # loading data
137 | # it's faster to load data from the pre-built dataloader
138 | logger.info('starting to read dataloader')
139 | train_loader = load_dataloader(args)
140 |
141 | # initial model and optimizer
142 | logger.info('starting to build model')
143 | model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
144 | model.cuda()
145 | checkpoint = None
146 | optim = build_optim(args, model, None, model)
147 | pretrained_model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
148 | if args.recadam:
149 | pretrained_model.cuda()
150 | optim = build_optim(args, model, None, pretrained_model)
151 |
152 | # training
153 | train(model, train_loader, optim, checkpoint, args, pretrained_model)
154 |
--------------------------------------------------------------------------------
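
Note: the "Euclid dist" logged above is the squared L2 distance between the current and pretrained parameter vectors (no square root is taken). A self-contained version of the same computation, using placeholder Linear modules instead of two BART checkpoints:

    import torch

    model = torch.nn.Linear(3, 2)
    pretrained_model = torch.nn.Linear(3, 2)

    flat = torch.cat([p.view(-1) for _, p in model.named_parameters()])
    flat_pre = torch.cat([p.view(-1) for _, p in pretrained_model.named_parameters()])
    dist = torch.sum(torch.abs(flat - flat_pre) ** 2).item()  # squared Euclidean distance
    print(dist)
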
/src/run.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | from torch.utils.data import DataLoader
5 | from transformers import BartForConditionalGeneration
6 | from others.logging import init_logger, logger
7 | from others.utils import load, count_parameters, initialize_weights, batch_generator, fix_random_seed
8 | from preprocessing import BartDataset, DataReader
9 | from others.optimizer import build_optim
10 | from trainer import train, multitask_train
11 | from dapt_pretraining import CorpusDataset
12 | import random
13 | import numpy as np
14 |
15 | def make_log_file_name(args):
16 | if args.pre_trained_lm != '':
17 | log_file_name = args.log_file + args.data_name + '/train_' + args.pre_trained_lm.split('/')[-1][:-3] + '_' + args.percentage + '%_pretrain_lm.log'
18 | elif args.pre_trained_src:
19 | log_file_name = args.log_file + args.data_name + '/train_' + args.train_from.split('/')[-1][:-6] + '_' + args.percentage + '_' + '%_pretrain_src.log'
20 | else:
21 | log_file_name = args.log_file + args.data_name + '/train_' + args.percentage + '%.log'
22 | return log_file_name
23 |
24 | def load_dataloader(args):
25 | train_file_name = './dataset/' + args.data_name + '/trainloader.pt'
26 | train_loader = load(train_file_name)
27 | valid_file_name = './dataset/' + args.data_name + '/validloader.pt'
28 | valid_loader = load(valid_file_name)
29 | logger.info('train loader has {} samples'.format(len(train_loader.dataset)))
30 | return train_loader, valid_loader
31 |
32 | def load_dataloader_for_domain_corpus(args):
33 | corpus_dataset = CorpusDataset(args.corpus_path, denoising_flag=True)
34 | dataloader = DataLoader(dataset=corpus_dataset, batch_size=args.bsz, shuffle=True)
35 | data_generator = batch_generator(dataloader)
36 | return data_generator
37 |
38 | def load_model(args):
39 | model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
40 | if args.pre_trained_lm != '':
41 | model = torch.load(args.pre_trained_lm, map_location='cpu')
42 | # load from saved model
43 | if args.train_from != '':
44 | logger.info("train from : {}".format(args.train_from))
45 | if "mtl_pre_trained_lm" in args.train_from:
46 | checkpoint = torch.load(args.train_from, map_location='cpu')
47 | model.load_state_dict(checkpoint['model_lm'])
48 | elif "xsum" in args.train_from:
49 | checkpoint = None
50 | model = BartForConditionalGeneration.from_pretrained('VictorSanh/bart-base-finetuned-xsum')
51 | else:
52 | checkpoint = torch.load(args.train_from, map_location='cpu')
53 | model.load_state_dict(checkpoint['model'])
54 | if args.train_from == '':
55 | checkpoint = None
56 | if args.mtl:
57 | model_lm = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
58 | if args.pre_trained_lm != '':
59 | model_lm = torch.load(args.pre_trained_lm, map_location='cpu')
60 | model_cnn = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
61 | # shared part
62 | model_cnn.model.shared = model_lm.model.shared
63 | # encoder part
64 | model_cnn.model.encoder = model_lm.model.encoder
65 |
66 | print("don't share decoder!")
67 | model = None
68 | return model_lm, model_cnn, checkpoint
69 | return model, checkpoint
70 |
71 | if __name__ == '__main__':
72 | # for training
73 | parser = argparse.ArgumentParser()
74 | parser.add_argument('-visible_gpu', default='1', type=str)
75 | parser.add_argument('-log_file', default='./logs/', type=str)
76 | parser.add_argument('-train_from', default='', type=str)
77 | parser.add_argument('-random_seed', type=int, default=0)
78 | parser.add_argument('-lr', default=0.05, type=float)
79 | parser.add_argument('-max_grad_norm', default=0, type=float)
80 | parser.add_argument('-epoch', type=int, default=50)
81 | parser.add_argument('-max_iter', type=int, default=800000)
82 | parser.add_argument('-saving_path', default='./save/', type=str)
83 | parser.add_argument('-data_name', default='debate', type=str)
84 | parser.add_argument('-pre_trained_lm', default='', type=str)
85 | parser.add_argument('-pre_trained_src', action='store_true')
86 | parser.add_argument('-break_point_continue', action='store_true')
87 | parser.add_argument('-percentage', default='100', type=str)
88 | parser.add_argument('-corpus_path', type=str, default="", help="target domain corpus path")
89 | parser.add_argument('-mask_prob', type=float, default=0.15, help="mask probability")
90 | # for learning, optimizer
91 | parser.add_argument('-mtl', action='store_true', help='multitask learning')
92 | parser.add_argument('-optim', default='adam', type=str)
93 | parser.add_argument('-beta1', default=0.9, type=float)
94 | parser.add_argument('-beta2', default=0.998, type=float)
95 | parser.add_argument('-warmup_steps', default=1000, type=int)
96 | parser.add_argument('-decay_method', default='noam', type=str)
97 | parser.add_argument('-enc_hidden_size', default=768, type=int)
98 | parser.add_argument('-clip', default=1.0, type=float)
99 | parser.add_argument('-accumulation_steps', default=10, type=int)
100 | parser.add_argument('-bsz', default=4, type=int, help='batch size')
101 | # for evaluation
102 | parser.add_argument('-process_num', default=4, type=int)
103 | parser.add_argument('-save_interval', default=100, type=int)
104 | parser.add_argument('-start_to_save_iter', default=100, type=int)
105 | # using RecAdam
106 | parser.add_argument("-adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
107 | parser.add_argument('-recadam', default=False, action='store_true')
108 | parser.add_argument("-weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
109 | parser.add_argument("-anneal_w", type=float, default=1.0, help="Weight for the annealing function in RecAdam. Default 1.0.")
110 | parser.add_argument("-anneal_fun", type=str, default='sigmoid', choices=["sigmoid", "linear", 'constant'], help="the type of annealing function in RecAdam. Default sigmoid")
111 | parser.add_argument("-anneal_t0", type=int, default=1000, help="t0 for the annealing function in RecAdam.")
112 | parser.add_argument("-anneal_k", type=float, default=0.1, help="k for the annealing function in RecAdam.")
113 | parser.add_argument("-pretrain_cof", type=float, default=5000.0, help="Coefficient of the quadratic penalty in RecAdam. Default 5000.0.")
114 | parser.add_argument("-logging_Euclid_dist", action="store_true", help="Whether to log the Euclidean distance between the pretrained model and fine-tuning model")
115 | parser.add_argument("-max_steps", default=-1, type=int, help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
116 | parser.add_argument("-model_type", type=str, default="layers")
117 | args = parser.parse_args()
118 |
119 | # initial logger
120 | init_logger(make_log_file_name(args))
121 | logger.info(args)
122 | # set gpu
123 | os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_gpu
124 | # set random seed
125 | fix_random_seed(args.random_seed)
126 |
127 | # loading data
128 | # it's faster to load data from the pre-built dataloader
129 | logger.info('starting to read dataloader')
130 | train_loader, valid_loader = load_dataloader(args)
131 |
132 | # initial model and optimizer
133 | logger.info('starting to build model')
134 | if not args.mtl:
135 | model, checkpoint = load_model(args)
136 | model.cuda()
137 | optim = build_optim(args, model, None, model)
138 | pretrained_model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
139 | if args.recadam:
140 | pretrained_model.cuda()
141 | optim = build_optim(args, model, None, pretrained_model)
142 | else:
143 | model_lm, model_cnn, checkpoint = load_model(args)
144 | model_lm.cuda()
145 | model_cnn.cuda()
146 | optim_lm = build_optim(args, model_lm, checkpoint)
147 | optim_cnn = build_optim(args, model_cnn, checkpoint)
148 |
149 | # training
150 | if args.mtl:
151 | assert args.data_name == "cnn_dm" or args.data_name == "xsum"
152 | tgtdomain_data = load_dataloader_for_domain_corpus(args)
153 | cnn_train_data = batch_generator(train_loader)
154 | cnn_valid_data = batch_generator(valid_loader)
155 | multitask_train(model_lm, model_cnn, cnn_train_data, cnn_valid_data, tgtdomain_data, optim_lm, optim_cnn, checkpoint, args)
156 | else:
157 | train(model, train_loader, valid_loader, optim, checkpoint, args, pretrained_model)
158 |
--------------------------------------------------------------------------------
/src/others/optimizer.py:
--------------------------------------------------------------------------------
1 | from others.recadam import RecAdam, anneal_function
2 | import torch
3 | import torch.optim as optim
4 | from torch.nn.utils import clip_grad_norm_
5 | def build_optim(args, model, checkpoint, pretrained_model=None):
6 | """ Build optimizer """
7 | if args.recadam:
8 | print("Using RecAdam")
9 | no_decay = ["bias", "layer_norm.weight", "layernorm_embedding.weight"]
10 | optimizer_grouped_parameters = [
11 | {
12 | "params": [p for n, p in model.named_parameters() if
13 | not any(nd in n for nd in no_decay) and args.model_type in n],
14 | "weight_decay": args.weight_decay,
15 | "anneal_w": args.anneal_w,
16 | "pretrain_params": [p_p for p_n, p_p in pretrained_model.named_parameters() if
17 | not any(nd in p_n for nd in no_decay) and args.model_type in p_n]
18 | },
19 | {
20 | "params": [p for n, p in model.named_parameters() if
21 | not any(nd in n for nd in no_decay) and args.model_type not in n],
22 | "weight_decay": args.weight_decay,
23 | "anneal_w": 0.0,
24 | "pretrain_params": [p_p for p_n, p_p in pretrained_model.named_parameters() if
25 | not any(nd in p_n for nd in no_decay) and args.model_type not in p_n]
26 | },
27 | {
28 | "params": [p for n, p in model.named_parameters() if
29 | any(nd in n for nd in no_decay) and args.model_type in n],
30 | "weight_decay": 0.0,
31 | "anneal_w": args.anneal_w,
32 | "pretrain_params": [p_p for p_n, p_p in pretrained_model.named_parameters() if
33 | any(nd in p_n for nd in no_decay) and args.model_type in p_n]
34 | },
35 | {
36 | "params": [p for n, p in model.named_parameters() if
37 | any(nd in n for nd in no_decay) and args.model_type not in n],
38 | "weight_decay": 0.0,
39 | "anneal_w": 0.0,
40 | "pretrain_params": [p_p for p_n, p_p in pretrained_model.named_parameters() if
41 | any(nd in p_n for nd in no_decay) and args.model_type not in p_n]
42 | }
43 | ]
44 | optim = RecAdam(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_epsilon, anneal_fun=args.anneal_fun, anneal_k=args.anneal_k, anneal_t0=args.anneal_t0, pretrain_cof=args.pretrain_cof)
45 | else:
46 | optim = Optimizer(
47 | args.optim, args.lr, args.max_grad_norm,
48 | beta1=args.beta1, beta2=args.beta2,
49 | decay_method=args.decay_method,
50 | warmup_steps=args.warmup_steps, model_size=args.enc_hidden_size)
51 |
52 | optim.set_parameters(list(model.named_parameters()))
53 |
54 | if args.train_from != '' and 'xsum' not in args.train_from:
55 | optim.optimizer.load_state_dict(checkpoint['optim'])
56 | if args.visible_gpu != '-1':
57 | for state in optim.optimizer.state.values():
58 | for k, v in state.items():
59 | if torch.is_tensor(v):
60 | state[k] = v.cuda()
61 |
62 | if (optim.method == 'adam') and (len(optim.optimizer.state) < 1):
63 | raise RuntimeError(
64 | "Error: loaded Adam optimizer from existing model" +
65 | " but optimizer state is empty")
66 |
67 |
68 | return optim
69 |
70 | class Optimizer(object):
71 | """
72 | Controller class for optimization. Mostly a thin
73 | wrapper for `optim`, but also useful for implementing
74 | rate scheduling beyond what is currently available.
75 | Also implements necessary methods for training RNNs such
76 | as grad manipulations.
77 |
78 | Args:
79 | method (:obj:`str`): one of [sgd, adagrad, adadelta, adam]
80 | lr (float): learning rate
81 | lr_decay (float, optional): learning rate decay multiplier
82 | start_decay_steps (int, optional): step to start learning rate decay
83 | beta1, beta2 (float, optional): parameters for adam
84 | adagrad_accum (float, optional): initialization parameter for adagrad
85 | decay_method (str, option): custom decay options
86 | warmup_steps (int, option): parameter for `noam` decay
87 | model_size (int, option): parameter for `noam` decay
88 |
89 | We use the default parameters for Adam that are suggested by
90 | the original paper https://arxiv.org/pdf/1412.6980.pdf
91 | These values are also used by other established implementations,
92 | e.g. https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer
93 | https://keras.io/optimizers/
94 | Recently there are slightly different values used in the paper
95 | "Attention is all you need"
96 | https://arxiv.org/pdf/1706.03762.pdf, particularly the value beta2=0.98
97 | was used there however, beta2=0.999 is still arguably the more
98 | established value, so we use that here as well
99 | """
100 |
101 | def __init__(self, method, learning_rate, max_grad_norm,
102 | lr_decay=1, start_decay_steps=None, decay_steps=None,
103 | beta1=0.9, beta2=0.999,
104 | adagrad_accum=0.0,
105 | decay_method=None,
106 | warmup_steps=4000,
107 | model_size=None):
108 | self.last_ppl = None
109 | self.learning_rate = learning_rate
110 | self.original_lr = learning_rate
111 | self.max_grad_norm = max_grad_norm
112 | self.method = method
113 | self.lr_decay = lr_decay
114 | self.start_decay_steps = start_decay_steps
115 | self.decay_steps = decay_steps
116 | self.start_decay = False
117 | self._step = 0
118 | self.betas = [beta1, beta2]
119 | self.adagrad_accum = adagrad_accum
120 | self.decay_method = decay_method
121 | self.warmup_steps = warmup_steps
122 | self.model_size = model_size
123 |
124 | def set_parameters(self, params):
125 | """ ? """
126 | self.params = []
127 | self.sparse_params = []
128 | for k, p in params:
129 | if p.requires_grad:
130 | if self.method != 'sparseadam' or "embed" not in k:
131 | self.params.append(p)
132 | else:
133 | self.sparse_params.append(p)
134 | if self.method == 'sgd':
135 | self.optimizer = optim.SGD(self.params, lr=self.learning_rate)
136 | elif self.method == 'adagrad':
137 | self.optimizer = optim.Adagrad(self.params, lr=self.learning_rate)
138 | for group in self.optimizer.param_groups:
139 | for p in group['params']:
140 | self.optimizer.state[p]['sum'] = self.optimizer\
141 | .state[p]['sum'].fill_(self.adagrad_accum)
142 | elif self.method == 'adadelta':
143 | self.optimizer = optim.Adadelta(self.params, lr=self.learning_rate)
144 | elif self.method == 'adam':
145 | self.optimizer = optim.Adam(self.params, lr=self.learning_rate,
146 | betas=self.betas, eps=1e-9)
147 | elif self.method == 'sparseadam':
148 | self.optimizer = MultipleOptimizer(
149 | [optim.Adam(self.params, lr=self.learning_rate,
150 | betas=self.betas, eps=1e-8),
151 | optim.SparseAdam(self.sparse_params, lr=self.learning_rate,
152 | betas=self.betas, eps=1e-8)])
153 | else:
154 | raise RuntimeError("Invalid optim method: " + self.method)
155 |
156 | def _set_rate(self, learning_rate):
157 | self.learning_rate = learning_rate
158 | if self.method != 'sparseadam':
159 | self.optimizer.param_groups[0]['lr'] = self.learning_rate
160 | else:
161 | for op in self.optimizer.optimizers:
162 | op.param_groups[0]['lr'] = self.learning_rate
163 |
164 | def step(self):
165 | """Update the model parameters based on current gradients.
166 |
167 | Optionally, will employ gradient modification or update learning
168 | rate.
169 | """
170 | self._step += 1
171 |
172 | # Decay method used in tensor2tensor.
173 | if self.decay_method == "noam":
174 | self._set_rate(
175 | self.original_lr *
176 | ( self.model_size ** -0.5*min(self._step ** (-0.5),
177 | self._step * self.warmup_steps**(-1.5))))
178 | else:
179 | if ((self.start_decay_steps is not None) and (
180 | self._step >= self.start_decay_steps)):
181 | self.start_decay = True
182 | if self.start_decay:
183 | if ((self._step - self.start_decay_steps)
184 | % self.decay_steps == 0):
185 | self.learning_rate = self.learning_rate * self.lr_decay
186 |
187 | if self.method != 'sparseadam':
188 | self.optimizer.param_groups[0]['lr'] = self.learning_rate
189 |
190 | if self.max_grad_norm:
191 | clip_grad_norm_(self.params, self.max_grad_norm)
192 | self.optimizer.step()
--------------------------------------------------------------------------------
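
Note: when decay_method='noam', Optimizer.step() rescales the learning rate with the schedule below; this is a standalone re-implementation for illustration, using the defaults from run.py (lr=0.05, enc_hidden_size=768, warmup_steps=1000):

    def noam_lr(step, original_lr=0.05, model_size=768, warmup_steps=1000):
        # linear warm-up for `warmup_steps`, then decay proportional to step ** -0.5
        return original_lr * (model_size ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)

    print(noam_lr(1))      # tiny rate at the start of warm-up
    print(noam_lr(1000))   # peak at the end of warm-up
    print(noam_lr(10000))  # decayed afterwards
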
/src/trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import tqdm
3 | import os
4 | from transformers import BartTokenizer, get_linear_schedule_with_warmup
5 | from others.logging import logger
6 | from others.utils import pad_sents, get_mask
7 | from cal_rouge import test_rouge, rouge_results_to_str
8 | from dapt_pretraining import text_infilling, sent_permutation, add_noise
9 |
10 | def train(model, training_data, validation_data, optimizer, checkpoint, args, pretrained_model):
11 | ''' Start training '''
12 | if args.recadam:
13 | t_total = len(training_data) // args.accumulation_steps * 10
14 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
15 | logger.info('Start training')
16 | iteration = 0
17 | if args.break_point_continue:
18 | iteration = checkpoint['iteration']
19 | total_loss = 0
20 | F1 = 0
21 | for epoch_i in range(args.epoch):
22 | logger.info('[ Epoch : {}]'.format(epoch_i))
23 | dist_sum, dist_num = 0.0, 0
24 | # training part
25 | model.train()
26 | for src_ids, decoder_ids, mask, label_ids in training_data:
27 | iteration += 1
28 | src_ids = src_ids.cuda()
29 | decoder_ids = decoder_ids.cuda()
30 | mask = mask.cuda()
31 | label_ids = label_ids.cuda()
32 | # forward
33 | # optimizer.optimizer.zero_grad()
34 | loss = model(input_ids=src_ids, attention_mask=mask, decoder_input_ids=decoder_ids, labels=label_ids)[0]
35 | total_loss += loss.item()
36 | loss = loss / args.accumulation_steps
37 | # backward
38 | loss.backward()
39 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
40 | # loss accumulation
41 | if (iteration+1) % args.accumulation_steps == 0:
42 | optimizer.step()
43 | if args.recadam:
44 | scheduler.step()
45 | model.zero_grad()
46 |
47 | if args.logging_Euclid_dist:
48 | dist = torch.sum(torch.abs(torch.cat(
49 | [p.view(-1) for n, p in model.named_parameters()]) - torch.cat(
50 | [p.view(-1) for n, p in pretrained_model.named_parameters()])) ** 2).item()
51 |
52 | dist_sum += dist
53 | dist_num += 1
54 | # write to log file
55 | if iteration % 20 == 0:
56 | if args.logging_Euclid_dist:
57 | logger.info("iteration: {} loss_per_word: {:4f} Euclid dist: {:.6f}".format(iteration, total_loss/20, dist_sum / dist_num))
58 | else:
59 | logger.info("iteration: {} loss_per_word: {:4f} learning rate: {:4f} ".format(iteration, total_loss/20, optimizer.learning_rate))
60 | total_loss = 0
61 | # save model
62 | if iteration % args.save_interval == 0 and iteration > args.start_to_save_iter:
63 | temp_F1 = evaluation(model, validation_data, args)
64 | model.train()
65 | if temp_F1 > F1:
66 | logger.info("saving model")
67 | if not os.path.exists(args.saving_path + args.data_name):
68 | os.makedirs(args.saving_path + args.data_name)
69 | model_name = make_file_name(args, iteration)
70 | # checkpoint = {'iteration': iteration, 'settings': args, 'optim': optimizer.optimizer.state_dict(), 'model': model.state_dict()}
71 | torch.save(model, model_name)
72 | F1 = temp_F1
73 | else:
74 | pass
75 |
76 |
77 | def multitask_train(model_lm, model_cnn, cnn_train_data, cnn_valid_data, tgtdomain_data, optimizer_lm, optimizer_cnn, checkpoint, args):
78 | ''' Start training '''
79 | tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
80 | logger.info('Start multitask training')
81 | iteration = 0
82 | if args.train_from!= '':
83 | iteration = checkpoint['iteration']
84 | cnn_loss = 0
85 | lm_loss = 0
86 | F1 = 0
87 | while iteration < args.max_iter:
88 | iteration += 1
89 | model_lm.train()
90 | model_cnn.train()
91 |
92 | ## cnn news summarization training part
93 | src_ids, decoder_ids, mask, label_ids = next(cnn_train_data)
94 | src_ids = src_ids.cuda()
95 | decoder_ids = decoder_ids.cuda()
96 | mask = mask.cuda()
97 | label_ids = label_ids.cuda()
98 |
99 | loss = model_cnn(input_ids=src_ids, attention_mask=mask, decoder_input_ids=decoder_ids, labels=label_ids)[0]
100 | cnn_loss += loss.item()
101 | loss = loss / args.accumulation_steps
102 | # backward
103 | loss.backward()
104 | # torch.nn.utils.clip_grad_norm_(model_cnn.parameters(), args.clip)
105 |
106 | ## denoising language modeling part
107 | sents = next(tgtdomain_data)
108 | tokenized_sents = [tokenizer.encode(sent, add_special_tokens=False) for sent in sents]
109 | decoder_ids = [[tokenizer.bos_token_id] + item for item in tokenized_sents]
110 | label_ids = [item + [tokenizer.eos_token_id] for item in tokenized_sents]
111 |
112 | noisy_text = add_noise(sents, args.mask_prob)
113 | inputs_ids = [tokenizer.encode(sent, add_special_tokens=False) for sent in noisy_text]
114 |
115 | # prepare data for training
116 | inputs_ids = torch.tensor(pad_sents(inputs_ids, pad_token=tokenizer.pad_token_id)[0]).cuda()
117 | mask = torch.tensor(get_mask(inputs_ids)).cuda()
118 | decoder_ids = torch.tensor(pad_sents(decoder_ids, pad_token=tokenizer.pad_token_id)[0]).cuda()
119 | label_ids = torch.tensor(pad_sents(label_ids, pad_token=-100)[0]).cuda()
120 |
121 | # optimize model
122 | loss = model_lm(input_ids=inputs_ids, attention_mask=mask, decoder_input_ids=decoder_ids, labels=label_ids)[0]
123 | lm_loss += loss.item()
124 | loss = loss / args.accumulation_steps
125 | loss.backward()
126 | # torch.nn.utils.clip_grad_norm_(model_lm.parameters(), args.clip)
127 |
128 | # loss accumulation
129 | if (iteration+1) % args.accumulation_steps == 0:
130 | optimizer_lm.step()
131 | optimizer_cnn.step()
132 | model_lm.zero_grad()
133 | model_cnn.zero_grad()
134 | # write to log file
135 | if iteration % 20 == 0:
136 |             logger.info("iteration: {} loss_cnn: {:.6f} loss_lm: {:.6f} learning rate lm: {:.9f} learning rate cnn: {:.9f}".format(iteration, cnn_loss/20, lm_loss/20, optimizer_lm.learning_rate, optimizer_cnn.learning_rate))
137 | cnn_loss = 0
138 | lm_loss = 0
139 |
140 | if iteration % 50000 == 0:
141 | # eval_F1 = evaluation(model, cnn_valid_data, args)
142 | # logger.info("Iteration: {}. F1 score: {:.4f}".format(iteration, eval_F1))
143 | logger.info("saving model")
144 | if not os.path.exists(args.saving_path + args.data_name):
145 | os.makedirs(args.saving_path + args.data_name)
146 | model_name = make_file_name(args, iteration)
147 | # checkpoint_1 = {'iteration': iteration, 'settings': args, 'optim': optimizer_lm.optimizer.state_dict(), 'model_lm': model_lm.state_dict()}
148 | # checkpoint_2 = {'iteration': iteration, 'settings': args, 'optim': optimizer_cnn.optimizer.state_dict(), 'model': model_cnn.state_dict()}
149 | torch.save(model_lm, model_name[0])
150 | torch.save(model_cnn, model_name[1])
151 |
152 |
153 | def evaluation(model, validation_data, args):
154 | model.eval()
155 | valid_reference_path = './dataset/' + args.data_name + '/valid.target'
156 |     with open(valid_reference_path, 'r') as valid_data:
157 |         valid_list = valid_data.readlines()
158 |     valid_list = [i.strip('\n') for i in valid_list]
159 | # inference
160 | tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
161 | outputs = []
162 | for src_ids, decoder_ids, mask, label_ids in tqdm(validation_data):
163 | src_ids = src_ids.cuda()
164 | summary_ids = model.generate(src_ids, num_beams=4, max_length=256, early_stopping=True)
165 | output = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids]
166 | outputs += output
167 | # calculate rouge
168 | final_results = test_rouge(outputs, valid_list, args.process_num)
169 | R1_F1 = final_results["rouge_1_f_score"] * 100
170 | logger.info('[ Validation ]')
171 | logger.info(rouge_results_to_str(final_results))
172 | return R1_F1
173 |
174 | def make_file_name(args, iteration):
175 |     model_name = args.saving_path + "{}/{}_{}%sample.chkpt".format(args.data_name, iteration, args.percentage)
176 | if args.pre_trained_lm != '':
177 | # model_name = args.saving_path + "{}/{}_{}%DAPT_REG_EPOCH_1.chkpt".format(args.data_name,iteration,args.percentage)
178 | model_name = args.saving_path + "{}/{}_{}%DAPT_reg_160000.chkpt".format(args.data_name,iteration,args.percentage)
179 | if args.pre_trained_src:
180 | model_name = args.saving_path + "{}/{}_{}%cnn_pretrain_400000.chkpt".format(args.data_name,iteration,args.percentage)
181 | # model_name = args.saving_path + "{}/{}_{}%_pre_trained_src.chkpt".format(args.data_name,iteration,args.percentage)
182 | if args.mtl:
183 | model_name_1 = args.saving_path + "{}/{}_mtl_pre_trained_lm.chkpt".format('social_media_short',iteration)
184 | model_name_2 = args.saving_path + "{}/{}_mtl_pre_trained_src.chkpt".format('social_media_short',iteration)
185 | model_name = [model_name_1, model_name_2]
186 | return model_name
187 |
188 | def evaluation_loss(model, validation_data, args):
189 | model.eval()
190 | total_loss = 0
191 | for src_ids, decoder_ids, mask, label_ids in tqdm(validation_data):
192 | src_ids = src_ids.cuda()
193 | decoder_ids = decoder_ids.cuda()
194 | mask = mask.cuda()
195 | label_ids = label_ids.cuda()
196 | loss = model(input_ids=src_ids, attention_mask=mask, decoder_input_ids=decoder_ids, labels=label_ids)[0]
197 | total_loss += loss.item()
198 | loss = None
199 |     logger.info("validation loss: {:.6f}".format(total_loss))
200 | return total_loss
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AdaptSum: Towards Low-Resource Domain Adaptation for Abstractive Summarization
2 |
3 | [![Python 3.6](https://img.shields.io/badge/python-3.6-blue.svg)](https://www.python.org/downloads/) [![CC BY 4.0][cc-by-shield]][cc-by]
4 |
5 |
6 |
7 |
8 | [cc-by]: http://creativecommons.org/licenses/by/4.0/
9 | [cc-by-shield]: https://img.shields.io/badge/License-CC%20BY%204.0-lightgrey.svg
10 |
11 |
12 | Paper accepted at [NAACL-HLT 2021](https://2021.naacl.org):
13 |
14 | **[AdaptSum: Towards Low-Resource Domain Adaptation for Abstractive Summarization](https://arxiv.org/pdf/2103.11332)**, by **[Tiezheng Yu*](https://tysonyu.github.io/)**, **[Zihan Liu*](https://zliucr.github.io/)**, [Pascale Fung](https://pascale.home.ece.ust.hk).
15 |
16 | ## Abstract
17 | State-of-the-art abstractive summarization models generally rely on extensive labeled data, which lowers their generalization ability on domains where such data are not available. In this paper, we present a study of domain adaptation for the abstractive summarization task across six diverse target domains in a low-resource setting. Specifically, we investigate the second phase of pre-training on large-scale generative models under three different settings: 1) source domain pre-training; 2) domain-adaptive pre-training; and 3) task-adaptive pre-training. Experiments show that the effectiveness of pre-training is correlated with the similarity between the pre-training data and the target domain task. Moreover, we find that continuing pre-training could lead to the pre-trained model's catastrophic forgetting, and a learning method with less forgetting can alleviate this issue. Furthermore, results illustrate that a huge gap still exists between the low-resource and high-resource settings, which highlights the need for more advanced domain adaptation methods for the abstractive summarization task.
18 |
19 | ## Dataset
20 | We release the AdaptSum dataset, which contains the summarization datasets across six target domains as well as the corpora for SDPT, DAPT and TAPT. You can download AdaptSum from [Here](https://drive.google.com/drive/folders/1qdkavIQonTAepkJhGpo3TZpU4LUW44sp?usp=sharing).
21 |
22 | ## Preparation for running
23 | 1. Create a new folder named `dataset` at the root of this project
24 | 2. Download the data from Google Drive and put it in the `dataset` folder
25 | 3. Create the conda environment
26 | ```
27 | conda create -n adaptsum python=3.6
28 | ```
29 | 4. Activate the conda environment
30 | ```
31 | conda activate adaptsum
32 | ```
33 | 5. Install PyTorch. Check your CUDA version before installation and adjust the `cudatoolkit` version in the command below accordingly; see the [PyTorch website](https://pytorch.org) for details
34 | ```
35 | conda install pytorch cudatoolkit=11.0 -c pytorch
36 | ```
37 | 6. Install requirements
38 | ```
39 | pip install -r requirements.txt
40 | ```
41 | 7. Create a new folder named `logs` at the root of this project
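(Optional) The folders required by the later sections (`SDPT_save`, `DAPT_save`, `TAPT_save`, `logs/debate`, `logs/inference`) can also be created up front. The one-liner below is only a convenience and is equivalent to the per-section "create a new folder" steps:
```
mkdir -p SDPT_save DAPT_save TAPT_save logs/debate logs/inference
```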
42 | ## SDPT pretraining
43 | - We take `cnn_dm` as an example
44 | 1. Create a new folder named `SDPT_save` at the root of this project
45 | 2. Prepare dataloader:
46 | ```
47 | python ./src/preprocessing.py -data_path=dataset/ \
48 | -data_name=SDPT-cnn_dm \
49 | -mode=train \
50 | -batch_size=4
51 | ```
52 | 3. Run `./scripts/sdpt_pretraining.sh`. You can add `-recadam` and `-logging_Euclid_dist` to use RecAdam.
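For example, RecAdam can also be enabled by passing the two flags to the pretraining script directly (a minimal sketch; keep the script's remaining arguments unchanged):
```
python ./src/sdpt_pretraining.py -data_name=SDPT-cnn_dm \
                                 -visible_gpu=0 \
                                 -recadam \
                                 -logging_Euclid_dist
```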
53 |
54 | ## DAPT pretraining
55 | - We take the `debate` domain as an example
56 | 1. Create a new folder named `DAPT_save` at the root of this project
57 | 2. Run `./scripts/dapt_pretraining.sh`. You can add `-recadam` and `-logging_Euclid_dist` to use RecAdam.
58 |
59 | ## TAPT pretraining
60 | - We take the `debate` domain as an example
61 | 1. Create a new folder named `TAPT_save` at the root of this project
62 | 2. Run `./scripts/tapt_pretraining.sh`. You can add `-recadam` and `-logging_Euclid_dist` to use RecAdam.
63 |
64 | ## Fine-tuning
65 | - We take the `debate` domain as an example
66 | 1. Create a new folder named `debate` inside `logs`
67 |
68 | 2. Prepare dataloader:
69 | ```
70 | python ./src/preprocessing.py -data_path=dataset/ \
71 | -data_name=debate \
72 | -mode=train \
73 | -batch_size=4
74 | python ./src/preprocessing.py -data_path=dataset/ \
75 | -data_name=debate \
76 | -mode=valid \
77 | -batch_size=4
78 | python ./src/preprocessing.py -data_path=dataset/ \
79 | -data_name=debate \
80 | -mode=test \
81 | -batch_size=4
82 | ```
83 |
84 | 3. Install the `pyrouge` package (you can skip this step if `pyrouge` is already installed)
85 |     - Step 1: Install pyrouge from source (not from pip)
86 | ```
87 | git clone https://github.com/bheinzerling/pyrouge
88 | cd pyrouge
89 | pip install -e .
90 | ```
91 |     - Step 2: Install the official ROUGE script
92 | ```
93 | git clone https://github.com/andersjo/pyrouge.git rouge
94 | ```
95 |     - Step 3: Point pyrouge to the official ROUGE script (the path given to pyrouge must be an absolute path!)
96 | ```
97 | pyrouge_set_rouge_path ~/pyrouge/rouge/tools/ROUGE-1.5.5/
98 | ```
99 |     - Step 4: Install the libxml parser
100 |     As mentioned in this [issue](https://github.com/bheinzerling/pyrouge/issues/27), you need to install the libxml parser
101 | ```
102 | sudo apt-get install libxml-parser-perl
103 | ```
104 |     - Step 5: Regenerate the Exceptions DB
105 |     As mentioned in this [issue](https://github.com/bheinzerling/pyrouge/issues/8), you need to regenerate the Exceptions DB
106 | ```
107 | cd rouge/tools/ROUGE-1.5.5/data
108 | rm WordNet-2.0.exc.db
109 | ./WordNet-2.0-Exceptions/buildExeptionDB.pl ./WordNet-2.0-Exceptions ./smart_common_words.txt ./WordNet-2.0.exc.db
110 | ```
111 |     - Step 6: Run the tests
112 | ```
113 | python -m pyrouge.test
114 | ```
115 |
116 | 4. Run fine-tuning
117 | - If you don't want to use any second phase of pre-training, run:
118 | ```
119 | python ./src/run.py -visible_gpu=0 \
120 | -data_name=debate \
121 | -save_interval=100 \
122 | -start_to_save_iter=3000
123 | ```
124 | - If you want to use pretrained checkpoints from **SDPT**, run:
125 | ```
126 | python ./src/run.py -visible_gpu=0 \
127 | -data_name=debate \
128 | -save_interval=100 \
129 | -start_to_save_iter=3000 \
130 | -pre_trained_src \
131 | -train_from=YOUR_SAVED_CHECKPOINTS
132 | ```
133 | - If you want to use pretrained checkpoints from **DAPT** or **TAPT**, run:
134 | ```
135 | python ./src/run.py -visible_gpu=0 \
136 | -data_name=debate \
137 | -save_interval=100 \
138 | -start_to_save_iter=3000 \
139 | -pre_trained_lm=YOUR_SAVED_CHECKPOINTS
140 | ```
141 |
142 | 5. Evaluate the performance
143 |     1) Create a folder named `inference` inside `logs`
144 |     2) Run inference with
145 | ```
146 | python ./src/inference.py -visible_gpu=0 -train_from=YOUR_SAVED_CHECKPOINT
147 | ```
148 |     3) Calculate ROUGE scores with
149 | ```
150 |     python ./src/cal_rouge.py -c=CANDIDATE_FILE -r=REFERENCE_FILE -p=NUMBER_OF_PROCESSES
151 | ```
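    For instance, with purely illustrative file names (assuming `inference.py` wrote its candidate summaries under `logs/inference`; substitute the files produced by your own run):
    ```
    # hypothetical paths for illustration only
    python ./src/cal_rouge.py -c=logs/inference/debate_candidate.txt -r=dataset/debate/test.target -p=8
    ```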
152 |
153 | ## References
154 | If you use our benchmark or the code in this repo, please cite our paper.
155 |
156 |
157 |     @article{Yu2021AdaptSum,
158 | title={AdaptSum: Towards Low-Resource Domain Adaptation for Abstractive Summarization},
159 | author={Tiezheng Yu and Zihan Liu and Pascale Fung},
160 | journal={arXiv preprint arXiv:2103.11332},
161 | year={2021}
162 | }
163 |
164 |
165 | Also, please consider citing all the individual datasets in your paper.
166 |
167 | Dialog domain:
168 |
169 | @inproceedings{gliwa2019samsum,
170 | title={SAMSum Corpus: A Human-annotated Dialogue Dataset for Abstractive Summarization},
171 | author={Gliwa, Bogdan and Mochol, Iwona and Biesek, Maciej and Wawer, Aleksander},
172 | booktitle={Proceedings of the 2nd Workshop on New Frontiers in Summarization},
173 | pages={70--79},
174 | year={2019}
175 | }
176 |
177 |
178 | Email domain:
179 |
180 | @inproceedings{zhang2019email,
181 | title={This Email Could Save Your Life: Introducing the Task of Email Subject Line Generation},
182 | author={Zhang, Rui and Tetreault, Joel},
183 | booktitle={Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics},
184 | pages={446--456},
185 | year={2019}
186 | }
187 |
188 |
189 | Movie and debate domains:
190 |
191 | @inproceedings{wang2016neural,
192 | title={Neural Network-Based Abstract Generation for Opinions and Arguments},
193 | author={Wang, Lu and Ling, Wang},
194 | booktitle={Proceedings of the 2016 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies},
195 | pages={47--57},
196 | year={2016}
197 | }
198 |
199 |
200 | Social media domain:
201 |
202 | @inproceedings{kim2019abstractive,
203 | title={Abstractive Summarization of Reddit Posts with Multi-level Memory Networks},
204 | author={Kim, Byeongchang and Kim, Hyunwoo and Kim, Gunhee},
205 | booktitle={Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers)},
206 | pages={2519--2531},
207 | year={2019}
208 | }
209 |
210 |
211 | Science domain:
212 |
213 | @inproceedings{yasunaga2019scisummnet,
214 | title={Scisummnet: A large annotated corpus and content-impact models for scientific paper summarization with citation networks},
215 | author={Yasunaga, Michihiro and Kasai, Jungo and Zhang, Rui and Fabbri, Alexander R and Li, Irene and Friedman, Dan and Radev, Dragomir R},
216 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
217 | volume={33},
218 | pages={7386--7393},
219 | year={2019}
220 | }
221 |
222 |
223 |
--------------------------------------------------------------------------------
/src/dapt_pretraining.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from torch.utils.data import Dataset, DataLoader
4 | from transformers import BartForConditionalGeneration, BartTokenizer, get_linear_schedule_with_warmup
5 | from others.logging import logger
6 | from others.utils import pad_sents, get_mask
7 | from others.optimizer import build_optim
8 | from tqdm import tqdm
9 | import numpy as np
10 | import argparse
11 | import random
12 | import os
13 | from nltk.tokenize import sent_tokenize
14 |
15 |
16 | def text_infilling(sent, mask_probability=0.05, lamda=3):
17 | '''
18 | inputs:
19 | sent: a sentence string
20 | mask_probability: probability for masking tokens
21 |         lamda: lambda parameter of the Poisson distribution used to sample span lengths
22 |     outputs:
23 |         sent: a list of tokens in which masked positions are replaced by the mask token
24 | '''
25 | sent = sent.split()
26 | length = len(sent)
27 | mask_indices = (np.random.uniform(0, 1, length) < mask_probability) * 1
28 |     span_list = np.random.poisson(lamda, length)  # sample a Poisson-distributed span length for every position
29 | nonzero_idx = np.nonzero(mask_indices)[0]
30 | for item in nonzero_idx:
31 |         span = min(span_list[item], 5)  # mask at most 5 consecutive tokens
32 | for i in range(span):
33 | if item+i >= length:
34 | continue
35 | mask_indices[item+i] = 1
36 | for i in range(length):
37 | if mask_indices[i] == 1:
38 | sent[i] = '{name}
".format( 466 | id=system_id, name=system_filename) 467 | 468 | model_elems = ["